Optimize async processing for increased speed

allegroai 2022-10-08 02:11:57 +03:00
parent f4aaf095a3
commit 395a547c04
7 changed files with 193 additions and 56 deletions

View File

@ -1,6 +1,7 @@
clearml >= 1.3.1
clearml-serving
tritonclient[grpc]>=2.22.3,<2.23
tritonclient[grpc]>=2.25,<2.26
starlette
grpcio
Pillow>=9.0.1,<10
pathlib2
pathlib2

View File

@ -0,0 +1,33 @@
import os
from clearml import Task
from clearml_serving.serving.model_request_processor import ModelRequestProcessor
from clearml_serving.serving.preprocess_service import BasePreprocessRequest
def setup_task(force_threaded_logging=None):
serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
# always use background thread, it requires less memory
if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
os.environ["CLEARML_BKG_THREAD_REPORT"] = "1"
Task._report_subprocess_enabled = False
# get the serving controller task
# noinspection PyProtectedMember
serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
# set to running (because we are here)
if serving_task.status != "in_progress":
serving_task.started(force=True)
# create a new serving instance (for visibility and monitoring)
instance_task = Task.init(
project_name=serving_task.get_project_name(),
task_name="{} - serve instance".format(serving_task.name),
task_type="inference", # noqa
)
instance_task.set_system_tags(["service"])
# preload modules into memory before forking
BasePreprocessRequest.load_modules()
return serving_service_task_id
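
setup_task() is intended to run once per process before the web workers begin handling requests. A hypothetical sketch of wiring it into a gunicorn_conf.py via gunicorn's standard on_starting server hook; the file name and this particular wiring are illustrative only, not necessarily how the service is launched:

# gunicorn_conf.py (hypothetical)
from clearml_serving.serving.init import setup_task

def on_starting(server):
    # run the ClearML setup once, with threaded reporting,
    # before gunicorn forks its worker processes
    setup_task(force_threaded_logging=True)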

View File

@ -7,10 +7,9 @@ from fastapi.routing import APIRoute
from typing import Optional, Dict, Any, Callable, Union
from clearml import Task
from clearml_serving.version import __version__
from clearml_serving.serving.init import setup_task
from clearml_serving.serving.model_request_processor import ModelRequestProcessor
from clearml_serving.serving.preprocess_service import BasePreprocessRequest
class GzipRequest(Request):
@ -35,31 +34,20 @@ class GzipRoute(APIRoute):
# process Lock, so that we can have only a single process doing the model reloading at a time
singleton_sync_lock = Lock()
singleton_sync_lock = None # Lock()
# shared Model processor object
processor = None # type: Optional[ModelRequestProcessor]
serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
# create clearml Task and load models
serving_service_task_id = setup_task()
# polling frequency
model_sync_frequency_secs = 5
try:
model_sync_frequency_secs = float(os.environ.get("CLEARML_SERVING_POLL_FREQ", model_sync_frequency_secs))
except (ValueError, TypeError):
pass
# get the serving controller task
# noinspection PyProtectedMember
serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
# set to running (because we are here)
if serving_task.status != "in_progress":
serving_task.started(force=True)
# create a new serving instance (for visibility and monitoring)
instance_task = Task.init(
project_name=serving_task.get_project_name(),
task_name="{} - serve instance".format(serving_task.name),
task_type="inference",
)
instance_task.set_system_tags(["service"])
processor = None # type: Optional[ModelRequestProcessor]
# preload modules into memory before forking
BasePreprocessRequest.load_modules()
# start FastAPI app
app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Serving Service router")
@ -67,12 +55,18 @@ app = FastAPI(title="ClearML Serving Service", version=__version__, description=
@app.on_event("startup")
async def startup_event():
global processor
print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(os.getpid(), serving_service_task_id))
processor = ModelRequestProcessor(
task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
)
print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
if processor:
print("ModelRequestProcessor already initialized [pid={}] [service_id={}]".format(
os.getpid(), serving_service_task_id))
else:
print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(
os.getpid(), serving_service_task_id))
processor = ModelRequestProcessor(
task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
)
print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
router = APIRouter(
@ -89,13 +83,15 @@ router = APIRouter(
@router.post("/{model_id}")
async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
try:
return_value = processor.process_request(
return_value = await processor.process_request(
base_url=model_id,
version=version,
request_body=request
)
except Exception as ex:
except ValueError as ex:
raise HTTPException(status_code=404, detail="Error processing request: {}".format(ex))
except Exception as ex:
raise HTTPException(status_code=500, detail="Error processing request: {}".format(ex))
return return_value
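
A minimal client-side sketch of how the two error paths surface to callers, assuming the router is mounted under the default /serve prefix on port 8080 and that an endpoint named "test_model" exists (URL, port and endpoint name are illustrative):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/serve/test_model",  # illustrative serving URL / endpoint name
    json={"x0": 1, "x1": 2},
)
if resp.status_code == 404:
    # unknown endpoint/version, or a request the processor rejected (ValueError)
    print("not found:", resp.json()["detail"])
elif resp.status_code == 500:
    # unexpected failure inside preprocess / process / postprocess
    print("server error:", resp.json()["detail"])
else:
    print("result:", resp.json())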

View File

@ -2,13 +2,13 @@ import json
import os
from collections import deque
from pathlib import Path
# from queue import Queue
from random import random
from time import sleep, time
from typing import Optional, Union, Dict, List
import itertools
import threading
from multiprocessing import Lock
import asyncio
from numpy import isin
from numpy.random import choice
@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
self._serving_base_url = None
self._metric_log_freq = None
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
"""
Process an incoming request.
Raise ValueError if the url does not match any existing endpoint
@ -134,9 +134,9 @@ class ModelRequestProcessor(object):
if self._update_lock_flag:
self._request_processing_state.dec()
while self._update_lock_flag:
sleep(0.5+random())
await asyncio.sleep(0.5+random())
# retry to process
return self.process_request(base_url=base_url, version=version, request_body=request_body)
return await self.process_request(base_url=base_url, version=version, request_body=request_body)
try:
# normalize url and version
@ -157,7 +157,7 @@ class ModelRequestProcessor(object):
processor = processor_cls(model_endpoint=ep, task=self._task)
self._engine_processor_lookup[url] = processor
return_value = self._process_request(processor=processor, url=url, body=request_body)
return_value = await self._process_request(processor=processor, url=url, body=request_body)
finally:
self._request_processing_state.dec()
@ -271,7 +271,7 @@ class ModelRequestProcessor(object):
)
models = Model.query_models(max_results=2, **model_query)
if not models:
raise ValueError("Could not fine any Model to serve {}".format(model_query))
raise ValueError("Could not find any Model to serve {}".format(model_query))
if len(models) > 1:
print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
endpoint.model_id = models[0].id
@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
# update preprocessing classes
BasePreprocessRequest.set_server_config(self._configuration)
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
# collect statistics for this request
stats_collect_fn = None
collect_stats = False
@ -1151,9 +1151,18 @@ class ModelRequestProcessor(object):
tic = time()
state = dict()
preprocessed = processor.preprocess(body, state, stats_collect_fn)
processed = processor.process(preprocessed, state, stats_collect_fn)
return_value = processor.postprocess(processed, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
if processor.is_preprocess_async \
else processor.preprocess(body, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
processed = await processor.process(preprocessed, state, stats_collect_fn) \
if processor.is_process_async \
else processor.process(preprocessed, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
return_value = await processor.postprocess(processed, state, stats_collect_fn) \
if processor.is_postprocess_async \
else processor.postprocess(processed, state, stats_collect_fn)
tic = time() - tic
if collect_stats:
stats = dict(
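
The conditional awaits above reduce to a simple dispatch rule: a handler is awaited only when its engine flags it as async, otherwise it is called inline. A standalone sketch of that pattern (class and attribute names here are illustrative, not the module's API):

import asyncio

class SyncStep:
    is_async = False
    def run(self, x):
        return x + 1

class AsyncStep:
    is_async = True
    async def run(self, x):
        await asyncio.sleep(0)  # stands in for an awaitable gRPC/HTTP call
        return x + 1

async def run_step(step, x):
    # mirror the is_preprocess_async / is_process_async / is_postprocess_async checks:
    # await only when the step declares itself async
    return await step.run(x) if step.is_async else step.run(x)

async def main():
    print(await run_step(SyncStep(), 1))   # 2
    print(await run_step(AsyncStep(), 1))  # 2

asyncio.run(main())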

View File

@ -1,5 +1,6 @@
import os
import sys
import threading
from pathlib import Path
from typing import Optional, Any, Callable, List
@ -18,6 +19,9 @@ class BasePreprocessRequest(object):
_default_serving_base_url = "http://127.0.0.1:8080/serve/"
_server_config = {} # externally configured by the serving inference service
_timeout = None # timeout in seconds for the entire request, set in __init__
is_preprocess_async = False
is_process_async = False
is_postprocess_async = False
def __init__(
self,
@ -259,6 +263,9 @@ class TritonPreprocessRequest(BasePreprocessRequest):
_ext_np_to_triton_dtype = None
_ext_service_pb2 = None
_ext_service_pb2_grpc = None
is_preprocess_async = False
is_process_async = True
is_postprocess_async = False
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(TritonPreprocessRequest, self).__init__(
@ -266,7 +273,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
# load Triton Module
if self._ext_grpc is None:
import grpc # noqa
from tritonclient.grpc import grpc # noqa
self._ext_grpc = grpc
if self._ext_np_to_triton_dtype is None:
@ -274,11 +281,13 @@ class TritonPreprocessRequest(BasePreprocessRequest):
self._ext_np_to_triton_dtype = np_to_triton_dtype
if self._ext_service_pb2 is None:
from tritonclient.grpc import service_pb2, service_pb2_grpc # noqa
from tritonclient.grpc.aio import service_pb2, service_pb2_grpc # noqa
self._ext_service_pb2 = service_pb2
self._ext_service_pb2_grpc = service_pb2_grpc
def process(
self._grpc_stub = {}
async def process(
self,
data: Any,
state: dict,
@ -312,11 +321,17 @@ class TritonPreprocessRequest(BasePreprocessRequest):
triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address
if not triton_server_address:
raise ValueError("External Triton gRPC server is not configured!")
try:
channel = self._ext_grpc.insecure_channel(triton_server_address)
grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
except Exception as ex:
raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
tid = threading.get_ident()
if self._grpc_stub.get(tid):
grpc_stub = self._grpc_stub.get(tid)
else:
try:
channel = self._ext_grpc.aio.insecure_channel(triton_server_address)
grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
self._grpc_stub[tid] = grpc_stub
except Exception as ex:
raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
use_compression = self._server_config.get("triton_grpc_compression", self._default_grpc_compression)
@ -364,15 +379,11 @@ class TritonPreprocessRequest(BasePreprocessRequest):
try:
compression = self._ext_grpc.Compression.Gzip if use_compression \
else self._ext_grpc.Compression.NoCompression
response = grpc_stub.ModelInfer(
request,
compression=compression,
timeout=self._timeout
)
except Exception:
response = await grpc_stub.ModelInfer(request, compression=compression, timeout=self._timeout)
except Exception as ex:
print("Exception calling Triton RPC function: "
"request_inputs={}, ".format([(r.name, r.shape, r.datatype) for r in (request.inputs or [])]) +
f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}")
f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}:\n{ex}")
raise
# process result
@ -464,3 +475,83 @@ class CustomPreprocessRequest(BasePreprocessRequest):
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return self._preprocess.process(data, state, collect_custom_statistics_fn)
return None
@BasePreprocessRequest.register_engine("custom_async")
class CustomAsyncPreprocessRequest(BasePreprocessRequest):
is_preprocess_async = True
is_process_async = True
is_postprocess_async = True
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(CustomAsyncPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
async def preprocess(
self,
request: dict,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None,
) -> Optional[Any]:
"""
Raise exception to report an error
Return value will be passed to serving engine
:param request: dictionary as received from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
to the statistics collector service
Usage example:
>>> print(request)
{"x0": 1, "x1": 2}
>>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
:return: Object to be passed directly to the model inference
"""
if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
return request
async def postprocess(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None
) -> Optional[dict]:
"""
Raise exception to report an error
Return value will be passed to serving engine
:param data: object as received from the inference model function
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
to the statistics collector service
Usage example:
>>> collect_custom_statistics_fn({"y": 1})
:return: Dictionary passed directly as the returned result of the RestAPI
"""
if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
return data
async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return await self._preprocess.process(data, state, collect_custom_statistics_fn)
return None
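
A minimal user-code sketch for the "custom_async" engine, assuming the same Preprocess-class convention used by the existing custom engine; with this engine the handlers may be coroutines and the serving process awaits them. File, class and field names below are illustrative:

# preprocess.py (illustrative)
from typing import Any, Callable, Optional

class Preprocess(object):
    def __init__(self):
        # instantiated once per worker process, before requests are served
        pass

    async def preprocess(self, body: dict, state: dict,
                         collect_custom_statistics_fn: Optional[Callable] = None) -> Any:
        # e.g. await an async feature store / tokenizer here
        return [body.get("x0", 0), body.get("x1", 0)]

    async def process(self, data: Any, state: dict,
                      collect_custom_statistics_fn: Optional[Callable] = None) -> Any:
        # awaitable model call (remote inference service, async runtime, etc.)
        return {"y": sum(data)}

    async def postprocess(self, data: Any, state: dict,
                          collect_custom_statistics_fn: Optional[Callable] = None) -> dict:
        return data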

View File

@ -5,7 +5,8 @@ uvicorn[standard]
gunicorn>=20.1.0,<20.2
asyncio>=3.4.3,<3.5
aiocache>=0.11.1,<0.12
tritonclient[grpc]>=2.22.3,<2.23
tritonclient[grpc]>=2.25,<2.26
starlette
numpy>=1.20,<1.24
scikit-learn>=1.0.2,<1.1
pandas>=1.0.5,<1.5

View File

@ -0,0 +1,6 @@
import uvicorn
from clearml_serving.serving.init import setup_task
if __name__ == "__main__":
setup_task(force_threaded_logging=True)
uvicorn.main()