diff --git a/clearml_serving/engines/triton/requirements.txt b/clearml_serving/engines/triton/requirements.txt
index d1f103b..284d834 100644
--- a/clearml_serving/engines/triton/requirements.txt
+++ b/clearml_serving/engines/triton/requirements.txt
@@ -1,6 +1,7 @@
 clearml >= 1.3.1
 clearml-serving
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
+starlette
 grpcio
 Pillow>=9.0.1,<10
-pathlib2
\ No newline at end of file
+pathlib2
diff --git a/clearml_serving/serving/init.py b/clearml_serving/serving/init.py
new file mode 100644
index 0000000..34e3da9
--- /dev/null
+++ b/clearml_serving/serving/init.py
@@ -0,0 +1,33 @@
+import os
+from clearml import Task
+from clearml_serving.serving.model_request_processor import ModelRequestProcessor
+from clearml_serving.serving.preprocess_service import BasePreprocessRequest
+
+
+def setup_task(force_threaded_logging=None):
+    serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+
+    # always use background thread, it requires less memory
+    if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
+        os.environ["CLEARML_BKG_THREAD_REPORT"] = "1"
+        Task._report_subprocess_enabled = False
+
+    # get the serving controller task
+    # noinspection PyProtectedMember
+    serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
+    # set to running (because we are here)
+    if serving_task.status != "in_progress":
+        serving_task.started(force=True)
+
+    # create a new serving instance (for visibility and monitoring)
+    instance_task = Task.init(
+        project_name=serving_task.get_project_name(),
+        task_name="{} - serve instance".format(serving_task.name),
+        task_type="inference",  # noqa
+    )
+    instance_task.set_system_tags(["service"])
+
+    # preload modules into memory before forking
+    BasePreprocessRequest.load_modules()
+
+    return serving_service_task_id
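setup_task() above is the bootstrap now shared by every server entrypoint: it resolves the controller task from CLEARML_SERVING_TASK_ID, marks it as running, registers a per-process "serve instance" task, and preloads preprocess modules before the server forks. A minimal sketch of driving it directly, assuming a reachable ClearML server; the task id value is a hypothetical placeholder:

import os

# hypothetical controller task id; a real one comes from `clearml-serving create`
os.environ.setdefault("CLEARML_SERVING_TASK_ID", "0123456789abcdef")

from clearml_serving.serving.init import setup_task

# returns the id it read from the environment, after registering the instance task
task_id = setup_task(force_threaded_logging=True)
print("serving controller task:", task_id)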
diff --git a/clearml_serving/serving/main.py b/clearml_serving/serving/main.py
index 88b9ddc..e6847f5 100644
--- a/clearml_serving/serving/main.py
+++ b/clearml_serving/serving/main.py
@@ -7,10 +7,9 @@ from fastapi.routing import APIRoute
 
 from typing import Optional, Dict, Any, Callable, Union
 
-from clearml import Task
 from clearml_serving.version import __version__
+from clearml_serving.serving.init import setup_task
 from clearml_serving.serving.model_request_processor import ModelRequestProcessor
-from clearml_serving.serving.preprocess_service import BasePreprocessRequest
 
 
 class GzipRequest(Request):
@@ -35,31 +34,20 @@ class GzipRoute(APIRoute):
 
 # process Lock, so that we can have only a single process doing the model reloading at a time
-singleton_sync_lock = Lock()
+singleton_sync_lock = None  # Lock()
+# shared Model processor object
+processor = None  # type: Optional[ModelRequestProcessor]
 
-serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+# create clearml Task and load models
+serving_service_task_id = setup_task()
 
+# polling frequency
 model_sync_frequency_secs = 5
 try:
     model_sync_frequency_secs = float(os.environ.get("CLEARML_SERVING_POLL_FREQ", model_sync_frequency_secs))
 except (ValueError, TypeError):
     pass
 
-# get the serving controller task
-# noinspection PyProtectedMember
-serving_task = ModelRequestProcessor._get_control_plane_task(task_id=serving_service_task_id)
-# set to running (because we are here)
-if serving_task.status != "in_progress":
-    serving_task.started(force=True)
-# create a new serving instance (for visibility and monitoring)
-instance_task = Task.init(
-    project_name=serving_task.get_project_name(),
-    task_name="{} - serve instance".format(serving_task.name),
-    task_type="inference",
-)
-instance_task.set_system_tags(["service"])
-processor = None  # type: Optional[ModelRequestProcessor]
-# preload modules into memory before forking
-BasePreprocessRequest.load_modules()
+
 
 # start FastAPI app
 app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -67,12 +55,18 @@ app = FastAPI(title="ClearML Serving Service", version=__version__, description=
 
 @app.on_event("startup")
 async def startup_event():
     global processor
-    print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(os.getpid(), serving_service_task_id))
-    processor = ModelRequestProcessor(
-        task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
-    )
-    print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
-    processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
+
+    if processor:
+        print("ModelRequestProcessor already initialized [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
+    else:
+        print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(
+            os.getpid(), serving_service_task_id))
+        processor = ModelRequestProcessor(
+            task_id=serving_service_task_id, update_lock_guard=singleton_sync_lock,
+        )
+        print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
+        processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
 
 router = APIRouter(
@@ -89,13 +83,15 @@ router = APIRouter(
 
 
 @router.post("/{model_id}")
 async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
     try:
-        return_value = processor.process_request(
+        return_value = await processor.process_request(
             base_url=model_id,
             version=version,
             request_body=request
         )
-    except Exception as ex:
+    except ValueError as ex:
         raise HTTPException(status_code=404, detail="Error processing request: {}".format(ex))
+    except Exception as ex:
+        raise HTTPException(status_code=500, detail="Error processing request: {}".format(ex))
     return return_value
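With the split exception handling above, an unknown endpoint or version (the ValueError path) now maps to 404, while a failure inside preprocess/model code returns 500 instead of masquerading as 404. A hypothetical client-side check; the endpoint name and port are placeholders for a locally running serving service:

import requests  # third-party HTTP client, used here only for illustration

resp = requests.post("http://127.0.0.1:8080/serve/test_model_sklearn", json={"x0": 1, "x1": 2})
if resp.status_code == 404:
    print("unknown endpoint or version:", resp.json()["detail"])
elif resp.status_code == 500:
    print("preprocess/model code raised:", resp.json()["detail"])
else:
    print("prediction:", resp.json())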
diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py
index 3e932ec..22c7414 100644
--- a/clearml_serving/serving/model_request_processor.py
+++ b/clearml_serving/serving/model_request_processor.py
@@ -2,13 +2,13 @@ import json
 import os
 from collections import deque
 from pathlib import Path
-# from queue import Queue
 from random import random
 from time import sleep, time
 from typing import Optional, Union, Dict, List
 import itertools
 import threading
 from multiprocessing import Lock
+import asyncio
 
 from numpy import isin
 from numpy.random import choice
@@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
         self._serving_base_url = None
         self._metric_log_freq = None
 
-    def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
+    async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
         """
         Process request coming in,
         Raise Value error if url does not match existing endpoints
@@ -134,9 +134,9 @@
         if self._update_lock_flag:
             self._request_processing_state.dec()
             while self._update_lock_flag:
-                sleep(0.5+random())
+                await asyncio.sleep(0.5+random())
             # retry to process
-            return self.process_request(base_url=base_url, version=version, request_body=request_body)
+            return await self.process_request(base_url=base_url, version=version, request_body=request_body)
 
         try:
             # normalize url and version
@@ -157,7 +157,7 @@
             processor = processor_cls(model_endpoint=ep, task=self._task)
             self._engine_processor_lookup[url] = processor
 
-            return_value = self._process_request(processor=processor, url=url, body=request_body)
+            return_value = await self._process_request(processor=processor, url=url, body=request_body)
         finally:
             self._request_processing_state.dec()
 
@@ -271,7 +271,7 @@
             )
             models = Model.query_models(max_results=2, **model_query)
             if not models:
-                raise ValueError("Could not fine any Model to serve {}".format(model_query))
+                raise ValueError("Could not find any Model to serve {}".format(model_query))
             if len(models) > 1:
                 print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
             endpoint.model_id = models[0].id
@@ -1133,7 +1133,7 @@
         # update preprocessing classes
         BasePreprocessRequest.set_server_config(self._configuration)
 
-    def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
+    async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
         # collect statistics for this request
         stats_collect_fn = None
         collect_stats = False
@@ -1151,9 +1151,18 @@
 
         tic = time()
         state = dict()
-        preprocessed = processor.preprocess(body, state, stats_collect_fn)
-        processed = processor.process(preprocessed, state, stats_collect_fn)
-        return_value = processor.postprocess(processed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
+            if processor.is_preprocess_async \
+            else processor.preprocess(body, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        processed = await processor.process(preprocessed, state, stats_collect_fn) \
+            if processor.is_process_async \
+            else processor.process(preprocessed, state, stats_collect_fn)
+        # noinspection PyUnresolvedReferences
+        return_value = await processor.postprocess(processed, state, stats_collect_fn) \
+            if processor.is_postprocess_async \
+            else processor.postprocess(processed, state, stats_collect_fn)
         tic = time() - tic
         if collect_stats:
             stats = dict(
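The per-stage is_*_async flags let a single request path mix sync and async processors without wrapping sync calls in coroutines. A standalone sketch of the same dispatch pattern; the class and method names here are illustrative, not taken from the codebase:

import asyncio

class SyncSquare:
    is_process_async = False

    def process(self, x):
        return x * x

class AsyncSquare:
    is_process_async = True

    async def process(self, x):
        await asyncio.sleep(0)  # stand-in for a real awaitable call such as gRPC
        return x * x

async def run(processor, x):
    # mirrors _process_request above: await only when the stage declares itself async
    return await processor.process(x) if processor.is_process_async else processor.process(x)

print(asyncio.run(run(SyncSquare(), 3)))   # 9
print(asyncio.run(run(AsyncSquare(), 3)))  # 9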
diff --git a/clearml_serving/serving/preprocess_service.py b/clearml_serving/serving/preprocess_service.py
index ee50e2f..d17d6c8 100644
--- a/clearml_serving/serving/preprocess_service.py
+++ b/clearml_serving/serving/preprocess_service.py
@@ -1,5 +1,6 @@
 import os
 import sys
+import threading
 from pathlib import Path
 from typing import Optional, Any, Callable, List
 
@@ -18,6 +19,9 @@ class BasePreprocessRequest(object):
     _default_serving_base_url = "http://127.0.0.1:8080/serve/"
     _server_config = {}  # externally configured by the serving inference service
     _timeout = None  # timeout in seconds for the entire request, set in __init__
+    is_preprocess_async = False
+    is_process_async = False
+    is_postprocess_async = False
 
     def __init__(
             self,
@@ -259,6 +263,9 @@
     _ext_np_to_triton_dtype = None
     _ext_service_pb2 = None
     _ext_service_pb2_grpc = None
+    is_preprocess_async = False
+    is_process_async = True
+    is_postprocess_async = False
 
     def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
         super(TritonPreprocessRequest, self).__init__(
@@ -266,7 +273,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
 
         # load Triton Module
        if self._ext_grpc is None:
-            import grpc  # noqa
+            from tritonclient.grpc import grpc  # noqa
             self._ext_grpc = grpc
 
         if self._ext_np_to_triton_dtype is None:
@@ -274,11 +281,13 @@
             self._ext_np_to_triton_dtype = np_to_triton_dtype
 
         if self._ext_service_pb2 is None:
-            from tritonclient.grpc import service_pb2, service_pb2_grpc  # noqa
+            from tritonclient.grpc.aio import service_pb2, service_pb2_grpc  # noqa
             self._ext_service_pb2 = service_pb2
             self._ext_service_pb2_grpc = service_pb2_grpc
 
-    def process(
+        self._grpc_stub = {}
+
+    async def process(
             self,
             data: Any,
             state: dict,
@@ -312,11 +321,17 @@
         triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address
         if not triton_server_address:
             raise ValueError("External Triton gRPC server is not configured!")
-        try:
-            channel = self._ext_grpc.insecure_channel(triton_server_address)
-            grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
-        except Exception as ex:
-            raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
+
+        tid = threading.get_ident()
+        if self._grpc_stub.get(tid):
+            grpc_stub = self._grpc_stub.get(tid)
+        else:
+            try:
+                channel = self._ext_grpc.aio.insecure_channel(triton_server_address)
+                grpc_stub = self._ext_service_pb2_grpc.GRPCInferenceServiceStub(channel)
+                self._grpc_stub[tid] = grpc_stub
+            except Exception as ex:
+                raise ValueError("External Triton gRPC server misconfigured [{}]: {}".format(triton_server_address, ex))
 
         use_compression = self._server_config.get("triton_grpc_compression", self._default_grpc_compression)
 
@@ -364,15 +379,11 @@
         try:
             compression = self._ext_grpc.Compression.Gzip if use_compression \
                 else self._ext_grpc.Compression.NoCompression
-            response = grpc_stub.ModelInfer(
-                request,
-                compression=compression,
-                timeout=self._timeout
-            )
-        except Exception:
+            response = await grpc_stub.ModelInfer(request, compression=compression, timeout=self._timeout)
+        except Exception as ex:
             print("Exception calling Triton RPC function: "
                   "request_inputs={}, ".format([(r.name, r.shape, r.datatype) for r in (request.inputs or [])]) +
-                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}")
+                  f"triton_address={triton_server_address}, compression={compression}, timeout={self._timeout}:\n{ex}")
             raise
 
         # process result
@@ -464,3 +475,83 @@ class CustomPreprocessRequest(BasePreprocessRequest):
         if self._preprocess is not None and hasattr(self._preprocess, 'process'):
             return self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
+
+
+@BasePreprocessRequest.register_engine("custom_async")
+class CustomAsyncPreprocessRequest(BasePreprocessRequest):
+    is_preprocess_async = True
+    is_process_async = True
+    is_postprocess_async = True
+
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(CustomAsyncPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+
+    async def preprocess(
+            self,
+            request: dict,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None,
+    ) -> Optional[Any]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param request: dictionary as received from the RestAPI
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> print(request)
+            {"x0": 1, "x1": 2}
+            >>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
+
+        :return: Object to be passed directly to the model inference
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
+            return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
+        return request
+
+    async def postprocess(
+            self,
+            data: Any,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None
+    ) -> Optional[dict]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param data: object as received from the inference model function
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> collect_custom_statistics_fn({"y": 1})
+
+        :return: Dictionary passed directly as the returned result of the RestAPI
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
+            return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
+        return data
+
+    async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return await self._preprocess.process(data, state, collect_custom_statistics_fn)
+        return None
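CustomAsyncPreprocessRequest awaits the user-supplied hooks, so a preprocess module registered under the new "custom_async" engine is expected to expose coroutine methods with the same signatures the sync engine uses. A hypothetical user module following that contract; the actual computation is illustrative only:

from typing import Any, Callable, Optional

class Preprocess(object):
    async def preprocess(self, body: dict, state: dict,
                         collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # stash data for postprocess, same as in the sync flow
        state["factor"] = body.get("factor", 1)
        return body["values"]

    async def process(self, data: Any, state: dict,
                      collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # an async model call (e.g. a remote inference request) would be awaited here
        return [v * state["factor"] for v in data]

    async def postprocess(self, data: Any, state: dict,
                          collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> dict:
        return {"y": data}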
diff --git a/clearml_serving/serving/requirements.txt b/clearml_serving/serving/requirements.txt
index 276d5e6..4335e83 100644
--- a/clearml_serving/serving/requirements.txt
+++ b/clearml_serving/serving/requirements.txt
@@ -5,7 +5,8 @@ uvicorn[standard]
 gunicorn>=20.1.0,<20.2
 asyncio>=3.4.3,<3.5
 aiocache>=0.11.1,<0.12
-tritonclient[grpc]>=2.22.3,<2.23
+tritonclient[grpc]>=2.25,<2.26
+starlette
 numpy>=1.20,<1.24
 scikit-learn>=1.0.2,<1.1
 pandas>=1.0.5,<1.5
diff --git a/clearml_serving/serving/uvicorn_mp_entrypoint.py b/clearml_serving/serving/uvicorn_mp_entrypoint.py
new file mode 100644
index 0000000..8c250f2
--- /dev/null
+++ b/clearml_serving/serving/uvicorn_mp_entrypoint.py
@@ -0,0 +1,6 @@
+import uvicorn
+from clearml_serving.serving.init import setup_task
+
+if __name__ == "__main__":
+    setup_task(force_threaded_logging=True)
+    uvicorn.main()
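uvicorn.main() is uvicorn's click-based CLI entry point, so this entrypoint accepts the regular uvicorn arguments while guaranteeing setup_task() runs in the parent process before workers spawn. A rough programmatic equivalent; host, port, and worker count are illustrative values, not defaults from this repo:

import uvicorn
from clearml_serving.serving.init import setup_task

if __name__ == "__main__":
    setup_task(force_threaded_logging=True)
    # comparable CLI: python -m clearml_serving.serving.uvicorn_mp_entrypoint \
    #     clearml_serving.serving.main:app --workers 4 --host 0.0.0.0 --port 8080
    uvicorn.run("clearml_serving.serving.main:app", host="0.0.0.0", port=8080, workers=4)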