[DEV] feature/bytes-payload | Add typing

This commit is contained in:
Aleksandar Ivanovski 2022-10-06 16:01:31 +02:00
parent 2aa91a3d43
commit d89d1370d8
2 changed files with 7 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, Callable
from typing import Any, Optional, Callable, Union
# Preprocess class Must be named "Preprocess"
@ -12,8 +12,8 @@ class Preprocess(object):
Notice the execution flows is synchronous as follows:
1. RestAPI(...) -> body: dict
2. preprocess(body: dict, ...) -> data: Any
1. RestAPI(...) -> body: Union[bytes, dict]
2. preprocess(body: Union[bytes, dict], ...) -> data: Any
3. process(data: Any, ...) -> data: Any
4. postprocess(data: Any, ...) -> result: dict
5. RestAPI(result: dict) -> returned request
@ -35,7 +35,7 @@ class Preprocess(object):
def preprocess(
self,
body: dict,
body: Union[bytes, dict],
state: dict,
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
) -> Any: # noqa
@ -43,7 +43,7 @@ class Preprocess(object):
Optional: do something with the request data, return any type of object.
The returned object will be passed as is to the inference engine
:param body: dictionary as recieved from the RestAPI
:param body: dictionary or bytes as recieved from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
This is a per-request state dict (meaning a new empty dict will be passed per request)
Usage example:

View File

@ -124,7 +124,7 @@ class ModelRequestProcessor(object):
self._serving_base_url = None
self._metric_log_freq = None
def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
"""
Process request coming in,
Raise Value error if url does not match existing endpoints
@ -1133,7 +1133,7 @@ class ModelRequestProcessor(object):
# update preprocessing classes
BasePreprocessRequest.set_server_config(self._configuration)
def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
# collect statistics for this request
stats_collect_fn = None
collect_stats = False