mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
Merge pull request #37 from codechem/feature/bytes-payload
Bytes request payload
This commit is contained in:
commit
1569f08d1d
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Optional, Callable
|
from typing import Any, Optional, Callable, Union
|
||||||
|
|
||||||
|
|
||||||
# Preprocess class Must be named "Preprocess"
|
# Preprocess class Must be named "Preprocess"
|
||||||
@ -12,8 +12,8 @@ class Preprocess(object):
|
|||||||
|
|
||||||
Notice the execution flows is synchronous as follows:
|
Notice the execution flows is synchronous as follows:
|
||||||
|
|
||||||
1. RestAPI(...) -> body: dict
|
1. RestAPI(...) -> body: Union[bytes, dict]
|
||||||
2. preprocess(body: dict, ...) -> data: Any
|
2. preprocess(body: Union[bytes, dict], ...) -> data: Any
|
||||||
3. process(data: Any, ...) -> data: Any
|
3. process(data: Any, ...) -> data: Any
|
||||||
4. postprocess(data: Any, ...) -> result: dict
|
4. postprocess(data: Any, ...) -> result: dict
|
||||||
5. RestAPI(result: dict) -> returned request
|
5. RestAPI(result: dict) -> returned request
|
||||||
@ -35,7 +35,7 @@ class Preprocess(object):
|
|||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
body: dict,
|
body: Union[bytes, dict],
|
||||||
state: dict,
|
state: dict,
|
||||||
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
|
collect_custom_statistics_fn: Optional[Callable[[dict], None]],
|
||||||
) -> Any: # noqa
|
) -> Any: # noqa
|
||||||
@ -43,7 +43,7 @@ class Preprocess(object):
|
|||||||
Optional: do something with the request data, return any type of object.
|
Optional: do something with the request data, return any type of object.
|
||||||
The returned object will be passed as is to the inference engine
|
The returned object will be passed as is to the inference engine
|
||||||
|
|
||||||
:param body: dictionary as recieved from the RestAPI
|
:param body: dictionary or bytes as recieved from the RestAPI
|
||||||
:param state: Use state dict to store data passed to the post-processing function call.
|
:param state: Use state dict to store data passed to the post-processing function call.
|
||||||
This is a per-request state dict (meaning a new empty dict will be passed per request)
|
This is a per-request state dict (meaning a new empty dict will be passed per request)
|
||||||
Usage example:
|
Usage example:
|
||||||
|
@ -5,7 +5,7 @@ import gzip
|
|||||||
from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
|
from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
|
|
||||||
from typing import Optional, Dict, Any, Callable
|
from typing import Optional, Dict, Any, Callable, Union
|
||||||
|
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
from clearml_serving.version import __version__
|
from clearml_serving.version import __version__
|
||||||
@ -87,7 +87,7 @@ router = APIRouter(
|
|||||||
@router.post("/{model_id}/{version}")
|
@router.post("/{model_id}/{version}")
|
||||||
@router.post("/{model_id}/")
|
@router.post("/{model_id}/")
|
||||||
@router.post("/{model_id}")
|
@router.post("/{model_id}")
|
||||||
async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
|
async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
|
||||||
try:
|
try:
|
||||||
return_value = processor.process_request(
|
return_value = processor.process_request(
|
||||||
base_url=model_id,
|
base_url=model_id,
|
||||||
|
@ -9,6 +9,7 @@ from typing import Optional, Union, Dict, List
|
|||||||
import itertools
|
import itertools
|
||||||
import threading
|
import threading
|
||||||
from multiprocessing import Lock
|
from multiprocessing import Lock
|
||||||
|
from numpy import isin
|
||||||
from numpy.random import choice
|
from numpy.random import choice
|
||||||
|
|
||||||
from clearml import Task, Model
|
from clearml import Task, Model
|
||||||
@ -123,7 +124,7 @@ class ModelRequestProcessor(object):
|
|||||||
self._serving_base_url = None
|
self._serving_base_url = None
|
||||||
self._metric_log_freq = None
|
self._metric_log_freq = None
|
||||||
|
|
||||||
def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
|
def process_request(self, base_url: str, version: str, request_body: Union[dict, bytes]) -> dict:
|
||||||
"""
|
"""
|
||||||
Process request coming in,
|
Process request coming in,
|
||||||
Raise Value error if url does not match existing endpoints
|
Raise Value error if url does not match existing endpoints
|
||||||
@ -1132,7 +1133,7 @@ class ModelRequestProcessor(object):
|
|||||||
# update preprocessing classes
|
# update preprocessing classes
|
||||||
BasePreprocessRequest.set_server_config(self._configuration)
|
BasePreprocessRequest.set_server_config(self._configuration)
|
||||||
|
|
||||||
def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
|
def _process_request(self, processor: BasePreprocessRequest, url: str, body: Union[bytes, dict]) -> dict:
|
||||||
# collect statistics for this request
|
# collect statistics for this request
|
||||||
stats_collect_fn = None
|
stats_collect_fn = None
|
||||||
collect_stats = False
|
collect_stats = False
|
||||||
@ -1167,7 +1168,7 @@ class ModelRequestProcessor(object):
|
|||||||
if metric_endpoint:
|
if metric_endpoint:
|
||||||
metric_keys = set(metric_endpoint.metrics.keys())
|
metric_keys = set(metric_endpoint.metrics.keys())
|
||||||
# collect inputs
|
# collect inputs
|
||||||
if body:
|
if body and isinstance(body, dict):
|
||||||
keys = set(body.keys()) & metric_keys
|
keys = set(body.keys()) & metric_keys
|
||||||
stats.update({k: body[k] for k in keys})
|
stats.update({k: body[k] for k in keys})
|
||||||
# collect outputs
|
# collect outputs
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Any
|
import io
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
@ -13,16 +14,22 @@ class Preprocess(object):
|
|||||||
# set internal state, this will be called only once. (i.e. not per request)
|
# set internal state, this will be called only once. (i.e. not per request)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
|
def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
|
||||||
# we expect to get two valid on the dict x0, and x1
|
# we expect to get two valid on the dict x0, and x1
|
||||||
url = body.get("url")
|
if isinstance(body, bytes):
|
||||||
if not url:
|
# we expect to get a stream of encoded image bytes
|
||||||
raise ValueError("'url' entry not provided, expected http/s link to image")
|
try:
|
||||||
|
image = Image.open(io.BytesIO(body)).convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Image could not be decoded")
|
||||||
|
|
||||||
local_file = StorageManager.get_local_copy(remote_url=url)
|
if isinstance(body, dict) and "url" in body.keys():
|
||||||
image = Image.open(local_file)
|
# image is given as url, and is fetched
|
||||||
|
url = body.get("url")
|
||||||
|
local_file = StorageManager.get_local_copy(remote_url=url)
|
||||||
|
image = Image.open(local_file)
|
||||||
|
|
||||||
image = ImageOps.grayscale(image).resize((28, 28))
|
image = ImageOps.grayscale(image).resize((28, 28))
|
||||||
|
|
||||||
return np.array([np.array(image).flatten()])
|
return np.array([np.array(image).flatten()])
|
||||||
|
|
||||||
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
|
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Any
|
import io
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
@ -13,16 +14,23 @@ class Preprocess(object):
|
|||||||
# set internal state, this will be called only once. (i.e. not per request)
|
# set internal state, this will be called only once. (i.e. not per request)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
|
def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any:
|
||||||
# we expect to get two valid on the dict x0, and x1
|
# we expect to get two valid on the dict x0, and x1
|
||||||
url = body.get("url")
|
if isinstance(body, bytes):
|
||||||
if not url:
|
# we expect to get a stream of encoded image bytes
|
||||||
raise ValueError("'url' entry not provided, expected http/s link to image")
|
try:
|
||||||
|
image = Image.open(io.BytesIO(body)).convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Image could not be decoded")
|
||||||
|
|
||||||
local_file = StorageManager.get_local_copy(remote_url=url)
|
if isinstance(body, dict) and "url" in body.keys():
|
||||||
image = Image.open(local_file)
|
# image is given as url, and is fetched
|
||||||
|
url = body.get("url")
|
||||||
|
local_file = StorageManager.get_local_copy(remote_url=url)
|
||||||
|
image = Image.open(local_file)
|
||||||
|
|
||||||
image = ImageOps.grayscale(image).resize((28, 28))
|
image = ImageOps.grayscale(image).resize((28, 28))
|
||||||
return np.array([np.array(image)])
|
return np.array([np.array(image).flatten()])
|
||||||
|
|
||||||
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
|
def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
|
||||||
# post process the data returned from the model inference engine
|
# post process the data returned from the model inference engine
|
||||||
|
Loading…
Reference in New Issue
Block a user