Add pre/post processing callback state dict, for safe per-request state storage

This commit is contained in:
allegroai 2022-06-05 16:10:16 +03:00
parent 45d0877c71
commit 8778f723e6
10 changed files with 118 additions and 37 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, List, Callable from typing import Any, Optional, Callable
# Preprocess class Must be named "Preprocess" # Preprocess class Must be named "Preprocess"
@ -24,12 +24,24 @@ class Preprocess(object):
""" """
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> Any: # noqa def preprocess(
self,
body: dict,
state: dict,
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
) -> Any: # noqa
""" """
Optional: do something with the request data, return any type of object. Optional: do something with the request data, return any type of object.
The returned object will be passed as is to the inference engine The returned object will be passed as is to the inference engine
:param body: dictionary as received from the RestAPI :param body: dictionary as received from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
This is a per-request state dict (meaning a new empty dict will be passed per request)
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values
to the statistics collector service. to the statistics collector service.
None is passed if statistics collector is not configured, or if the current request should not be collected None is passed if statistics collector is not configured, or if the current request should not be collected
@ -44,12 +56,24 @@ class Preprocess(object):
""" """
return body return body
def postprocess(self, data: Any, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> dict: # noqa def postprocess(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
) -> dict: # noqa
""" """
Optional: post process the data returned from the model inference engine Optional: post process the data returned from the model inference engine
returned dict will be passed back as the request result as is. returned dict will be passed back as the request result as is.
:param data: object as received from the inference model function :param data: object as received from the inference model function
:param state: Use state dict to store data passed to the post-processing function call.
This is a per-request state dict (meaning a dict instance per request)
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values
to the statistics collector service. to the statistics collector service.
None is passed if statistics collector is not configured, or if the current request should not be collected None is passed if statistics collector is not configured, or if the current request should not be collected
@ -62,12 +86,24 @@ class Preprocess(object):
""" """
return data return data
def process(self, data: Any, collect_custom_statistics_fn: Optional[Callable[[dict], None]]) -> Any: # noqa def process(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
) -> Any: # noqa
""" """
Optional: do something with the actual data, return any type of object. Optional: do something with the actual data, return any type of object.
The returned object will be passed as is to the postprocess function engine The returned object will be passed as is to the postprocess function engine
:param data: object as received from the preprocessing function :param data: object as received from the preprocessing function
:param state: Use state dict to store data passed to the post-processing function call.
This is a per-request state dict (meaning a dict instance per request)
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, if provided allows to send a custom set of key/values
to the statistics collector service. to the statistics collector service.
None is passed if statistics collector is not configured, or if the current request should not be collected None is passed if statistics collector is not configured, or if the current request should not be collected

View File

@ -1029,9 +1029,10 @@ class ModelRequestProcessor(object):
collect_stats = True collect_stats = True
tic = time() tic = time()
preprocessed = processor.preprocess(body, stats_collect_fn) state = dict()
processed = processor.process(preprocessed, stats_collect_fn) preprocessed = processor.preprocess(body, state, stats_collect_fn)
return_value = processor.postprocess(processed, stats_collect_fn) processed = processor.process(preprocessed, state, stats_collect_fn)
return_value = processor.postprocess(processed, state, stats_collect_fn)
tic = time() - tic tic = time() - tic
if collect_stats: if collect_stats:
# 10th of a millisecond should be enough # 10th of a millisecond should be enough

View File

