mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
Fix check triton config.pbtxt for missing values or colliding specifications (#62)
This commit is contained in:
parent
96b335e3c2
commit
82ade1e24a
@ -444,6 +444,36 @@ class TritonHelper(object):
|
|||||||
return "BYTES"
|
return "BYTES"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def triton_to_np_dtype(dtype):
|
||||||
|
if dtype == "BOOL":
|
||||||
|
return bool
|
||||||
|
elif dtype == "INT8":
|
||||||
|
return np.int8
|
||||||
|
elif dtype == "INT16":
|
||||||
|
return np.int16
|
||||||
|
elif dtype == "INT32":
|
||||||
|
return np.int32
|
||||||
|
elif dtype == "INT64":
|
||||||
|
return np.int64
|
||||||
|
elif dtype == "UINT8":
|
||||||
|
return np.uint8
|
||||||
|
elif dtype == "UINT16":
|
||||||
|
return np.uint16
|
||||||
|
elif dtype == "UINT32":
|
||||||
|
return np.uint32
|
||||||
|
elif dtype == "UINT64":
|
||||||
|
return np.uint64
|
||||||
|
elif dtype == "FP16":
|
||||||
|
return np.float16
|
||||||
|
elif dtype == "FP32":
|
||||||
|
return np.float32
|
||||||
|
elif dtype == "FP64":
|
||||||
|
return np.float64
|
||||||
|
elif dtype == "BYTES":
|
||||||
|
return np.object_
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
title = 'clearml-serving - Nvidia Triton Engine Controller'
|
title = 'clearml-serving - Nvidia Triton Engine Controller'
|
||||||
|
@ -1292,6 +1292,63 @@ class ModelRequestProcessor(object):
|
|||||||
Raise exception if validation fails, otherwise return True
|
Raise exception if validation fails, otherwise return True
|
||||||
"""
|
"""
|
||||||
if endpoint.engine_type in ("triton", ):
|
if endpoint.engine_type in ("triton", ):
|
||||||
|
if endpoint.auxiliary_cfg:
|
||||||
|
aux_config_dict = {}
|
||||||
|
|
||||||
|
if isinstance(endpoint.auxiliary_cfg, dict):
|
||||||
|
aux_config_dict = endpoint.auxiliary_cfg
|
||||||
|
elif isinstance(endpoint.auxiliary_cfg, str):
|
||||||
|
from clearml.utilities.pyhocon import ConfigFactory
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
aux_config_dict = ConfigFactory.parse_string(endpoint.auxiliary_cfg)
|
||||||
|
except Exception:
|
||||||
|
# we failed parsing the auxiliary pbtxt
|
||||||
|
aux_config_dict = {}
|
||||||
|
|
||||||
|
if aux_config_dict.get("input", None) or aux_config_dict.get("output", None):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from ..engines.triton.triton_helper import TritonHelper
|
||||||
|
|
||||||
|
suggested_cli_in = {"name": [], "dims": [], "data_type": []}
|
||||||
|
suggested_cli_out = {"name": [], "dims": [], "data_type": []}
|
||||||
|
for layer in aux_config_dict.get("input", None) or []:
|
||||||
|
suggested_cli_in["name"] += ['"{}"'.format(layer["name"])]
|
||||||
|
suggested_cli_in["data_type"] += [
|
||||||
|
TritonHelper.triton_to_np_dtype(layer["data_type"].replace("TYPE_", "", 1)).__name__]
|
||||||
|
suggested_cli_in["dims"] += ['"{}"'.format(layer["dims"])]
|
||||||
|
|
||||||
|
for layer in aux_config_dict.get("output", None) or []:
|
||||||
|
suggested_cli_out["name"] += ['"{}"'.format(layer["name"])]
|
||||||
|
suggested_cli_out["data_type"] += [
|
||||||
|
TritonHelper.triton_to_np_dtype(layer["data_type"].replace("TYPE_", "", 1)).__name__]
|
||||||
|
suggested_cli_out["dims"] += ['"{}"'.format(layer["dims"])]
|
||||||
|
|
||||||
|
suggested_cli = "Add to your command line: "\
|
||||||
|
"--input-name {} --input-type {} --input-size {} " \
|
||||||
|
"--output-name {} --output-type {} --output-size {} ".format(
|
||||||
|
" ".join(suggested_cli_in["name"]),
|
||||||
|
" ".join(suggested_cli_in["data_type"]),
|
||||||
|
" ".join(suggested_cli_in["dims"]),
|
||||||
|
" ".join(suggested_cli_out["name"]),
|
||||||
|
" ".join(suggested_cli_out["data_type"]),
|
||||||
|
" ".join(suggested_cli_out["dims"]),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
suggested_cli = "?"
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Triton engine requires *manual* input/output specification, "
|
||||||
|
"You input/output in your pbtxt, please remove them and specify manually.\n"
|
||||||
|
"{}".format(suggested_cli)
|
||||||
|
)
|
||||||
|
|
||||||
|
if aux_config_dict.get("default_model_filename", None):
|
||||||
|
raise ValueError("ERROR: You have `default_model_filename` in your config pbtxt, "
|
||||||
|
"please remove it. It will be added automatically by the system.")
|
||||||
|
|
||||||
# verify we have all the info we need
|
# verify we have all the info we need
|
||||||
d = endpoint.as_dict()
|
d = endpoint.as_dict()
|
||||||
missing = [
|
missing = [
|
||||||
@ -1300,7 +1357,8 @@ class ModelRequestProcessor(object):
|
|||||||
'output_type', 'output_size', 'output_name',
|
'output_type', 'output_size', 'output_name',
|
||||||
] if not d.get(k)
|
] if not d.get(k)
|
||||||
]
|
]
|
||||||
if not endpoint.auxiliary_cfg and missing:
|
|
||||||
|
if missing:
|
||||||
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
|
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user