mirror of
https://github.com/clearml/clearml-serving
synced 2025-01-31 10:56:52 +00:00
1162 lines
48 KiB
Python
import json
import os
from pathlib import Path
from queue import Queue
from random import random
from time import sleep, time
from typing import Optional, Union, Dict, List
import itertools
import threading
from multiprocessing import Lock
from numpy.random import choice

from clearml import Task, Model
from clearml.storage.util import hash_dict
from .preprocess_service import BasePreprocessRequest
from .endpoints import ModelEndpoint, ModelMonitoring, CanaryEP, EndpointMetricLogging


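# FastWriteCounter: a minimal lock-free counter. itertools.count() advances
# in C code, so calling next() on it is effectively atomic under the CPython
# GIL; keeping separate increment/decrement counters lets many threads update
# concurrently, while value() reads the difference without taking a lock
# (value() consumes one tick from each counter, leaving the difference intact).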
class FastWriteCounter(object):
    def __init__(self):
        self._counter_inc = itertools.count()
        self._counter_dec = itertools.count()

    def inc(self):
        next(self._counter_inc)

    def dec(self):
        next(self._counter_dec)

    def value(self):
        return next(self._counter_inc) - next(self._counter_dec)


class ModelRequestProcessor(object):
    _system_tag = "serving-control-plane"
    _kafka_topic = "clearml_inference_stats"
    _config_key_serving_base_url = "serving_base_url"
    _config_key_triton_grpc = "triton_grpc_server"
    _config_key_kafka_stats = "kafka_service_server"
    _config_key_def_metric_freq = "metric_logging_freq"

    def __init__(
            self,
            task_id: Optional[str] = None,
            update_lock_guard: Optional[Lock] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None,
            force_create: bool = False,
    ) -> None:
        """
        :param task_id: Optional, specify an existing Task ID of the ServingService
        :param update_lock_guard: If provided, use an external (usually multi-process) lock guard for updates
        :param name: Optional, name of the current serving service
        :param project: Optional, select the project for the current serving service
        :param tags: Optional, add tags to the serving service
        :param force_create: If True, ignore task_id and create a new serving Task
        """
        self._task = self._create_task(name=name, project=project, tags=tags) \
            if force_create else self._get_control_plane_task(task_id=task_id, name=name, project=project, tags=tags)
        self._endpoints = dict()  # type: Dict[str, ModelEndpoint]
        self._model_monitoring = dict()  # type: Dict[str, ModelMonitoring]
        # Dict[base_serve_url, Dict[version, model_id]]
        self._model_monitoring_versions = dict()  # type: Dict[str, Dict[int, str]]
        self._model_monitoring_endpoints = dict()  # type: Dict[str, ModelEndpoint]
        self._model_monitoring_update_request = True
        self._canary_endpoints = dict()  # type: Dict[str, CanaryEP]
        self._canary_route = dict()  # type: Dict[str, dict]
        self._engine_processor_lookup = dict()  # type: Dict[str, BasePreprocessRequest]
        self._metric_logging = dict()  # type: Dict[str, EndpointMetricLogging]
        self._endpoint_metric_logging = dict()  # type: Dict[str, EndpointMetricLogging]
        self._last_update_hash = None
        self._sync_daemon_thread = None
        self._stats_sending_thread = None
        self._stats_queue = Queue()
        # this is used for the fast locking mechanism (so we do not actually need to use Locks)
        self._update_lock_flag = False
        self._request_processing_state = FastWriteCounter()
        self._update_lock_guard = update_lock_guard or threading.Lock()
        self._instance_task = None
        # serving server config
        self._configuration = {}
        # deserialized values go here
        self._kafka_stats_url = None
        self._triton_grpc = None
        self._serving_base_url = None
        self._metric_log_freq = None

    def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
        """
        Process an incoming request.
        Raise ValueError if the url does not match any existing endpoint.
        """
        self._request_processing_state.inc()
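        # Fast-lock dance: we optimistically registered this request as
        # "in flight" above; if a configuration update is in progress we back
        # out, spin until the update completes, then retry from the top.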
        # check if we need to stall
        if self._update_lock_flag:
            self._request_processing_state.dec()
            while self._update_lock_flag:
                sleep(1)
            # retry to process
            return self.process_request(base_url=base_url, version=version, request_body=request_body)

        try:
            # normalize url and version
            url = self._normalize_endpoint_url(base_url, version)

            # check canary
            canary_url = self._process_canary(base_url=url)
            if canary_url:
                url = canary_url

            ep = self._endpoints.get(url, None) or self._model_monitoring_endpoints.get(url, None)
            if not ep:
                raise ValueError("Model inference endpoint '{}' not found".format(url))

            processor = self._engine_processor_lookup.get(url)
            if not processor:
                processor_cls = BasePreprocessRequest.get_engine_cls(ep.engine_type)
                processor = processor_cls(model_endpoint=ep, task=self._task)
                self._engine_processor_lookup[url] = processor

            return_value = self._process_request(processor=processor, url=url, body=request_body)
        finally:
            self._request_processing_state.dec()

        return return_value

    def _process_canary(self, base_url: str) -> Optional[str]:
        canary = self._canary_route.get(base_url)
        if not canary:
            return None
        # weighted random choice of a single endpoint
        draw = choice(canary['endpoints'], 1, p=canary['weights'])
        # the new endpoint to use
        return draw[0]

    def configure(
            self,
            external_serving_base_url: Optional[str] = None,
            external_triton_grpc_server: Optional[str] = None,
            external_kafka_service_server: Optional[str] = None,
            default_metric_log_freq: Optional[float] = None,
    ):
        """
        Set ModelRequestProcessor configuration arguments.

        :param external_serving_base_url: Set the external base http endpoint for the serving service.
            This URL will be passed to the user's custom preprocess class,
            allowing it to concatenate and combine multiple model requests into one
        :param external_triton_grpc_server: Set the external gRPC tcp port of the Nvidia Triton ClearML container.
            Used by the clearml triton engine class to send inference requests
        :param external_kafka_service_server: Optional, Kafka endpoint for the statistics controller collection
        :param default_metric_log_freq: Default request metric logging (0 to 1.0, 1. means 100% of requests are logged)
        """
        if external_serving_base_url is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_serving_base_url),
                value=str(external_serving_base_url),
                value_type="str",
                description="external base http endpoint for the serving service"
            )
        if external_triton_grpc_server is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_triton_grpc),
                value=str(external_triton_grpc_server),
                value_type="str",
                description="external grpc tcp port of the Nvidia Triton ClearML container"
            )
        if external_kafka_service_server is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_kafka_stats),
                value=str(external_kafka_service_server),
                value_type="str",
                description="external Kafka service url for the statistics controller server"
            )
        if default_metric_log_freq is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_def_metric_freq),
                value=str(default_metric_log_freq),
                value_type="float",
                description="Request metric logging frequency"
            )

    def get_configuration(self) -> dict:
        return dict(**self._configuration)

    def add_endpoint(
            self,
            endpoint: Union[ModelEndpoint, dict],
            preprocess_code: Optional[str] = None,
            model_name: Optional[str] = None,
            model_project: Optional[str] = None,
            model_tags: Optional[List[str]] = None,
            model_published: Optional[bool] = None,
    ) -> str:
        """
        Return the unique name of the endpoint (endpoint + version).
        Overwrites an existing endpoint if one exists (outputs a warning).

        :param endpoint: New endpoint to register (overwrites an existing endpoint if one exists)
        :param preprocess_code: If provided, upload local code as an artifact
        :param model_name: If model_id is not provided, search for a model based on its name
        :param model_project: If model_id is not provided, search for a model based on its project
        :param model_tags: If model_id is not provided, search for a model based on its tags
        :param model_published: If model_id is not provided, search for a model based on its published state
        """
        if not isinstance(endpoint, ModelEndpoint):
            endpoint = ModelEndpoint(**endpoint)

        # make sure we have everything configured
        self._validate_model(endpoint)

        url = self._normalize_endpoint_url(endpoint.serving_url, endpoint.version)
        if url in self._endpoints:
            print("Warning: Model endpoint '{}' overwritten".format(url))

        if not endpoint.model_id and any([model_project, model_name, model_tags]):
            model_query = dict(
                project_name=model_project,
                model_name=model_name,
                tags=model_tags,
                only_published=bool(model_published),
                include_archived=False,
            )
            models = Model.query_models(max_results=2, **model_query)
            if not models:
                raise ValueError("Could not find any Model to serve {}".format(model_query))
            if len(models) > 1:
                print("Warning: Found multiple Models for '{}', selecting id={}".format(model_query, models[0].id))
            endpoint.model_id = models[0].id
        elif not endpoint.model_id:
            print("Warning: No Model provided for '{}'".format(url))

        # upload as new artifact
        if preprocess_code:
            if not Path(preprocess_code).exists():
                raise ValueError("Preprocessing code '{}' could not be found".format(preprocess_code))
            preprocess_artifact_name = "py_code_{}".format(url.replace("/", "_"))
            self._task.upload_artifact(
                name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
            endpoint.preprocess_artifact = preprocess_artifact_name

        self._endpoints[url] = endpoint
        return url

    def add_model_monitoring(
            self,
            monitoring: Union[ModelMonitoring, dict],
            preprocess_code: Optional[str] = None,
    ) -> str:
        """
        Return the unique name of the model monitoring entry (its base serving url).
        Overwrites an existing monitoring entry if one exists (outputs a warning).

        :param monitoring: Model endpoint monitor (overwrites an existing monitoring entry if one exists)
        :param preprocess_code: If provided, upload local code as an artifact
        :return: Unique model monitoring ID (base_model_url)
        """
        if not isinstance(monitoring, ModelMonitoring):
            monitoring = ModelMonitoring(**monitoring)

        # make sure we actually have something to monitor
        if not any([monitoring.monitor_project, monitoring.monitor_name, monitoring.monitor_tags]):
            raise ValueError("Model monitoring requires at least a "
                             "project / name / tag to monitor, none were provided.")

        # make sure we have everything configured
        self._validate_model(monitoring)

        name = monitoring.base_serving_url
        if name in self._model_monitoring:
            print("Warning: Model monitoring '{}' overwritten".format(name))

        # upload as new artifact
        if preprocess_code:
            if not Path(preprocess_code).exists():
                raise ValueError("Preprocessing code '{}' could not be found".format(preprocess_code))
            preprocess_artifact_name = "py_code_{}".format(name.replace("/", "_"))
            self._task.upload_artifact(
                name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
            monitoring.preprocess_artifact = preprocess_artifact_name

        self._model_monitoring[name] = monitoring
        return name

    def remove_model_monitoring(self, model_base_url: str) -> bool:
        """
        Remove model monitoring, using base_model_url as the unique identifier
        """
        if model_base_url not in self._model_monitoring:
            return False
        self._model_monitoring.pop(model_base_url, None)
        return True

    def remove_endpoint(self, endpoint_url: str, version: Optional[str] = None) -> bool:
        """
        Remove a specific model endpoint, using base_model_url as the unique identifier
        """
        endpoint_url = self._normalize_endpoint_url(endpoint_url, version)
        if endpoint_url not in self._endpoints:
            return False
        self._endpoints.pop(endpoint_url, None)
        return True

    def add_canary_endpoint(
            self,
            canary: Union[CanaryEP, dict],
    ) -> str:
        """
        Return the unique name of the canary endpoint (endpoint + version).
        Overwrites an existing canary endpoint if one exists (outputs a warning).

        :param canary: Canary endpoint router (overwrites an existing canary endpoint if one exists)
        :return: Unique canary ID (base_model_url)
        """
        if not isinstance(canary, CanaryEP):
            canary = CanaryEP(**canary)
        if canary.load_endpoints and canary.load_endpoint_prefix:
            raise ValueError(
                "Could not add canary endpoint with both "
                "prefix ({}) and fixed set of endpoints ({})".format(
                    canary.load_endpoint_prefix, canary.load_endpoints))
        name = canary.endpoint
        if name in self._canary_endpoints:
            print("Warning: Canary endpoint '{}' overwritten".format(name))

        self._canary_endpoints[name] = canary
        return name

    def remove_canary_endpoint(self, endpoint_url: str) -> bool:
        """
        Remove a specific canary model endpoint, using base_model_url as the unique identifier
        """
        if endpoint_url not in self._canary_endpoints:
            return False
        self._canary_endpoints.pop(endpoint_url, None)
        return True

    def add_metric_logging(self, metric: Union[EndpointMetricLogging, dict]) -> bool:
        """
        Add metric logging to a specific endpoint.
        Valid metric variables are any variables on the request or response dictionary,
        or a custom variable reported by the preprocess class.

        When overwriting an existing monitored variable, a warning is printed.

        :param metric: Metric variable to monitor
        :return: True if successful
        """
        if not isinstance(metric, EndpointMetricLogging):
            metric = EndpointMetricLogging(**metric)

        name = str(metric.endpoint).strip("/")
        metric.endpoint = name

        if name not in self._endpoints and not name.endswith('*'):
            raise ValueError("Metric logging '{}' references a nonexistent endpoint".format(name))

        if name in self._metric_logging:
            print("Warning: Metric logging '{}' overwritten".format(name))

        self._metric_logging[name] = metric
        return True

    def remove_metric_logging(
            self,
            endpoint: str,
            variable_name: str = None,
    ) -> bool:
        """
        Remove an existing logged metric variable. Use the variable name and endpoint as a unique identifier.

        :param endpoint: Endpoint name (including version, e.g. "model/1" or "model/*")
        :param variable_name: Variable name (str), pass None to remove the entire endpoint logging

        :return: True if successful
        """
        name = str(endpoint).strip("/")

        if name not in self._metric_logging or \
                (variable_name and variable_name not in self._metric_logging[name].metrics):
            return False

        if not variable_name:
            self._metric_logging.pop(name, None)
        else:
            self._metric_logging[name].metrics.pop(variable_name, None)

        return True

    def list_metric_logging(self) -> Dict[str, EndpointMetricLogging]:
        """
        List existing logged metric variables.

        :return: Dictionary, key='endpoint/version' value=EndpointMetricLogging
        """
        return dict(**self._metric_logging)

    def list_endpoint_logging(self) -> Dict[str, EndpointMetricLogging]:
        """
        List the current metric logging state of the (fully synced) endpoints.

        :return: Dictionary, key='endpoint/version' value=EndpointMetricLogging
        """
        return dict(**self._endpoint_metric_logging)

    def deserialize(
            self,
            task: Task = None,
            prefetch_artifacts: bool = False,
            skip_sync: bool = False,
            update_current_task: bool = True
    ) -> bool:
        """
        Restore ModelRequestProcessor state from a Task.
        Return True if deserialization was actually needed, False if nothing changed.

        :param task: Load data from the specified Task
        :param prefetch_artifacts: If True, prefetch artifacts requested by the endpoints
        :param skip_sync: If True, do not update the canary/monitoring state
        :param update_current_task: If skip_sync is False and this is True,
            update the current Task with the configuration synced from the serving service Task
        """
        if not task:
            task = self._task
        configuration = task.get_parameters_as_dict().get("General") or {}
        endpoints = task.get_configuration_object_as_dict(name='endpoints') or {}
        canary_ep = task.get_configuration_object_as_dict(name='canary') or {}
        model_monitoring = task.get_configuration_object_as_dict(name='model_monitoring') or {}
        metric_logging = task.get_configuration_object_as_dict(name='metric_logging') or {}

        hashed_conf = hash_dict(
            dict(endpoints=endpoints,
                 canary_ep=canary_ep,
                 model_monitoring=model_monitoring,
                 metric_logging=metric_logging,
                 configuration=configuration)
        )
        if self._last_update_hash == hashed_conf and not self._model_monitoring_update_request:
            return False
        print("Info: syncing model endpoint configuration, state hash={}".format(hashed_conf))
        self._last_update_hash = hashed_conf

        endpoints = {
            k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)})
            for k, v in endpoints.items()
        }
        model_monitoring = {
            k: ModelMonitoring(**{i: j for i, j in v.items() if hasattr(ModelMonitoring.__attrs_attrs__, i)})
            for k, v in model_monitoring.items()
        }
        canary_endpoints = {
            k: CanaryEP(**{i: j for i, j in v.items() if hasattr(CanaryEP.__attrs_attrs__, i)})
            for k, v in canary_ep.items()
        }
        metric_logging = {
            k: EndpointMetricLogging(**{i: j for i, j in v.items() if hasattr(EndpointMetricLogging.__attrs_attrs__, i)})
            for k, v in metric_logging.items()
        }

        # if there is no need to sync Canary and Models we can just leave
        if skip_sync:
            self._endpoints = endpoints
            self._model_monitoring = model_monitoring
            self._canary_endpoints = canary_endpoints
            self._metric_logging = metric_logging
            self._deserialize_conf_dict(configuration)
            return True

        # make sure we only have one stall request at any given moment
        with self._update_lock_guard:
            # download artifacts
            # todo: separate into two, download before lock, and overwrite inside lock
            if prefetch_artifacts:
                for item in list(endpoints.values()) + list(model_monitoring.values()):
                    if item.preprocess_artifact:
                        # noinspection PyBroadException
                        try:
                            self._task.artifacts[item.preprocess_artifact].get_local_copy(
                                extract_archive=True,
                            )
                        except Exception:
                            pass

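            # From here on we swap the runtime state: raising the flag stalls
            # any new request in process_request(), and the FastWriteCounter
            # tells us once every in-flight request has drained.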
            # stall all requests
            self._update_lock_flag = True
            # wait until we have no request processed
            while self._request_processing_state.value() != 0:
                sleep(1)

            self._endpoints = endpoints
            self._model_monitoring = model_monitoring
            self._canary_endpoints = canary_endpoints
            self._metric_logging = metric_logging
            self._deserialize_conf_dict(configuration)

            # if we have models we need to sync, now is the time
            self._sync_monitored_models()

            self._update_canary_lookup()

            self._sync_metric_logging()

            # release stall lock
            self._update_lock_flag = False

        # update the state on the inference task
        if update_current_task and Task.current_task() and Task.current_task().id != self._task.id:
            self.serialize(task=Task.current_task())

        return True

    def serialize(self, task: Optional[Task] = None) -> None:
        """
        Store ModelRequestProcessor state into the Task
        """
        if not task:
            task = self._task
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._endpoints.items()}
        task.set_configuration_object(name='endpoints', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._canary_endpoints.items()}
        task.set_configuration_object(name='canary', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring.items()}
        task.set_configuration_object(name='model_monitoring', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._metric_logging.items()}
        task.set_configuration_object(name='metric_logging', config_dict=config_dict)

    def _update_canary_lookup(self):
        canary_route = {}
        for k, v in self._canary_endpoints.items():
            if v.load_endpoint_prefix and v.load_endpoints:
                print("Warning: Canary has both prefix and fixed endpoints, ignoring canary endpoint")
                continue
            if v.load_endpoints:
                if len(v.load_endpoints) != len(v.weights):
                    print("Warning: Canary '{}' weights [{}] do not match number of endpoints [{}], skipping!".format(
                        k, v.weights, v.load_endpoints))
                    continue
                endpoints = []
                weights = []
                for w, ep in zip(v.weights, v.load_endpoints):
                    if ep not in self._endpoints and ep not in self._model_monitoring_endpoints:
                        print("Warning: Canary '{}' endpoint '{}' could not be found, skipping".format(k, ep))
                        continue
                    endpoints.append(ep)
                    weights.append(float(w))
                # normalize weights
                sum_weights = sum(weights)
                weights = [w / sum_weights for w in weights]
                canary_route[k] = dict(endpoints=endpoints, weights=weights)
            elif v.load_endpoint_prefix:
                endpoints = [ep for ep in list(self._endpoints.keys()) + list(self._model_monitoring_endpoints.keys())
                             if str(ep).startswith(v.load_endpoint_prefix)]
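                # sort by version, newest first: zero-pad the last URL segment
                # (the numeric version) to 9 characters so that lexicographic
                # order matches numeric order (e.g. "model/10" sorts above "model/2")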
                endpoints = sorted(
                    endpoints,
                    reverse=True,
                    key=lambda x: '{}/{:0>9}'.format('/'.join(x.split('/')[:-1]), x.split('/')[-1]) if '/' in x else x
                )
                endpoints = endpoints[:len(v.weights)]
                weights = v.weights[:len(endpoints)]
                # normalize weights
                sum_weights = sum(weights)
                weights = [w / sum_weights for w in weights]
                canary_route[k] = dict(endpoints=endpoints, weights=weights)
                self._report_text(
                    "Info: Canary endpoint '{}' selected [{}]".format(k, canary_route[k])
                )

        # update back
        self._canary_route = canary_route

    def _sync_monitored_models(self, force: bool = False) -> bool:
        if not force and not self._model_monitoring_update_request:
            return False
        dirty = False

        for serving_base_url, versions_model_id_dict in self._model_monitoring_versions.items():
            # find existing endpoint versions
            for ep_base_url in list(self._model_monitoring_endpoints.keys()):
                # skip over endpoints that are not our own
                if not ep_base_url.startswith(serving_base_url + "/"):
                    continue
                # find the endpoint version: the last URL segment
                # (the base serving URL itself may contain "/")
                _, version = ep_base_url.rsplit("/", 1)
                if int(version) not in versions_model_id_dict:
                    # remove old endpoint
                    self._model_monitoring_endpoints.pop(ep_base_url, None)
                    dirty = True
                    continue

            # add new endpoints
            for version, model_id in versions_model_id_dict.items():
                url = "{}/{}".format(serving_base_url, version)
                if url in self._model_monitoring_endpoints:
                    continue
                model = self._model_monitoring.get(serving_base_url)
                if not model:
                    # this should never happen
                    continue
                ep = ModelEndpoint(
                    engine_type=model.engine_type,
                    serving_url=serving_base_url,
                    model_id=model_id,
                    version=str(version),
                    preprocess_artifact=model.preprocess_artifact,
                    input_size=model.input_size,
                    input_type=model.input_type,
                    output_size=model.output_size,
                    output_type=model.output_type
                )
                self._model_monitoring_endpoints[url] = ep
                dirty = True

        # filter out old model monitoring endpoints
        for ep_url in list(self._model_monitoring_endpoints.keys()):
            if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url + "/")):
                self._model_monitoring_endpoints.pop(ep_url, None)
                dirty = True

        # reset flag
        self._model_monitoring_update_request = False

        if dirty:
            config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
            self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)

        return dirty

    def _update_monitored_models(self):
        for model in self._model_monitoring.values():
            current_served_models = self._model_monitoring_versions.get(model.base_serving_url, {})
            # To Do: sort by updated time ?
            models = Model.query_models(
                project_name=model.monitor_project or None,
                model_name=model.monitor_name or None,
                tags=model.monitor_tags or None,
                only_published=model.only_published,
                max_results=model.max_versions,
                include_archived=False,
            )

            # check what we already have:
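            # invert the stored version -> model_id map into model_id -> version,
            # so models already being served keep their assigned version numbers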
            current_model_id_version_lookup = dict(
                zip(list(current_served_models.values()), list(current_served_models.keys()))
            )
            versions = sorted(current_served_models.keys(), reverse=True)

            # notice, most updated model first
            # first select only the new models
            model_ids = [m.id for m in models]

            # we want last updated model to be last (so it gets the highest version number)
            max_v = 1 + (versions[0] if versions else 0)
            versions_model_ids = []
            for m_id in reversed(model_ids):
                v = current_model_id_version_lookup.get(m_id)
                if v is None:
                    v = max_v
                    max_v += 1
                versions_model_ids.append((v, m_id))

            # remove extra entries (old models)
            versions_model_ids_dict = dict(versions_model_ids[:model.max_versions])

            # mark dirty if something changed:
            if versions_model_ids_dict != current_served_models:
                self._model_monitoring_update_request = True

            # update model serving state
            self._model_monitoring_versions[model.base_serving_url] = versions_model_ids_dict

        if not self._model_monitoring_update_request:
            return False

        self._report_text("INFO: Monitored Models updated: {}".format(
            json.dumps(self._model_monitoring_versions, indent=2))
        )
        return True

    def _sync_metric_logging(self, force: bool = False) -> bool:
        if not force and not self._metric_logging:
            return False

        # wildcard endpoints (e.g. "model/*") are matched by prefix, fixed ones by exact name
        fixed_metric_endpoint = {
            k: v for k, v in self._metric_logging.items() if "*" not in k
        }
        prefix_metric_endpoint = {k.split("*")[0]: v for k, v in self._metric_logging.items() if "*" in k}

        endpoint_metric_logging = {}
        for k, ep in list(self._endpoints.items()) + list(self._model_monitoring_endpoints.items()):
            if k in fixed_metric_endpoint:
                if k not in endpoint_metric_logging:
                    endpoint_metric_logging[k] = fixed_metric_endpoint[k]
                continue
            for p, v in prefix_metric_endpoint.items():
                if k.startswith(p):
                    if k not in endpoint_metric_logging:
                        endpoint_metric_logging[k] = v
                    break

        self._endpoint_metric_logging = endpoint_metric_logging
        return True

    def launch(self, poll_frequency_sec=300):
        """
        Launch the background synchronization and monitoring threads
        (updating the runtime process based on changes to the Task, and monitoring model changes in the system)

        :param poll_frequency_sec: Sync every X seconds (default 300 seconds)
        """
        if self._sync_daemon_thread:
            return

        # read state
        self.deserialize(self._task, prefetch_artifacts=True)
        # model monitoring sync
        if self._update_monitored_models():
            # update endpoints
            self.deserialize(self._task, prefetch_artifacts=True)

        # get the serving instance (for visibility and monitoring)
        self._instance_task = Task.current_task()

        # start the background threads
        with self._update_lock_guard:
            if self._sync_daemon_thread:
                return
            self._sync_daemon_thread = threading.Thread(
                target=self._sync_daemon, args=(poll_frequency_sec, ), daemon=True)
            self._stats_sending_thread = threading.Thread(
                target=self._stats_send_loop, daemon=True)

            self._sync_daemon_thread.start()
            self._stats_sending_thread.start()

        # we return immediately

    def _sync_daemon(self, poll_frequency_sec: float = 300) -> None:
        """
        Background thread, syncing model changes into the request service.
        """
        poll_frequency_sec = float(poll_frequency_sec)
        # force mark started on the main serving service task
        self._task.mark_started(force=True)
        self._report_text("Launching - configuration sync every {} sec".format(poll_frequency_sec))
        cleanup = False
        self._update_serving_plot()
        while True:
            try:
                # this should be the only place where we call deserialize
                self._task.reload()
                if self.deserialize(self._task):
                    self._report_text("New configuration updated")
                    # mark clean up for next round
                    cleanup = True
                # model monitoring sync
                if self._update_monitored_models():
                    self._report_text("Model monitoring synced")
                    # update endpoints
                    self.deserialize(self._task)
                    # mark clean up for next round
                    cleanup = True
                # update serving layout plot
                if cleanup:
                    self._update_serving_plot()
            except Exception as ex:
                print("Exception occurred in monitoring thread: {}".format(ex))
            sleep(poll_frequency_sec)
            try:
                # we assume that by now all old deleted endpoints requests already returned
                if cleanup:
                    cleanup = False
                    for k in list(self._engine_processor_lookup.keys()):
                        if k not in self._endpoints:
                            # atomic
                            self._engine_processor_lookup.pop(k, None)
            except Exception as ex:
                print("Exception occurred in monitoring thread: {}".format(ex))

    def _stats_send_loop(self) -> None:
        """
        Background thread for sending stats to the Kafka service
        """
        if not self._kafka_stats_url:
            print("No Kafka Statistics service configured, shutting down statistics report")
            return

        print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url))

        from kafka import KafkaProducer  # noqa

        while True:
            try:
                producer = KafkaProducer(
                    bootstrap_servers=self._kafka_stats_url,  # e.g. ['localhost:9092']
                    value_serializer=lambda x: json.dumps(x).encode('utf-8'),
                    compression_type='lz4',  # requires the python lz4 package
                )
                break
            except Exception as ex:
                print("Error: failed opening Kafka producer [{}]: {}".format(self._kafka_stats_url, ex))
                print("Retrying in 30 seconds")
                sleep(30)

        while True:
            try:
                stats_dict = self._stats_queue.get(block=True)
            except Exception as ex:
                print("Warning: Statistics thread exception: {}".format(ex))
                break
            # send into kafka service
            try:
                producer.send(self._kafka_topic, value=stats_dict).get()
            except Exception as ex:
                print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex))

    def get_id(self) -> str:
        return self._task.id

    def get_endpoints(self) -> Dict[str, ModelEndpoint]:
        endpoints = dict(**self._endpoints)
        endpoints.update(**self._model_monitoring_endpoints)
        return endpoints

    def get_synced_endpoints(self) -> Dict[str, ModelEndpoint]:
        self._task.reload()
        _endpoints = self._task.get_configuration_object_as_dict(name='endpoints') or {}
        _monitor_endpoints = self._task.get_configuration_object_as_dict(name='model_monitoring_eps') or {}
        endpoints = {
            k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)})
            for k, v in _endpoints.items()}
        endpoints.update({
            k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)})
            for k, v in _monitor_endpoints.items()
        })
        return endpoints

    def get_canary_endpoints(self) -> dict:
        return self._canary_endpoints

    def get_model_monitoring(self) -> dict:
        return self._model_monitoring

    def _get_instance_id(self) -> Optional[str]:
        return self._instance_task.id if self._instance_task else None

    def _report_text(self, text) -> Optional[str]:
        return self._task.get_logger().report_text("Instance [{}, pid={}]: {}".format(
            self._get_instance_id(), os.getpid(), text))

    def _update_serving_plot(self) -> None:
        """
        Update the endpoint serving graph on the serving instance Task
        """
        if not self._instance_task:
            return

        # Generate configuration table and details
        endpoints = list(self._endpoints.values()) + list(self._model_monitoring_endpoints.values())
        if not endpoints:
            # clear plot if we had any
            return

        endpoints = [e.as_dict() for e in endpoints]
        table_values = [list(endpoints[0].keys())]
        table_values += [[e[c] or "" for c in table_values[0]] for e in endpoints]
        self._instance_task.get_logger().report_table(
            title='Serving Endpoint Configuration', series='Details', iteration=0, table_plot=table_values,
            extra_layout={"title": "Model Endpoints Details"})

        # generate current endpoint view
        sankey_node = dict(
            label=[],
            color=[],
            customdata=[],
            hovertemplate='%{customdata}<extra></extra>',
            hoverlabel={"align": "left"},
        )
        sankey_link = dict(
            source=[],
            target=[],
            value=[],
            hovertemplate='<extra></extra>',
        )

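        # node indexing: node 0 is the "external" root; static endpoint nodes
        # follow at 1..N, then the monitoring root and its models, then the
        # canary nodes; sankey_node_idx maps endpoint URLs to node indices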
        # root
        sankey_node['color'].append("mediumpurple")
        sankey_node['label'].append('{}'.format('external'))
        sankey_node['customdata'].append("")

        sankey_node_idx = {}

        # base_url = self._task._get_app_server() + '/projects/*/models/{model_id}/general'

        # draw all static endpoints
        # noinspection PyProtectedMember
        for i, ep in enumerate(endpoints):
            serve_url = ep['serving_url']
            full_url = '{}/{}'.format(serve_url, ep['version'] or "")
            sankey_node['color'].append("blue")
            sankey_node['label'].append("/{}/".format(full_url.strip("/")))
            sankey_node['customdata'].append(
                "model id: {}".format(ep['model_id'])
            )
            sankey_link['source'].append(0)
            sankey_link['target'].append(i + 1)
            # weight by the combined endpoint list we are iterating over
            # (avoids division by zero when only monitored endpoints exist)
            sankey_link['value'].append(1. / len(endpoints))
            sankey_node_idx[full_url] = i + 1

        # draw all model monitoring
        sankey_node['color'].append("mediumpurple")
        sankey_node['label'].append('{}'.format('monitoring models'))
        sankey_node['customdata'].append("")
        monitoring_root_idx = len(sankey_node['customdata']) - 1

        for i, m in enumerate(self._model_monitoring.values()):
            serve_url = m.base_serving_url
            sankey_node['color'].append("purple")
            sankey_node['label'].append('{}'.format(serve_url))
            sankey_node['customdata'].append(
                "project: {}<br />name: {}<br />tags: {}".format(
                    m.monitor_project or '', m.monitor_name or '', m.monitor_tags or '')
            )
            sankey_link['source'].append(monitoring_root_idx)
            sankey_link['target'].append(monitoring_root_idx + i + 1)
            sankey_link['value'].append(1. / len(self._model_monitoring))

            # add links to the current models
            serve_url = serve_url.rstrip("/") + "/"
            for k in sankey_node_idx:
                if k.startswith(serve_url):
                    sankey_link['source'].append(monitoring_root_idx + i + 1)
                    sankey_link['target'].append(sankey_node_idx[k])
                    sankey_link['value'].append(1.0 / m.max_versions)

        # add canary endpoints
        # sankey_node['color'].append("mediumpurple")
        # sankey_node['label'].append('{}'.format('Canary endpoints'))
        # sankey_node['customdata'].append("")
        canary_root_idx = len(sankey_node['customdata']) - 1

        # sankey_link['source'].append(0)
        # sankey_link['target'].append(canary_root_idx)
        # sankey_link['value'].append(1.)

        for i, c in enumerate(self._canary_endpoints.values()):
            serve_url = c.endpoint
            sankey_node['color'].append("green")
            sankey_node['label'].append('CANARY: /{}/'.format(serve_url.strip("/")))
            sankey_node['customdata'].append(
                "outputs: {}".format(
                    c.load_endpoints or c.load_endpoint_prefix)
            )
            sankey_link['source'].append(0)
            sankey_link['target'].append(canary_root_idx + i + 1)
            sankey_link['value'].append(1. / len(self._canary_endpoints))

            # add links to the current models
            if serve_url not in self._canary_route:
                continue
            for ep, w in zip(self._canary_route[serve_url]['endpoints'], self._canary_route[serve_url]['weights']):
                idx = sankey_node_idx.get(ep)
                if idx is None:
                    continue
                sankey_link['source'].append(canary_root_idx + i + 1)
                sankey_link['target'].append(idx)
                sankey_link['value'].append(w)

        # create the sankey graph
        dag_flow = dict(
            link=sankey_link,
            node=sankey_node,
            textfont=dict(color='rgba(0,0,0,255)', size=10),
            type='sankey',
            orientation='h'
        )
        fig = dict(data=[dag_flow], layout={'xaxis': {'visible': False}, 'yaxis': {'visible': False}})

        self._instance_task.get_logger().report_plotly(
            title='Serving Endpoints Layout', series='', iteration=0, figure=fig)

    def _deserialize_conf_dict(self, configuration: dict) -> None:
        self._configuration = configuration

        # deserialized values go here
        self._kafka_stats_url = \
            configuration.get(self._config_key_kafka_stats) or \
            os.environ.get("CLEARML_DEFAULT_KAFKA_SERVE_URL")
        self._triton_grpc = \
            configuration.get(self._config_key_triton_grpc) or \
            os.environ.get("CLEARML_DEFAULT_TRITON_GRPC_ADDR")
        self._serving_base_url = \
            configuration.get(self._config_key_serving_base_url) or \
            os.environ.get("CLEARML_DEFAULT_BASE_SERVE_URL")
        self._metric_log_freq = \
            float(configuration.get(self._config_key_def_metric_freq,
                                    os.environ.get("CLEARML_DEFAULT_METRIC_LOG_FREQ", 1.0)))
        # update back configuration
        self._configuration[self._config_key_kafka_stats] = self._kafka_stats_url
        self._configuration[self._config_key_triton_grpc] = self._triton_grpc
        self._configuration[self._config_key_serving_base_url] = self._serving_base_url
        self._configuration[self._config_key_def_metric_freq] = self._metric_log_freq
        # update preprocessing classes
        BasePreprocessRequest.set_server_config(self._configuration)

    def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
        # collect statistics for this request
        stats = {}
        stats_collect_fn = None
        collect_stats = False
        freq = 1
        # decide if we are collecting the stats
        metric_endpoint = self._metric_logging.get(url)
        if self._kafka_stats_url:
            freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \
                else self._metric_log_freq

            if freq and random() <= freq:
                stats_collect_fn = stats.update
                collect_stats = True

        tic = time()
        preprocessed = processor.preprocess(body, stats_collect_fn)
        processed = processor.process(preprocessed, stats_collect_fn)
        return_value = processor.postprocess(processed, stats_collect_fn)
        tic = time() - tic
        if collect_stats:
            # a 10th of a millisecond resolution should be enough
            stats['_latency'] = round(tic, 4)
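            # each logged sample stands in for ~1/freq requests, letting
            # downstream aggregation recover absolute request counts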
            stats['_count'] = int(1.0 / freq)
            stats['_url'] = url

            # collect inputs
            if metric_endpoint and body:
                for k, v in body.items():
                    if k in metric_endpoint.metrics:
                        stats[k] = v
            # collect outputs
            if metric_endpoint and return_value:
                for k, v in return_value.items():
                    if k in metric_endpoint.metrics:
                        stats[k] = v

            # send stats in background, push it into a thread queue
            # noinspection PyBroadException
            try:
                self._stats_queue.put(stats, block=False)
            except Exception:
                pass

        return return_value

    @classmethod
    def list_control_plane_tasks(
            cls,
            task_id: Optional[str] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None
    ) -> List[dict]:

        # noinspection PyProtectedMember
        tasks = Task.query_tasks(
            task_name=name or None,
            project_name=project or None,
            tags=tags or None,
            additional_return_fields=["id", "name", "project", "tags"],
            task_filter={'type': ['service'],
                         'status': ["created", "in_progress"],
                         'system_tags': [cls._system_tag]}
        )  # type: List[dict]
        if not tasks:
            return []

        for t in tasks:
            # noinspection PyProtectedMember
            t['project'] = Task._get_project_name(t['project'])

        return tasks

    @classmethod
    def _get_control_plane_task(
            cls,
            task_id: Optional[str] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None,
            disable_change_state: bool = False,
    ) -> Task:
        if task_id:
            task = Task.get_task(task_id=task_id)
            if not task:
                raise ValueError("Could not find Control Task ID={}".format(task_id))
            task_status = task.status
            if task_status not in ("created", "in_progress",):
                if disable_change_state:
                    raise ValueError(
                        "Control Task ID={} status [{}] "
                        "is not valid (only 'draft' and 'running' are supported)".format(task_id, task_status))
                else:
                    task.mark_started(force=True)
            return task

        # noinspection PyProtectedMember
        tasks = Task.query_tasks(
            task_name=name or None,
            project_name=project or None,
            tags=tags or None,
            task_filter={'type': ['service'],
                         'status': ["created", "in_progress"],
                         'system_tags': [cls._system_tag]}
        )
        if not tasks:
            raise ValueError("Could not find any valid Control Tasks")

        if len(tasks) > 1:
            print("Warning: more than one valid Controller Task found, using Task ID={}".format(tasks[0]))

        return Task.get_task(task_id=tasks[0])

    @classmethod
    def _create_task(
            cls,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None
    ) -> Task:
        task = Task.create(
            project_name=project or "DevOps",
            task_name=name or "Serving Service",
            task_type="service",
        )
        task.set_system_tags([cls._system_tag])
        if tags:
            task.set_tags(tags)
        return task

    @classmethod
    def _normalize_endpoint_url(cls, endpoint: str, version: Optional[str] = None) -> str:
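        # e.g. ("my_model/", "2") -> "my_model/2" ; ("my_model", None) -> "my_model"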
        return "{}/{}".format(endpoint.rstrip("/"), version or "").rstrip("/")

    @classmethod
    def _validate_model(cls, endpoint: Union[ModelEndpoint, ModelMonitoring]) -> bool:
        """
        Raise an exception if validation fails, otherwise return True
        """
        if endpoint.engine_type in ("triton", ):
            # verify we have all the info we need
            d = endpoint.as_dict()
            missing = [
                k for k in [
                    'input_type', 'input_size', 'input_name',
                    'output_type', 'output_size', 'output_name',
                ] if not d.get(k)
            ]
            if not endpoint.auxiliary_cfg and missing:
                raise ValueError("Triton engine requires input description - missing values in {}".format(missing))
        return True