Optimize request serving statistics reporting

This commit is contained in:
allegroai 2022-06-07 00:20:33 +03:00
parent 4a55c10366
commit 48f720ac91
3 changed files with 127 additions and 62 deletions

View File

@ -87,7 +87,7 @@ router = APIRouter(
@router.post("/{model_id}/{version}") @router.post("/{model_id}/{version}")
@router.post("/{model_id}/") @router.post("/{model_id}/")
@router.post("/{model_id}") @router.post("/{model_id}")
def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None): async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
try: try:
return_value = processor.process_request( return_value = processor.process_request(
base_url=model_id, base_url=model_id,

View File

@ -1,7 +1,8 @@
import json import json
import os import os
from collections import deque
from pathlib import Path from pathlib import Path
from queue import Queue # from queue import Queue
from random import random from random import random
from time import sleep, time from time import sleep, time
from typing import Optional, Union, Dict, List from typing import Optional, Union, Dict, List
@ -32,6 +33,37 @@ class FastWriteCounter(object):
return next(self._counter_inc) - next(self._counter_dec) return next(self._counter_inc) - next(self._counter_dec)
class FastSimpleQueue(object):
_default_wait_timeout = 10
def __init__(self):
self._deque = deque()
# Notify not_empty whenever an item is added to the queue; a
# thread waiting to get is notified then.
self._not_empty = threading.Event()
self._last_notify = time()
def put(self, a_object, block=True):
self._deque.append(a_object)
if time() - self._last_notify > self._default_wait_timeout:
self._not_empty.set()
self._last_notify = time()
def get(self, block=True):
while True:
try:
return self._deque.popleft()
except IndexError:
if not block:
return None
# wait until signaled
try:
if self._not_empty.wait(timeout=self._default_wait_timeout):
self._not_empty.clear()
except Exception as ex: # noqa
pass
class ModelRequestProcessor(object): class ModelRequestProcessor(object):
_system_tag = "serving-control-plane" _system_tag = "serving-control-plane"
_kafka_topic = "clearml_inference_stats" _kafka_topic = "clearml_inference_stats"
@ -75,7 +107,7 @@ class ModelRequestProcessor(object):
self._last_update_hash = None self._last_update_hash = None
self._sync_daemon_thread = None self._sync_daemon_thread = None
self._stats_sending_thread = None self._stats_sending_thread = None
self._stats_queue = Queue() self._stats_queue = FastSimpleQueue()
# this is used for Fast locking mechanisms (so we do not actually need to use Locks) # this is used for Fast locking mechanisms (so we do not actually need to use Locks)
self._update_lock_flag = False self._update_lock_flag = False
self._request_processing_state = FastWriteCounter() self._request_processing_state = FastWriteCounter()
@ -99,7 +131,7 @@ class ModelRequestProcessor(object):
if self._update_lock_flag: if self._update_lock_flag:
self._request_processing_state.dec() self._request_processing_state.dec()
while self._update_lock_flag: while self._update_lock_flag:
sleep(1) sleep(0.5+random())
# retry to process # retry to process
return self.process_request(base_url=base_url, version=version, request_body=request_body) return self.process_request(base_url=base_url, version=version, request_body=request_body)
@ -820,6 +852,7 @@ class ModelRequestProcessor(object):
print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url)) print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url))
from kafka import KafkaProducer # noqa from kafka import KafkaProducer # noqa
import kafka.errors as Errors # noqa
while True: while True:
try: try:
@ -836,16 +869,35 @@ class ModelRequestProcessor(object):
while True: while True:
try: try:
stats_dict = self._stats_queue.get(block=True) stats_list_dict = [self._stats_queue.get(block=True)]
while True:
v = self._stats_queue.get(block=False)
if v is None:
break
stats_list_dict.append(v)
except Exception as ex: except Exception as ex:
print("Warning: Statistics thread exception: {}".format(ex)) print("Warning: Statistics thread exception: {}".format(ex))
break break
# send into kafka service
try: left_overs = []
producer.send(self._kafka_topic, value=stats_dict).get() while stats_list_dict or left_overs:
except Exception as ex: if not stats_list_dict:
print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex)) stats_list_dict = left_overs
pass left_overs = []
# send into kafka service
try:
producer.send(self._kafka_topic, value=stats_list_dict).get()
stats_list_dict = []
except Errors.MessageSizeTooLargeError:
# log.debug("Splitting Kafka message in half [{}]".format(len(stats_list_dict)))
# split in half - message is too long for kafka to send
left_overs += stats_list_dict[len(stats_list_dict)//2:]
stats_list_dict = stats_list_dict[:len(stats_list_dict)//2]
continue
except Exception as ex:
print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex))
break
def get_id(self) -> str: def get_id(self) -> str:
return self._task.id return self._task.id
@ -1046,9 +1098,9 @@ class ModelRequestProcessor(object):
def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict: def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
# collect statistics for this request # collect statistics for this request
stats = {}
stats_collect_fn = None stats_collect_fn = None
collect_stats = False collect_stats = False
custom_stats = dict()
freq = 1 freq = 1
# decide if we are collecting the stats # decide if we are collecting the stats
metric_endpoint = self._metric_logging.get(url) metric_endpoint = self._metric_logging.get(url)
@ -1056,8 +1108,8 @@ class ModelRequestProcessor(object):
freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \ freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \
else self._metric_log_freq else self._metric_log_freq
if freq and random() <= freq: if freq and (freq >= 1 or random() <= freq):
stats_collect_fn = stats.update stats_collect_fn = custom_stats.update
collect_stats = True collect_stats = True
tic = time() tic = time()
@ -1067,21 +1119,25 @@ class ModelRequestProcessor(object):
return_value = processor.postprocess(processed, state, stats_collect_fn) return_value = processor.postprocess(processed, state, stats_collect_fn)
tic = time() - tic tic = time() - tic
if collect_stats: if collect_stats:
# 10th of a millisecond should be enough stats = dict(
stats['_latency'] = round(tic, 4) _latency=round(tic, 4), # 10th of a millisecond should be enough
stats['_count'] = int(1.0/freq) _count=int(1.0/freq),
stats['_url'] = url _url=url
)
# collect inputs if custom_stats:
if metric_endpoint and body: stats.update(custom_stats)
for k, v in body.items():
if k in metric_endpoint.metrics: if metric_endpoint:
stats[k] = v metric_keys = set(metric_endpoint.metrics.keys())
# collect outputs # collect inputs
if metric_endpoint and return_value: if body:
for k, v in return_value.items(): keys = set(body.keys()) & metric_keys
if k in metric_endpoint.metrics: stats.update({k: body[k] for k in keys})
stats[k] = v # collect outputs
if return_value:
keys = set(return_value.keys()) & metric_keys
stats.update({k: return_value[k] for k in keys})
# send stats in background, push it into a thread queue # send stats in background, push it into a thread queue
# noinspection PyBroadException # noinspection PyBroadException

