Register models on serving session

This commit is contained in:
allegroai 2023-04-12 23:34:49 +03:00
parent 3bddccbaef
commit 78a03cc166

View File

@ -9,7 +9,6 @@ import itertools
import threading
from multiprocessing import Lock
import asyncio
from numpy import isin
from numpy.random import choice
from clearml import Task, Model
@ -288,6 +287,9 @@ class ModelRequestProcessor(object):
name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
endpoint.preprocess_artifact = preprocess_artifact_name
# register the model
self._add_registered_input_model(endpoint_url=endpoint.serving_url, model_id=endpoint.model_id)
self._endpoints[url] = endpoint
return url
@ -348,6 +350,7 @@ class ModelRequestProcessor(object):
if endpoint_url not in self._endpoints:
return False
self._endpoints.pop(endpoint_url, None)
self._remove_registered_input_model(endpoint_url)
return True
def add_canary_endpoint(
@ -688,17 +691,14 @@ class ModelRequestProcessor(object):
if not model:
# this should never happen
continue
ep = ModelEndpoint(
engine_type=model.engine_type,
serving_url=serving_base_url,
model_id=model_id,
version=str(version),
preprocess_artifact=model.preprocess_artifact,
input_size=model.input_size,
input_type=model.input_type,
output_size=model.output_size,
output_type=model.output_type
)
model_endpoint_config = {
i: j for i, j in model.as_dict(remove_null_entries=True).items()
if hasattr(ModelEndpoint.__attrs_attrs__, i)
}
model_endpoint_config["serving_url"] = serving_base_url
model_endpoint_config["model_id"] = model_id
model_endpoint_config["version"] = str(version)
ep = ModelEndpoint(**model_endpoint_config)
self._model_monitoring_endpoints[url] = ep
dirty = True
@ -706,6 +706,7 @@ class ModelRequestProcessor(object):
for ep_url in list(self._model_monitoring_endpoints.keys()):
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")):
self._model_monitoring_endpoints.pop(ep_url, None)
self._remove_registered_input_model(ep_url)
dirty = True
# reset flag
@ -714,6 +715,9 @@ class ModelRequestProcessor(object):
if dirty:
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)
for m in self._model_monitoring_endpoints.values():
# log us on the main task
self._add_registered_input_model(endpoint_url=m.serving_url, model_id=m.model_id)
return dirty
@ -1299,3 +1303,37 @@ class ModelRequestProcessor(object):
if not endpoint.auxiliary_cfg and missing:
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
return True
def _add_registered_input_model(self, endpoint_url: str, model_id: str) -> bool:
"""
Add registered endpoint url, return True if successful
"""
if not self._task or not model_id or not endpoint_url:
return False
# noinspection PyBroadException
try:
self._task.set_input_model(model_id=model_id, name=endpoint_url.strip("/"))
except Exception:
return False
return True
def _remove_registered_input_model(self, endpoint_url: str) -> bool:
"""
Remove registered endpoint url, return True if successful
"""
if not self._task or not endpoint_url:
return False
# noinspection PyBroadException
try:
# we assume we have the API version ot support it
from clearml.backend_api.services import tasks
self._task.send(tasks.DeleteModelsRequest(
task=self._task.id, models=[dict(name=endpoint_url.strip("/"), type=tasks.ModelTypeEnum.input)]
))
except Exception:
return False
return True