mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
Register models on serving session
This commit is contained in:
parent
3bddccbaef
commit
78a03cc166
@ -9,7 +9,6 @@ import itertools
|
||||
import threading
|
||||
from multiprocessing import Lock
|
||||
import asyncio
|
||||
from numpy import isin
|
||||
from numpy.random import choice
|
||||
|
||||
from clearml import Task, Model
|
||||
@ -288,6 +287,9 @@ class ModelRequestProcessor(object):
|
||||
name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
|
||||
endpoint.preprocess_artifact = preprocess_artifact_name
|
||||
|
||||
# register the model
|
||||
self._add_registered_input_model(endpoint_url=endpoint.serving_url, model_id=endpoint.model_id)
|
||||
|
||||
self._endpoints[url] = endpoint
|
||||
return url
|
||||
|
||||
@ -348,6 +350,7 @@ class ModelRequestProcessor(object):
|
||||
if endpoint_url not in self._endpoints:
|
||||
return False
|
||||
self._endpoints.pop(endpoint_url, None)
|
||||
self._remove_registered_input_model(endpoint_url)
|
||||
return True
|
||||
|
||||
def add_canary_endpoint(
|
||||
@ -688,17 +691,14 @@ class ModelRequestProcessor(object):
|
||||
if not model:
|
||||
# this should never happen
|
||||
continue
|
||||
ep = ModelEndpoint(
|
||||
engine_type=model.engine_type,
|
||||
serving_url=serving_base_url,
|
||||
model_id=model_id,
|
||||
version=str(version),
|
||||
preprocess_artifact=model.preprocess_artifact,
|
||||
input_size=model.input_size,
|
||||
input_type=model.input_type,
|
||||
output_size=model.output_size,
|
||||
output_type=model.output_type
|
||||
)
|
||||
model_endpoint_config = {
|
||||
i: j for i, j in model.as_dict(remove_null_entries=True).items()
|
||||
if hasattr(ModelEndpoint.__attrs_attrs__, i)
|
||||
}
|
||||
model_endpoint_config["serving_url"] = serving_base_url
|
||||
model_endpoint_config["model_id"] = model_id
|
||||
model_endpoint_config["version"] = str(version)
|
||||
ep = ModelEndpoint(**model_endpoint_config)
|
||||
self._model_monitoring_endpoints[url] = ep
|
||||
dirty = True
|
||||
|
||||
@ -706,6 +706,7 @@ class ModelRequestProcessor(object):
|
||||
for ep_url in list(self._model_monitoring_endpoints.keys()):
|
||||
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")):
|
||||
self._model_monitoring_endpoints.pop(ep_url, None)
|
||||
self._remove_registered_input_model(ep_url)
|
||||
dirty = True
|
||||
|
||||
# reset flag
|
||||
@ -714,6 +715,9 @@ class ModelRequestProcessor(object):
|
||||
if dirty:
|
||||
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
|
||||
self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)
|
||||
for m in self._model_monitoring_endpoints.values():
|
||||
# log us on the main task
|
||||
self._add_registered_input_model(endpoint_url=m.serving_url, model_id=m.model_id)
|
||||
|
||||
return dirty
|
||||
|
||||
@ -1299,3 +1303,37 @@ class ModelRequestProcessor(object):
|
||||
if not endpoint.auxiliary_cfg and missing:
|
||||
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
|
||||
return True
|
||||
|
||||
def _add_registered_input_model(self, endpoint_url: str, model_id: str) -> bool:
    """
    Register a model as an input model of the serving task, named after its endpoint.

    The endpoint url, with leading/trailing "/" stripped, is used as the
    ``name`` passed to ``self._task.set_input_model``.

    :param endpoint_url: Serving url of the endpoint the model is attached to.
    :param model_id: ID of the model to register.
    :return: True if the registration call completed. False if there is no task,
        either argument is empty, the url contains nothing but "/" characters,
        or the registration call raised.
    """
    if not self._task or not model_id or not endpoint_url:
        return False

    # registration is best-effort: any failure is reported as False, never raised
    # noinspection PyBroadException
    try:
        model_name = endpoint_url.strip("/")
        # a url made only of "/" would produce an empty name; refuse to register it
        if not model_name:
            return False
        self._task.set_input_model(model_id=model_id, name=model_name)
    except Exception:
        return False

    return True
|
||||
|
||||
def _remove_registered_input_model(self, endpoint_url: str) -> bool:
    """
    Remove the input model registered on the serving task under an endpoint url.

    The endpoint url, with leading/trailing "/" stripped, is the model name sent
    in a ``tasks.DeleteModelsRequest`` (type ``input``) for ``self._task.id``.

    :param endpoint_url: Serving url of the endpoint whose model entry is removed.
    :return: True if the delete request was sent without raising. False if there
        is no task, the url is empty or contains nothing but "/" characters,
        or anything in the request path raised (including the backend API import).
    """
    if not self._task or not endpoint_url:
        return False

    # removal is best-effort: any failure is reported as False, never raised
    # noinspection PyBroadException
    try:
        model_name = endpoint_url.strip("/")
        # a url made only of "/" would produce an empty name; do not send a delete for it
        if not model_name:
            return False
        # we assume we have the API version to support it
        from clearml.backend_api.services import tasks
        self._task.send(tasks.DeleteModelsRequest(
            task=self._task.id, models=[dict(name=model_name, type=tasks.ModelTypeEnum.input)]
        ))
    except Exception:
        return False

    return True
|
||||
|
Loading…
Reference in New Issue
Block a user