@ -46,7 +46,7 @@ class BasePreprocessRequest(object):
def _instantiate_custom_preprocess_cls(self, task: Task) -> None: def _instantiate_custom_preprocess_cls(self, task: Task) -> None:
path = task.artifacts[self.model_endpoint.preprocess_artifact].get_local_copy() path = task.artifacts[self.model_endpoint.preprocess_artifact].get_local_copy()
# check file content hash, should only happens once?! # check file content hash, should only happen once?!
# noinspection PyProtectedMember # noinspection PyProtectedMember
file_hash, _ = sha256sum(path, block_size=Artifacts._hash_block_size) file_hash, _ = sha256sum(path, block_size=Artifacts._hash_block_size)
if file_hash != task.artifacts[self.model_endpoint.preprocess_artifact].hash: if file_hash != task.artifacts[self.model_endpoint.preprocess_artifact].hash:
@ -77,12 +77,23 @@ class BasePreprocessRequest(object):
if callable(getattr(self._preprocess, 'load', None)): if callable(getattr(self._preprocess, 'load', None)):
self._model = self._preprocess.load(self._get_local_model_file()) self._model = self._preprocess.load(self._get_local_model_file())
def preprocess(self, request: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Optional[Any]: def preprocess(
self,
request: dict,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None,
) -> Optional[Any]:
""" """
Raise exception to report an error Raise exception to report an error
Return value will be passed to serving engine Return value will be passed to serving engine
:param request: dictionary as received from the RestAPI :param request: dictionary as received from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
to the statistics collector service to the statistics collector service
@ -94,15 +105,26 @@ class BasePreprocessRequest(object):
:return: Object to be passed directly to the model inference :return: Object to be passed directly to the model inference
""" """
if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'): if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
return self._preprocess.preprocess(request, collect_custom_statistics_fn) return self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
return request return request
def postprocess(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Optional[dict]: def postprocess(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None
) -> Optional[dict]:
""" """
Raise exception to report an error Raise exception to report an error
Return value will be passed to serving engine Return value will be passed to serving engine
:param data: object as received from the inference model function :param data: object as received from the inference model function
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
to the statistics collector service to the statistics collector service
@ -112,14 +134,25 @@ class BasePreprocessRequest(object):
:return: Dictionary passed directly as the returned result of the RestAPI :return: Dictionary passed directly as the returned result of the RestAPI
""" """
if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'): if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
return self._preprocess.postprocess(data, collect_custom_statistics_fn) return self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
return data return data
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None
) -> Any:
""" """
The actual processing function. Can be send to external service The actual processing function. Can be sent to external service
:param data: object as received from the preprocessing function :param data: object as received from the preprocessing function
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
to the statistics collector service to the statistics collector service
@ -178,7 +211,7 @@ class BasePreprocessRequest(object):
pass pass
@staticmethod @staticmethod
def _preprocess_send_request(self, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]: def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/") endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
base_url = BasePreprocessRequest.get_server_config().get("base_serving_url") base_url = BasePreprocessRequest.get_server_config().get("base_serving_url")
base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/") base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
@ -226,12 +259,23 @@ class TritonPreprocessRequest(BasePreprocessRequest):
self._ext_service_pb2 = service_pb2 self._ext_service_pb2 = service_pb2
self._ext_service_pb2_grpc = service_pb2_grpc self._ext_service_pb2_grpc = service_pb2_grpc
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None
) -> Any:
""" """
The actual processing function. The actual processing function.
Detect gRPC server and send the request to it Detect gRPC server and send the request to it
:param data: object as received from the preprocessing function :param data: object as received from the preprocessing function
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
to the statistics collector service to the statistics collector service
@ -240,9 +284,9 @@ class TritonPreprocessRequest(BasePreprocessRequest):
:return: Object to be passed to the post-processing function :return: Object to be passed to the post-processing function
""" """
# allow to override bt preprocessing class # allow overriding the process method
if self._preprocess is not None and hasattr(self._preprocess, "process"): if self._preprocess is not None and hasattr(self._preprocess, "process"):
return self._preprocess.process(data, collect_custom_statistics_fn) return self._preprocess.process(data, state, collect_custom_statistics_fn)
# Create gRPC stub for communicating with the server # Create gRPC stub for communicating with the server
triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address triton_server_address = self._server_config.get("triton_grpc_server") or self._default_grpc_address
@ -316,7 +360,7 @@ class SKLearnPreprocessRequest(BasePreprocessRequest):
import joblib # noqa import joblib # noqa
self._model = joblib.load(filename=self._get_local_model_file()) self._model = joblib.load(filename=self._get_local_model_file())
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
""" """
The actual processing function. The actual processing function.
We run the model in this context We run the model in this context
@ -335,7 +379,7 @@ class XGBoostPreprocessRequest(BasePreprocessRequest):
self._model = xgboost.Booster() self._model = xgboost.Booster()
self._model.load_model(self._get_local_model_file()) self._model.load_model(self._get_local_model_file())
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
""" """
The actual processing function. The actual processing function.
We run the model in this context We run the model in this context
@ -353,7 +397,7 @@ class LightGBMPreprocessRequest(BasePreprocessRequest):
import lightgbm # noqa import lightgbm # noqa
self._model = lightgbm.Booster(model_file=self._get_local_model_file()) self._model = lightgbm.Booster(model_file=self._get_local_model_file())
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
""" """
The actual processing function. The actual processing function.
We run the model in this context We run the model in this context
@ -367,11 +411,11 @@ class CustomPreprocessRequest(BasePreprocessRequest):
super(CustomPreprocessRequest, self).__init__( super(CustomPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task) model_endpoint=model_endpoint, task=task)
def process(self, data: Any, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any: def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
""" """
The actual processing function. The actual processing function.
We run the process in this context We run the process in this context
""" """
if self._preprocess is not None and hasattr(self._preprocess, 'process'): if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return self._preprocess.process(data, collect_custom_statistics_fn) return self._preprocess.process(data, state, collect_custom_statistics_fn)
return None return None

View File

@ -9,11 +9,11 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid values on the dict: x0 and x1 # we expect to get two valid values on the dict: x0 and x1
return [[body.get("x0", None), body.get("x1", None)], ] return [[body.get("x0", None), body.get("x1", None)], ]
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
return dict(y=data.tolist() if isinstance(data, np.ndarray) else data) return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)

