Add Triton support for variable-length requests; add support for HuggingFace Transformers

Add triton_grpc_compression (default: False) to control gRPC connection compression
allegroai 2022-09-02 23:41:49 +03:00
parent c6c40c9a36
commit f4eed33f10
7 changed files with 297 additions and 138 deletions
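For context, a minimal sketch of how the new compression flag can be toggled on an existing serving session through the Python API added in this commit (the session ID below is a placeholder):

    from clearml_serving.serving.model_request_processor import ModelRequestProcessor

    # attach to an existing clearml-serving session (ID is hypothetical)
    processor = ModelRequestProcessor(task_id="<serving_session_id>")
    processor.deserialize(skip_sync=True)
    # enable gzip compression on the Triton gRPC connection (defaults to False)
    processor.configure(external_triton_grpc_compression=True)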

View File

@ -1,5 +1,6 @@
import json
import os.path
import sys
from argparse import ArgumentParser
from pathlib import Path
@ -7,6 +8,33 @@ from clearml_serving.serving.model_request_processor import ModelRequestProcessor
from clearml_serving.serving.endpoints import ModelMonitoring, ModelEndpoint, EndpointMetricLogging
verbosity = False
answer_yes = False
def verify_session_version(request_processor):
from clearml_serving.version import __version__
current_v = float('.'.join(str(__version__).split(".")[:2]))
stored_v = float('.'.join(str(request_processor.get_version()).split(".")[:2]))
if stored_v != current_v:
print(
"WARNING: You are about to edit clearml-serving session ID={}\n"
"It was created with a different version ({}), you are currently using version {}".format(
request_processor.get_id(), stored_v, current_v))
print("Do you want to continue [n]/y? ", end="")
if answer_yes:
print("y")
else:
should_continue = input().lower()
if should_continue not in ("y", "yes"):
print("INFO: If you wish to downgrade your clearml-serving CLI use "
"`pip3 install clearml-serving=={}`".format(request_processor.get_version()))
sys.exit(0)
def safe_ModelRequestProcessor(*args, **kwargs):
request_processor = ModelRequestProcessor(*args, **kwargs)
verify_session_version(request_processor)
return request_processor
def func_metric_ls(args):
@ -18,9 +46,10 @@ def func_metric_ls(args):
def func_metric_rm(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Removing metrics from endpoint={}".format(
request_processor.get_id(), args.endpoint))
verify_session_version(request_processor)
request_processor.deserialize(skip_sync=True)
if not args.variable:
if request_processor.remove_metric_logging(endpoint=args.endpoint):
@ -38,7 +67,7 @@ def func_metric_rm(args):
def func_metric_add(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Adding metric logging endpoint \'/{}/\'".format(
request_processor.get_id(), args.endpoint))
request_processor.deserialize(skip_sync=True)
@ -104,6 +133,7 @@ def func_model_upload(args):
else:
print("Registering model file \'{}\'".format(args.url))
model.update_weights(weights_filename=args.path, register_uri=args.url, auto_delete_file=False)
t.flush(wait_for_uploads=True)
if args.project:
# noinspection PyProtectedMember
model._base_model.update(
@ -130,11 +160,12 @@ def func_model_ls(args):
def func_create_service(args):
request_processor = ModelRequestProcessor(
force_create=True, name=args.name, project=args.project, tags=args.tags or None)
request_processor.serialize()
print("New Serving Service created: id={}".format(request_processor.get_id()))
def func_config_service(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Configure serving service id={}".format(request_processor.get_id()))
request_processor.deserialize(skip_sync=True)
if args.base_serving_url:
@ -162,7 +193,7 @@ def func_list_services(_):
def func_model_remove(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Removing Model endpoint={}".format(request_processor.get_id(), args.endpoint))
request_processor.deserialize(skip_sync=True)
if request_processor.remove_endpoint(endpoint_url=args.endpoint):
@ -179,7 +210,7 @@ def func_model_remove(args):
def func_canary_add(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Adding canary endpoint \'/{}/\'".format(
request_processor.get_id(), args.endpoint))
request_processor.deserialize(skip_sync=True)
@ -198,7 +229,7 @@ def func_canary_add(args):
def func_model_auto_update_add(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Adding Model monitoring endpoint: \'/{}/\'".format(
request_processor.get_id(), args.endpoint))
@ -207,7 +238,9 @@ def func_model_auto_update_add(args):
aux_config = Path(args.aux_config[0]).read_text()
else:
from clearml.utilities.pyhocon import ConfigFactory
aux_config = ConfigFactory.parse_string('\n'.join(args.aux_config)).as_plain_ordered_dict()
aux_config = ConfigFactory.parse_string(
'\n'.join(args.aux_config).replace("\"", "\\\"").replace("'", "\\\'")
).as_plain_ordered_dict()
else:
aux_config = None
@ -238,7 +271,7 @@ def func_model_auto_update_add(args):
def func_model_endpoint_add(args):
request_processor = ModelRequestProcessor(task_id=args.id)
request_processor = safe_ModelRequestProcessor(task_id=args.id)
print("Serving service Task {}, Adding Model endpoint \'/{}/\'".format(
request_processor.get_id(), args.endpoint))
request_processor.deserialize(skip_sync=True)
@ -248,7 +281,9 @@ def func_model_endpoint_add(args):
aux_config = Path(args.aux_config[0]).read_text()
else:
from clearml.utilities.pyhocon import ConfigFactory
aux_config = ConfigFactory.parse_string('\n'.join(args.aux_config)).as_plain_ordered_dict()
aux_config = ConfigFactory.parse_string(
'\n'.join(args.aux_config).replace("\"", "\\\"").replace("'", "\\\'")
).as_plain_ordered_dict()
else:
aux_config = None
@ -283,6 +318,7 @@ def cli():
print(title)
parser = ArgumentParser(prog='clearml-serving', description=title)
parser.add_argument('--debug', action='store_true', help='Print debug messages')
parser.add_argument('--yes', action='store_true', help='Always answer YES on interactive inputs')
parser.add_argument(
'--id', type=str,
help='Control plane Task ID to configure '
@ -447,34 +483,44 @@ def cli():
'- this should hold for all the models'
)
parser_model_monitor.add_argument(
'--input-size', type=int, nargs='+',
help='Optional: Specify the model matrix input size [Rows x Columns X Channels etc ...]'
'--input-size', nargs='+', type=json.loads,
help='Optional: Specify the model matrix input size [Rows x Columns x Channels etc ...]. '
'If multiple inputs are required, specify them using JSON notation, e.g.: '
'\"[dim0, dim1, dim2, ...]\" \"[dim0, dim1, dim2, ...]\"'
)
parser_model_monitor.add_argument(
'--input-type', type=str,
help='Optional: Specify the model matrix input type, examples: uint8, float32, int16, float16 etc.'
'--input-type', nargs='+',
help='Optional: Specify the model matrix input type, examples: uint8, float32, int16, float16 etc. '
'If multiple inputs are required, pass multiple values, e.g.: float32 float32'
)
parser_model_monitor.add_argument(
'--input-name', type=str,
help='Optional: Specify the model layer pushing input into, examples: layer_0'
'--input-name', nargs='+',
help='Optional: Specify the model layer to push the input into, example: layer_0. '
'If multiple inputs are required, pass multiple values, e.g.: layer_0 layer_1'
)
parser_model_monitor.add_argument(
'--output-size', type=int, nargs='+',
help='Optional: Specify the model matrix output size [Rows x Columns X Channels etc ...]'
'--output-size', nargs='+', type=json.loads,
help='Optional: Specify the model matrix output size [Rows x Columns x Channels etc ...]. '
'If multiple outputs are required, specify them using JSON notation, e.g.: '
'\"[dim0, dim1, dim2, ...]\" \"[dim0, dim1, dim2, ...]\"'
)
parser_model_monitor.add_argument(
'--output_type', type=str,
help='Optional: Specify the model matrix output type, examples: uint8, float32, int16, float16 etc.'
'--output-type', nargs='+',
help='Optional: Specify the model matrix output type, examples: uint8, float32, int16, float16 etc. '
'If multiple outputs are required, pass multiple values, e.g.: float32 float32'
)
parser_model_monitor.add_argument(
'--output-name', type=str,
help='Optional: Specify the model layer pulling results from, examples: layer_99'
'--output-name', nargs='+',
help='Optional: Specify the model layer to pull results from, example: layer_99. '
'If multiple outputs are required, pass multiple values, e.g.: layer_98 layer_99'
)
parser_model_monitor.add_argument(
'--aux-config', nargs='+',
help='Specify additional engine specific auxiliary configuration in the form of key=value. '
'Example: platform=onnxruntime_onnx response_cache.enable=true max_batch_size=8 '
'Notice: you can also pass full configuration file (e.g. Triton "config.pbtxt")'
'Examples: platform=\\"onnxruntime_onnx\\" response_cache.enable=true max_batch_size=8 '
'input.0.format=FORMAT_NCHW output.0.format=FORMAT_NCHW '
'Remarks: (1) strings must be quoted (e.g. key=\\"a_string\\") '
'(2) instead of key/value pairs, you can also pass a full configuration file (e.g. "./config.pbtxt")'
)
parser_model_monitor.set_defaults(func=func_model_auto_update_add)
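For clarity, a short sketch of how the new --input-size values parse: with nargs='+' and type=json.loads, argparse applies json.loads to each value separately, so every model input gets its own dimension list (the values below are illustrative):

    import json
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--input-size', nargs='+', type=json.loads)
    args = parser.parse_args(['--input-size', '[1, 28, 28]', '[-1, 10]'])
    assert args.input_size == [[1, 28, 28], [-1, 10]]  # one dims list per model input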
@ -496,34 +542,44 @@ def cli():
help='Specify Pre/Post processing code to be used with the model (point to local file / folder)'
)
parser_model_add.add_argument(
'--input-size', type=int, nargs='+',
help='Optional: Specify the model matrix input size [Rows x Columns X Channels etc ...]'
'--input-size', nargs='+', type=json.loads,
help='Optional: Specify the model matrix input size [Rows x Columns x Channels etc ...]. '
'If multiple inputs are required, specify them using JSON notation, e.g.: '
'\"[dim0, dim1, dim2, ...]\" \"[dim0, dim1, dim2, ...]\"'
)
parser_model_add.add_argument(
'--input-type', type=str,
help='Optional: Specify the model matrix input type, examples: uint8, float32, int16, float16 etc.'
'--input-type', nargs='+',
help='Optional: Specify the model matrix input type, examples: uint8, float32, int16, float16 etc. '
'If multiple inputs are required, pass multiple values, e.g.: float32 float32'
)
parser_model_add.add_argument(
'--input-name', type=str,
help='Optional: Specify the model layer pushing input into, examples: layer_0'
'--input-name', nargs='+',
help='Optional: Specify the model layer to push the input into, example: layer_0. '
'If multiple inputs are required, pass multiple values, e.g.: layer_0 layer_1'
)
parser_model_add.add_argument(
'--output-size', type=int, nargs='+',
help='Optional: Specify the model matrix output size [Rows x Columns X Channels etc ...]'
'--output-size', nargs='+', type=json.loads,
help='Optional: Specify the model matrix output size [Rows x Columns x Channels etc ...]. '
'If multiple outputs are required, specify them using JSON notation, e.g.: '
'\"[dim0, dim1, dim2, ...]\" \"[dim0, dim1, dim2, ...]\"'
)
parser_model_add.add_argument(
'--output-type', type=str,
help='Specify the model matrix output type, examples: uint8, float32, int16, float16 etc.'
'--output-type', nargs='+',
help='Optional: Specify the model matrix output type, examples: uint8, float32, int16, float16 etc. '
'If multiple outputs are required, pass multiple values, e.g.: float32 float32'
)
parser_model_add.add_argument(
'--output-name', type=str,
help='Optional: Specify the model layer pulling results from, examples: layer_99'
'--output-name', nargs='+',
help='Optional: Specify the model layer to pull results from, example: layer_99. '
'If multiple outputs are required, pass multiple values, e.g.: layer_98 layer_99'
)
parser_model_add.add_argument(
'--aux-config', type=int, nargs='+',
'--aux-config', nargs='+',
help='Specify additional engine specific auxiliary configuration in the form of key=value. '
'Example: platform=onnxruntime_onnx response_cache.enable=true max_batch_size=8 '
'Notice: you can also pass full configuration file (e.g. Triton "config.pbtxt")'
'Examples: platform=\\"onnxruntime_onnx\\" response_cache.enable=true max_batch_size=8 '
'input.0.format=FORMAT_NCHW output.0.format=FORMAT_NCHW '
'Remarks: (1) strings must be quoted (e.g. key=\\"a_string\\") '
'(2) instead of key/value pairs, you can also pass a full configuration file (e.g. "./config.pbtxt")'
)
parser_model_add.add_argument(
'--name', type=str,
@ -540,8 +596,9 @@ def cli():
parser_model_add.set_defaults(func=func_model_endpoint_add)
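To illustrate the quoting remark above, a hedged sketch of how the key=value pairs reach the HOCON parser after the escaping added in this commit (the values are illustrative; quoted string values keep their quotes here and are unescaped later on the Triton side):

    from clearml.utilities.pyhocon import ConfigFactory

    # as received from the shell after: --aux-config platform=\"onnxruntime_onnx\" max_batch_size=8 response_cache.enable=true
    aux_args = ['platform="onnxruntime_onnx"', 'max_batch_size=8', 'response_cache.enable=true']
    aux_config = ConfigFactory.parse_string(
        '\n'.join(aux_args).replace('"', '\\"').replace("'", "\\'")
    ).as_plain_ordered_dict()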
args = parser.parse_args()
global verbosity
global verbosity, answer_yes
verbosity = args.debug
answer_yes = args.yes
if args.command:
if args.command not in ("create", "list") and not args.id:

View File

@ -110,6 +110,9 @@ class TritonHelper(object):
for url, endpoint in active_endpoints.items():
# Triton model folder structure reference:
# https://github.com/triton-inference-server/server/blob/r22.07/docs/model_repository.md#model-repository
# skip if there is no change
if url in self._current_endpoints and self._current_endpoints.get(url) == endpoint:
continue
@ -138,7 +141,8 @@ class TritonHelper(object):
local_path = model.get_local_copy()
except Exception:
local_path = None
if not local_path:
if not local_path or not model:
print("Error retrieving model ID {} []".format(model_id, model.url if model else ''))
continue
@ -152,42 +156,34 @@ class TritonHelper(object):
if verbose:
print('Update model v{} in {}'.format(version, model_folder))
framework = str(model.framework).lower()
# if this is a folder, copy everything and delete the temp folder
if local_path.is_dir() and model and (
str(model.framework).lower().startswith("tensorflow") or
str(model.framework).lower().startswith("keras")
):
if local_path.is_dir() and model and ("tensorflow" in framework or "keras" in framework):
# we assume we have a `tensorflow.savedmodel` folder
model_folder /= 'model.savedmodel'
model_folder.mkdir(parents=True, exist_ok=True)
# rename to old
old_folder = None
if model_folder.exists():
old_folder = model_folder.parent / '.old.{}'.format(model_folder.name)
model_folder.replace(old_folder)
if verbose:
print('copy model into {}'.format(model_folder))
shutil.copytree(
local_path.as_posix(), model_folder.as_posix(), symlinks=False,
)
if old_folder:
shutil.rmtree(path=old_folder.as_posix())
# delete temp folder
shutil.rmtree(local_path.as_posix())
else:
self._extract_folder(local_path, model_folder, verbose, remove_existing=True)
elif "torch" in framework and local_path.is_file():
# single file should be moved
if model and str(model.framework).lower().startswith("pytorch"):
target_path = model_folder / "model.pt"
else:
target_path = model_folder / local_path.name
old_file = None
if target_path.exists():
old_file = target_path.parent / '.old.{}'.format(target_path.name)
target_path.replace(old_file)
shutil.move(local_path.as_posix(), target_path.as_posix())
if old_file:
old_file.unlink()
self._extract_single_file(local_path, model_folder / "model.pt", verbose)
elif "onnx" in framework and local_path.is_dir():
# just unzip both model.bin & model.xml into the model folder
self._extract_folder(local_path, model_folder, verbose)
elif ("tensorflow" in framework or "keras" in framework) and local_path.is_file():
# just rename the single file to "model.graphdef"
self._extract_single_file(local_path, model_folder / "model.graphdef", verbose)
elif "tensorrt" in framework and local_path.is_file():
# just rename the single file to "model.plan"
self._extract_single_file(local_path, model_folder / "model.plan", verbose)
elif local_path.is_file():
# generic model will be stored as 'model.bin'
self._extract_single_file(local_path, model_folder / "model.bin", verbose)
elif local_path.is_dir():
# generic model will be stored into the model folder
self._extract_folder(local_path, model_folder, verbose)
else:
print("Model type could not be inferred skipping", model.id, model.framework, model.name)
continue
# todo: trigger triton model reloading (instead of relying on the current poll mechanism)
# based on the model endpoint changes
@ -197,6 +193,36 @@ class TritonHelper(object):
return True
@staticmethod
def _extract_single_file(local_path, target_path, verbose):
old_file = None
if target_path.exists():
old_file = target_path.parent / '.old.{}'.format(target_path.name)
target_path.replace(old_file)
if verbose:
print('copy model into {}'.format(target_path))
shutil.move(local_path.as_posix(), target_path.as_posix())
if old_file:
old_file.unlink()
@staticmethod
def _extract_folder(local_path, model_folder, verbose, remove_existing=False):
model_folder.mkdir(parents=True, exist_ok=True)
# rename to old
old_folder = None
if remove_existing and model_folder.exists():
old_folder = model_folder.parent / '.old.{}'.format(model_folder.name)
model_folder.replace(old_folder)
if verbose:
print('copy model into {}'.format(model_folder))
shutil.copytree(
local_path.as_posix(), model_folder.as_posix(), symlinks=False, dirs_exist_ok=True
)
if old_folder:
shutil.rmtree(path=old_folder.as_posix())
# delete temp folder
shutil.rmtree(local_path.as_posix())
def maintenance_daemon(
self,
local_model_repo='/models', # type: str
@ -313,36 +339,34 @@ class TritonHelper(object):
# replace ": [{" with ": [{" (currently not needed)
# pattern = re.compile(r"(?P<key>\w+)(?P<space>\s+)(?P<bracket>(\[)|({))")
if endpoint.input_size:
config_dict.put("input.0.dims", endpoint.input_size)
for i, s in enumerate(endpoint.input_size or []):
config_dict.put("input.{}.dims".format(i), s)
if endpoint.output_size:
config_dict.put("output.0.dims", endpoint.output_size)
for i, s in enumerate(endpoint.output_size or []):
config_dict.put("output.{}.dims".format(i), s)
input_type = None
if endpoint.input_type:
input_type = "TYPE_" + cls.np_to_triton_dtype(np.dtype(endpoint.input_type))
config_dict.put("input.0.data_type", input_type)
for i, s in enumerate(endpoint.input_type or []):
input_type = "TYPE_" + cls.np_to_triton_dtype(np.dtype(s))
config_dict.put("input.{}.data_type".format(i), input_type)
output_type = None
if endpoint.output_type:
output_type = "TYPE_" + cls.np_to_triton_dtype(np.dtype(endpoint.output_type))
config_dict.put("output.0.data_type", output_type)
for i, s in enumerate(endpoint.output_type or []):
output_type = "TYPE_" + cls.np_to_triton_dtype(np.dtype(s))
config_dict.put("output.{}.data_type".format(i), output_type)
if endpoint.input_name:
config_dict.put("input.0.name", endpoint.input_name)
for i, s in enumerate(endpoint.input_name or []):
config_dict.put("input.{}.name".format(i), "\"{}\"".format(s))
if endpoint.output_name:
config_dict.put("output.0.name", endpoint.output_name)
for i, s in enumerate(endpoint.output_name or []):
config_dict.put("output.{}.name".format(i), "\"{}\"".format(s))
if platform and not config_dict.get("platform", None) and not config_dict.get("backend", None):
platform = str(platform).lower()
if platform.startswith("tensorflow") or platform.startswith("keras"):
config_dict["platform"] = "tensorflow_savedmodel"
config_dict["platform"] = "\"tensorflow_savedmodel\""
elif platform.startswith("pytorch") or platform.startswith("caffe"):
config_dict["backend"] = "pytorch"
config_dict["backend"] = "\"pytorch\""
elif platform.startswith("onnx"):
config_dict["platform"] = "onnxruntime_onnx"
config_dict["platform"] = "\"onnxruntime_onnx\""
# convert to lists anything that we can:
if config_dict:
@ -350,13 +374,11 @@ class TritonHelper(object):
# Convert HOCON standard to predefined message format
config_pbtxt = "\n" + HOCONConverter.to_hocon(config_dict). \
replace("=", ":").replace(" : ", ": ")
# conform types (remove string quotes)
if input_type:
config_pbtxt = config_pbtxt.replace(f"\"{input_type}\"", f"{input_type}")
if output_type:
config_pbtxt = config_pbtxt.replace(f"\"{output_type}\"", f"{output_type}")
# conform types (remove string quotes)
config_pbtxt = config_pbtxt.replace("\"KIND_CPU\"", "KIND_CPU").replace("\"KIND_GPU\"", "KIND_GPU")
config_pbtxt = config_pbtxt.replace("\\\"", "<DQUOTE>").\
replace("\\\'", "<QUOTE>").replace("\"", "").replace("\'", "").\
replace("<DQUOTE>", "\"").replace("<QUOTE>", "\'")
else:
config_pbtxt = ""

View File

@ -9,10 +9,30 @@ def _engine_validator(inst, attr, value): # noqa
def _matrix_type_validator(inst, attr, value): # noqa
if value and not np.dtype(value):
if isinstance(value, (tuple, list)):
for v in value:
if v and not np.dtype(v):
raise TypeError("{} not supported matrix type".format(v))
elif value and not np.dtype(value):
raise TypeError("{} not supported matrix type".format(value))
def _list_type_convertor(inst): # noqa
if inst is None:
return None
return inst if isinstance(inst, (tuple, list)) else [inst]
def _nested_list_type_convertor(inst): # noqa
if inst is None:
return None
if isinstance(inst, (tuple, list)) and all(not isinstance(i, (tuple, list)) for i in inst):
return [inst]
inst = inst if isinstance(inst, (tuple, list)) else [inst]
return inst
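A short sketch of what these converters do, run in this module's context: scalar values from single-input endpoints keep working unchanged, while multi-input endpoints can pass lists (values illustrative):

    assert _list_type_convertor("float32") == ["float32"]
    assert _list_type_convertor(["float32", "int32"]) == ["float32", "int32"]
    assert _nested_list_type_convertor([1, 28, 28]) == [[1, 28, 28]]  # a single input size
    assert _nested_list_type_convertor([[1, 28, 28], [-1, 10]]) == [[1, 28, 28], [-1, 10]]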
@attrs
class BaseStruct(object):
def as_dict(self, remove_null_entries=False):
@ -30,12 +50,12 @@ class ModelMonitoring(BaseStruct):
monitor_tags = attrib(type=list, default=[]) # monitor model tag (for model auto update)
only_published = attrib(type=bool, default=False) # only select published models
max_versions = attrib(type=int, default=None) # Maximum number of models to keep serving (latest X models)
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
input_size = attrib(type=list, default=None, converter=_nested_list_type_convertor) # optional, model matrix size
input_type = attrib(type=list, default=None, validator=_matrix_type_validator, converter=_list_type_convertor)
input_name = attrib(type=list, default=None, converter=_list_type_convertor) # optional, input layer names
output_size = attrib(type=list, default=None, converter=_nested_list_type_convertor) # optional, model matrix size
output_type = attrib(type=list, default=None, validator=_matrix_type_validator, converter=_list_type_convertor)
output_name = attrib(type=list, default=None, converter=_list_type_convertor) # optional, output layer names
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
auxiliary_cfg = attrib(type=dict, default=None) # Auxiliary configuration (e.g. triton conf), Union[str, dict]
@ -49,12 +69,12 @@ class ModelEndpoint(BaseStruct):
version = attrib(type=str, default="") # key (version string), default no version
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
input_size = attrib(type=list, default=None, converter=_nested_list_type_convertor) # optional, model matrix size
input_type = attrib(type=list, default=None, validator=_matrix_type_validator, converter=_list_type_convertor)
input_name = attrib(type=list, default=None, converter=_list_type_convertor) # optional, input layer names
output_size = attrib(type=list, default=None, converter=_nested_list_type_convertor) # optional, model matrix size
output_type = attrib(type=list, default=None, validator=_matrix_type_validator, converter=_list_type_convertor)
output_name = attrib(type=list, default=None, converter=_list_type_convertor) # optional, output layer names
auxiliary_cfg = attrib(type=dict, default=None) # Optional: Auxiliary configuration (e.g. triton conf), [str, dict]

View File

@ -12,7 +12,7 @@ from multiprocessing import Lock
from numpy.random import choice
from clearml import Task, Model
from clearml.utilities.dicts import merge_dicts
from clearml.utilities.dicts import merge_dicts, cast_str_to_bool
from clearml.storage.util import hash_dict
from .preprocess_service import BasePreprocessRequest
from .endpoints import ModelEndpoint, ModelMonitoring, CanaryEP, EndpointMetricLogging
@ -69,6 +69,7 @@ class ModelRequestProcessor(object):
_kafka_topic = "clearml_inference_stats"
_config_key_serving_base_url = "serving_base_url"
_config_key_triton_grpc = "triton_grpc_server"
_config_key_triton_compression = "triton_grpc_compression"
_config_key_kafka_stats = "kafka_service_server"
_config_key_def_metric_freq = "metric_logging_freq"
@ -118,6 +119,7 @@ class ModelRequestProcessor(object):
# deserialized values go here
self._kafka_stats_url = None
self._triton_grpc = None
self._triton_grpc_compression = None
self._serving_base_url = None
self._metric_log_freq = None
@ -173,6 +175,7 @@ class ModelRequestProcessor(object):
self,
external_serving_base_url: Optional[str] = None,
external_triton_grpc_server: Optional[str] = None,
external_triton_grpc_compression: Optional[bool] = None,
external_kafka_service_server: Optional[str] = None,
default_metric_log_freq: Optional[float] = None,
):
@ -184,6 +187,7 @@ class ModelRequestProcessor(object):
allowing it to concatenate and combine multiple model requests into one
:param external_triton_grpc_server: set the external grpc tcp port of the Nvidia Triton clearml container.
Used by the clearml triton engine class to send inference requests
:param external_triton_grpc_compression: set gRPC compression (default: False, no compression)
:param external_kafka_service_server: Optional, Kafka endpoint for the statistics controller collection.
:param default_metric_log_freq: Default request metric logging (0 to 1.0, 1. means 100% of requests are logged)
"""
@ -201,6 +205,13 @@ class ModelRequestProcessor(object):
value_type="str",
description="external grpc tcp port of the Nvidia Triton ClearML container running"
)
if external_triton_grpc_compression is not None:
self._task.set_parameter(
name="General/{}".format(self._config_key_triton_compression),
value=str(external_triton_grpc_compression),
value_type="bool",
description="use external grpc tcp compression"
)
if external_kafka_service_server is not None:
self._task.set_parameter(
name="General/{}".format(self._config_key_kafka_stats),
@ -587,6 +598,22 @@ class ModelRequestProcessor(object):
task.set_configuration_object(name='model_monitoring', config_dict=config_dict)
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._metric_logging.items()}
task.set_configuration_object(name='metric_logging', config_dict=config_dict)
# store our version
from ..version import __version__
# noinspection PyProtectedMember
if task._get_runtime_properties().get("version") != str(__version__):
# noinspection PyProtectedMember
task._set_runtime_properties(runtime_properties=dict(version=str(__version__)))
def get_version(self) -> str:
"""
:return: version number (string) of the ModelRequestProcessor (clearml-serving session)
"""
default_version = "1.0.0"
if not self._task:
return default_version
# noinspection PyProtectedMember
return self._task._get_runtime_properties().get("version", default_version)
def _update_canary_lookup(self):
canary_route = {}
@ -1086,6 +1113,10 @@ class ModelRequestProcessor(object):
self._triton_grpc = \
configuration.get(self._config_key_triton_grpc) or \
os.environ.get("CLEARML_DEFAULT_TRITON_GRPC_ADDR")
self._triton_grpc_compression = \
cast_str_to_bool(str(configuration.get(
self._config_key_triton_compression, os.environ.get("CLEARML_DEFAULT_TRITON_GRPC_COMPRESSION", '0')
)))
self._serving_base_url = \
configuration.get(self._config_key_serving_base_url) or \
os.environ.get("CLEARML_DEFAULT_BASE_SERVE_URL")
@ -1095,6 +1126,7 @@ class ModelRequestProcessor(object):
# update back configuration
self._configuration[self._config_key_kafka_stats] = self._kafka_stats_url
self._configuration[self._config_key_triton_grpc] = self._triton_grpc
self._configuration[self._config_key_triton_compression] = self._triton_grpc_compression
self._configuration[self._config_key_serving_base_url] = self._serving_base_url
self._configuration[self._config_key_def_metric_freq] = self._metric_log_freq
# update preprocessing classes
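As an aside, when the task parameter is unset the compression flag falls back to the environment variable referenced above; a minimal sketch (variable name as used in this commit, value semantics assumed to follow cast_str_to_bool):

    import os

    # enable gzip compression on the Triton gRPC channel for this serving container
    os.environ["CLEARML_DEFAULT_TRITON_GRPC_COMPRESSION"] = "1"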

View File

@ -247,12 +247,14 @@ class TritonPreprocessRequest(BasePreprocessRequest):
np.int64: 'int64_contents',
np.uint64: 'uint64_contents',
np.int: 'int_contents',
np.int32: 'int_contents',
np.uint: 'uint_contents',
np.bool: 'bool_contents',
np.float32: 'fp32_contents',
np.float64: 'fp64_contents',
}
_default_grpc_address = "127.0.0.1:8001"
_default_grpc_compression = False
_ext_grpc = None
_ext_np_to_triton_dtype = None
_ext_service_pb2 = None
@ -287,6 +289,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
Detect gRPC server and send the request to it
:param data: object as received from the preprocessing function.
If multiple inputs are needed, data is a list of numpy arrays
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
@ -315,51 +318,76 @@ class TritonPreprocessRequest(BasePreprocessRequest):
except Exception as ex:
raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
use_compression = self._server_config.get("triton_grpc_compression", self._default_grpc_compression)
# Generate the request
request = self._ext_service_pb2.ModelInferRequest()
request.model_name = "{}/{}".format(self.model_endpoint.serving_url, self.model_endpoint.version).strip("/")
# we do not use the Triton model versions, we just assume a single version per endpoint
request.model_version = "1"
# take the input data
input_data = np.array(data, dtype=self.model_endpoint.input_type)
# make sure that if we have only one input we maintain backwards compatibility
list_data = [data] if len(self.model_endpoint.input_name) == 1 else data
# Populate the inputs in inference request
input0 = request.InferInputTensor()
input0.name = self.model_endpoint.input_name
input_dtype = np.dtype(self.model_endpoint.input_type).type
input0.datatype = self._ext_np_to_triton_dtype(input_dtype)
input0.shape.extend(self.model_endpoint.input_size)
for i_data, m_name, m_type, m_size in zip(
list_data, self.model_endpoint.input_name,
self.model_endpoint.input_type, self.model_endpoint.input_size
):
# take the input data
input_data = np.array(i_data, dtype=m_type)
# to be inferred
input_func = self._content_lookup.get(input_dtype)
if not input_func:
raise ValueError("Input type nt supported {}".format(input_dtype))
input_func = getattr(input0.contents, input_func)
input_func[:] = input_data.flatten()
input0 = request.InferInputTensor()
input0.name = m_name
input_dtype = np.dtype(m_type).type
input0.datatype = self._ext_np_to_triton_dtype(input_dtype)
input0.shape.extend(input_data.shape)
# push into request
request.inputs.extend([input0])
# to be inferred
input_func = self._content_lookup.get(input_dtype)
if not input_func:
raise ValueError("Input type nt supported {}".format(input_dtype))
input_func = getattr(input0.contents, input_func)
input_func[:] = input_data.flatten()
# push into request
request.inputs.extend([input0])
# Populate the outputs in the inference request
output0 = request.InferRequestedOutputTensor()
output0.name = self.model_endpoint.output_name
for m_name in self.model_endpoint.output_name:
output0 = request.InferRequestedOutputTensor()
output0.name = m_name
request.outputs.extend([output0])
request.outputs.extend([output0])
response = grpc_stub.ModelInfer(
request,
compression=self._ext_grpc.Compression.Gzip,
timeout=self._timeout
)
# send infer request over gRPC
compression = None
try:
compression = self._ext_grpc.Compression.Gzip if use_compression \
else self._ext_grpc.Compression.NoCompression
response = grpc_stub.ModelInfer(
request,
compression=compression,
timeout=self._timeout
)
except Exception:
print("Exception calling Triton RPC function: "
"request_inputs={}, ".format([(r.name, r.shape, r.datatype) for r in (request.inputs or [])]) +
f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}")
raise
# process result
output_results = []
index = 0
for output in response.outputs:
for i, output in enumerate(response.outputs):
shape = []
for value in output.shape:
shape.append(value)
output_results.append(
np.frombuffer(response.raw_output_contents[index], dtype=self.model_endpoint.output_type))
np.frombuffer(
response.raw_output_contents[index],
dtype=self.model_endpoint.output_type[min(i, len(self.model_endpoint.output_type)-1)]
)
)
output_results[-1] = np.resize(output_results[-1], shape)
index += 1
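Since the docstring above notes that data becomes a list of numpy arrays when multiple inputs are configured, here is a hedged sketch of a matching preprocess() for a two-input (e.g. HuggingFace Transformer) endpoint; the signature mirrors the example Preprocess classes below, while the tokenizer attribute is hypothetical:

    from typing import Any
    import numpy as np

    class Preprocess(object):
        def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
            # assumes a tokenizer was loaded elsewhere (hypothetical attribute)
            tokens = self._tokenizer(body["text"], return_tensors="np")
            # one array per configured --input-name entry
            return [
                np.array(tokens["input_ids"], dtype=np.int32),
                np.array(tokens["attention_mask"], dtype=np.int32),
            ]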

View File

@ -23,7 +23,7 @@ class Preprocess(object):
image = Image.open(local_file)
image = ImageOps.grayscale(image).resize((28, 28))
return np.array(image).flatten()
return np.array([np.array(image).flatten()])
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine

View File

@ -22,7 +22,7 @@ class Preprocess(object):
local_file = StorageManager.get_local_copy(remote_url=url)
image = Image.open(local_file)
image = ImageOps.grayscale(image).resize((28, 28))
return np.array(image).flatten()
return np.array([np.array(image)])
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine