Add custom model serving example

This commit is contained in:
allegroai 2022-07-15 22:27:55 +03:00
parent 4f31c4c178
commit 6005e238ca
6 changed files with 196 additions and 0 deletions


@@ -291,6 +291,7 @@ Grafana model performance example:
- PyTorch [example](examples/pytorch/readme.md) - mnist dataset
- TensorFlow/Keras [example](examples/keras/readme.md) - mnist dataset
- Model Pipeline [example](examples/pipeline/readme.md) - random data
- Custom Model [example](examples/custom/readme.md) - custom data
### :pray: Status


@@ -8,6 +8,15 @@ class Preprocess(object):
    Preprocess class must be named "Preprocess"
    Otherwise there are no limitations, no need to inherit or to implement all methods

    Notice! This is not thread safe! The same instance may be accessed from multiple threads simultaneously.
    To store data in a safe way, push it into the `state` dict argument of the preprocessing/postprocessing functions.

    Notice the execution flow is synchronous, as follows:

    1. RestAPI(...) -> body: dict
    2. preprocess(body: dict, ...) -> data: Any
    3. process(data: Any, ...) -> data: Any
    4. postprocess(data: Any, ...) -> result: dict
    5. RestAPI(result: dict) -> returned request
    """

    def __init__(self):


@@ -0,0 +1,132 @@
from typing import Any, Callable, Optional

import joblib
import numpy as np


# Notice: the Preprocess class must be named "Preprocess"
class Preprocess(object):
    """
    Notice the execution flow is synchronous, as follows:

    1. RestAPI(...) -> body: dict
    2. preprocess(body: dict, ...) -> data: Any
    3. process(data: Any, ...) -> data: Any
    4. postprocess(data: Any, ...) -> result: dict
    5. RestAPI(result: dict) -> returned request
    """
    def __init__(self):
        """
        Set any initial property on the Task (usually the model object).
        Notice these properties will be accessed from multiple threads.
        If you need stateful (per-request) data, use the `state` dict argument passed to the
        preprocess/process/postprocess functions.
        """
        # set internal state, this will be called only once (i.e. not per request)
        self._model = None
    def load(self, local_file_name: str) -> Optional[Any]:  # noqa
        """
        Optional: provide a loading method for the model,
        useful if we need to load the model in a specific way for the prediction engine to work.

        :param local_file_name: file name / path to load the model from
        :return: Object that will be called with the .predict() method for inference
        """
        # Example: now let's load the actual model
        self._model = joblib.load(local_file_name)
    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
        """
        Optional: do something with the request data, return any type of object.
        The returned object will be passed as is to the inference engine.

        :param body: dictionary as received from the RestAPI
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a new empty dict will be passed per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be
            collected.
            Usage example:
            >>> print(body)
            {"x0": 1, "x1": 2}
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"x0": 1, "x1": 2})
        :return: Object to be passed directly to the model inference
        """
        # we expect to get a feature vector in the `features` entry of the dict
        return np.array(body.get("features", []), dtype=float)
    def process(
        self,
        data: Any,
        state: dict,
        collect_custom_statistics_fn: Optional[Callable[[dict], None]],
    ) -> Any:  # noqa
        """
        Optional: do something with the actual data, return any type of object.
        The returned object will be passed as is to the postprocess function.

        :param data: object as received from the preprocessing function
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a dict instance per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be collected.
            Usage example:
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"type": "classification"})
        :return: Object to be passed to the post-processing function
        """
        # this is where we do the heavy lifting, i.e. run our model.
        # notice we know data is a numpy array of type float, because this is what we prepared in the preprocessing function
        data = self._model.predict(np.atleast_2d(data))
        # data is also a numpy array, as returned from our fit function
        return data
    def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
        """
        Optional: post-process the data returned from the model inference engine.
        The returned dict will be passed back as the request result as is.

        :param data: object as received from the inference model function
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a dict instance per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be
            collected.
            Usage example:
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"y": 1})
        :return: Dictionary passed directly as the returned result of the RestAPI
        """
        # take the numpy prediction result and create a list of values to
        # send back as the RestAPI return value;
        # data is the return value of model.predict, we return it under the "predict" key
        return dict(predict=data.tolist())
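
# Hypothetical local sanity check (an illustrative addition, not part of the clearml-serving flow;
# it assumes the model file produced by examples/custom/train_model.py is available locally):
# >>> p = Preprocess()
# >>> p.load("custom-model.pkl")
# >>> features = p.preprocess({"features": [1, 2, 3]}, state={}, collect_custom_statistics_fn=None)
# >>> p.postprocess(p.process(features, {}, None), {}, None)
# returns something like {'predict': [0]}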

examples/custom/readme.md Normal file

@@ -0,0 +1,37 @@
# Train and Deploy custom model
## Training the mock custom model
Run the mock Python training code:
```bash
pip install -r examples/custom/requirements.txt
python examples/custom/train_model.py
```
The output will be a model created in the project "serving examples", named "custom train model".
## Setting up the serving service
1. Create serving Service: `clearml-serving create --name "serving example"` (write down the service ID)
2. Make sure to add any additional packages required by your custom model to the [docker-compose.yml](https://github.com/allegroai/clearml-serving/blob/826f503cf4a9b069b89eb053696d218d1ce26f47/docker/docker-compose.yml#L97) (or as an environment variable to the `clearml-serving-inference` container), for example by defining: `CLEARML_EXTRA_PYTHON_PACKAGES="scikit-learn numpy"`
3. Create model endpoint:
   `clearml-serving --id <service_id> model add --engine custom --endpoint "test_model_custom" --preprocess "examples/custom/preprocess.py" --name "custom train model" --project "serving examples"`
   Or auto-update:
   `clearml-serving --id <service_id> model auto-update --engine custom --endpoint "test_model_custom_auto" --preprocess "examples/custom/preprocess.py" --name "custom train model" --project "serving examples" --max-versions 2`
   Or add a canary endpoint:
   `clearml-serving --id <service_id> model canary --endpoint "test_model_custom_auto" --weights 0.1 0.9 --input-endpoint-prefix test_model_custom_auto`
4. If you already have the `clearml-serving` docker-compose running, it might take a minute or two for it to sync with the new endpoint.
   Or you can run the clearml-serving container independently: `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`
5. Test the new endpoint (notice the first call will trigger downloading the model, so it might take longer; from then on it is all in memory): `curl -X POST "http://127.0.0.1:8080/serve/test_model_custom" -H "accept: application/json" -H "Content-Type: application/json" -d '{"features": [1, 2, 3]}'`
   A Python alternative to this `curl` call is sketched after the note below.
> **_Notice:_** You can also change the serving service while it is already running!
> This includes adding/removing endpoints, adding canary model routing, etc.
> By default, new endpoints/models will be automatically updated within 1 minute.
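
For reference, here is a minimal Python client sketch that sends the same payload as the `curl` call above (an illustrative addition; it assumes the default local endpoint from the steps above and that the `requests` package is installed):
```python
import requests

# same payload and endpoint as the curl example above
response = requests.post(
    "http://127.0.0.1:8080/serve/test_model_custom",
    json={"features": [1, 2, 3]},
)
print(response.json())  # expected shape: {"predict": [...]}
```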


@@ -0,0 +1,2 @@
clearml >= 1.1.6
scikit-learn


@@ -0,0 +1,15 @@
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_blobs
from joblib import dump
from clearml import Task

task = Task.init(project_name="serving examples", task_name="custom train model", output_uri=True)

# generate a mock classification dataset with 3 features
X, y = make_blobs(n_samples=100, centers=2, n_features=3, random_state=1)

# fit the final model
model = LogisticRegression()
model.fit(X, y)

# store the trained model to disk
dump(model, filename="custom-model.pkl", compress=9)
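# Note (assumption): since the Task was initialized with output_uri=True, ClearML should
# automatically pick up the stored model file and register it in the "serving examples"
# project, which is the model referenced by name in examples/custom/readme.md.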