mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
Register models on serving session
This commit is contained in:
parent
3bddccbaef
commit
78a03cc166
@ -9,7 +9,6 @@ import itertools
|
|||||||
import threading
|
import threading
|
||||||
from multiprocessing import Lock
|
from multiprocessing import Lock
|
||||||
import asyncio
|
import asyncio
|
||||||
from numpy import isin
|
|
||||||
from numpy.random import choice
|
from numpy.random import choice
|
||||||
|
|
||||||
from clearml import Task, Model
|
from clearml import Task, Model
|
||||||
@ -288,6 +287,9 @@ class ModelRequestProcessor(object):
|
|||||||
name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
|
name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
|
||||||
endpoint.preprocess_artifact = preprocess_artifact_name
|
endpoint.preprocess_artifact = preprocess_artifact_name
|
||||||
|
|
||||||
|
# register the model
|
||||||
|
self._add_registered_input_model(endpoint_url=endpoint.serving_url, model_id=endpoint.model_id)
|
||||||
|
|
||||||
self._endpoints[url] = endpoint
|
self._endpoints[url] = endpoint
|
||||||
return url
|
return url
|
||||||
|
|
||||||
@ -348,6 +350,7 @@ class ModelRequestProcessor(object):
|
|||||||
if endpoint_url not in self._endpoints:
|
if endpoint_url not in self._endpoints:
|
||||||
return False
|
return False
|
||||||
self._endpoints.pop(endpoint_url, None)
|
self._endpoints.pop(endpoint_url, None)
|
||||||
|
self._remove_registered_input_model(endpoint_url)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_canary_endpoint(
|
def add_canary_endpoint(
|
||||||
@ -688,17 +691,14 @@ class ModelRequestProcessor(object):
|
|||||||
if not model:
|
if not model:
|
||||||
# this should never happen
|
# this should never happen
|
||||||
continue
|
continue
|
||||||
ep = ModelEndpoint(
|
model_endpoint_config = {
|
||||||
engine_type=model.engine_type,
|
i: j for i, j in model.as_dict(remove_null_entries=True).items()
|
||||||
serving_url=serving_base_url,
|
if hasattr(ModelEndpoint.__attrs_attrs__, i)
|
||||||
model_id=model_id,
|
}
|
||||||
version=str(version),
|
model_endpoint_config["serving_url"] = serving_base_url
|
||||||
preprocess_artifact=model.preprocess_artifact,
|
model_endpoint_config["model_id"] = model_id
|
||||||
input_size=model.input_size,
|
model_endpoint_config["version"] = str(version)
|
||||||
input_type=model.input_type,
|
ep = ModelEndpoint(**model_endpoint_config)
|
||||||
output_size=model.output_size,
|
|
||||||
output_type=model.output_type
|
|
||||||
)
|
|
||||||
self._model_monitoring_endpoints[url] = ep
|
self._model_monitoring_endpoints[url] = ep
|
||||||
dirty = True
|
dirty = True
|
||||||
|
|
||||||
@ -706,6 +706,7 @@ class ModelRequestProcessor(object):
|
|||||||
for ep_url in list(self._model_monitoring_endpoints.keys()):
|
for ep_url in list(self._model_monitoring_endpoints.keys()):
|
||||||
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")):
|
if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url+"/")):
|
||||||
self._model_monitoring_endpoints.pop(ep_url, None)
|
self._model_monitoring_endpoints.pop(ep_url, None)
|
||||||
|
self._remove_registered_input_model(ep_url)
|
||||||
dirty = True
|
dirty = True
|
||||||
|
|
||||||
# reset flag
|
# reset flag
|
||||||
@ -714,6 +715,9 @@ class ModelRequestProcessor(object):
|
|||||||
if dirty:
|
if dirty:
|
||||||
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
|
config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
|
||||||
self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)
|
self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)
|
||||||
|
for m in self._model_monitoring_endpoints.values():
|
||||||
|
# log us on the main task
|
||||||
|
self._add_registered_input_model(endpoint_url=m.serving_url, model_id=m.model_id)
|
||||||
|
|
||||||
return dirty
|
return dirty
|
||||||
|
|
||||||
@ -1299,3 +1303,37 @@ class ModelRequestProcessor(object):
|
|||||||
if not endpoint.auxiliary_cfg and missing:
|
if not endpoint.auxiliary_cfg and missing:
|
||||||
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
|
raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _add_registered_input_model(self, endpoint_url: str, model_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Add registered endpoint url, return True if successful
|
||||||
|
"""
|
||||||
|
if not self._task or not model_id or not endpoint_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
self._task.set_input_model(model_id=model_id, name=endpoint_url.strip("/"))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _remove_registered_input_model(self, endpoint_url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Remove registered endpoint url, return True if successful
|
||||||
|
"""
|
||||||
|
if not self._task or not endpoint_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# we assume we have the API version ot support it
|
||||||
|
from clearml.backend_api.services import tasks
|
||||||
|
self._task.send(tasks.DeleteModelsRequest(
|
||||||
|
task=self._task.id, models=[dict(name=endpoint_url.strip("/"), type=tasks.ModelTypeEnum.input)]
|
||||||
|
))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user