Mirror of https://github.com/clearml/clearml-serving (synced 2025-01-31 02:46:54 +00:00)

Commit 395a547c04 (parent f4aaf095a3): Optimize async processing for increased speed
@@ -1,6 +1,7 @@
 clearml >= 1.3.1
 clearml-serving
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
 starlette
 grpcio
 Pillow>=9.0.1,<10
 pathlib2
clearml_serving/serving/init.py  (new file, 33 lines)
@@ -0,0 +1,33 @@
+import os
+from clearml import Task
+from clearml_serving.serving.model_request_processor import ModelRequestProcessor
+from clearml_serving.serving.preprocess_service import BasePreprocessRequest
+
+
+def setup_task(force_threaded_logging=None):
+    serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+
+    # always use background thread, it requires less memory
+    if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
+        os.environ["CLEARML_BKG_THREAD_REPORT"] = "1"
+        Task._report_subprocess_enabled = False
+
+    # get the serving controller task
+    # noinspection PyProtectedMember
+    serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
+    # set to running (because we are here)
+    if serving_task.status != "in_progress":
+        serving_task.started(force=True)
+
+    # create a new serving instance (for visibility and monitoring)
+    instance_task = Task.init(
+        project_name=serving_task.get_project_name(),
+        task_name="{} - serve instance".format(serving_task.name),
+        task_type="inference",  # noqa
+    )
+    instance_task.set_system_tags(["service"])
+
+    # preload modules into memory before forking
+    BasePreprocessRequest.load_modules()
+
+    return serving_service_task_id
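For context, setup_task() is what each worker process calls before serving anything: it resolves the controller task from CLEARML_SERVING_TASK_ID, marks it running, registers a per-instance inference Task, and preloads the preprocess modules. A minimal sketch of driving it from a standalone script; the environment values here are placeholders for illustration, not values taken from this commit:

import os

from clearml_serving.serving.init import setup_task

# Placeholder ID of an existing serving controller task (hypothetical value).
os.environ["CLEARML_SERVING_TASK_ID"] = "<serving-service-task-id>"
# Force background-thread reporting, mirroring what the uvicorn entrypoint added below does.
os.environ["CLEARML_BKG_THREAD_REPORT"] = "1"

task_id = setup_task(force_threaded_logging=True)
print("serving controller task:", task_id)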
@@ -7,10 +7,9 @@ from fastapi.routing import APIRoute

 from typing import Optional, Dict, Any, Callable, Union

-from clearml import Task
 from clearml_serving.version import __version__
+from clearml_serving.serving.init import setup_task
 from clearml_serving.serving.model_request_processor import ModelRequestProcessor
 from clearml_serving.serving.preprocess_service import BasePreprocessRequest


 class GzipRequest(Request):
@@ -35,31 +34,20 @@ class GzipRoute(APIRoute):


 # process Lock, so that we can have only a single process doing the model reloading at a time
-singleton_sync_lock = Lock()
+singleton_sync_lock = None  # Lock()
 # shared Model processor object
 processor = None  # type: Optional[ModelRequestProcessor]

-serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+# create clearml Task and load models
+serving_service_task_id = setup_task()
 # polling frequency
 model_sync_frequency_secs = 5
 try:
     model_sync_frequency_secs = float(os.environ.get("CLEARML_SERVING_POLL_FREQ", model_sync_frequency_secs))
 except (ValueError, TypeError):
     pass

-# get the serving controller task
-# noinspection PyProtectedMember
-serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
-# set to running (because we are here)
-if serving_task.status != "in_progress":
-    serving_task.started(force=True)
-# create a new serving instance (for visibility and monitoring)
-instance_task = Task.init(
-    project_name=serving_task.get_project_name(),
-    task_name="{} - serve instance".format(serving_task.name),
-    task_type="inference",
-)
-instance_task.set_system_tags(["service"])
-processor = None  # type: Optional[ModelRequestProcessor]
-# preload modules into memory before forking
-BasePreprocessRequest.load_modules()

 # start FastAPI app
 app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -67,12 +55,18 @@ app = FastAPI(title="ClearML Serving Service", version=__version__, description=
 @app.on_event("startup")
 async def startup_event():
     global processor
-    print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(os.getpid(), serving_service_task_id))
-    processor = ModelRequestProcessor(
-        task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
-    )
-    print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
-    processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
+
+    if processor:
+        print("ModelRequestProcessor already initialized [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
+    else:
+        print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
+        processor = ModelRequestProcessor(
+            task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
+        )
+        print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
+        processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)


 router = APIRouter(
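With several uvicorn workers (or repeated app reloads) the startup hook can fire more than once per process, so the hunk above makes initialization idempotent: the ModelRequestProcessor is created only if the module-level processor is still unset. A stripped-down sketch of the same guard pattern outside of clearml-serving; the names here are illustrative, not part of the commit:

from fastapi import FastAPI

app = FastAPI()
_heavy_client = None  # shared per-process singleton, stand-in for the ModelRequestProcessor


class HeavyClient:
    def __init__(self):
        print("expensive construction runs once per process")


@app.on_event("startup")
async def startup_event():
    global _heavy_client
    if _heavy_client:
        print("already initialized, skipping")
        return
    _heavy_client = HeavyClient()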
@@ -89,13 +83,15 @@ router = APIRouter(
 @router.post("/{model_id}")
 async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
     try:
-        return_value = processor.process_request(
+        return_value = await processor.process_request(
             base_url=model_id,
             version=version,
             request_body=request
         )
-    except Exception as ex:
+    except ValueError as ex:
         raise HTTPException(status_code=404, detail="Error processing request: {}".format(ex))
+    except Exception as ex:
+        raise HTTPException(status_code=500, detail="Error processing request: {}".format(ex))
     return return_value
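The endpoint change above does two things: it awaits the now-asynchronous process_request(), and it splits error handling so a ValueError (unknown endpoint or version) maps to HTTP 404 while any other failure maps to HTTP 500. A hedged client-side sketch of exercising that mapping; host, port, and the /serve prefix are assumptions based on the default base URL that appears later in this diff, and the endpoint name is hypothetical:

import requests

# Assumed local serving address (default base URL in preprocess_service.py is http://127.0.0.1:8080/serve/)
url = "http://127.0.0.1:8080/serve/my_model_endpoint"  # "my_model_endpoint" is a hypothetical endpoint name

resp = requests.post(url, json={"x0": 1, "x1": 2})
if resp.status_code == 404:
    print("unknown endpoint or version:", resp.json().get("detail"))
elif resp.status_code == 500:
    print("request failed inside the serving pipeline:", resp.json().get("detail"))
else:
    print("prediction:", resp.json())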
@@ -2,13 +2,13 @@ import json
 import os
 from collections import deque
 from pathlib import Path
 # from queue import Queue
 from random import random
 from time import sleep, time
 from typing import Optional, Union, Dict, List
 import itertools
 import threading
 from multiprocessing import Lock
+import asyncio
 from numpy import isin
 from numpy.random import choice

@@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
         self._serving_base_url = None
         self._metric_log_freq = None

-    def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
+    async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
         """
         Process request coming in,
         Raise Value error if url does not match existing endpoints
@@ -134,9 +134,9 @@ class ModelRequestProcessor(object):
         if self._update_lock_flag:
             self._request_processing_state.dec()
             while self._update_lock_flag:
-                sleep(0.5+random())
+                await asyncio.sleep(0.5+random())
             # retry to process
-            return self.process_request(base_url=base_url, version=version, request_body=request_body)
+            return await self.process_request(base_url=base_url, version=version, request_body=request_body)

         try:
             # normalize url and version
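The switch from sleep() to await asyncio.sleep() matters because process_request() now runs on the event loop: time.sleep() would block every in-flight request while waiting for the update lock to clear, whereas awaiting yields control so other coroutines keep running. A self-contained sketch of the difference (illustrative only, not part of the commit):

import asyncio
import time


async def blocking_wait():
    time.sleep(1)  # freezes the whole event loop for 1 second


async def cooperative_wait():
    await asyncio.sleep(1)  # suspends only this coroutine; others keep running


async def main():
    t = time.monotonic()
    await asyncio.gather(*(cooperative_wait() for _ in range(10)))
    print("10 cooperative waits:", round(time.monotonic() - t, 1), "s")  # ~1 s total

    t = time.monotonic()
    await asyncio.gather(*(blocking_wait() for _ in range(10)))
    print("10 blocking waits:", round(time.monotonic() - t, 1), "s")  # ~10 s total


asyncio.run(main())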
@@ -157,7 +157,7 @@ class ModelRequestProcessor(object):
                 processor = processor_cls(model_endpoint=ep, task=self._task)
                 self._engine_processor_lookup[url] = processor

-            return_value = self._process_request(processor=processor, url=url, body=request_body)
+            return_value = await self._process_request(processor=processor, url=url, body=request_body)
         finally:
             self._request_processing_state.dec()

@@ -271,7 +271,7 @@ class ModelRequestProcessor(object):
             )
             models = Model.query_models(max_results=2, **model_query)
             if not models:
-                raise ValueError("Could not fine any Model to serve {}".format(model_query))
+                raise ValueError("Could not find any Model to serve {}".format(model_query))
             if len(models) > 1:
                 print("Warning: Found multiple Models for '{}', selecting id={}".format(model_query, models[0].id))
             endpoint.model_id = models[0].id
@@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
         # update preprocessing classes
         BasePreprocessRequest.set_server_config(self._configuration)

-    def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
+    async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
         # collect statistics for this request
         stats_collect_fn = None
         collect_stats = False
@@ -1151,9 +1151,18 @@ class ModelRequestProcessor(object):

         tic = time()
         state = dict()
-        preprocessed = processor.preprocess(body, state, stats_collect_fn)
-        processed = processor.process(preprocessed, state, stats_collect_fn)
-        return_value = processor.postprocess(processed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
+            if processor.is_preprocess_async \
+            else processor.preprocess(body, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        processed = await processor.process(preprocessed, state, stats_collect_fn) \
+            if processor.is_process_async \
+            else processor.process(preprocessed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        return_value = await processor.postprocess(processed, state, stats_collect_fn) \
+            if processor.is_postprocess_async \
+            else processor.postprocess(processed, state, stats_collect_fn)
         tic = time() - tic
         if collect_stats:
             stats = dict(
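The "await f(...) if flag else f(...)" construction above works because await binds tighter than the conditional expression, so the synchronous branch is never accidentally awaited. An equivalent and arguably more readable dispatch helper is sketched below; this helper is not part of the commit, just an illustration of the same idea:

import asyncio
import inspect


async def maybe_await(value):
    # Await the result only if the callee actually returned an awaitable
    if inspect.isawaitable(value):
        return await value
    return value


class SyncProcessor:
    def process(self, data):
        return data * 2


class AsyncProcessor:
    async def process(self, data):
        await asyncio.sleep(0)  # pretend to call an async backend
        return data * 2


async def main():
    for processor in (SyncProcessor(), AsyncProcessor()):
        result = await maybe_await(processor.process(21))
        print(type(processor).__name__, "->", result)


asyncio.run(main())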
@@ -1,5 +1,6 @@
 import os
 import sys
+import threading
 from pathlib import Path
 from typing import Optional, Any, Callable, List

@@ -18,6 +19,9 @@ class BasePreprocessRequest(object):
     _default_serving_base_url = "http://127.0.0.1:8080/serve/"
     _server_config = {}  # externally configured by the serving inference service
     _timeout = None  # timeout in seconds for the entire request, set in __init__
+    is_preprocess_async = False
+    is_process_async = False
+    is_postprocess_async = False

     def __init__(
             self,
@@ -259,6 +263,9 @@ class TritonPreprocessRequest(BasePreprocessRequest):
     _ext_np_to_triton_dtype = None
     _ext_service_pb2 = None
     _ext_service_pb2_grpc = None
+    is_preprocess_async = False
+    is_process_async = True
+    is_postprocess_async = False

     def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
         super(TritonPreprocessRequest, self).__init__(
@@ -266,7 +273,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):

         # load Triton Module
         if self._ext_grpc is None:
-            import grpc  # noqa
+            from tritonclient.grpc import grpc  # noqa
             self._ext_grpc = grpc

         if self._ext_np_to_triton_dtype is None:
@@ -274,11 +281,13 @@ class TritonPreprocessRequest(BasePreprocessRequest):
             self._ext_np_to_triton_dtype = np_to_triton_dtype

         if self._ext_service_pb2 is None:
-            from tritonclient.grpc import service_pb2, service_pb2_grpc  # noqa
+            from tritonclient.grpc.aio import service_pb2, service_pb2_grpc  # noqa
             self._ext_service_pb2 = service_pb2
             self._ext_service_pb2_grpc = service_pb2_grpc

-    def process(
+        self._grpc_stub = {}
+
+    async def process(
             self,
             data: Any,
             state: dict,
@@ -312,11 +321,17 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address
         if not triton_server_address:
             raise ValueError("External Triton gRPC server is not configured!")
-        try:
-            channel = self._ext_grpc.insecure_channel(triton_server_address)
-            grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
-        except Exception as ex:
-            raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
+
+        tid = threading.get_ident()
+        if self._grpc_stub.get(tid):
+            grpc_stub = self._grpc_stub.get(tid)
+        else:
+            try:
+                channel = self._ext_grpc.aio.insecure_channel(triton_server_address)
+                grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
+                self._grpc_stub[tid] = grpc_stub
+            except Exception as ex:
+                raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))

         use_compression = self._server_config.get("triton_grpc_compression", self._default_grpc_compression)

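Instead of opening a new insecure channel on every request, the hunk above lazily builds one aio stub per thread id and reuses it from self._grpc_stub. The generic shape of that lazy per-key cache is sketched below with a dummy factory; grpc itself is not imported here, the caching logic is the point, and the lock is a defensive addition of this sketch rather than something the diff uses:

import threading

_stub_cache = {}  # maps threading.get_ident() -> connection/stub object, mirroring self._grpc_stub
_cache_lock = threading.Lock()


def _make_stub(address):
    # Stand-in for: channel = grpc.aio.insecure_channel(address); stub = GRPCInferenceServiceStub(channel)
    return object()


def get_stub(address="127.0.0.1:8001"):
    tid = threading.get_ident()
    stub = _stub_cache.get(tid)
    if stub is None:
        with _cache_lock:
            stub = _stub_cache.get(tid) or _make_stub(address)
            _stub_cache[tid] = stub
    return stub


print(get_stub() is get_stub())  # True: the stub is reused within the same thread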
@@ -364,15 +379,11 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         try:
             compression = self._ext_grpc.Compression.Gzip if use_compression \
                 else self._ext_grpc.Compression.NoCompression
-            response = grpc_stub.ModelInfer(
-                request,
-                compression=compression,
-                timeout=self._timeout
-            )
-        except Exception:
+            response = await grpc_stub.ModelInfer(request, compression=compression, timeout=self._timeout)
+        except Exception as ex:
             print("Exception calling Triton RPC function: "
                   "request_inputs={}, ".format([(r.name, r.shape, r.datatype) for r in (request.inputs or [])]) +
-                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}")
+                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}:\n{ex}")
             raise

         # process result
@@ -464,3 +475,83 @@ class CustomPreprocessRequest(BasePreprocessRequest):
         if self._preprocess is not None and hasattr(self._preprocess, 'process'):
             return self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
+
+
+@BasePreprocessRequest.register_engine("custom_async")
+class CustomAsyncPreprocessRequest(BasePreprocessRequest):
+    is_preprocess_async = True
+    is_process_async = True
+    is_postprocess_async = True
+
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(CustomAsyncPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+
+    async def preprocess(
+            self,
+            request: dict,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None,
+    ) -> Optional[Any]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param request: dictionary as recieved from the RestAPI
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statictics collector servicd
+
+            Usage example:
+            >>> print(request)
+            {"x0": 1, "x1": 2}
+            >>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
+
+        :return: Object to be passed directly to the model inference
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
+            return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
+        return request
+
+    async def postprocess(
+            self,
+            data: Any,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None
+    ) -> Optional[dict]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param data: object as recieved from the inference model function
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statictics collector servicd
+
+            Usage example:
+            >>> collect_custom_statistics_fn({"y": 1})
+
+        :return: Dictionary passed directly as the returned result of the RestAPI
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
+            return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
+        return data
+
+    async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return await self._preprocess.process(data, state, collect_custom_statistics_fn)
+        return None
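CustomAsyncPreprocessRequest only delegates: whatever object is loaded into self._preprocess must itself expose awaitable preprocess/process/postprocess hooks. A hedged sketch of what a user-supplied module for the custom_async engine might look like; the class name Preprocess and the exact wiring are assumptions here, so check the project's own examples for the authoritative contract:

import asyncio
from typing import Any, Callable, Optional


class Preprocess(object):
    """User code loaded by the serving instance (assumed contract: one class with async hooks)."""

    def __init__(self):
        self.scale = 2.0  # anything set up here is reused across requests

    async def preprocess(self, body: dict, state: dict,
                         collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        state["received_keys"] = list(body.keys())
        return [float(v) for v in body.values()]

    async def process(self, data: Any, state: dict,
                      collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        await asyncio.sleep(0)  # placeholder for an awaitable call to an external model/service
        return [v * self.scale for v in data]

    async def postprocess(self, data: Any, state: dict,
                          collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> dict:
        if collect_custom_statistics_fn:
            collect_custom_statistics_fn({"outputs": len(data)})
        return {"y": data, "received_keys": state.get("received_keys")}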
@@ -5,7 +5,8 @@ uvicorn[standard]
 gunicorn>=20.1.0,<20.2
 asyncio>=3.4.3,<3.5
+aiocache>=0.11.1,<0.12
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
 starlette
 numpy>=1.20,<1.24
 scikit-learn>=1.0.2,<1.1
 pandas>=1.0.5,<1.5
clearml_serving/serving/uvicorn_mp_entrypoint.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
+import uvicorn
+from clearml_serving.serving.init import setup_task
+
+if __name__ == "__main__":
+    setup_task(force_threaded_logging=True)
+    uvicorn.main()
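Since uvicorn.main() is uvicorn's regular CLI entry point, running this module hands any command-line arguments straight to uvicorn after the ClearML task has been set up once in the parent process, before worker processes start. A hedged sketch of an equivalent programmatic launch; the application import path, port, and worker count are illustrative assumptions, not values taken from this commit:

import uvicorn

from clearml_serving.serving.init import setup_task

if __name__ == "__main__":
    # Initialize the ClearML serving task once, with threaded reporting, before workers spawn
    setup_task(force_threaded_logging=True)
    # "clearml_serving.serving.main:app" is an assumed module path for the FastAPI app shown in this diff
    uvicorn.run("clearml_serving.serving.main:app", host="0.0.0.0", port=8080, workers=2)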