diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py index 22c7414..2bac29f 100644 --- a/clearml_serving/serving/model_request_processor.py +++ b/clearml_serving/serving/model_request_processor.py @@ -9,7 +9,6 @@ import itertools import threading from multiprocessing import Lock import asyncio -from numpy import isin from numpy.random import choice from clearml import Task, Model @@ -288,6 +287,9 @@ class ModelRequestProcessor(object): name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True) endpoint.preprocess_artifact = preprocess_artifact_name + # register the model + self._add_registered_input_model(endpoint_url=endpoint.serving_url, model_id=endpoint.model_id) + self._endpoints[url] = endpoint return url @@ -348,6 +350,7 @@ class ModelRequestProcessor(object): if endpoint_url not in self._endpoints: return False self._endpoints.pop(endpoint_url, None) + self._remove_registered_input_model(endpoint_url) return True def add_canary_endpoint( @@ -688,17 +691,14 @@ class ModelRequestProcessor(object): if not model: # this should never happen continue - ep = ModelEndpoint( - engine_type=model.engine_type, - serving_url=serving_base_url, - model_id=model_id, - version=str(version), - preprocess_artifact=model.preprocess_artifact, - input_size=model.input_size, - input_type=model.input_type, - output_size=model.output_size, - output_type=model.output_type - ) + model_endpoint_config = { + i: j for i, j in model.as_dict(remove_null_entries=True).items() + if hasattr(ModelEndpoint.__attrs_attrs__, i) + } + model_endpoint_config["serving_url"] = serving_base_url + model_endpoint_config["model_id"] = model_id + model_endpoint_config["version"] = str(version) + ep = ModelEndpoint(**model_endpoint_config) self._model_monitoring_endpoints[url] = ep dirty = True @@ -706,6 +706,7 @@ class ModelRequestProcessor(object): for ep_url in list(self._model_monitoring_endpoints.keys()): 
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")): self._model_monitoring_endpoints.pop(ep_url, None) + self._remove_registered_input_model(ep_url) dirty = True # reset flag @@ -714,6 +715,9 @@ class ModelRequestProcessor(object): if dirty: config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()} self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict) + for m in self._model_monitoring_endpoints.values(): + # log us on the main task + self._add_registered_input_model(endpoint_url=m.serving_url, model_id=m.model_id) return dirty @@ -1299,3 +1303,37 @@ class ModelRequestProcessor(object): if not endpoint.auxiliary_cfg and missing: raise ValueError("Triton engine requires input description - missing values in {}".format(missing)) return True + + def _add_registered_input_model(self, endpoint_url: str, model_id: str) -> bool: + """ + Add registered endpoint url, return True if successful + """ + if not self._task or not model_id or not endpoint_url: + return False + + # noinspection PyBroadException + try: + self._task.set_input_model(model_id=model_id, name=endpoint_url.strip("/")) + except Exception: + return False + + return True + + def _remove_registered_input_model(self, endpoint_url: str) -> bool: + """ + Remove registered endpoint url, return True if successful + """ + if not self._task or not endpoint_url: + return False + + # noinspection PyBroadException + try: + # we assume we have the API version to support it + from clearml.backend_api.services import tasks + self._task.send(tasks.DeleteModelsRequest( + task=self._task.id, models=[dict(name=endpoint_url.strip("/"), type=tasks.ModelTypeEnum.input)] + )) + except Exception: + return False + + return True