Merge pull request #37 from codechem/feature/bytes-payload

Bytes request payload
This commit is contained in:
Allegro AI 2022-10-08 00:01:01 +03:00 committed by GitHub
commit 1569f08d1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 26 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, Callable
from typing import Any, Optional, Callable, Union
# Preprocess class must be named "Preprocess"
@ -12,8 +12,8 @@ class Preprocess(object):
Notice the execution flows is synchronous as follows:
1. RestAPI(...) -> body: dict
2. preprocess(body: dict, ...) -> data: Any
1. RestAPI(...) -> body: Union[bytes, dict]
2. preprocess(body: Union[bytes, dict], ...) -> data: Any
3. process(data: Any, ...) -> data: Any
4. postprocess(data: Any, ...) -> result: dict
5. RestAPI(result: dict) -> returned request
@ -35,7 +35,7 @@ class Preprocess(object):
def preprocess(
self,
body: dict,
body: Union[bytes, dict],
state: dict,
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
) -> Any: # noqa
@ -43,7 +43,7 @@ class Preprocess(object):
Optional: do something with the request data, return any type of object.
The returned object will be passed as is to the inference engine
:param body: dictionary as received from the RestAPI
:param body: dictionary or bytes as received from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
This is a per-request state dict (meaning a new empty dict will be passed per request)
Usage example:

View File

@ -5,7 +5,7 @@ import gzip
from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
from fastapi.routing import APIRoute
from typing import Optional, Dict, Any, Callable
from typing import Optional, Dict, Any, Callable, Union
from clearml import Task
from clearml_serving.version import __version__
@ -87,7 +87,7 @@ router = APIRouter(
@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
try:
return_value = processor.process_request(
base_url=model_id,

View File

@ -9,6 +9,7 @@ from typing import Optional, Union, Dict, List
import itertools
import threading
from multiprocessing import Lock
from numpy import isin
from numpy.random import choice
from clearml import Task, Model
@ -123,7 +124,7 @@ class ModelRequestProcessor(object):
self._serving_base_url = None
self._metric_log_freq = None
def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
"""
Process request coming in,
Raise Value error if url does not match existing endpoints
@ -1132,7 +1133,7 @@ class ModelRequestProcessor(object):
# update preprocessing classes
BasePreprocessRequest.set_server_config(self._configuration)
def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
# collect statistics for this request
stats_collect_fn = None
collect_stats = False
@ -1167,7 +1168,7 @@ class ModelRequestProcessor(object):
if metric_endpoint:
metric_keys = set(metric_endpoint.metrics.keys())
# collect inputs
if body:
if body and isinstance(body, dict):
keys = set(body.keys()) & metric_keys
stats.update({k: body[k] for k in keys})
# collect outputs

View File

@ -1,4 +1,5 @@
from typing import Any
import io
from typing import Any, Union
import numpy as np
from PIL import Image, ImageOps
@ -13,16 +14,22 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request)
pass
def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid entries in the dict: x0 and x1
url = body.get("url")
if not url:
raise ValueError("'url' entry not provided, expected http/s link to image")
if isinstance(body, bytes):
# we expect to get a stream of encoded image bytes
try:
image = Image.open(io.BytesIO(body)).convert("RGB")
except Exception:
raise ValueError("Image could not be decoded")
local_file = StorageManager.get_local_copy(remote_url=url)
image = Image.open(local_file)
if isinstance(body, dict) and "url" in body.keys():
# image is given as url, and is fetched
url = body.get("url")
local_file = StorageManager.get_local_copy(remote_url=url)
image = Image.open(local_file)
image = ImageOps.grayscale(image).resize((28, 28))
return np.array([np.array(image).flatten()])
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:

View File

@ -1,4 +1,5 @@
from typing import Any
import io
from typing import Any, Union
import numpy as np
from PIL import Image, ImageOps
@ -13,16 +14,23 @@ class Preprocess(object):
# set internal state, this will be called only once. (i.e. not per request)
pass
def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
# we expect to get two valid entries in the dict: x0 and x1
url = body.get("url")
if not url:
raise ValueError("'url' entry not provided, expected http/s link to image")
if isinstance(body, bytes):
# we expect to get a stream of encoded image bytes
try:
image = Image.open(io.BytesIO(body)).convert("RGB")
except Exception:
raise ValueError("Image could not be decoded")
local_file = StorageManager.get_local_copy(remote_url=url)
image = Image.open(local_file)
if isinstance(body, dict) and "url" in body.keys():
# image is given as url, and is fetched
url = body.get("url")
local_file = StorageManager.get_local_copy(remote_url=url)
image = Image.open(local_file)
image = ImageOps.grayscale(image).resize((28, 28))
return np.array([np.array(image)])
return np.array([np.array(image).flatten()])
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
# post process the data returned from the model inference engine