Optimize async processing for increased speed

allegroai 2022-10-08 02:11:57 +03:00
parent f4aaf095a3
commit 395a547c04
7 changed files with 193 additions and 56 deletions

View File

@@ -1,6 +1,7 @@
 clearml >= 1.3.1
 clearml-serving
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
+starlette
 grpcio
 Pillow>=9.0.1,<10
 pathlib2

View File

@@ -0,0 +1,33 @@
+import os
+
+from clearml import Task
+
+from clearml_serving.serving.model_request_processor import ModelRequestProcessor
+from clearml_serving.serving.preprocess_service import BasePreprocessRequest
+
+
+def setup_task(force_threaded_logging=None):
+    serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+    # always use background thread, it requires less memory
+    if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
+        os.environ["CLEARML_BKG_THREAD_REPORT"] = "1"
+        Task._report_subprocess_enabled = False
+
+    # get the serving controller task
+    # noinspection PyProtectedMember
+    serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
+    # set to running (because we are here)
+    if serving_task.status != "in_progress":
+        serving_task.started(force=True)
+
+    # create a new serving instance (for visibility and monitoring)
+    instance_task = Task.init(
+        project_name=serving_task.get_project_name(),
+        task_name="{} - serve instance".format(serving_task.name),
+        task_type="inference",  # noqa
+    )
+    instance_task.set_system_tags(["service"])
+
+    # preload modules into memory before forking
+    BasePreprocessRequest.load_modules()
+    return serving_service_task_id

View File

@@ -7,10 +7,9 @@ from fastapi.routing import APIRoute
 from typing import Optional, Dict, Any, Callable, Union

-from clearml import Task
 from clearml_serving.version import __version__
+from clearml_serving.serving.init import setup_task
 from clearml_serving.serving.model_request_processor import ModelRequestProcessor
-from clearml_serving.serving.preprocess_service import BasePreprocessRequest


 class GzipRequest(Request):
@@ -35,31 +34,20 @@ class GzipRoute(APIRoute):
 # process Lock, so that we can have only a single process doing the model reloading at a time
-singleton_sync_lock = Lock()
+singleton_sync_lock = None  # Lock()
+# shared Model processor object
+processor = None  # type: Optional[ModelRequestProcessor]

-serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+# create clearml Task and load models
+serving_service_task_id = setup_task()

+# polling frequency
 model_sync_frequency_secs = 5
 try:
     model_sync_frequency_secs = float(os.environ.get("CLEARML_SERVING_POLL_FREQ", model_sync_frequency_secs))
 except (ValueError, TypeError):
     pass

-# get the serving controller task
-# noinspection PyProtectedMember
-serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
-# set to running (because we are here)
-if serving_task.status != "in_progress":
-    serving_task.started(force=True)
-# create a new serving instance (for visibility and monitoring)
-instance_task = Task.init(
-    project_name=serving_task.get_project_name(),
-    task_name="{} - serve instance".format(serving_task.name),
-    task_type="inference",
-)
-instance_task.set_system_tags(["service"])
-
-processor = None  # type: Optional[ModelRequestProcessor]
-
-# preload modules into memory before forking
-BasePreprocessRequest.load_modules()
-
 # start FastAPI app
 app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -67,7 +55,13 @@ app = FastAPI(title="ClearML Serving Service", version=__version__, description=
 @app.on_event("startup")
 async def startup_event():
     global processor
-    print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(os.getpid(), serving_service_task_id))
+
+    if processor:
+        print("ModelRequestProcessor already initialized [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
+    else:
+        print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
     processor = ModelRequestProcessor(
         task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
     )
@@ -89,13 +83,15 @@ router = APIRouter(
 @router.post("/{model_id}")
 async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
     try:
-        return_value = processor.process_request(
+        return_value = await processor.process_request(
             base_url=model_id,
             version=version,
             request_body=request
         )
-    except Exception as ex:
+    except ValueError as ex:
         raise HTTPException(status_code=404, detail="Error processing request: {}".format(ex))
+    except Exception as ex:
+        raise HTTPException(status_code=500, detail="Error processing request: {}".format(ex))
     return return_value

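For context on the request path that now runs asynchronously: a minimal client-side sketch of calling the serving route. The endpoint name and payload below are illustrative assumptions; the base address comes from the code's default serving URL (http://127.0.0.1:8080/serve/), and the 404/500 split matches the new exception handlers above.

# Hypothetical client call -- endpoint name "test_model" and payload are assumptions.
import requests

body = {"x0": 1, "x1": 2}  # schema is defined by the endpoint's preprocess module
response = requests.post("http://127.0.0.1:8080/serve/test_model", json=body, timeout=10)
response.raise_for_status()  # 404 -> unknown endpoint (ValueError), 500 -> processing error
print(response.json())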
View File

@@ -2,13 +2,13 @@ import json
 import os
 from collections import deque
 from pathlib import Path
-# from queue import Queue
 from random import random
 from time import sleep, time
 from typing import Optional, Union, Dict, List
 import itertools
 import threading
 from multiprocessing import Lock
+import asyncio

 from numpy import isin
 from numpy.random import choice
@@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
         self._serving_base_url = None
         self._metric_log_freq = None

-    def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
+    async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
         """
         Process request coming in,
         Raise Value error if url does not match existing endpoints
@@ -134,9 +134,9 @@ class ModelRequestProcessor(object):
         if self._update_lock_flag:
             self._request_processing_state.dec()
             while self._update_lock_flag:
-                sleep(0.5+random())
+                await asyncio.sleep(0.5+random())
             # retry to process
-            return self.process_request(base_url=base_url, version=version, request_body=request_body)
+            return await self.process_request(base_url=base_url, version=version, request_body=request_body)

         try:
             # normalize url and version
@@ -157,7 +157,7 @@ class ModelRequestProcessor(object):
                 processor = processor_cls(model_endpoint=ep, task=self._task)
                 self._engine_processor_lookup[url] = processor

-            return_value = self._process_request(processor=processor, url=url, body=request_body)
+            return_value = await self._process_request(processor=processor, url=url, body=request_body)
         finally:
             self._request_processing_state.dec()
@@ -271,7 +271,7 @@ class ModelRequestProcessor(object):
             )
             models = Model.query_models(max_results=2, **model_query)
             if not models:
-                raise ValueError("Could not fine any Model to serve {}".format(model_query))
+                raise ValueError("Could not find any Model to serve {}".format(model_query))
             if len(models) > 1:
                 print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
             endpoint.model_id = models[0].id
@@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
         # update preprocessing classes
         BasePreprocessRequest.set_server_config(self._configuration)

-    def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
+    async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
         # collect statistics for this request
         stats_collect_fn = None
         collect_stats = False
@@ -1151,9 +1151,18 @@ class ModelRequestProcessor(object):

         tic = time()
         state = dict()
-        preprocessed = processor.preprocess(body, state, stats_collect_fn)
-        processed = processor.process(preprocessed, state, stats_collect_fn)
-        return_value = processor.postprocess(processed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
+            if processor.is_preprocess_async \
+            else processor.preprocess(body, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        processed = await processor.process(preprocessed, state, stats_collect_fn) \
+            if processor.is_process_async \
+            else processor.process(preprocessed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        return_value = await processor.postprocess(processed, state, stats_collect_fn) \
+            if processor.is_postprocess_async \
+            else processor.postprocess(processed, state, stats_collect_fn)
         tic = time() - tic
         if collect_stats:
             stats = dict(

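The dispatch above awaits a hook only when the processor class advertises the matching is_*_async flag, so synchronous engines keep running unchanged on the same code path. A self-contained sketch of that pattern (the handler classes here are illustrative, not part of the diff):

# Minimal illustration of the "await only if the handler declares itself async" dispatch.
import asyncio


class SyncHandler:
    is_process_async = False

    def process(self, data):
        return data * 2


class AsyncHandler:
    is_process_async = True

    async def process(self, data):
        await asyncio.sleep(0)  # stands in for an awaited gRPC/HTTP call
        return data * 2


async def run(handler, data):
    # same conditional-await shape used in ModelRequestProcessor._process_request
    return await handler.process(data) if handler.is_process_async else handler.process(data)


print(asyncio.run(run(SyncHandler(), 21)))   # 42
print(asyncio.run(run(AsyncHandler(), 21)))  # 42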
View File

@@ -1,5 +1,6 @@
 import os
 import sys
+import threading
 from pathlib import Path
 from typing import Optional, Any, Callable, List
@@ -18,6 +19,9 @@ class BasePreprocessRequest(object):
     _default_serving_base_url = "http://127.0.0.1:8080/serve/"
     _server_config = {}  # externally configured by the serving inference service
     _timeout = None  # timeout in seconds for the entire request, set in __init__
+    is_preprocess_async = False
+    is_process_async = False
+    is_postprocess_async = False

     def __init__(
             self,
@@ -259,6 +263,9 @@ class TritonPreprocessRequest(BasePreprocessRequest):
     _ext_np_to_triton_dtype = None
     _ext_service_pb2 = None
     _ext_service_pb2_grpc = None
+    is_preprocess_async = False
+    is_process_async = True
+    is_postprocess_async = False

     def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
         super(TritonPreprocessRequest, self).__init__(
@@ -266,7 +273,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):

         # load Triton Module
         if self._ext_grpc is None:
-            import grpc  # noqa
+            from tritonclient.grpc import grpc  # noqa
             self._ext_grpc = grpc

         if self._ext_np_to_triton_dtype is None:
@@ -274,11 +281,13 @@ class TritonPreprocessRequest(BasePreprocessRequest):
             self._ext_np_to_triton_dtype = np_to_triton_dtype

         if self._ext_service_pb2 is None:
-            from tritonclient.grpc import service_pb2, service_pb2_grpc  # noqa
+            from tritonclient.grpc.aio import service_pb2, service_pb2_grpc  # noqa
             self._ext_service_pb2 = service_pb2
             self._ext_service_pb2_grpc = service_pb2_grpc

-    def process(
+        self._grpc_stub = {}
+
+    async def process(
             self,
             data: Any,
             state: dict,
@@ -312,9 +321,15 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address
         if not triton_server_address:
             raise ValueError("External Triton gRPC server is not configured!")

-        try:
-            channel = self._ext_grpc.insecure_channel(triton_server_address)
-            grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
-        except Exception as ex:
-            raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
+        tid = threading.get_ident()
+        if self._grpc_stub.get(tid):
+            grpc_stub = self._grpc_stub.get(tid)
+        else:
+            try:
+                channel = self._ext_grpc.aio.insecure_channel(triton_server_address)
+                grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
+                self._grpc_stub[tid] = grpc_stub
+            except Exception as ex:
+                raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
@@ -364,15 +379,11 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         try:
             compression = self._ext_grpc.Compression.Gzip if use_compression \
                 else self._ext_grpc.Compression.NoCompression
-            response = grpc_stub.ModelInfer(
-                request,
-                compression=compression,
-                timeout=self._timeout
-            )
-        except Exception:
+            response = await grpc_stub.ModelInfer(request, compression=compression, timeout=self._timeout)
+        except Exception as ex:
             print("Exception calling Triton RPC function: "
                   "request_inputs={}, ".format([(r.name, r.shape, r.datatype) for r in (request.inputs or [])]) +
-                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}")
+                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}:\n{ex}")
             raise

         # process result
@@ -464,3 +475,83 @@ class CustomPreprocessRequest(BasePreprocessRequest):
         if self._preprocess is not None and hasattr(self._preprocess, 'process'):
             return self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
+
+
+@BasePreprocessRequest.register_engine("custom_async")
+class CustomAsyncPreprocessRequest(BasePreprocessRequest):
+    is_preprocess_async = True
+    is_process_async = True
+    is_postprocess_async = True
+
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(CustomAsyncPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+
+    async def preprocess(
+            self,
+            request: dict,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None,
+    ) -> Optional[Any]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param request: dictionary as received from the RestAPI
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> print(request)
+            {"x0": 1, "x1": 2}
+            >>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
+
+        :return: Object to be passed directly to the model inference
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
+            return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
+        return request
+
+    async def postprocess(
+            self,
+            data: Any,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None
+    ) -> Optional[dict]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param data: object as received from the inference model function
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> collect_custom_statistics_fn({"y": 1})
+
+        :return: Dictionary passed directly as the returned result of the RestAPI
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
+            return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
+        return data
+
+    async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return await self._preprocess.process(data, state, collect_custom_statistics_fn)
+        return None

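The new "custom_async" engine awaits the user module's preprocess/process/postprocess hooks, so custom endpoint code can perform non-blocking I/O inside the serving event loop. A hedged sketch of such a user module (the Preprocess class-name convention and the toy logic are assumptions, not part of this diff):

# Sketch of a user-supplied preprocess module for an endpoint registered with the
# "custom_async" engine; CustomAsyncPreprocessRequest awaits these coroutines if present.
import asyncio
from typing import Any, Callable, Optional


class Preprocess(object):
    async def preprocess(self, body: dict, state: dict,
                         collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # e.g. await a feature-store or DB lookup instead of blocking the event loop
        await asyncio.sleep(0)
        state["keys"] = list(body.keys())
        return [body.get("x0", 0), body.get("x1", 0)]

    async def process(self, data: Any, state: dict,
                      collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # toy "model": sum the features
        return sum(data)

    async def postprocess(self, data: Any, state: dict,
                          collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> dict:
        return {"y": data, "used_keys": state.get("keys", [])}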
View File

@@ -5,7 +5,8 @@ uvicorn[standard]
 gunicorn>=20.1.0,<20.2
 asyncio>=3.4.3,<3.5
 aiocache>=0.11.1,<0.12
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
+starlette
 numpy>=1.20,<1.24
 scikit-learn>=1.0.2,<1.1
 pandas>=1.0.5,<1.5

View File

@@ -0,0 +1,6 @@
+import uvicorn
+from clearml_serving.serving.init import setup_task
+
+if __name__ == "__main__":
+    setup_task(force_threaded_logging=True)
+    uvicorn.main()
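uvicorn.main() invokes uvicorn's command-line interface, so this entrypoint accepts the usual uvicorn arguments while initializing the ClearML serving task (with threaded reporting) before worker processes are forked. A rough programmatic equivalent, where the app import string and port are assumptions not stated in this diff:

# Assumed-equivalent launch using uvicorn's programmatic API; app path and port are guesses.
import uvicorn
from clearml_serving.serving.init import setup_task

if __name__ == "__main__":
    setup_task(force_threaded_logging=True)
    uvicorn.run("clearml_serving.serving.main:app", host="0.0.0.0", port=8080)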