View File

@ -13,7 +13,7 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid values on the dict: x0 and x1 # we expect to get two valid values on the dict: x0 and x1
url = body.get("url") url = body.get("url")
if not url: if not url:
@ -25,7 +25,7 @@ class Preprocess(object):
return np.array(image).flatten() return np.array(image).flatten()
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
if not isinstance(data, np.ndarray): if not isinstance(data, np.ndarray):

View File

@ -9,14 +9,14 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get four valid numbers on the dict: x0, x1, x2, x3 # we expect to get four valid numbers on the dict: x0, x1, x2, x3
return np.array( return np.array(
[[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)], ], [[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)], ],
dtype=np.float32 dtype=np.float32
) )
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
# we pick the most probable class and return the class index (argmax) # we pick the most probable class and return the class index (argmax)

View File

@ -8,14 +8,14 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
self.executor = ThreadPoolExecutor(max_workers=32) self.executor = ThreadPoolExecutor(max_workers=32)
def postprocess(self, data: List[dict], collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: List[dict], state: dict, collect_custom_statistics_fn=None) -> dict:
# we will here average the results and return the new value # we will here average the results and return the new value
# assume data is a list of dicts greater than 1 # assume data is a list of dicts greater than 1
# average result # average result
return dict(y=0.5 * data[0]['y'][0] + 0.5 * data[1]['y'][0]) return dict(y=0.5 * data[0]['y'][0] + 0.5 * data[1]['y'][0])
def process(self, data: Any, collect_custom_statistics_fn=None) -> Any: def process(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> Any:
""" """
do something with the actual data, return any type of object. do something with the actual data, return any type of object.
The returned object will be passed as is to the postprocess function engine The returned object will be passed as is to the postprocess function engine

View File

@ -13,7 +13,7 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid values on the dict: x0 and x1 # we expect to get two valid values on the dict: x0 and x1
url = body.get("url") url = body.get("url")
if not url: if not url:
@ -24,7 +24,7 @@ class Preprocess(object):
image = ImageOps.grayscale(image).resize((28, 28)) image = ImageOps.grayscale(image).resize((28, 28))
return np.array(image).flatten() return np.array(image).flatten()
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
if not isinstance(data, np.ndarray): if not isinstance(data, np.ndarray):

View File

@ -9,11 +9,11 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid values on the dict: x0 and x1 # we expect to get two valid values on the dict: x0 and x1
return [[body.get("x0", None), body.get("x1", None)], ] return [[body.get("x0", None), body.get("x1", None)], ]
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
return dict(y=data.tolist() if isinstance(data, np.ndarray) else data) return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)

View File

@ -10,12 +10,12 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request) # set internal state, this will be called only once. (i.e. not per request)
pass pass
def preprocess(self, body: dict, collect_custom_statistics_fn=None) -> Any: def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get four valid numbers on the dict: x0, x1, x2, x3 # we expect to get four valid numbers on the dict: x0, x1, x2, x3
return xgb.DMatrix( return xgb.DMatrix(
[[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)]]) [[body.get("x0", None), body.get("x1", None), body.get("x2", None), body.get("x3", None)]])
def postprocess(self, data: Any, collect_custom_statistics_fn=None) -> dict: def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine # post process the data returned from the model inference engine
# data is the return value from model.predict we will put it inside a return value as Y # data is the return value from model.predict we will put it inside a return value as Y
return dict(y=data.tolist() if isinstance(data, np.ndarray) else data) return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)