From 48f720ac91506de08387ee6e1692511353c69c87 Mon Sep 17 00:00:00 2001 From: allegroai Date: Tue, 7 Jun 2022 00:20:33 +0300 Subject: [PATCH] Optimize request serving statistics reporting --- clearml_serving/serving/main.py | 2 +- .../serving/model_request_processor.py | 110 +++++++++++++----- clearml_serving/statistics/metrics.py | 77 ++++++------ 3 files changed, 127 insertions(+), 62 deletions(-) diff --git a/clearml_serving/serving/main.py b/clearml_serving/serving/main.py index f3473e7..369aa4b 100644 --- a/clearml_serving/serving/main.py +++ b/clearml_serving/serving/main.py @@ -87,7 +87,7 @@ router = APIRouter( @router.post("/{model_id}/{version}") @router.post("/{model_id}/") @router.post("/{model_id}") -def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None): +async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None): try: return_value = processor.process_request( base_url=model_id, diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py index acd1877..42635f6 100644 --- a/clearml_serving/serving/model_request_processor.py +++ b/clearml_serving/serving/model_request_processor.py @@ -1,7 +1,8 @@ import json import os +from collections import deque from pathlib import Path -from queue import Queue +# from queue import Queue from random import random from time import sleep, time from typing import Optional, Union, Dict, List @@ -32,6 +33,37 @@ class FastWriteCounter(object): return next(self._counter_inc) - next(self._counter_dec) +class FastSimpleQueue(object): + _default_wait_timeout = 10 + + def __init__(self): + self._deque = deque() + # Notify not_empty whenever an item is added to the queue; a + # thread waiting to get is notified then. + self._not_empty = threading.Event() + self._last_notify = time() + + def put(self, a_object, block=True): + self._deque.append(a_object) + if time() - self._last_notify > self._default_wait_timeout: + self._not_empty.set() + self._last_notify = time() + + def get(self, block=True): + while True: + try: + return self._deque.popleft() + except IndexError: + if not block: + return None + # wait until signaled + try: + if self._not_empty.wait(timeout=self._default_wait_timeout): + self._not_empty.clear() + except Exception as ex: # noqa + pass + + class ModelRequestProcessor(object): _system_tag = "serving-control-plane" _kafka_topic = "clearml_inference_stats" @@ -75,7 +107,7 @@ class ModelRequestProcessor(object): self._last_update_hash = None self._sync_daemon_thread = None self._stats_sending_thread = None - self._stats_queue = Queue() + self._stats_queue = FastSimpleQueue() # this is used for Fast locking mechanisms (so we do not actually need to use Locks) self._update_lock_flag = False self._request_processing_state = FastWriteCounter() @@ -99,7 +131,7 @@ class ModelRequestProcessor(object): if self._update_lock_flag: self._request_processing_state.dec() while self._update_lock_flag: - sleep(1) + sleep(0.5+random()) # retry to process return self.process_request(base_url=base_url, version=version, request_body=request_body) @@ -820,6 +852,7 @@ class ModelRequestProcessor(object): print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url)) from kafka import KafkaProducer # noqa + import kafka.errors as Errors # noqa while True: try: @@ -836,16 +869,35 @@ class ModelRequestProcessor(object): while True: try: - stats_dict = self._stats_queue.get(block=True) + stats_list_dict = [self._stats_queue.get(block=True)] + while True: + v = self._stats_queue.get(block=False) + if v is None: + break + stats_list_dict.append(v) except Exception as ex: print("Warning: Statistics thread exception: {}".format(ex)) break - # send into kafka service - try: - producer.send(self._kafka_topic, value=stats_dict).get() - except Exception as ex: - print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex)) - pass + + left_overs = [] + while stats_list_dict or left_overs: + if not stats_list_dict: + stats_list_dict = left_overs + left_overs = [] + + # send into kafka service + try: + producer.send(self._kafka_topic, value=stats_list_dict).get() + stats_list_dict = [] + except Errors.MessageSizeTooLargeError: + # log.debug("Splitting Kafka message in half [{}]".format(len(stats_list_dict))) + # split in half - message is too long for kafka to send + left_overs += stats_list_dict[len(stats_list_dict)//2:] + stats_list_dict = stats_list_dict[:len(stats_list_dict)//2] + continue + except Exception as ex: + print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex)) + break def get_id(self) -> str: return self._task.id @@ -1046,9 +1098,9 @@ class ModelRequestProcessor(object): def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict: # collect statistics for this request - stats = {} stats_collect_fn = None collect_stats = False + custom_stats = dict() freq = 1 # decide if we are collecting the stats metric_endpoint = self._metric_logging.get(url) @@ -1056,8 +1108,8 @@ class ModelRequestProcessor(object): freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \ else self._metric_log_freq - if freq and random() <= freq: - stats_collect_fn = stats.update + if freq and (freq >= 1 or random() <= freq): + stats_collect_fn = custom_stats.update collect_stats = True tic = time() @@ -1067,21 +1119,25 @@ class ModelRequestProcessor(object): return_value = processor.postprocess(processed, state, stats_collect_fn) tic = time() - tic if collect_stats: - # 10th of a millisecond should be enough - stats['_latency'] = round(tic, 4) - stats['_count'] = int(1.0/freq) - stats['_url'] = url + stats = dict( + _latency=round(tic, 4), # 10th of a millisecond should be enough + _count=int(1.0/freq), + _url=url + ) - # collect inputs - if metric_endpoint and body: - for k, v in body.items(): - if k in metric_endpoint.metrics: - stats[k] = v - # collect outputs - if metric_endpoint and return_value: - for k, v in return_value.items(): - if k in metric_endpoint.metrics: - stats[k] = v + if custom_stats: + stats.update(custom_stats) + + if metric_endpoint: + metric_keys = set(metric_endpoint.metrics.keys()) + # collect inputs + if body: + keys = set(body.keys()) & metric_keys + stats.update({k: body[k] for k in keys}) + # collect outputs + if return_value: + keys = set(return_value.keys()) & metric_keys + stats.update({k: return_value[k] for k in keys}) # send stats in background, push it into a thread queue # noinspection PyBroadException diff --git a/clearml_serving/statistics/metrics.py b/clearml_serving/statistics/metrics.py index 1b6f60f..afdd095 100644 --- a/clearml_serving/statistics/metrics.py +++ b/clearml_serving/statistics/metrics.py @@ -7,7 +7,7 @@ from threading import Event, Thread from time import time, sleep from clearml import Task -from typing import Optional, Dict, Any, Iterable +from typing import Optional, Dict, Any, Iterable, Set from prometheus_client import Histogram, Enum, Gauge, Counter, values from kafka import KafkaConsumer @@ -204,6 +204,7 @@ class StatisticsController(object): self._poll_frequency_min = float(poll_frequency_min) self._serving_service = None # type: Optional[ModelRequestProcessor] self._current_endpoints = {} # type: Optional[Dict[str, EndpointMetricLogging]] + self._auto_added_endpoints = set() # type: Set[str] self._prometheus_metrics = {} # type: Optional[Dict[str, Dict[str, MetricWrapperBase]]] self._timestamp = time() self._sync_thread = None @@ -242,45 +243,47 @@ class StatisticsController(object): for message in consumer: # noinspection PyBroadException try: - data = json.loads(message.value.decode("utf-8")) + list_data = json.loads(message.value.decode("utf-8")) except Exception: print("Warning: failed to decode kafka stats message") continue - try: - url = data.pop("_url", None) - if not url: - # should not happen - continue - endpoint_metric = self._current_endpoints.get(url) - if not endpoint_metric: - # add default one, we will just log the reserved valued: - endpoint_metric = dict() - self._current_endpoints[url] = EndpointMetricLogging(endpoint=url) - # we should sync, - if time()-self._last_sync_time > self._sync_threshold_sec: - self._last_sync_time = time() - self._sync_event.set() + for data in list_data: + try: + url = data.pop("_url", None) + if not url: + # should not happen + continue + endpoint_metric = self._current_endpoints.get(url) + if not endpoint_metric: + # add default one, we will just log the reserved valued: + endpoint_metric = dict() + self._current_endpoints[url] = EndpointMetricLogging(endpoint=url) + self._auto_added_endpoints.add(url) + # we should sync, + if time()-self._last_sync_time > self._sync_threshold_sec: + self._last_sync_time = time() + self._sync_event.set() - metric_url_log = self._prometheus_metrics.get(url) - if not metric_url_log: - # create a new one - metric_url_log = dict() - self._prometheus_metrics[url] = metric_url_log + metric_url_log = self._prometheus_metrics.get(url) + if not metric_url_log: + # create a new one + metric_url_log = dict() + self._prometheus_metrics[url] = metric_url_log - # check if we have the prometheus_logger - for k, v in data.items(): - prometheus_logger = metric_url_log.get(k) - if not prometheus_logger: - prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric) + # check if we have the prometheus_logger + for k, v in data.items(): + prometheus_logger = metric_url_log.get(k) if not prometheus_logger: - continue - metric_url_log[k] = prometheus_logger + prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric) + if not prometheus_logger: + continue + metric_url_log[k] = prometheus_logger - self._report_value(prometheus_logger, v) + self._report_value(prometheus_logger, v) - except Exception as ex: - print("Warning: failed to report stat to Prometheus: {}".format(ex)) - continue + except Exception as ex: + print("Warning: failed to report stat to Prometheus: {}".format(ex)) + continue @staticmethod def _report_value(prometheus_logger: Optional[MetricWrapperBase], v: Any) -> bool: @@ -341,14 +344,20 @@ class StatisticsController(object): self._serving_service.reload() endpoint_metrics = self._serving_service.list_endpoint_logging() self._last_sync_time = time() - if self._current_endpoints == endpoint_metrics: + # we might have added new urls (auto metric logging), we need to compare only configured keys + current_endpoints = { + k: v for k, v in self._current_endpoints.items() + if k not in self._auto_added_endpoints} + if current_endpoints == endpoint_metrics: self._sync_event.wait(timeout=poll_freq_sec) self._sync_event.clear() continue # update metrics: self._dirty = True - self._current_endpoints = deepcopy(endpoint_metrics) + self._auto_added_endpoints -= set(endpoint_metrics.keys()) + # merge top level configuration (we might have auto logged url endpoints) + self._current_endpoints.update(deepcopy(endpoint_metrics)) print("New configuration synced") except Exception as ex: print("Warning: failed to sync state from serving service Task: {}".format(ex))