import json
import os
from pathlib import Path
from queue import Queue
from random import random
from time import sleep, time
from typing import Optional, Union, Dict, List
import itertools
import threading
from multiprocessing import Lock
from numpy.random import choice

from clearml import Task, Model
from clearml.storage.util import hash_dict

from .preprocess_service import BasePreprocessRequest
from .endpoints import ModelEndpoint, ModelMonitoring, CanaryEP, EndpointMetricLogging


class FastWriteCounter(object):
    def __init__(self):
        self._counter_inc = itertools.count()
        self._counter_dec = itertools.count()

    def inc(self):
        next(self._counter_inc)

    def dec(self):
        next(self._counter_dec)

    def value(self):
        return next(self._counter_inc) - next(self._counter_dec)


class ModelRequestProcessor(object):
    _system_tag = "serving-control-plane"
    _kafka_topic = "clearml_inference_stats"
    _config_key_serving_base_url = "serving_base_url"
    _config_key_triton_grpc = "triton_grpc_server"
    _config_key_kafka_stats = "kafka_service_server"
    _config_key_def_metric_freq = "metric_logging_freq"

    def __init__(
            self,
            task_id: Optional[str] = None,
            update_lock_guard: Optional[Lock] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None,
            force_create: bool = False,
    ) -> None:
        """
        :param task_id: Optional specify existing Task ID of the ServingService
        :param update_lock_guard: If provided use external (usually multi-process) lock guard for updates
        :param name: Optional name for the current serving service
        :param project: Optional select project for the current serving service
        :param tags: Optional add tags to the serving service
        :param force_create: If True, ignore task_id and create a new serving Task
        """
        self._task = self._create_task(name=name, project=project, tags=tags) \
            if force_create else self._get_control_plane_task(task_id=task_id, name=name, project=project, tags=tags)
        self._endpoints = dict()  # type: Dict[str, ModelEndpoint]
        self._model_monitoring = dict()  # type: Dict[str, ModelMonitoring]
        self._model_monitoring_versions = dict()  # type: Dict[str, Dict[int, str]]
        self._model_monitoring_endpoints = dict()  # type: Dict[str, ModelEndpoint]
        self._model_monitoring_update_request = True
        # Dict[base_serve_url, Dict[version, model_id]]
        self._canary_endpoints = dict()  # type: Dict[str, CanaryEP]
        self._canary_route = dict()  # type: Dict[str, dict]
        self._engine_processor_lookup = dict()  # type: Dict[str, BasePreprocessRequest]
        self._metric_logging = dict()  # type: Dict[str, EndpointMetricLogging]
        self._endpoint_metric_logging = dict()  # type: Dict[str, EndpointMetricLogging]
        self._last_update_hash = None
        self._sync_daemon_thread = None
        self._stats_sending_thread = None
        self._stats_queue = Queue()
        # this is used for Fast locking mechanisms (so we do not actually need to use Locks)
        self._update_lock_flag = False
        self._request_processing_state = FastWriteCounter()
        self._update_lock_guard = update_lock_guard or threading.Lock()
        self._instance_task = None
        # serving server config
        self._configuration = {}  # deserialized values go here
        self._kafka_stats_url = None
        self._triton_grpc = None
        self._serving_base_url = None
        self._metric_log_freq = None

    def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
        """
        Process an incoming request.
        Raise ValueError if the url does not match any existing endpoint.
        """
        self._request_processing_state.inc()
        # check if we need to stall
        if self._update_lock_flag:
            self._request_processing_state.dec()
            while self._update_lock_flag:
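                # deserialize() raises this flag while it swaps the endpoint state;
                # wait for it to clear before retrying the request below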
                sleep(1)
            # retry to process
            return self.process_request(base_url=base_url, version=version, request_body=request_body)

        try:
            # normalize url and version
            url = self._normalize_endpoint_url(base_url, version)

            # check canary
            canary_url = self._process_canary(base_url=url)
            if canary_url:
                url = canary_url

            ep = self._endpoints.get(url, None) or self._model_monitoring_endpoints.get(url, None)
            if not ep:
                raise ValueError("Model inference endpoint '{}' not found".format(url))

            processor = self._engine_processor_lookup.get(url)
            if not processor:
                processor_cls = BasePreprocessRequest.get_engine_cls(ep.engine_type)
                processor = processor_cls(model_endpoint=ep, task=self._task)
                self._engine_processor_lookup[url] = processor

            return_value = self._process_request(processor=processor, url=url, body=request_body)
        finally:
            self._request_processing_state.dec()

        return return_value

    def _process_canary(self, base_url: str) -> Optional[dict]:
        canary = self._canary_route.get(base_url)
        if not canary:
            return None
        # random choice
        draw = choice(canary['endpoints'], 1, p=canary['weights'])
        # the new endpoint to use
        return draw[0]

    def configure(
            self,
            external_serving_base_url: Optional[str] = None,
            external_triton_grpc_server: Optional[str] = None,
            external_kafka_service_server: Optional[str] = None,
            default_metric_log_freq: Optional[float] = None,
    ):
        """
        Set ModelRequestProcessor configuration arguments.

        :param external_serving_base_url: Set the external base http endpoint for the serving service.
            This URL will be passed to the user custom preprocess class,
            allowing it to concatenate and combine multiple model requests into one
        :param external_triton_grpc_server: Set the external grpc tcp port of the Nvidia Triton clearml container.
            Used by the clearml triton engine class to send inference requests
        :param external_kafka_service_server: Optional, Kafka endpoint for the statistics controller collection.
        :param default_metric_log_freq: Default request metric logging (0 to 1.0,
            1. means 100% of requests are logged)
        """
        if external_serving_base_url is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_serving_base_url),
                value=str(external_serving_base_url),
                value_type="str",
                description="external base http endpoint for the serving service"
            )
        if external_triton_grpc_server is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_triton_grpc),
                value=str(external_triton_grpc_server),
                value_type="str",
                description="external grpc tcp port of the Nvidia Triton ClearML container running"
            )
        if external_kafka_service_server is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_kafka_stats),
                value=str(external_kafka_service_server),
                value_type="str",
                description="external Kafka service url for the statistics controller server"
            )
        if default_metric_log_freq is not None:
            self._task.set_parameter(
                name="General/{}".format(self._config_key_def_metric_freq),
                value=str(default_metric_log_freq),
                value_type="float",
                description="Request metric logging frequency"
            )

    def get_configuration(self) -> dict:
        return dict(**self._configuration)

    def add_endpoint(
            self,
            endpoint: Union[ModelEndpoint, dict],
            preprocess_code: Optional[str] = None,
            model_name: Optional[str] = None,
            model_project: Optional[str] = None,
            model_tags: Optional[List[str]] = None,
            model_published: Optional[bool] = None,
    ) -> str:
        """
        Return the unique name of the endpoint (endpoint + version)
        Overwrite existing endpoint if already exists (outputs a warning)

        :param endpoint: New endpoint to register (overwrite existing endpoint if exists)
        :param preprocess_code: If provided upload local code as artifact
        :param model_name: If model_id is not provided, search based on model name
        :param model_project: If model_id is not provided, search based on model project
        :param model_tags: If model_id is not provided, search based on model tags
        :param model_published: If model_id is not provided, search based on model published state
        """
        if not isinstance(endpoint, ModelEndpoint):
            endpoint = ModelEndpoint(**endpoint)

        # make sure we have everything configured
        self._validate_model(endpoint)

        url = self._normalize_endpoint_url(endpoint.serving_url, endpoint.version)
        if url in self._endpoints:
            print("Warning: Model endpoint \'{}\' overwritten".format(url))

        if not endpoint.model_id and any([model_project, model_name, model_tags]):
            model_query = dict(
                project_name=model_project,
                model_name=model_name,
                tags=model_tags,
                only_published=bool(model_published),
                include_archived=False,
            )
            models = Model.query_models(max_results=2, **model_query)
            if not models:
                raise ValueError("Could not find any Model to serve {}".format(model_query))
            if len(models) > 1:
                print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
            endpoint.model_id = models[0].id
        elif not endpoint.model_id:
            print("Warning: No Model provided for \'{}\'".format(url))

        # upload as new artifact
        if preprocess_code:
            if not Path(preprocess_code).exists():
                raise ValueError("Preprocessing code \'{}\' could not be found".format(preprocess_code))
            preprocess_artifact_name = "py_code_{}".format(url.replace("/", "_"))
            self._task.upload_artifact(
                name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
            endpoint.preprocess_artifact = preprocess_artifact_name

        self._endpoints[url] = endpoint
        return url

    def add_model_monitoring(
            self,
            monitoring: Union[ModelMonitoring, dict],
            preprocess_code: Optional[str] = None,
    ) -> str:
        """
        Return the unique name of the endpoint (endpoint + version)
        Overwrite existing endpoint if already exists (outputs a warning)

        :param monitoring: Model endpoint monitor (overwrite existing endpoint if exists)
        :param preprocess_code: If provided upload local code as artifact
        :return: Unique model monitoring ID (base_model_url)
        """
        if not isinstance(monitoring, ModelMonitoring):
            monitoring = ModelMonitoring(**monitoring)

        # make sure we actually have something to monitor
        if not any([monitoring.monitor_project, monitoring.monitor_name, monitoring.monitor_tags]):
            raise ValueError("Model monitoring requires at least a "
                             "project / name / tag to monitor, none were provided.")

        # make sure we have everything configured
        self._validate_model(monitoring)

        name = monitoring.base_serving_url
        if name in self._model_monitoring:
            print("Warning: Model monitoring \'{}\' overwritten".format(name))

        # upload as new artifact
        if preprocess_code:
            if not Path(preprocess_code).exists():
                raise ValueError("Preprocessing code \'{}\' could not be found".format(preprocess_code))
            preprocess_artifact_name = "py_code_{}".format(name.replace("/", "_"))
            self._task.upload_artifact(
                name=preprocess_artifact_name, artifact_object=Path(preprocess_code), wait_on_upload=True)
            monitoring.preprocess_artifact = preprocess_artifact_name

        self._model_monitoring[name] = monitoring
        return name

    def remove_model_monitoring(self, model_base_url: str) -> bool:
        """
        Remove model monitoring, use base_model_url as unique identifier
        """
        if model_base_url not in self._model_monitoring:
            return False
        self._model_monitoring.pop(model_base_url, None)
        return True

    def remove_endpoint(self, endpoint_url: str, version: Optional[str] = None) -> bool:
        """
        Remove specific model endpoint, use base_model_url as unique identifier
        """
        endpoint_url = self._normalize_endpoint_url(endpoint_url, version)
        if endpoint_url not in self._endpoints:
            return False
        self._endpoints.pop(endpoint_url, None)
        return True

    def add_canary_endpoint(
            self,
            canary: Union[CanaryEP, dict],
    ) -> str:
        """
        Return the unique name of the endpoint (endpoint + version)
        Overwrite existing endpoint if already exists (outputs a warning)

        :param canary: Canary endpoint router (overwrite existing endpoint if exists)
        :return: Unique canary ID (base_model_url)
        """
        if not isinstance(canary, CanaryEP):
            canary = CanaryEP(**canary)

        if canary.load_endpoints and canary.load_endpoint_prefix:
            raise ValueError(
                "Could not add canary endpoint with both "
                "prefix ({}) and fixed set of endpoints ({})".format(
                    canary.load_endpoints, canary.load_endpoint_prefix))

        name = canary.endpoint
        if name in self._canary_endpoints:
            print("Warning: Canary endpoint \'{}\' overwritten".format(name))

        self._canary_endpoints[name] = canary
        return name

    def remove_canary_endpoint(self, endpoint_url: str) -> bool:
        """
        Remove specific canary model endpoint, use base_model_url as unique identifier
        """
        if endpoint_url not in self._canary_endpoints:
            return False
        self._canary_endpoints.pop(endpoint_url, None)
        return True

    def add_metric_logging(self, metric: Union[EndpointMetricLogging, dict]) -> bool:
        """
        Add metric logging to a specific endpoint.
        Valid metric variables are any variables on the request or response dictionary,
        or a custom preprocess reported variable.
        When overwriting an existing monitored variable, output a warning.
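
        Example (illustrative only; the endpoint name, frequency and metrics spec below are hypothetical):
            add_metric_logging(dict(endpoint="test_model/1", log_frequency=0.5, metrics={...}))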

        :param metric: Metric variable to monitor
        :return: True if successful
        """
        if not isinstance(metric, EndpointMetricLogging):
            metric = EndpointMetricLogging(**metric)

        name = str(metric.endpoint).strip("/")
        metric.endpoint = name

        if name not in self._endpoints and not name.endswith('*'):
            raise ValueError("Metric logging \'{}\' references a nonexistent endpoint".format(name))

        if name in self._metric_logging:
            print("Warning: Metric logging \'{}\' overwritten".format(name))

        self._metric_logging[name] = metric
        return True

    def remove_metric_logging(
            self,
            endpoint: str,
            variable_name: str = None,
    ) -> bool:
        """
        Remove existing logged metric variable. Use variable name and endpoint as unique identifier

        :param endpoint: Endpoint name (including version, e.g. "model/1" or "model/*")
        :param variable_name: Variable name (str), pass None to remove the entire endpoint logging
        :return: True if successful
        """
        name = str(endpoint).strip("/")
        if name not in self._metric_logging or \
                (variable_name and variable_name not in self._metric_logging[name].metrics):
            return False

        if not variable_name:
            self._metric_logging.pop(name, None)
        else:
            self._metric_logging[name].metrics.pop(variable_name, None)

        return True

    def list_metric_logging(self) -> Dict[str, EndpointMetricLogging]:
        """
        List existing logged metric variables.

        :return: Dictionary, key='endpoint/version' value=EndpointMetricLogging
        """
        return dict(**self._metric_logging)

    def list_endpoint_logging(self) -> Dict[str, EndpointMetricLogging]:
        """
        List the (fully synced) current metric logging state of the endpoints.

        :return: Dictionary, key='endpoint/version' value=EndpointMetricLogging
        """
        return dict(**self._endpoint_metric_logging)

    def deserialize(
            self,
            task: Task = None,
            prefetch_artifacts: bool = False,
            skip_sync: bool = False,
            update_current_task: bool = True
    ) -> bool:
        """
        Restore ModelRequestProcessor state from a Task.
        Return True if the stored state actually changed and was reloaded, False if nothing changed.

        :param task: Load data from Task
        :param prefetch_artifacts: If True prefetch artifacts requested by the endpoints
        :param skip_sync: If True do not update the canary/monitoring state
        :param update_current_task: If skip_sync is False and this is True, update the current Task
            with the configuration synced from the serving service Task
        """
        if not task:
            task = self._task

        configuration = task.get_parameters_as_dict().get("General") or {}
        endpoints = task.get_configuration_object_as_dict(name='endpoints') or {}
        canary_ep = task.get_configuration_object_as_dict(name='canary') or {}
        model_monitoring = task.get_configuration_object_as_dict(name='model_monitoring') or {}
        metric_logging = task.get_configuration_object_as_dict(name='metric_logging') or {}

        hashed_conf = hash_dict(
            dict(endpoints=endpoints, canary_ep=canary_ep, model_monitoring=model_monitoring,
                 metric_logging=metric_logging, configuration=configuration)
        )
        if self._last_update_hash == hashed_conf and not self._model_monitoring_update_request:
            return False
        print("Info: syncing model endpoint configuration, state hash={}".format(hashed_conf))
        self._last_update_hash = hashed_conf

        endpoints = {
            k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)})
            for k, v in endpoints.items()
        }
        model_monitoring = {
            k: ModelMonitoring(**{i: j for i, j in v.items() if hasattr(ModelMonitoring.__attrs_attrs__, i)})
            for k, v in model_monitoring.items()
        }
        canary_endpoints = {
            k: CanaryEP(**{i: j for i, j in v.items() if hasattr(CanaryEP.__attrs_attrs__, i)})
            for k, v in canary_ep.items()
        }
        metric_logging = {
            k: EndpointMetricLogging(**{i: j for i, j in v.items() if hasattr(EndpointMetricLogging.__attrs_attrs__, i)})
            for k, v in metric_logging.items()
        }

        # if there is no need to sync Canary and Models we can just leave
        if skip_sync:
            self._endpoints = endpoints
            self._model_monitoring = model_monitoring
            self._canary_endpoints = canary_endpoints
            self._metric_logging = metric_logging
            self._deserialize_conf_dict(configuration)
            return True

        # make sure we only have one stall request at any given moment
        with self._update_lock_guard:
            # download artifacts
            # todo: separate into two, download before lock, and overwrite inside lock
            if prefetch_artifacts:
                for item in list(endpoints.values()) + list(model_monitoring.values()):
                    if item.preprocess_artifact:
                        # noinspection PyBroadException
                        try:
                            self._task.artifacts[item.preprocess_artifact].get_local_copy(
                                extract_archive=True,
                            )
                        except Exception:
                            pass

            # stall all requests
            self._update_lock_flag = True
            # wait until we have no request processed
            while self._request_processing_state.value() != 0:
                sleep(1)

            self._endpoints = endpoints
            self._model_monitoring = model_monitoring
            self._canary_endpoints = canary_endpoints
            self._metric_logging = metric_logging
            self._deserialize_conf_dict(configuration)

            # if we have models we need to sync, now is the time
            self._sync_monitored_models()
            self._update_canary_lookup()
            self._sync_metric_logging()

            # release stall lock
            self._update_lock_flag = False

        # update the state on the inference task
        if update_current_task and Task.current_task() and Task.current_task().id != self._task.id:
            self.serialize(task=Task.current_task())

        return True

    def serialize(self, task: Optional[Task] = None) -> None:
        """
        Store ModelRequestProcessor state into Task
        """
        if not task:
            task = self._task
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._endpoints.items()}
        task.set_configuration_object(name='endpoints', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._canary_endpoints.items()}
        task.set_configuration_object(name='canary', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring.items()}
        task.set_configuration_object(name='model_monitoring', config_dict=config_dict)
        config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._metric_logging.items()}
        task.set_configuration_object(name='metric_logging', config_dict=config_dict)

    def _update_canary_lookup(self):
        canary_route = {}
        for k, v in self._canary_endpoints.items():
            if v.load_endpoint_prefix and v.load_endpoints:
                print("Warning: Canary has both prefix and fixed endpoints, ignoring canary endpoint")
                continue
            if v.load_endpoints:
                if len(v.load_endpoints) != len(v.weights):
                    print("Warning: Canary \'{}\' weights [{}] do not match number of endpoints [{}], skipping!".format(
                        k, v.weights, v.load_endpoints))
                    continue
                endpoints = []
                weights = []
                for w, ep in zip(v.weights, v.load_endpoints):
                    if ep not in self._endpoints and ep not in self._model_monitoring_endpoints:
                        print("Warning: Canary \'{}\' endpoint \'{}\' could not be found, skipping".format(k, ep))
                        continue
                    endpoints.append(ep)
                    weights.append(float(w))
                # normalize weights
                sum_weights = sum(weights)
                weights = [w/sum_weights for w in weights]
                canary_route[k] = dict(endpoints=endpoints, weights=weights)

            elif v.load_endpoint_prefix:
                endpoints = [ep for ep in list(self._endpoints.keys()) + list(self._model_monitoring_endpoints.keys())
                             if str(ep).startswith(v.load_endpoint_prefix)]
                endpoints = sorted(
                    endpoints,
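                    # sort so the version segment compares numerically (zero-padded),
                    # putting the highest versions first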
                    reverse=True,
                    key=lambda x: '{}/{:0>9}'.format('/'.join(x.split('/')[:-1]), x.split('/')[-1]) if '/' in x else x
                )
                endpoints = endpoints[:len(v.weights)]
                weights = v.weights[:len(endpoints)]
                # normalize weights
                sum_weights = sum(weights)
                weights = [w/sum_weights for w in weights]
                canary_route[k] = dict(endpoints=endpoints, weights=weights)
                self._report_text(
                    "Info: Canary endpoint \'{}\' selected [{}]".format(k, canary_route[k])
                )

        # update back
        self._canary_route = canary_route

    def _sync_monitored_models(self, force: bool = False) -> bool:
        if not force and not self._model_monitoring_update_request:
            return False

        dirty = False
        for serving_base_url, versions_model_id_dict in self._model_monitoring_versions.items():
            # find existing endpoint versions
            for ep_base_url in list(self._model_monitoring_endpoints.keys()):
                # skip over endpoints that are not our own
                if not ep_base_url.startswith(serving_base_url + "/"):
                    continue
                # find endpoint version
                _, version = ep_base_url.split("/", 1)
                if int(version) not in versions_model_id_dict:
                    # remove old endpoint
                    self._model_monitoring_endpoints.pop(ep_base_url, None)
                    dirty = True
                    continue

            # add new endpoint
            for version, model_id in versions_model_id_dict.items():
                url = "{}/{}".format(serving_base_url, version)
                if url in self._model_monitoring_endpoints:
                    continue
                model = self._model_monitoring.get(serving_base_url)
                if not model:
                    # this should never happen
                    continue
                ep = ModelEndpoint(
                    engine_type=model.engine_type,
                    serving_url=serving_base_url,
                    model_id=model_id,
                    version=str(version),
                    preprocess_artifact=model.preprocess_artifact,
                    input_size=model.input_size,
                    input_type=model.input_type,
                    output_size=model.output_size,
                    output_type=model.output_type
                )
                self._model_monitoring_endpoints[url] = ep
                dirty = True

        # filter out old model monitoring endpoints
        for ep_url in list(self._model_monitoring_endpoints.keys()):
            if not any(True for url in self._model_monitoring_versions if ep_url.startswith(url + "/")):
                self._model_monitoring_endpoints.pop(ep_url, None)
                dirty = True

        # reset flag
        self._model_monitoring_update_request = False

        if dirty:
            config_dict = {k: v.as_dict(remove_null_entries=True) for k, v in self._model_monitoring_endpoints.items()}
            self._task.set_configuration_object(name='model_monitoring_eps', config_dict=config_dict)

        return dirty

    def _update_monitored_models(self):
        for model in self._model_monitoring.values():
            current_served_models = self._model_monitoring_versions.get(model.base_serving_url, {})
            # To Do: sort by updated time ?
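            # fetch the candidate models matching the monitoring spec (project / name / tags),
            # limited to the maximum number of versions we are allowed to serve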
            models = Model.query_models(
                project_name=model.monitor_project or None,
                model_name=model.monitor_name or None,
                tags=model.monitor_tags or None,
                only_published=model.only_published,
                max_results=model.max_versions,
                include_archived=False,
            )

            # check what we already have:
            current_model_id_version_lookup = dict(
                zip(list(current_served_models.values()), list(current_served_models.keys()))
            )
            versions = sorted(current_served_models.keys(), reverse=True)  # notice, most updated model first

            # first select only the new models
            model_ids = [m.id for m in models]
            # we want last updated model to be last (so it gets the highest version number)
            max_v = 1 + (versions[0] if versions else 0)
            versions_model_ids = []
            for m_id in reversed(model_ids):
                v = current_model_id_version_lookup.get(m_id)
                if v is None:
                    v = max_v
                    max_v += 1
                versions_model_ids.append((v, m_id))

            # remove extra entries (old models)
            versions_model_ids_dict = dict(versions_model_ids[:model.max_versions])

            # mark dirty if something changed:
            if versions_model_ids_dict != current_served_models:
                self._model_monitoring_update_request = True

            # update model serving state
            self._model_monitoring_versions[model.base_serving_url] = versions_model_ids_dict

        if not self._model_monitoring_update_request:
            return False

        self._report_text("INFO: Monitored Models updated: {}".format(
            json.dumps(self._model_monitoring_versions, indent=2))
        )
        return True

    def _sync_metric_logging(self, force: bool = False) -> bool:
        if not force and not self._metric_logging:
            return False

        fixed_metric_endpoint = {
            k: v for k, v in self._metric_logging.items() if "*/" not in k
        }
        prefix_metric_endpoint = {k.split("*/")[0]: v for k, v in self._metric_logging.items() if "*/" in k}

        endpoint_metric_logging = {}
        for k, ep in list(self._endpoints.items()) + list(self._model_monitoring_endpoints.items()):
            if k in fixed_metric_endpoint:
                if k not in endpoint_metric_logging:
                    endpoint_metric_logging[k] = fixed_metric_endpoint[k]
                continue
            for p, v in prefix_metric_endpoint.items():
                if k.startswith(p):
                    if k not in endpoint_metric_logging:
                        endpoint_metric_logging[k] = v
                    break

        self._endpoint_metric_logging = endpoint_metric_logging
        return True

    def launch(self, poll_frequency_sec=300):
        """
        Launch the background synchronization thread and monitoring thread
        (updating runtime process based on changes on the Task, and monitoring model changes in the system)

        :param poll_frequency_sec: Sync every X seconds (default 300 seconds)
        """
        if self._sync_daemon_thread:
            return

        # read state
        self.deserialize(self._task, prefetch_artifacts=True)
        # model monitoring sync
        if self._update_monitored_models():
            # update endpoints
            self.deserialize(self._task, prefetch_artifacts=True)

        # get the serving instance (for visibility and monitoring)
        self._instance_task = Task.current_task()

        # start the background thread
        with self._update_lock_guard:
            if self._sync_daemon_thread:
                return
            self._sync_daemon_thread = threading.Thread(
                target=self._sync_daemon, args=(poll_frequency_sec, ), daemon=True)
            self._stats_sending_thread = threading.Thread(
                target=self._stats_send_loop, daemon=True)
            self._sync_daemon_thread.start()
            self._stats_sending_thread.start()

        # we return immediately

    def _sync_daemon(self, poll_frequency_sec: float = 300) -> None:
        """
        Background thread, syncing model changes into request service.
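        Runs until the process exits; exceptions are caught and logged, and the loop continues on the next poll cycle.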
""" poll_frequency_sec = float(poll_frequency_sec) # force mark started on the main serving service task self._task.mark_started(force=True) self._report_text("Launching - configuration sync every {} sec".format(poll_frequency_sec)) cleanup = False self._update_serving_plot() while True: try: # this should be the only place where we call deserialize self._task.reload() if self.deserialize(self._task): self._report_text("New configuration updated") # mark clean up for next round cleanup = True # model monitoring sync if self._update_monitored_models(): self._report_text("Model monitoring synced") # update endpoints self.deserialize(self._task) # mark clean up for next round cleanup = True # update serving layout plot if cleanup: self._update_serving_plot() except Exception as ex: print("Exception occurred in monitoring thread: {}".format(ex)) sleep(poll_frequency_sec) try: # we assume that by now all old deleted endpoints requests already returned if cleanup: cleanup = False for k in list(self._engine_processor_lookup.keys()): if k not in self._endpoints: # atomic self._engine_processor_lookup.pop(k, None) except Exception as ex: print("Exception occurred in monitoring thread: {}".format(ex)) def _stats_send_loop(self) -> None: """ Background thread for sending stats to Kafka service """ if not self._kafka_stats_url: print("No Kafka Statistics service configured, shutting down statistics report") return print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url)) from kafka import KafkaProducer # noqa while True: try: producer = KafkaProducer( bootstrap_servers=self._kafka_stats_url, # ['localhost:9092'], value_serializer=lambda x: json.dumps(x).encode('utf-8'), compression_type='lz4', # requires python lz4 package ) break except Exception as ex: print("Error: failed opening Kafka consumer [{}]: {}".format(self._kafka_stats_url, ex)) print("Retrying in 30 seconds") sleep(30) while True: try: stats_dict = self._stats_queue.get(block=True) except Exception as ex: print("Warning: Statistics thread exception: {}".format(ex)) break # send into kafka service try: producer.send(self._kafka_topic, value=stats_dict).get() except Exception as ex: print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex)) pass def get_id(self) -> str: return self._task.id def get_endpoints(self) -> Dict[str, ModelEndpoint]: endpoints = dict(**self._endpoints) endpoints.update(**self._model_monitoring_endpoints) return endpoints def get_synced_endpoints(self) -> Dict[str, ModelEndpoint]: self._task.reload() _endpoints = self._task.get_configuration_object_as_dict(name='endpoints') or {} _monitor_endpoints = self._task.get_configuration_object_as_dict(name='model_monitoring_eps') or {} endpoints = { k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)}) for k, v in _endpoints.items()} endpoints.update({ k: ModelEndpoint(**{i: j for i, j in v.items() if hasattr(ModelEndpoint.__attrs_attrs__, i)}) for k, v in _monitor_endpoints.items() }) return endpoints def get_canary_endpoints(self) -> dict: return self._canary_endpoints def get_model_monitoring(self) -> dict: return self._model_monitoring def _get_instance_id(self) -> Optional[str]: return self._instance_task.id if self._instance_task else None def _report_text(self, text) -> Optional[str]: return self._task.get_logger().report_text("Instance [{}, pid={}]: {}".format( self._get_instance_id(), os.getpid(), text)) def _update_serving_plot(self) -> None: """ Update the endpoint serving graph 
        """
        if not self._instance_task:
            return

        # Generate configuration table and details
        endpoints = list(self._endpoints.values()) + list(self._model_monitoring_endpoints.values())
        if not endpoints:
            # clear plot if we had any
            return
        endpoints = [e.as_dict() for e in endpoints]
        table_values = [list(endpoints[0].keys())]
        table_values += [[e[c] or "" for c in table_values[0]] for e in endpoints]
        self._instance_task.get_logger().report_table(
            title='Serving Endpoint Configuration', series='Details', iteration=0, table_plot=table_values,
            extra_layout={"title": "Model Endpoints Details"})

        # generate current endpoint view
        sankey_node = dict(
            label=[],
            color=[],
            customdata=[],
            hovertemplate='%{customdata}',
            hoverlabel={"align": "left"},
        )
        sankey_link = dict(
            source=[],
            target=[],
            value=[],
            hovertemplate='',
        )

        # root
        sankey_node['color'].append("mediumpurple")
        sankey_node['label'].append('{}'.format('external'))
        sankey_node['customdata'].append("")

        sankey_node_idx = {}
        # base_url = self._task._get_app_server() + '/projects/*/models/{model_id}/general'

        # draw all static endpoints
        # noinspection PyProtectedMember
        for i, ep in enumerate(endpoints):
            serve_url = ep['serving_url']
            full_url = '{}/{}'.format(serve_url, ep['version'] or "")
            sankey_node['color'].append("blue")
            sankey_node['label'].append("/{}/".format(full_url.strip("/")))
            sankey_node['customdata'].append(
                "model id: {}".format(ep['model_id'])
            )
            sankey_link['source'].append(0)
            sankey_link['target'].append(i + 1)
            sankey_link['value'].append(1. / len(self._endpoints))
            sankey_node_idx[full_url] = i + 1

        # draw all model monitoring
        sankey_node['color'].append("mediumpurple")
        sankey_node['label'].append('{}'.format('monitoring models'))
        sankey_node['customdata'].append("")
        monitoring_root_idx = len(sankey_node['customdata']) - 1
        for i, m in enumerate(self._model_monitoring.values()):
            serve_url = m.base_serving_url
            sankey_node['color'].append("purple")
            sankey_node['label'].append('{}'.format(serve_url))
            sankey_node['customdata'].append(
                "project: {}\nname: {}\ntags: {}".format(
                    m.monitor_project or '', m.monitor_name or '', m.monitor_tags or '')
            )
            sankey_link['source'].append(monitoring_root_idx)
            sankey_link['target'].append(monitoring_root_idx + i + 1)
            sankey_link['value'].append(1. / len(self._model_monitoring))
            # add links to the current models
            serve_url = serve_url.rstrip("/") + "/"
            for k in sankey_node_idx:
                if k.startswith(serve_url):
                    sankey_link['source'].append(monitoring_root_idx + i + 1)
                    sankey_link['target'].append(sankey_node_idx[k])
                    sankey_link['value'].append(1.0 / m.max_versions)

        # add canary endpoints
        # sankey_node['color'].append("mediumpurple")
        # sankey_node['label'].append('{}'.format('Canary endpoints'))
        # sankey_node['customdata'].append("")
        canary_root_idx = len(sankey_node['customdata']) - 1
        # sankey_link['source'].append(0)
        # sankey_link['target'].append(canary_root_idx)
        # sankey_link['value'].append(1.)
        for i, c in enumerate(self._canary_endpoints.values()):
            serve_url = c.endpoint
            sankey_node['color'].append("green")
            sankey_node['label'].append('CANARY: /{}/'.format(serve_url.strip("/")))
            sankey_node['customdata'].append(
                "outputs: {}".format(
                    c.load_endpoints or c.load_endpoint_prefix)
            )
            sankey_link['source'].append(0)
            sankey_link['target'].append(canary_root_idx + i + 1)
            sankey_link['value'].append(1. / len(self._canary_endpoints))
            # add links to the current models
            if serve_url not in self._canary_route:
                continue
            for ep, w in zip(self._canary_route[serve_url]['endpoints'], self._canary_route[serve_url]['weights']):
                idx = sankey_node_idx.get(ep)
                if idx is None:
                    continue
                sankey_link['source'].append(canary_root_idx + i + 1)
                sankey_link['target'].append(idx)
                sankey_link['value'].append(w)

        # create the sankey graph
        dag_flow = dict(
            link=sankey_link,
            node=sankey_node,
            textfont=dict(color='rgba(0,0,0,255)', size=10),
            type='sankey',
            orientation='h'
        )
        fig = dict(data=[dag_flow], layout={'xaxis': {'visible': False}, 'yaxis': {'visible': False}})
        self._instance_task.get_logger().report_plotly(
            title='Serving Endpoints Layout', series='', iteration=0, figure=fig)

    def _deserialize_conf_dict(self, configuration: dict) -> None:
        self._configuration = configuration

        # deserialized values go here
        self._kafka_stats_url = \
            configuration.get(self._config_key_kafka_stats) or \
            os.environ.get("CLEARML_DEFAULT_KAFKA_SERVE_URL")
        self._triton_grpc = \
            configuration.get(self._config_key_triton_grpc) or \
            os.environ.get("CLEARML_DEFAULT_TRITON_GRPC_ADDR")
        self._serving_base_url = \
            configuration.get(self._config_key_serving_base_url) or \
            os.environ.get("CLEARML_DEFAULT_BASE_SERVE_URL")
        self._metric_log_freq = \
            float(configuration.get(self._config_key_def_metric_freq,
                                    os.environ.get("CLEARML_DEFAULT_METRIC_LOG_FREQ", 1.0)))
        # update back configuration
        self._configuration[self._config_key_kafka_stats] = self._kafka_stats_url
        self._configuration[self._config_key_triton_grpc] = self._triton_grpc
        self._configuration[self._config_key_serving_base_url] = self._serving_base_url
        self._configuration[self._config_key_def_metric_freq] = self._metric_log_freq
        # update preprocessing classes
        BasePreprocessRequest.set_server_config(self._configuration)

    def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
        # collect statistics for this request
        stats = {}
        stats_collect_fn = None
        collect_stats = False
        freq = 1
        # decide if we are collecting the stats
        metric_endpoint = self._metric_logging.get(url)
        if self._kafka_stats_url:
            freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \
                else self._metric_log_freq

            if freq and random() <= freq:
                stats_collect_fn = stats.update
                collect_stats = True

        tic = time()
        preprocessed = processor.preprocess(body, stats_collect_fn)
        processed = processor.process(preprocessed, stats_collect_fn)
        return_value = processor.postprocess(processed, stats_collect_fn)
        tic = time() - tic

        if collect_stats:
            # 10th of a millisecond should be enough
            stats['_latency'] = round(tic, 4)
            stats['_count'] = int(1.0/freq)
            stats['_url'] = url

            # collect inputs
            if metric_endpoint and body:
                for k, v in body.items():
                    if k in metric_endpoint.metrics:
                        stats[k] = v
            # collect outputs
            if metric_endpoint and return_value:
                for k, v in return_value.items():
                    if k in metric_endpoint.metrics:
                        stats[k] = v

            # send stats in background, push it into a thread queue
            # noinspection PyBroadException
            try:
                self._stats_queue.put(stats, block=False)
            except Exception:
                pass

        return return_value

    @classmethod
    def list_control_plane_tasks(
            cls,
            task_id: Optional[str] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None
    ) -> List[dict]:

        # noinspection PyProtectedMember
        tasks = Task.query_tasks(
            task_name=name or None,
            project_name=project or None,
            tags=tags or None,
            additional_return_fields=["id", "name", "project", "tags"],
            task_filter={'type': ['service'],
                         'status': ["created", "in_progress"],
                         'system_tags': [cls._system_tag]}
        )  # type: List[dict]

        if not tasks:
            return []

        for t in tasks:
            # noinspection PyProtectedMember
            t['project'] = Task._get_project_name(t['project'])

        return tasks

    @classmethod
    def _get_control_plane_task(
            cls,
            task_id: Optional[str] = None,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None,
            disable_change_state: bool = False,
    ) -> Task:
        if task_id:
            task = Task.get_task(task_id=task_id)
            if not task:
                raise ValueError("Could not find Control Task ID={}".format(task_id))
            task_status = task.status
            if task_status not in ("created", "in_progress",):
                if disable_change_state:
                    raise ValueError(
                        "Control Task ID={} status [{}] "
                        "is not valid (only 'draft', 'running' are supported)".format(task_id, task_status))
                else:
                    task.mark_started(force=True)
            return task

        # noinspection PyProtectedMember
        tasks = Task.query_tasks(
            task_name=name or None,
            project_name=project or None,
            tags=tags or None,
            task_filter={'type': ['service'],
                         'status': ["created", "in_progress"],
                         'system_tags': [cls._system_tag]}
        )
        if not tasks:
            raise ValueError("Could not find any valid Control Tasks")

        if len(tasks) > 1:
            print("Warning: more than one valid Controller Task found, using Task ID={}".format(tasks[0]))

        return Task.get_task(task_id=tasks[0])

    @classmethod
    def _create_task(
            cls,
            name: Optional[str] = None,
            project: Optional[str] = None,
            tags: Optional[List[str]] = None
    ) -> Task:
        task = Task.create(
            project_name=project or "DevOps",
            task_name=name or "Serving Service",
            task_type="service",
        )
        task.set_system_tags([cls._system_tag])
        if tags:
            task.set_tags(tags)
        return task

    @classmethod
    def _normalize_endpoint_url(cls, endpoint: str, version: Optional[str] = None) -> str:
        return "{}/{}".format(endpoint.rstrip("/"), version or "").rstrip("/")

    @classmethod
    def _validate_model(cls, endpoint: Union[ModelEndpoint, ModelMonitoring]) -> bool:
        """
        Raise exception if validation fails, otherwise return True
        """
        if endpoint.engine_type in ("triton", ):
            # verify we have all the info we need
            d = endpoint.as_dict()
            missing = [
                k for k in [
                    'input_type', 'input_size', 'input_name',
                    'output_type', 'output_size', 'output_name',
                ] if not d.get(k)
            ]
            if not endpoint.auxiliary_cfg and missing:
                raise ValueError("Triton engine requires input description - missing values in {}".format(missing))

        return True
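

# Minimal usage sketch (not part of the original module): register a single endpoint on a new
# serving service Task and persist the configuration. It assumes a reachable ClearML server;
# the service name, project, engine type and model id below are placeholders, and all other
# ModelEndpoint fields are left at their defaults.
if __name__ == "__main__":
    _processor = ModelRequestProcessor(
        name="Serving Service", project="DevOps", force_create=True)
    _processor.add_endpoint(
        endpoint=ModelEndpoint(
            engine_type="sklearn",      # assumed engine name; any registered preprocess engine works
            serving_url="test_model",   # endpoint will be exposed as "test_model/1"
            version="1",
            model_id="<model-id>",      # placeholder, replace with a real ClearML model ID
        )
    )
    # store the endpoint configuration on the serving service Task
    _processor.serialize()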