mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
Optimize async processing for increased speed
This commit is contained in:
@@ -2,13 +2,13 @@ import json
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
# from queue import Queue
|
||||
from random import random
|
||||
from time import sleep, time
|
||||
from typing import Optional, Union, Dict, List
|
||||
import itertools
|
||||
import threading
|
||||
from multiprocessing import Lock
|
||||
import asyncio
|
||||
from numpy import isin
|
||||
from numpy.random import choice
|
||||
|
||||
@@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
|
||||
self._serving_base_url = None
|
||||
self._metric_log_freq = None
|
||||
|
||||
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
|
||||
async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
|
||||
"""
|
||||
Process request coming in,
|
||||
Raise Value error if url does not match existing endpoints
|
||||
@@ -134,9 +134,9 @@ class ModelRequestProcessor(object):
|
||||
if self._update_lock_flag:
|
||||
self._request_processing_state.dec()
|
||||
while self._update_lock_flag:
|
||||
sleep(0.5+random())
|
||||
await asyncio.sleep(0.5+random())
|
||||
# retry to process
|
||||
return self.process_request(base_url=base_url, version=version, request_body=request_body)
|
||||
return await self.process_request(base_url=base_url, version=version, request_body=request_body)
|
||||
|
||||
try:
|
||||
# normalize url and version
|
||||
@@ -157,7 +157,7 @@ class ModelRequestProcessor(object):
|
||||
processor = processor_cls(model_endpoint=ep, task=self._task)
|
||||
self._engine_processor_lookup[url] = processor
|
||||
|
||||
return_value = self._process_request(processor=processor, url=url, body=request_body)
|
||||
return_value = await self._process_request(processor=processor, url=url, body=request_body)
|
||||
finally:
|
||||
self._request_processing_state.dec()
|
||||
|
||||
@@ -271,7 +271,7 @@ class ModelRequestProcessor(object):
|
||||
)
|
||||
models = Model.query_models(max_results=2, **model_query)
|
||||
if not models:
|
||||
raise ValueError("Could not fine any Model to serve {}".format(model_query))
|
||||
raise ValueError("Could not find any Model to serve {}".format(model_query))
|
||||
if len(models) > 1:
|
||||
print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
|
||||
endpoint.model_id = models[0].id
|
||||
@@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
|
||||
# update preprocessing classes
|
||||
BasePreprocessRequest.set_server_config(self._configuration)
|
||||
|
||||
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
|
||||
async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
|
||||
# collect statistics for this request
|
||||
stats_collect_fn = None
|
||||
collect_stats = False
|
||||
@@ -1151,9 +1151,18 @@ class ModelRequestProcessor(object):
|
||||
|
||||
tic = time()
|
||||
state = dict()
|
||||
preprocessed = processor.preprocess(body, state, stats_collect_fn)
|
||||
processed = processor.process(preprocessed, state, stats_collect_fn)
|
||||
return_value = processor.postprocess(processed, state, stats_collect_fn)
|
||||
# noinspection PyUnresolvedReferences
|
||||
preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
|
||||
if processor.is_preprocess_async \
|
||||
else processor.preprocess(body, state, stats_collect_fn)
|
||||
# noinspection PyUnresolvedReferences
|
||||
processed = await processor.process(preprocessed, state, stats_collect_fn) \
|
||||
if processor.is_process_async \
|
||||
else processor.process(preprocessed, state, stats_collect_fn)
|
||||
# noinspection PyUnresolvedReferences
|
||||
return_value = await processor.postprocess(processed, state, stats_collect_fn) \
|
||||
if processor.is_postprocess_async \
|
||||
else processor.postprocess(processed, state, stats_collect_fn)
|
||||
tic = time() - tic
|
||||
if collect_stats:
|
||||
stats = dict(
|
||||
|
||||
Reference in New Issue
Block a user