From 8778f723e66b33e86aebb4fa6602da644a2527d6 Mon Sep 17 00:00:00 2001 From: allegroai Date: Sun, 5 Jun 2022 16:10:16 +0300 Subject: [PATCH] Add pre/post processing callback state dict, for safe per request state storage --- .../preprocess/preprocess_template.py | 44 ++++++++++- .../serving/model_request_processor.py | 7 +- clearml_serving/serving/preprocess_service.py | 76 +++++++++++++++---- examples/ensemble/preprocess.py | 4 +- examples/keras/preprocess.py | 4 +- examples/lightgbm/preprocess.py | 4 +- examples/pipeline/preprocess.py | 4 +- examples/pytorch/preprocess.py | 4 +- examples/sklearn/preprocess.py | 4 +- examples/xgboost/preprocess.py | 4 +- 10 files changed, 118 insertions(+), 37 deletions(-) diff --git a/clearml_serving/preprocess/preprocess_template.py b/clearml_serving/preprocess/preprocess_template.py index 7fda339..0c82756 100644 --- a/clearml_serving/preprocess/preprocess_template.py +++ b/clearml_serving/preprocess/preprocess_template.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List, Callable +from typing import Any, Optional, Callable # Preprocess class Must be named "Preprocess" @@ -24,12 +24,24 @@ class Preprocess(object): """ pass - def preprocess(self, body: dict, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> Any: # noqa + def preprocess( + self, + body: dict, + state: dict, + collect_custom_statistics_fn: Optional[Callable[[dict], None]], + ) -> Any: # noqa """ Optional: do something with the request data, return any type of object. The returned object will be passed as is to the inference engine :param body: dictionary as recieved from the RestAPI :param state: Use state dict to store data passed to the post-processing function call. 
+ This is a per-request state dict (meaning a new empty dict will be passed per request) + Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values to the statictics collector servicd. None is passed if statiscs collector is not configured, or if the current request should not be collected @@ -44,12 +56,24 @@ class Preprocess(object): """ return body - def postprocess(self, data: Any, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> dict: # noqa + def postprocess( + self, + data: Any, + state: dict, + collect_custom_statistics_fn: Optional[Callable[[dict], None]], + ) -> dict: # noqa """ Optional: post process the data returned from the model inference engine returned dict will be passed back as the request result as is. :param data: object as recieved from the inference model function + :param state: Use state dict to store data passed to the post-processing function call. + This is a per-request state dict (meaning a dict instance per request) + Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values to the statictics collector servicd. None is passed if statiscs collector is not configured, or if the current request should not be collected @@ -62,12 +86,24 @@ class Preprocess(object): """ return data - def process(self, data: Any, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> Any: # noqa + def process( + self, + data: Any, + state: dict, + collect_custom_statistics_fn: Optional[Callable[[dict], None]], + ) -> Any: # noqa """ Optional: do something with the actual data, return any type of object. 
The returned object will be passed as is to the postprocess function engine :param data: object as recieved from the preprocessing function + :param state: Use state dict to store data passed to the post-processing function call. + This is a per-request state dict (meaning a dict instance per request) + Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values to the statictics collector servicd. None is passed if statiscs collector is not configured, or if the current request should not be collected diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py index 7145982..d52d7f6 100644 --- a/clearml_serving/serving/model_request_processor.py +++ b/clearml_serving/serving/model_request_processor.py @@ -1029,9 +1029,10 @@ class ModelRequestProcessor(object): collect_stats = True tic = time() - preprocessed = processor.preprocess(body, stats_collect_fn) - processed = processor.process(preprocessed, stats_collect_fn) - return_value = processor.postprocess(processed, stats_collect_fn) + state = dict() + preprocessed = processor.preprocess(body, state, stats_collect_fn) + processed = processor.process(preprocessed, state, stats_collect_fn) + return_value = processor.postprocess(processed, state, stats_collect_fn) tic = time() - tic if collect_stats: # 10th of a millisecond should be enough diff --git a/clearml_serving/serving/preprocess_service.py b/clearml_serving/serving/preprocess_service.py index 5c08054..66e7679 100644 --- a/clearml_serving/serving/preprocess_service.py +++ b/clearml_serving/serving/preprocess_service.py @@ -46,7 +46,7 @@ class BasePreprocessRequest(object): def _instantiate_custom_preprocess_cls(self, task: Task) -> None: path = 
task.artifacts[self.model_endpoint.preprocess_artifact].get_local_copy() - # check file content hash, should only happens once?! + # check file content hash, should only happen once?! # noinspection PyProtectedMember file_hash, _ = sha256sum(path, block_size=Artifacts._hash_block_size) if file_hash != task.artifacts[self.model_endpoint.preprocess_artifact].hash: @@ -77,12 +77,23 @@ class BasePreprocessRequest(object): if callable(getattr(self._preprocess, 'load', None)): self._model = self._preprocess.load(self._get_local_model_file()) - def preprocess(self, request: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Optional[Any]: + def preprocess( + self, + request: dict, + state: dict, + collect_custom_statistics_fn: Callable[[dict], None] = None, + ) -> Optional[Any]: """ Raise exception to report an error Return value will be passed to serving engine :param request: dictionary as recieved from the RestAPI + :param state: Use state dict to store data passed to the post-processing function call. 
+ Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values to the statictics collector servicd @@ -94,15 +105,26 @@ class BasePreprocessRequest(object): :return: Object to be passed directly to the model inference """ if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'): - return self._preprocess.preprocess(request, collect_custom_statistics_fn) + return self._preprocess.preprocess(request, state, collect_custom_statistics_fn) return request - def postprocess(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Optional[dict]: + def postprocess( + self, + data: Any, + state: dict, + collect_custom_statistics_fn: Callable[[dict], None] = None + ) -> Optional[dict]: """ Raise exception to report an error Return value will be passed to serving engine :param data: object as recieved from the inference model function + :param state: Use state dict to store data passed to the post-processing function call. 
+ Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values to the statictics collector servicd @@ -112,14 +134,25 @@ class BasePreprocessRequest(object): :return: Dictionary passed directly as the returned result of the RestAPI """ if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'): - return self._preprocess.postprocess(data, collect_custom_statistics_fn) + return self._preprocess.postprocess(data, state, collect_custom_statistics_fn) return data - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process( + self, + data: Any, + state: dict, + collect_custom_statistics_fn: Callable[[dict], None] = None + ) -> Any: """ - The actual processing function. Can be send to external service + The actual processing function. Can be sent to external service :param data: object as recieved from the preprocessing function + :param state: Use state dict to store data passed to the post-processing function call. 
+ Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values to the statictics collector servicd @@ -178,7 +211,7 @@ class BasePreprocessRequest(object): pass @staticmethod - def _preprocess_send_request(self, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]: + def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]: endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/") base_url = BasePreprocessRequest.get_server_config().get("base_serving_url") base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/") @@ -226,12 +259,23 @@ class TritonPreprocessRequest(BasePreprocessRequest): self._ext_service_pb2 = service_pb2 self._ext_service_pb2_grpc = service_pb2_grpc - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process( + self, + data: Any, + state: dict, + collect_custom_statistics_fn: Callable[[dict], None] = None + ) -> Any: """ The actual processing function. Detect gRPC server and send the request to it :param data: object as recieved from the preprocessing function + :param state: Use state dict to store data passed to the post-processing function call. 
+ Usage example: + >>> def preprocess(..., state): + state['preprocess_aux_data'] = [1,2,3] + >>> def postprocess(..., state): + print(state['preprocess_aux_data']) :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values to the statictics collector servicd @@ -240,9 +284,9 @@ class TritonPreprocessRequest(BasePreprocessRequest): :return: Object to be passed tp the post-processing function """ - # allow to override bt preprocessing class + # allow overriding the process method if self._preprocess is not None and hasattr(self._preprocess, "process"): - return self._preprocess.process(data, collect_custom_statistics_fn) + return self._preprocess.process(data, state, collect_custom_statistics_fn) # Create gRPC stub for communicating with the server triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address @@ -316,7 +360,7 @@ class SKLearnPreprocessRequest(BasePreprocessRequest): import joblib # noqa self._model = joblib.load(filename=self._get_local_model_file()) - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: """ The actual processing function. We run the model in this context @@ -335,7 +379,7 @@ class XGBoostPreprocessRequest(BasePreprocessRequest): self._model = xgboost.Booster() self._model.load_model(self._get_local_model_file()) - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: """ The actual processing function. 
We run the model in this context @@ -353,7 +397,7 @@ class LightGBMPreprocessRequest(BasePreprocessRequest): import lightgbm # noqa self._model = lightgbm.Booster(model_file=self._get_local_model_file()) - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: """ The actual processing function. We run the model in this context @@ -367,11 +411,11 @@ class CustomPreprocessRequest(BasePreprocessRequest): super(CustomPreprocessRequest, self).__init__( model_endpoint=model_endpoint, task=task) - def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: + def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: """ The actual processing function. We run the process in this context """ if self._preprocess is not None and hasattr(self._preprocess, 'process'): - return self._preprocess.process(data, collect_custom_statistics_fn) + return self._preprocess.process(data, state, collect_custom_statistics_fn) return None diff --git a/examples/ensemble/preprocess.py b/examples/ensemble/preprocess.py index 6ba648c..8ea4128 100644 --- a/examples/ensemble/preprocess.py +++ b/examples/ensemble/preprocess.py @@ -9,11 +9,11 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. 
not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 return [[body.get("x0", None), body.get("x1", None)], ] - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y return dict(y=data.tolist() if isinstance(data, np.ndarray) else data) diff --git a/examples/keras/preprocess.py b/examples/keras/preprocess.py index 27d039c..cfe7691 100644 --- a/examples/keras/preprocess.py +++ b/examples/keras/preprocess.py @@ -13,7 +13,7 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 url = body.get("url") if not url: @@ -25,7 +25,7 @@ class Preprocess(object): return np.array(image).flatten() - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y if not isinstance(data, np.ndarray): diff --git a/examples/lightgbm/preprocess.py b/examples/lightgbm/preprocess.py index e89f563..f6632c3 100644 --- a/examples/lightgbm/preprocess.py +++ b/examples/lightgbm/preprocess.py @@ -9,14 +9,14 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. 
not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get four valid numbers on the dict: x0, x1, x2, x3 return np.array( [[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)], ], dtype=np.float32 ) - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y # we pick the most probably class and return the class index (argmax) diff --git a/examples/pipeline/preprocess.py b/examples/pipeline/preprocess.py index a521482..e5860d8 100644 --- a/examples/pipeline/preprocess.py +++ b/examples/pipeline/preprocess.py @@ -8,14 +8,14 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) self.executor = ThreadPoolExecutor(max_workers=32) - def postprocess(self, data: List[dict], collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: List[dict], state: dict, collect_custom_statistics_fn=None) -> dict: # we will here average the results and return the new value # assume data is a list of dicts greater than 1 # average result return dict(y=0.5 * data[0]['y'][0] + 0.5 * data[1]['y'][0]) - def process(self, data: Any, collect_custom_statistics_fn=None) -> Any: + def process(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> Any: """ do something with the actual data, return any type of object. 
The returned object will be passed as is to the postprocess function engine diff --git a/examples/pytorch/preprocess.py b/examples/pytorch/preprocess.py index 2803f37..395dc69 100644 --- a/examples/pytorch/preprocess.py +++ b/examples/pytorch/preprocess.py @@ -13,7 +13,7 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 url = body.get("url") if not url: @@ -24,7 +24,7 @@ class Preprocess(object): image = ImageOps.grayscale(image).resize((28, 28)) return np.array(image).flatten() - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y if not isinstance(data, np.ndarray): diff --git a/examples/sklearn/preprocess.py b/examples/sklearn/preprocess.py index 6ba648c..8ea4128 100644 --- a/examples/sklearn/preprocess.py +++ b/examples/sklearn/preprocess.py @@ -9,11 +9,11 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. 
not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 return [[body.get("x0", None), body.get("x1", None)], ] - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y return dict(y=data.tolist() if isinstance(data, np.ndarray) else data) diff --git a/examples/xgboost/preprocess.py b/examples/xgboost/preprocess.py index e3a1771..0be8c0c 100644 --- a/examples/xgboost/preprocess.py +++ b/examples/xgboost/preprocess.py @@ -10,12 +10,12 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) pass - def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get four valid numbers on the dict: x0, x1, x2, x3 return xgb.DMatrix( [[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)]]) - def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: + def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine # data is the return value from model.predict we will put is inside a return value as Y return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)