Merge pull request #37 from codechem/feature/bytes-payload

Bytes request payload
Authored by Allegro AI on 2022-10-08 00:01:01 +03:00, committed by GitHub
commit 1569f08d1d
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
5 changed files with 42 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional, Callable
+from typing import Any, Optional, Callable, Union
 # Preprocess class Must be named "Preprocess"
@@ -12,8 +12,8 @@ class Preprocess(object):
     Notice the execution flows is synchronous as follows:
-    1. RestAPI(...) -> body: dict
-    2. preprocess(body: dict, ...) -> data: Any
+    1. RestAPI(...) -> body: Union[bytes, dict]
+    2. preprocess(body: Union[bytes, dict], ...) -> data: Any
     3. process(data: Any, ...) -> data: Any
     4. postprocess(data: Any, ...) -> result: dict
     5. RestAPI(result: dict) -> returned request
@@ -35,7 +35,7 @@ class Preprocess(object):
     def preprocess(
         self,
-        body: dict,
+        body: Union[bytes, dict],
         state: dict,
         collect_custom_statistics_fn: Optional[Callable[[dict], None]],
     ) -> Any:  # noqa
@@ -43,7 +43,7 @@ class Preprocess(object):
         Optional: do something with the request data, return any type of object.
        The returned object will be passed as is to the inference engine
-        :param body: dictionary as recieved from the RestAPI
+        :param body: dictionary or bytes as recieved from the RestAPI
         :param state: Use state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a new empty dict will be passed per request)
            Usage example:
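
In practice, a user-defined Preprocess can now branch on the payload type it receives. A minimal sketch of such a class, assuming a JSON payload that carries its data under an "input" key (the key name is illustrative, not part of this commit):

    from typing import Any, Callable, Optional, Union


    class Preprocess(object):
        def preprocess(
            self,
            body: Union[bytes, dict],
            state: dict,
            collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None,
        ) -> Any:
            # raw request: the client posted bytes (e.g. an encoded image or blob)
            if isinstance(body, bytes):
                return body
            # JSON request: the REST layer already parsed it into a dict
            return body.get("input")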

View File

@@ -5,7 +5,7 @@ import gzip
 from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
 from fastapi.routing import APIRoute
-from typing import Optional, Dict, Any, Callable
+from typing import Optional, Dict, Any, Callable, Union
 from clearml import Task
 from clearml_serving.version import __version__
@@ -87,7 +87,7 @@ router = APIRouter(
 @router.post("/{model_id}/{version}")
 @router.post("/{model_id}/")
 @router.post("/{model_id}")
-async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
+async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
     try:
         return_value = processor.process_request(
             base_url=model_id,
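
On the client side, both payload styles hit the same route. A sketch of how one might exercise them, assuming the service is reachable at http://127.0.0.1:8080/serve and that an endpoint named my_model exists (both are placeholders for an actual deployment):

    import requests

    base = "http://127.0.0.1:8080/serve"  # placeholder base URL for the serving service

    # JSON body: parsed into a dict and handed to preprocess()
    requests.post(f"{base}/my_model", json={"url": "https://example.com/digit.png"})

    # raw bytes body: handed to preprocess() unparsed
    with open("digit.png", "rb") as f:
        requests.post(
            f"{base}/my_model",
            data=f.read(),
            headers={"Content-Type": "application/octet-stream"},
        )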

View File

@@ -9,6 +9,7 @@ from typing import Optional, Union, Dict, List
 import itertools
 import threading
 from multiprocessing import Lock
+from numpy import isin
 from numpy.random import choice
 from clearml import Task, Model
@@ -123,7 +124,7 @@ class ModelRequestProcessor(object):
         self._serving_base_url = None
         self._metric_log_freq = None
-    def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
+    def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
         """
        Process request coming in,
        Raise Value error if url does not match existing endpoints
@@ -1132,7 +1133,7 @@ class ModelRequestProcessor(object):
         # update preprocessing classes
         BasePreprocessRequest.set_server_config(self._configuration)
-    def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
+    def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
         # collect statistics for this request
         stats_collect_fn = None
         collect_stats = False
@@ -1167,7 +1168,7 @@ class ModelRequestProcessor(object):
         if metric_endpoint:
             metric_keys = set(metric_endpoint.metrics.keys())
             # collect inputs
-            if body:
+            if body and isinstance(body, dict):
                 keys = set(body.keys()) & metric_keys
                 stats.update({k: body[k] for k in keys})
             # collect outputs
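
The added isinstance check matters for metrics: keyed input statistics are gathered by intersecting the configured metric keys with the request's keys, and a bytes payload has no keys to intersect, so it is simply skipped. A toy illustration of the same guard (names are illustrative, not from the codebase):

    metric_keys = {"x0", "x1"}

    def collect_input_stats(body):
        stats = {}
        # only dict bodies expose named inputs that can match the configured metric keys
        if body and isinstance(body, dict):
            stats.update({k: body[k] for k in set(body.keys()) & metric_keys})
        return stats

    print(collect_input_stats({"x0": 1.0, "x1": 2.0}))  # {'x0': 1.0, 'x1': 2.0}
    print(collect_input_stats(b"\x89PNG..."))           # {} (raw payloads are skipped)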

View File

@@ -1,4 +1,5 @@
-from typing import Any
+import io
+from typing import Any, Union
 import numpy as np
 from PIL import Image, ImageOps
@@ -13,16 +14,22 @@ class Preprocess(object):
         # set internal state, this will be called only once. (i.e. not per request)
         pass
-    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
+    def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
         # we expect to get two valid on the dict x0, and x1
-        url = body.get("url")
-        if not url:
-            raise ValueError("'url' entry not provided, expected http/s link to image")
+        if isinstance(body, bytes):
+            # we expect to get a stream of encoded image bytes
+            try:
+                image = Image.open(io.BytesIO(body)).convert("RGB")
+            except Exception:
+                raise ValueError("Image could not be decoded")
-        local_file = StorageManager.get_local_copy(remote_url=url)
-        image = Image.open(local_file)
+        if isinstance(body, dict) and "url" in body.keys():
+            # image is given as url, and is fetched
+            url = body.get("url")
+            local_file = StorageManager.get_local_copy(remote_url=url)
+            image = Image.open(local_file)
         image = ImageOps.grayscale(image).resize((28, 28))
         return np.array([np.array(image).flatten()])
     def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
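
A quick local check of the new bytes path might look like the following, assuming the example's Preprocess class above is importable and using a synthetic 28x28 image as a stand-in for a real request body:

    import io
    import numpy as np
    from PIL import Image

    # encode a dummy 28x28 image exactly the way a client posting raw bytes would
    buf = io.BytesIO()
    Image.fromarray(np.zeros((28, 28), dtype=np.uint8)).save(buf, format="PNG")
    payload = buf.getvalue()

    features = Preprocess().preprocess(payload, state={})
    print(features.shape)  # (1, 784): one flattened 28x28 grayscale image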

View File

@@ -1,4 +1,5 @@
-from typing import Any
+import io
+from typing import Any, Union
 import numpy as np
 from PIL import Image, ImageOps
@@ -13,16 +14,23 @@ class Preprocess(object):
         # set internal state, this will be called only once. (i.e. not per request)
         pass
-    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
+    def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
         # we expect to get two valid on the dict x0, and x1
-        url = body.get("url")
-        if not url:
-            raise ValueError("'url' entry not provided, expected http/s link to image")
+        if isinstance(body, bytes):
+            # we expect to get a stream of encoded image bytes
+            try:
+                image = Image.open(io.BytesIO(body)).convert("RGB")
+            except Exception:
+                raise ValueError("Image could not be decoded")
-        local_file = StorageManager.get_local_copy(remote_url=url)
-        image = Image.open(local_file)
+        if isinstance(body, dict) and "url" in body.keys():
+            # image is given as url, and is fetched
+            url = body.get("url")
+            local_file = StorageManager.get_local_copy(remote_url=url)
+            image = Image.open(local_file)
         image = ImageOps.grayscale(image).resize((28, 28))
-        return np.array([np.array(image)])
+        return np.array([np.array(image).flatten()])
     def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
         # post process the data returned from the model inference engine
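
For completeness, the postprocess step that follows usually turns the engine's score vector back into a JSON-serializable dict; a minimal sketch of what that could look like for this example (the "digit" key is illustrative, not dictated by this commit):

    def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
        # illustrative only: pick the highest-scoring class for the single input row
        return {"digit": int(np.argmax(data))}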