diff --git a/examples/keras/preprocess.py b/examples/keras/preprocess.py index b10b693..1738ea8 100644 --- a/examples/keras/preprocess.py +++ b/examples/keras/preprocess.py @@ -1,4 +1,5 @@ -from typing import Any +import io +from typing import Any, Union import numpy as np from PIL import Image, ImageOps @@ -13,16 +14,22 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) pass - def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 - url = body.get("url") - if not url: - raise ValueError("'url' entry not provided, expected http/s link to image") + if isinstance(body, bytes): + # we expect to get a stream of encoded image bytes + try: + image = Image.open(io.BytesIO(body)).convert("RGB") + except Exception: + raise ValueError("Image could not be decoded") - local_file = StorageManager.get_local_copy(remote_url=url) - image = Image.open(local_file) + if isinstance(body, dict) and "url" in body.keys(): + # image is given as url, and is fetched + url = body.get("url") + local_file = StorageManager.get_local_copy(remote_url=url) + image = Image.open(local_file) + image = ImageOps.grayscale(image).resize((28, 28)) - return np.array([np.array(image).flatten()]) def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: diff --git a/examples/pytorch/preprocess.py b/examples/pytorch/preprocess.py index 2fd1581..1738ea8 100644 --- a/examples/pytorch/preprocess.py +++ b/examples/pytorch/preprocess.py @@ -1,4 +1,5 @@ -from typing import Any +import io +from typing import Any, Union import numpy as np from PIL import Image, ImageOps @@ -13,16 +14,23 @@ class Preprocess(object): # set internal state, this will be called only once. (i.e. not per request) pass - def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any: + def preprocess(self, body: Union[bytes, dict], state: dict, collect_custom_statistics_fn=None) -> Any: # we expect to get two valid on the dict x0, and x1 - url = body.get("url") - if not url: - raise ValueError("'url' entry not provided, expected http/s link to image") + if isinstance(body, bytes): + # we expect to get a stream of encoded image bytes + try: + image = Image.open(io.BytesIO(body)).convert("RGB") + except Exception: + raise ValueError("Image could not be decoded") - local_file = StorageManager.get_local_copy(remote_url=url) - image = Image.open(local_file) + if isinstance(body, dict) and "url" in body.keys(): + # image is given as url, and is fetched + url = body.get("url") + local_file = StorageManager.get_local_copy(remote_url=url) + image = Image.open(local_file) + image = ImageOps.grayscale(image).resize((28, 28)) - return np.array([np.array(image)]) + return np.array([np.array(image).flatten()]) def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict: # post process the data returned from the model inference engine