Optimize async processing for increased speed

This commit is contained in:
allegroai
2022-10-08 02:11:57 +03:00
parent f4aaf095a3
commit 395a547c04
7 changed files with 193 additions and 56 deletions

View File

@@ -2,13 +2,13 @@ import json
import os
from collections import deque
from pathlib import Path
# from queue import Queue
from random import random
from time import sleep, time
from typing import Optional, Union, Dict, List
import itertools
import threading
from multiprocessing import Lock
import asyncio
from numpy import isin
from numpy.random import choice
@@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
self._serving_base_url = None
self._metric_log_freq = None
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
"""
Process request coming in,
Raise Value error if url does not match existing endpoints
@@ -134,9 +134,9 @@ class ModelRequestProcessor(object):
if self._update_lock_flag:
self._request_processing_state.dec()
while self._update_lock_flag:
sleep(0.5+random())
await asyncio.sleep(0.5+random())
# retry to process
return self.process_request(base_url=base_url, version=version, request_body=request_body)
return await self.process_request(base_url=base_url, version=version, request_body=request_body)
try:
# normalize url and version
@@ -157,7 +157,7 @@ class ModelRequestProcessor(object):
processor = processor_cls(model_endpoint=ep, task=self._task)
self._engine_processor_lookup[url] = processor
return_value = self._process_request(processor=processor, url=url, body=request_body)
return_value = await self._process_request(processor=processor, url=url, body=request_body)
finally:
self._request_processing_state.dec()
@@ -271,7 +271,7 @@ class ModelRequestProcessor(object):
)
models = Model.query_models(max_results=2, **model_query)
if not models:
raise ValueError("Could not fine any Model to serve {}".format(model_query))
raise ValueError("Could not find any Model to serve {}".format(model_query))
if len(models) > 1:
print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
endpoint.model_id = models[0].id
@@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
# update preprocessing classes
BasePreprocessRequest.set_server_config(self._configuration)
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
# collect statistics for this request
stats_collect_fn = None
collect_stats = False
@@ -1151,9 +1151,18 @@ class ModelRequestProcessor(object):
tic = time()
state = dict()
preprocessed = processor.preprocess(body, state, stats_collect_fn)
processed = processor.process(preprocessed, state, stats_collect_fn)
return_value = processor.postprocess(processed, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
if processor.is_preprocess_async \
else processor.preprocess(body, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
processed = await processor.process(preprocessed, state, stats_collect_fn) \
if processor.is_process_async \
else processor.process(preprocessed, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
return_value = await processor.postprocess(processed, state, stats_collect_fn) \
if processor.is_postprocess_async \
else processor.postprocess(processed, state, stats_collect_fn)
tic = time() - tic
if collect_stats:
stats = dict(