View File

@ -7,7 +7,7 @@ from threading import Event, Thread
from time import time, sleep from time import time, sleep
from clearml import Task from clearml import Task
from typing import Optional, Dict, Any, Iterable from typing import Optional, Dict, Any, Iterable, Set
from prometheus_client import Histogram, Enum, Gauge, Counter, values from prometheus_client import Histogram, Enum, Gauge, Counter, values
from kafka import KafkaConsumer from kafka import KafkaConsumer
@ -204,6 +204,7 @@ class StatisticsController(object):
self._poll_frequency_min = float(poll_frequency_min) self._poll_frequency_min = float(poll_frequency_min)
self._serving_service = None # type: Optional[ModelRequestProcessor] self._serving_service = None # type: Optional[ModelRequestProcessor]
self._current_endpoints = {} # type: Optional[Dict[str, EndpointMetricLogging]] self._current_endpoints = {} # type: Optional[Dict[str, EndpointMetricLogging]]
self._auto_added_endpoints = set() # type: Set[str]
self._prometheus_metrics = {} # type: Optional[Dict[str, Dict[str, MetricWrapperBase]]] self._prometheus_metrics = {} # type: Optional[Dict[str, Dict[str, MetricWrapperBase]]]
self._timestamp = time() self._timestamp = time()
self._sync_thread = None self._sync_thread = None
@ -242,45 +243,47 @@ class StatisticsController(object):
for message in consumer: for message in consumer:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
data = json.loads(message.value.decode("utf-8")) list_data = json.loads(message.value.decode("utf-8"))
except Exception: except Exception:
print("Warning: failed to decode kafka stats message") print("Warning: failed to decode kafka stats message")
continue continue
try: for data in list_data:
url = data.pop("_url", None) try:
if not url: url = data.pop("_url", None)
# should not happen if not url:
continue # should not happen
endpoint_metric = self._current_endpoints.get(url) continue
if not endpoint_metric: endpoint_metric = self._current_endpoints.get(url)
# add default one, we will just log the reserved valued: if not endpoint_metric:
endpoint_metric = dict() # add default one, we will just log the reserved valued:
self._current_endpoints[url] = EndpointMetricLogging(endpoint=url) endpoint_metric = dict()
# we should sync, self._current_endpoints[url] = EndpointMetricLogging(endpoint=url)
if time()-self._last_sync_time > self._sync_threshold_sec: self._auto_added_endpoints.add(url)
self._last_sync_time = time() # we should sync,
self._sync_event.set() if time()-self._last_sync_time > self._sync_threshold_sec:
self._last_sync_time = time()
self._sync_event.set()
metric_url_log = self._prometheus_metrics.get(url) metric_url_log = self._prometheus_metrics.get(url)
if not metric_url_log: if not metric_url_log:
# create a new one # create a new one
metric_url_log = dict() metric_url_log = dict()
self._prometheus_metrics[url] = metric_url_log self._prometheus_metrics[url] = metric_url_log
# check if we have the prometheus_logger # check if we have the prometheus_logger
for k, v in data.items(): for k, v in data.items():
prometheus_logger = metric_url_log.get(k) prometheus_logger = metric_url_log.get(k)
if not prometheus_logger:
prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric)
if not prometheus_logger: if not prometheus_logger:
continue prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric)
metric_url_log[k] = prometheus_logger if not prometheus_logger:
continue
metric_url_log[k] = prometheus_logger
self._report_value(prometheus_logger, v) self._report_value(prometheus_logger, v)
except Exception as ex: except Exception as ex:
print("Warning: failed to report stat to Prometheus: {}".format(ex)) print("Warning: failed to report stat to Prometheus: {}".format(ex))
continue continue
@staticmethod @staticmethod
def _report_value(prometheus_logger: Optional[MetricWrapperBase], v: Any) -> bool: def _report_value(prometheus_logger: Optional[MetricWrapperBase], v: Any) -> bool:
@ -341,14 +344,20 @@ class StatisticsController(object):
self._serving_service.reload() self._serving_service.reload()
endpoint_metrics = self._serving_service.list_endpoint_logging() endpoint_metrics = self._serving_service.list_endpoint_logging()
self._last_sync_time = time() self._last_sync_time = time()
if self._current_endpoints == endpoint_metrics: # we might have added new urls (auto metric logging), we need to compare only configured keys
current_endpoints = {
k: v for k, v in self._current_endpoints.items()
if k not in self._auto_added_endpoints}
if current_endpoints == endpoint_metrics:
self._sync_event.wait(timeout=poll_freq_sec) self._sync_event.wait(timeout=poll_freq_sec)
self._sync_event.clear() self._sync_event.clear()
continue continue
# update metrics: # update metrics:
self._dirty = True self._dirty = True
self._current_endpoints = deepcopy(endpoint_metrics) self._auto_added_endpoints -= set(endpoint_metrics.keys())
# merge top level configuration (we might have auto logged url endpoints)
self._current_endpoints.update(deepcopy(endpoint_metrics))
print("New configuration synced") print("New configuration synced")
except Exception as ex: except Exception as ex:
print("Warning: failed to sync state from serving service Task: {}".format(ex)) print("Warning: failed to sync state from serving service Task: {}".format(ex))