mirror of https://github.com/clearml/clearml-serving (synced 2025-06-26 18:16:00 +00:00)
Merge pull request #37 from codechem/feature/bytes-payload
Bytes request payload
commit 1569f08d1d
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Callable
+from typing import Any, Optional, Callable, Union


 # Preprocess class Must be named "Preprocess"
@@ -12,8 +12,8 @@ class Preprocess(object):

     Notice the execution flows is synchronous as follows:

-    1. RestAPI(...) -> body: dict
-    2. preprocess(body: dict, ...) -> data: Any
+    1. RestAPI(...) -> body: Union[bytes, dict]
+    2. preprocess(body: Union[bytes, dict], ...) -> data: Any
     3. process(data: Any, ...) -> data: Any
     4. postprocess(data: Any, ...) -> result: dict
     5. RestAPI(result: dict) -> returned request
@@ -35,7 +35,7 @@ class Preprocess(object):
     def preprocess(
             self,
-            body: dict,
+            body: Union[bytes, dict],
             state: dict,
             collect_custom_statistics_fn: Optional[Callable[[dict], None]],
     ) -> Any:  # noqa
@@ -43,7 +43,7 @@ class Preprocess(object):
         Optional: do something with the request data, return any type of object.
         The returned object will be passed as is to the inference engine

-        :param body: dictionary as recieved from the RestAPI
+        :param body: dictionary or bytes as recieved from the RestAPI
         :param state: Use state dict to store data passed to the post-processing function call.
             This is a per-request state dict (meaning a new empty dict will be passed per request)
             Usage example:
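For orientation, here is a minimal sketch of a user preprocessing class written against the updated template signature. The JSON keys ("x0", "x1"), the json-decoding of the bytes body, and the output wrapping are illustrative assumptions, not part of this PR.

import json
from typing import Any, Optional, Callable, Union


class Preprocess(object):
    def __init__(self):
        # set internal state, called once per serving instance (not per request)
        pass

    def preprocess(
            self,
            body: Union[bytes, dict],
            state: dict,
            collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None,
    ) -> Any:
        # with this PR the body may arrive as raw bytes instead of a parsed dict
        if isinstance(body, bytes):
            # hypothetical handling: treat the raw payload as JSON text
            body = json.loads(body)
        return [body.get("x0"), body.get("x1")]

    def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
        # wrap the engine output back into a JSON-serializable dict
        return {"y": data.tolist() if hasattr(data, "tolist") else data}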
@@ -5,7 +5,7 @@ import gzip
 from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
 from fastapi.routing import APIRoute

-from typing import Optional, Dict, Any, Callable
+from typing import Optional, Dict, Any, Callable, Union

 from clearml import Task
 from clearml_serving.version import __version__
@@ -87,7 +87,7 @@ router = APIRouter(
 @router.post("/{model_id}/{version}")
 @router.post("/{model_id}/")
 @router.post("/{model_id}")
-async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
+async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
     try:
         return_value = processor.process_request(
             base_url=model_id,
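With the widened request type, the same route can be exercised with either a JSON dictionary or a raw bytes payload. A minimal client sketch using requests; the host, port, endpoint name, file name, and Content-Type header are assumptions about a local deployment, not part of the diff.

import requests

# hypothetical local serving endpoint; adjust to your deployment
base = "http://127.0.0.1:8080/serve/test_model"

# JSON dictionary payload (existing behaviour)
r = requests.post(base, json={"x0": 1, "x1": 2})
print(r.json())

# raw bytes payload (behaviour introduced by this PR)
with open("digit.png", "rb") as f:
    r = requests.post(base, data=f.read(), headers={"Content-Type": "application/octet-stream"})
print(r.json())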
@@ -9,6 +9,7 @@ from typing import Optional, Union, Dict, List
 import itertools
 import threading
 from multiprocessing import Lock
+from numpy import isin
 from numpy.random import choice

 from clearml import Task, Model
@@ -123,7 +124,7 @@ class ModelRequestProcessor(object):
         self._serving_base_url = None
         self._metric_log_freq = None

-    def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
+    def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
         """
         Process request coming in,
         Raise Value error if url does not match existing endpoints
@@ -1132,7 +1133,7 @@ class ModelRequestProcessor(object):
         # update preprocessing classes
         BasePreprocessRequest.set_server_config(self._configuration)

-    def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
+    def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
         # collect statistics for this request
         stats_collect_fn = None
         collect_stats = False
@@ -1167,7 +1168,7 @@ class ModelRequestProcessor(object):
         if metric_endpoint:
             metric_keys = set(metric_endpoint.metrics.keys())
             # collect inputs
-            if body:
+            if body and isinstance(body, dict):
                 keys = set(body.keys()) & metric_keys
                 stats.update({k: body[k] for k in keys})
             # collect outputs
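The added isinstance(body, dict) guard matters because the metrics-collection code calls body.keys(); a raw bytes payload has no keys and would otherwise raise. A standalone sketch of the same pattern, with illustrative names:

metric_keys = {"x0", "x1"}
stats = {}

def collect_inputs(body):
    # bytes has no .keys(), so only dict payloads contribute input metrics
    if body and isinstance(body, dict):
        keys = set(body.keys()) & metric_keys
        stats.update({k: body[k] for k in keys})

collect_inputs({"x0": 1, "x1": 2, "other": 3})  # records x0 and x1
collect_inputs(b"\x89PNG...")                   # safely skipped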
@@ -1,4 +1,5 @@
-from typing import Any
+import io
+from typing import Any, Union

 import numpy as np
 from PIL import Image, ImageOps
@@ -13,16 +14,22 @@ class Preprocess(object):
         # set internal state, this will be called only once. (i.e. not per request)
         pass

-    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
+    def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
         # we expect to get two valid on the dict x0, and x1
-        url = body.get("url")
-        if not url:
-            raise ValueError("'url' entry not provided, expected http/s link to image")
+        if isinstance(body, bytes):
+            # we expect to get a stream of encoded image bytes
+            try:
+                image = Image.open(io.BytesIO(body)).convert("RGB")
+            except Exception:
+                raise ValueError("Image could not be decoded")

-        local_file = StorageManager.get_local_copy(remote_url=url)
-        image = Image.open(local_file)
+        if isinstance(body, dict) and "url" in body.keys():
+            # image is given as url, and is fetched
+            url = body.get("url")
+            local_file = StorageManager.get_local_copy(remote_url=url)
+            image = Image.open(local_file)

         image = ImageOps.grayscale(image).resize((28, 28))

         return np.array([np.array(image).flatten()])

     def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
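A quick way to see both input paths of this example preprocessor is to call it directly, outside the serving container. The image file name and URL below are placeholders; the shapes follow from the grayscale 28x28 resize and flatten in the code above.

preprocessor = Preprocess()

# bytes payload: raw encoded image
with open("digit.png", "rb") as f:
    batch = preprocessor.preprocess(f.read(), state={})
print(batch.shape)  # (1, 784)

# dict payload: image fetched from a URL via StorageManager
batch = preprocessor.preprocess({"url": "https://example.com/digit.png"}, state={})
print(batch.shape)  # (1, 784)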
@@ -1,4 +1,5 @@
-from typing import Any
+import io
+from typing import Any, Union

 import numpy as np
 from PIL import Image, ImageOps
@@ -13,16 +14,23 @@ class Preprocess(object):
         # set internal state, this will be called only once. (i.e. not per request)
         pass

-    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
+    def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
         # we expect to get two valid on the dict x0, and x1
-        url = body.get("url")
-        if not url:
-            raise ValueError("'url' entry not provided, expected http/s link to image")
+        if isinstance(body, bytes):
+            # we expect to get a stream of encoded image bytes
+            try:
+                image = Image.open(io.BytesIO(body)).convert("RGB")
+            except Exception:
+                raise ValueError("Image could not be decoded")

-        local_file = StorageManager.get_local_copy(remote_url=url)
-        image = Image.open(local_file)
+        if isinstance(body, dict) and "url" in body.keys():
+            # image is given as url, and is fetched
+            url = body.get("url")
+            local_file = StorageManager.get_local_copy(remote_url=url)
+            image = Image.open(local_file)

         image = ImageOps.grayscale(image).resize((28, 28))
-        return np.array([np.array(image)])
+
+        return np.array([np.array(image).flatten()])

     def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
         # post process the data returned from the model inference engine
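Besides accepting bytes, this second example also changes the returned batch shape: the added flatten() turns each 28x28 image into a 784-element vector. A quick numpy check of the two shapes, using a zero array as a stand-in for the resized grayscale image:

import numpy as np

image = np.zeros((28, 28), dtype=np.uint8)           # stand-in for the resized image
print(np.array([np.array(image)]).shape)             # (1, 28, 28) -- before this change
print(np.array([np.array(image).flatten()]).shape)   # (1, 784)    -- after this change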