Optimize request serving statistics reporting

This commit is contained in:
allegroai 2022-06-07 00:20:33 +03:00
parent 4a55c10366
commit 48f720ac91
3 changed files with 127 additions and 62 deletions

View File

@ -87,7 +87,7 @@ router = APIRouter(
@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
try:
return_value = processor.process_request(
base_url=model_id,

View File

@ -1,7 +1,8 @@
import json
import os
from collections import deque
from pathlib import Path
from queue import Queue
# from queue import Queue
from random import random
from time import sleep, time
from typing import Optional, Union, Dict, List
@ -32,6 +33,37 @@ class FastWriteCounter(object):
return next(self._counter_inc) - next(self._counter_dec)
class FastSimpleQueue(object):
    """
    Minimal multi-producer/consumer FIFO queue built on ``collections.deque``.

    ``deque.append`` / ``deque.popleft`` are atomic in CPython, so no lock is
    taken on the hot path. Wake-up signaling via the ``threading.Event`` is
    deliberately throttled: ``put`` only sets the event at most once per
    ``_default_wait_timeout`` seconds, trading wake-up latency (a blocked
    ``get`` may sleep up to ``_default_wait_timeout`` seconds before polling
    again) for lower per-``put`` overhead.
    """

    # Max seconds a blocked get() waits between polls, and the minimum
    # interval between event notifications from put().
    _default_wait_timeout = 10

    def __init__(self):
        self._deque = deque()
        # Notify not_empty whenever an item is added to the queue; a
        # thread waiting to get is notified then.
        self._not_empty = threading.Event()
        self._last_notify = time()

    def put(self, a_object, block=True):
        """
        Append an object to the queue.

        :param a_object: item to enqueue
        :param block: accepted for ``queue.Queue`` API compatibility only;
            put never blocks (the queue is unbounded)
        """
        self._deque.append(a_object)
        # Throttled notification: only signal waiters if we have not
        # signaled within the last _default_wait_timeout seconds.
        if time() - self._last_notify > self._default_wait_timeout:
            self._not_empty.set()
            self._last_notify = time()

    def get(self, block=True):
        """
        Pop the oldest item from the queue.

        :param block: if False, return None immediately when the queue is
            empty; if True, poll (waking at least every
            ``_default_wait_timeout`` seconds) until an item is available
        :return: the dequeued item, or None (only when ``block=False``)
        """
        while True:
            try:
                return self._deque.popleft()
            except IndexError:
                # queue is currently empty
                if not block:
                    return None
                # Wait until signaled (or timeout) then retry the pop.
                # Event.wait never raises, so no exception guard is needed.
                if self._not_empty.wait(timeout=self._default_wait_timeout):
                    self._not_empty.clear()
class ModelRequestProcessor(object):
_system_tag = "serving-control-plane"
_kafka_topic = "clearml_inference_stats"
@ -75,7 +107,7 @@ class ModelRequestProcessor(object):
self._last_update_hash = None
self._sync_daemon_thread = None
self._stats_sending_thread = None
self._stats_queue = Queue()
self._stats_queue = FastSimpleQueue()
# this is used for Fast locking mechanisms (so we do not actually need to use Locks)
self._update_lock_flag = False
self._request_processing_state = FastWriteCounter()
@ -99,7 +131,7 @@ class ModelRequestProcessor(object):
if self._update_lock_flag:
self._request_processing_state.dec()
while self._update_lock_flag:
sleep(1)
sleep(0.5+random())
# retry to process
return self.process_request(base_url=base_url, version=version, request_body=request_body)
@ -820,6 +852,7 @@ class ModelRequestProcessor(object):
print("Starting Kafka Statistics reporting: {}".format(self._kafka_stats_url))
from kafka import KafkaProducer # noqa
import kafka.errors as Errors # noqa
while True:
try:
@ -836,16 +869,35 @@ class ModelRequestProcessor(object):
while True:
try:
stats_dict = self._stats_queue.get(block=True)
stats_list_dict = [self._stats_queue.get(block=True)]
while True:
v = self._stats_queue.get(block=False)
if v is None:
break
stats_list_dict.append(v)
except Exception as ex:
print("Warning: Statistics thread exception: {}".format(ex))
break
# send into kafka service
try:
producer.send(self._kafka_topic, value=stats_dict).get()
except Exception as ex:
print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex))
pass
left_overs = []
while stats_list_dict or left_overs:
if not stats_list_dict:
stats_list_dict = left_overs
left_overs = []
# send into kafka service
try:
producer.send(self._kafka_topic, value=stats_list_dict).get()
stats_list_dict = []
except Errors.MessageSizeTooLargeError:
# log.debug("Splitting Kafka message in half [{}]".format(len(stats_list_dict)))
# split in half - message is too long for kafka to send
left_overs += stats_list_dict[len(stats_list_dict)//2:]
stats_list_dict = stats_list_dict[:len(stats_list_dict)//2]
continue
except Exception as ex:
print("Warning: Failed to send statistics packet to Kafka service: {}".format(ex))
break
def get_id(self) -> str:
return self._task.id
@ -1046,9 +1098,9 @@ class ModelRequestProcessor(object):
def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
# collect statistics for this request
stats = {}
stats_collect_fn = None
collect_stats = False
custom_stats = dict()
freq = 1
# decide if we are collecting the stats
metric_endpoint = self._metric_logging.get(url)
@ -1056,8 +1108,8 @@ class ModelRequestProcessor(object):
freq = metric_endpoint.log_frequency if metric_endpoint and metric_endpoint.log_frequency is not None \
else self._metric_log_freq
if freq and random() <= freq:
stats_collect_fn = stats.update
if freq and (freq >= 1 or random() <= freq):
stats_collect_fn = custom_stats.update
collect_stats = True
tic = time()
@ -1067,21 +1119,25 @@ class ModelRequestProcessor(object):
return_value = processor.postprocess(processed, state, stats_collect_fn)
tic = time() - tic
if collect_stats:
# 10th of a millisecond should be enough
stats['_latency'] = round(tic, 4)
stats['_count'] = int(1.0/freq)
stats['_url'] = url
stats = dict(
_latency=round(tic, 4), # 10th of a millisecond should be enough
_count=int(1.0/freq),
_url=url
)
# collect inputs
if metric_endpoint and body:
for k, v in body.items():
if k in metric_endpoint.metrics:
stats[k] = v
# collect outputs
if metric_endpoint and return_value:
for k, v in return_value.items():
if k in metric_endpoint.metrics:
stats[k] = v
if custom_stats:
stats.update(custom_stats)
if metric_endpoint:
metric_keys = set(metric_endpoint.metrics.keys())
# collect inputs
if body:
keys = set(body.keys()) & metric_keys
stats.update({k: body[k] for k in keys})
# collect outputs
if return_value:
keys = set(return_value.keys()) & metric_keys
stats.update({k: return_value[k] for k in keys})
# send stats in background, push it into a thread queue
# noinspection PyBroadException

View File

@ -7,7 +7,7 @@ from threading import Event, Thread
from time import time, sleep
from clearml import Task
from typing import Optional, Dict, Any, Iterable
from typing import Optional, Dict, Any, Iterable, Set
from prometheus_client import Histogram, Enum, Gauge, Counter, values
from kafka import KafkaConsumer
@ -204,6 +204,7 @@ class StatisticsController(object):
self._poll_frequency_min = float(poll_frequency_min)
self._serving_service = None # type: Optional[ModelRequestProcessor]
self._current_endpoints = {} # type: Optional[Dict[str, EndpointMetricLogging]]
self._auto_added_endpoints = set() # type: Set[str]
self._prometheus_metrics = {} # type: Optional[Dict[str, Dict[str, MetricWrapperBase]]]
self._timestamp = time()
self._sync_thread = None
@ -242,45 +243,47 @@ class StatisticsController(object):
for message in consumer:
# noinspection PyBroadException
try:
data = json.loads(message.value.decode("utf-8"))
list_data = json.loads(message.value.decode("utf-8"))
except Exception:
print("Warning: failed to decode kafka stats message")
continue
try:
url = data.pop("_url", None)
if not url:
# should not happen
continue
endpoint_metric = self._current_endpoints.get(url)
if not endpoint_metric:
# add default one, we will just log the reserved valued:
endpoint_metric = dict()
self._current_endpoints[url] = EndpointMetricLogging(endpoint=url)
# we should sync,
if time()-self._last_sync_time > self._sync_threshold_sec:
self._last_sync_time = time()
self._sync_event.set()
for data in list_data:
try:
url = data.pop("_url", None)
if not url:
# should not happen
continue
endpoint_metric = self._current_endpoints.get(url)
if not endpoint_metric:
# add a default one, we will just log the reserved values:
endpoint_metric = dict()
self._current_endpoints[url] = EndpointMetricLogging(endpoint=url)
self._auto_added_endpoints.add(url)
# we should sync,
if time()-self._last_sync_time > self._sync_threshold_sec:
self._last_sync_time = time()
self._sync_event.set()
metric_url_log = self._prometheus_metrics.get(url)
if not metric_url_log:
# create a new one
metric_url_log = dict()
self._prometheus_metrics[url] = metric_url_log
metric_url_log = self._prometheus_metrics.get(url)
if not metric_url_log:
# create a new one
metric_url_log = dict()
self._prometheus_metrics[url] = metric_url_log
# check if we have the prometheus_logger
for k, v in data.items():
prometheus_logger = metric_url_log.get(k)
if not prometheus_logger:
prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric)
# check if we have the prometheus_logger
for k, v in data.items():
prometheus_logger = metric_url_log.get(k)
if not prometheus_logger:
continue
metric_url_log[k] = prometheus_logger
prometheus_logger = self._create_prometheus_logger_class(url, k, endpoint_metric)
if not prometheus_logger:
continue
metric_url_log[k] = prometheus_logger
self._report_value(prometheus_logger, v)
self._report_value(prometheus_logger, v)
except Exception as ex:
print("Warning: failed to report stat to Prometheus: {}".format(ex))
continue
except Exception as ex:
print("Warning: failed to report stat to Prometheus: {}".format(ex))
continue
@staticmethod
def _report_value(prometheus_logger: Optional[MetricWrapperBase], v: Any) -> bool:
@ -341,14 +344,20 @@ class StatisticsController(object):
self._serving_service.reload()
endpoint_metrics = self._serving_service.list_endpoint_logging()
self._last_sync_time = time()
if self._current_endpoints == endpoint_metrics:
# we might have added new urls (auto metric logging), we need to compare only configured keys
current_endpoints = {
k: v for k, v in self._current_endpoints.items()
if k not in self._auto_added_endpoints}
if current_endpoints == endpoint_metrics:
self._sync_event.wait(timeout=poll_freq_sec)
self._sync_event.clear()
continue
# update metrics:
self._dirty = True
self._current_endpoints = deepcopy(endpoint_metrics)
self._auto_added_endpoints -= set(endpoint_metrics.keys())
# merge top level configuration (we might have auto logged url endpoints)
self._current_endpoints.update(deepcopy(endpoint_metrics))
print("New configuration synced")
except Exception as ex:
print("Warning: failed to sync state from serving service Task: {}".format(ex))