Register models on serving session

This commit is contained in:
allegroai 2023-04-12 23:34:49 +03:00
parent 3bddccbaef
commit 78a03cc166

View File

@ -9,7 +9,6 @@ import itertools
import threading import threading
from multiprocessing import Lock from multiprocessing import Lock
import asyncio import asyncio
from numpy import isin
from numpy.random import choice from numpy.random import choice
from clearml import Task, Model from clearml import Task, Model
@ -288,6 +287,9 @@ class ModelRequestProcessor(object):
name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True) name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
endpoint.preprocess_artifact = preprocess_artifact_name endpoint.preprocess_artifact = preprocess_artifact_name
# register the model
self._add_registered_input_model(endpoint_url=endpoint.serving_url, model_id=endpoint.model_id)
self._endpoints[url] = endpoint self._endpoints[url] = endpoint
return url return url
@ -348,6 +350,7 @@ class ModelRequestProcessor(object):
if endpoint_url not in self._endpoints: if endpoint_url not in self._endpoints:
return False return False
self._endpoints.pop(endpoint_url, None) self._endpoints.pop(endpoint_url, None)
self._remove_registered_input_model(endpoint_url)
return True return True
def add_canary_endpoint( def add_canary_endpoint(
@ -688,17 +691,14 @@ class ModelRequestProcessor(object):
if not model: if not model:
# this should never happen # this should never happen
continue continue
ep = ModelEndpoint( model_endpoint_config = {
engine_type=model.engine_type, i: j for i, j in model.as_dict(remove_null_entries=True).items()
serving_url=serving_base_url, if hasattr(ModelEndpoint.__attrs_attrs__, i)
model_id=model_id, }
version=str(version), model_endpoint_config["serving_url"] = serving_base_url
preprocess_artifact=model.preprocess_artifact, model_endpoint_config["model_id"] = model_id
input_size=model.input_size, model_endpoint_config["version"] = str(version)
input_type=model.input_type, ep = ModelEndpoint(**model_endpoint_config)
output_size=model.output_size,
output_type=model.output_type
)
self._model_monitoring_endpoints[url] = ep self._model_monitoring_endpoints[url] = ep
dirty = True dirty = True
@ -706,6 +706,7 @@ class ModelRequestProcessor(object):
for ep_url in list(self._model_monitoring_endpoints.keys()): for ep_url in list(self._model_monitoring_endpoints.keys()):
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")): if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")):
self._model_monitoring_endpoints.pop(ep_url, None) self._model_monitoring_endpoints.pop(ep_url, None)
self._remove_registered_input_model(ep_url)
dirty = True dirty = True
# reset flag # reset flag
@ -714,6 +715,9 @@ class ModelRequestProcessor(object):
if dirty: if dirty:
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()} config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict) self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)
for m in self._model_monitoring_endpoints.values():
# log us on the main task
self._add_registered_input_model(endpoint_url=m.serving_url, model_id=m.model_id)
return dirty return dirty
@ -1299,3 +1303,37 @@ class ModelRequestProcessor(object):
if not endpoint.auxiliary_cfg and missing: if not endpoint.auxiliary_cfg and missing:
raise ValueError("Triton engine requires input description - missing values in {}".format(missing)) raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
return True return True
def _add_registered_input_model(self, endpoint_url: str, model_id: str) -> bool:
"""
Add registered endpoint url, return True if successful
"""
if not self._task or not model_id or not endpoint_url:
return False
# noinspection PyBroadException
try:
self._task.set_input_model(model_id=model_id, name=endpoint_url.strip("/"))
except Exception:
return False
return True
def _remove_registered_input_model(self, endpoint_url: str) -> bool:
"""
Remove registered endpoint url, return True if successful
"""
if not self._task or not endpoint_url:
return False
# noinspection PyBroadException
try:
# we assume we have the API version ot support it
from clearml.backend_api.services import tasks
self._task.send(tasks.DeleteModelsRequest(
task=self._task.id, models=[dict(name=endpoint_url.strip("/"), type=tasks.ModelTypeEnum.input)]
))
except Exception:
return False
return True