mirror of https://github.com/clearml/clearml-serving
synced 2025-03-13 07:18:47 +00:00

Add custom model serving example

This commit is contained in:
parent 4f31c4c178
commit 6005e238ca
@@ -291,6 +291,7 @@ Grafana model performance example:
 - PyTorch [example](examples/pytorch/readme.md) - mnist dataset
 - TensorFlow/Keras [example](examples/keras/readme.md) - mnist dataset
 - Model Pipeline [example](examples/pipeline/readme.md) - random data
+- Custom Model [example](examples/custom/readme.md) - custom data

 ### :pray: Status

@@ -8,6 +8,15 @@ class Preprocess(object):
     Preprocess class Must be named "Preprocess"
     Otherwise there are No limitations, No need to inherit or to implement all methods
     Notice! This is not thread safe! the same instance may be accessed from multiple threads simultaneously
+    to store data in a safe way, push it into the `state` dict argument of the preprocessing/postprocessing functions
+
+    Notice the execution flow is synchronous, as follows:
+
+    1. RestAPI(...) -> body: dict
+    2. preprocess(body: dict, ...) -> data: Any
+    3. process(data: Any, ...) -> data: Any
+    4. postprocess(data: Any, ...) -> result: dict
+    5. RestAPI(result: dict) -> returned request
     """

     def __init__(self):
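Since instance attributes are shared across threads, the per-request `state` dict is the safe place for request-scoped data. A minimal sketch of the pattern the docstring describes (the key name `n_features` is illustrative, not part of the API):

```python
from typing import Any


class Preprocess(object):
    # minimal sketch: `state` is a fresh dict per request, so writes here never
    # race between concurrent requests, unlike attributes on `self`
    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
        state["n_features"] = len(body.get("features", []))  # request-scoped value
        return body.get("features", [])

    def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
        # safe to read back: this is the same request's dict
        return {"predict": data, "n_features": state["n_features"]}
```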
132 examples/custom/preprocess.py Normal file
@@ -0,0 +1,132 @@

from typing import Any, Callable, Optional

import joblib
import numpy as np


# Notice: the Preprocess class must be named "Preprocess"
class Preprocess(object):
    """
    Notice the execution flow is synchronous, as follows:

    1. RestAPI(...) -> body: dict
    2. preprocess(body: dict, ...) -> data: Any
    3. process(data: Any, ...) -> data: Any
    4. postprocess(data: Any, ...) -> result: dict
    5. RestAPI(result: dict) -> returned request
    """

    def __init__(self):
        """
        Set any initial property on the Task (usually the model object).
        Notice these properties will be accessed from multiple threads.
        If you need stateful (per request) data, use the `state` dict argument passed to the pre/post/process functions
        """
        # set internal state, this will be called only once (i.e. not per request)
        self._model = None

    def load(self, local_file_name: str) -> Optional[Any]:  # noqa
        """
        Optional: provide a loading method for the model,
        useful if we need to load the model in a specific way for the prediction engine to work

        :param local_file_name: file name / path to load the model from
        :return: Object that will be called with the .predict() method for inference
        """
        # Example: now let's load the actual model
        self._model = joblib.load(local_file_name)

    def preprocess(self, body: dict, state: dict, collect_custom_statistics_fn=None) -> Any:
        """
        Optional: do something with the request data, return any type of object.
        The returned object will be passed as-is to the inference engine.

        :param body: dictionary as received from the RestAPI
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a new empty dict will be passed per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be
            collected

            Usage example:
            >>> print(body)
            {"x0": 1, "x1": 2}
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"x0": 1, "x1": 2})

        :return: Object to be passed directly to the model inference
        """
        # we expect to get a feature vector on the `features` entry of the dict
        # (note: `dtype=float` replaces the deprecated `np.float` alias)
        return np.array(body.get("features", []), dtype=float)

    def process(
        self,
        data: Any,
        state: dict,
        collect_custom_statistics_fn: Optional[Callable[[dict], None]],
    ) -> Any:  # noqa
        """
        Optional: do something with the actual data, return any type of object.
        The returned object will be passed as-is to the postprocess function.

        :param data: object as received from the preprocessing function
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a dict instance per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be collected

            Usage example:
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"type": "classification"})

        :return: Object to be passed to the post-processing function
        """
        # this is where we do the heavy lifting, i.e. run our model.
        # notice we know data is a float numpy array, because this is what we prepared in the preprocessing function
        data = self._model.predict(np.atleast_2d(data))
        # data is also a numpy array, as returned from our fit function
        return data

    def postprocess(self, data: Any, state: dict, collect_custom_statistics_fn=None) -> dict:
        """
        Optional: post-process the data returned from the model inference engine.
        The returned dict will be passed back as the request result as-is.

        :param data: object as received from the inference model function
        :param state: Use the state dict to store data passed to the post-processing function call.
            This is a per-request state dict (meaning a dict instance per request)
            Usage example:
            >>> def preprocess(..., state):
                    state['preprocess_aux_data'] = [1, 2, 3]
            >>> def postprocess(..., state):
                    print(state['preprocess_aux_data'])
        :param collect_custom_statistics_fn: Optional, if provided allows sending a custom set of key/values
            to the statistics collector service.
            None is passed if the statistics collector is not configured, or if the current request should not be
            collected

            Usage example:
            >>> if collect_custom_statistics_fn:
            >>>     collect_custom_statistics_fn({"y": 1})

        :return: Dictionary passed directly as the returned result of the RestAPI
        """
        # take the predicted numpy array and build the list of values we
        # send back as the RestAPI return value
        # data is the return value from model.predict; we put it inside the return dict as `predict`
        return dict(predict=data.tolist())
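Before registering the endpoint, the hooks can be smoke-tested locally by calling them in the documented synchronous order. A sketch (the driver itself is not part of clearml-serving; it assumes `preprocess.py` is importable and `custom-model.pkl` from `train_model.py` sits in the working directory):

```python
# hypothetical local smoke test, not part of clearml-serving
from preprocess import Preprocess

processor = Preprocess()
processor.load("custom-model.pkl")  # path produced by train_model.py (assumption)

body = {"features": [1, 2, 3]}  # what the RestAPI would deliver
state = {}                      # the engine passes a fresh dict per request
data = processor.preprocess(body, state, None)
data = processor.process(data, state, None)
result = processor.postprocess(data, state, None)
print(result)  # e.g. {'predict': [0]}
```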
37 examples/custom/readme.md Normal file
@@ -0,0 +1,37 @@

# Train and Deploy custom model

## Training a mock custom model

Run the mock python training code:
```bash
pip install -r examples/custom/requirements.txt
python examples/custom/train_model.py
```

The output will be a model created on the project "serving examples", named "custom train model".

## Setting up the serving service

1. Create the serving service: `clearml-serving create --name "serving example"` (write down the service ID)
2. Make sure to add any additional packages required by your custom model to the [docker-compose.yml](https://github.com/allegroai/clearml-serving/blob/826f503cf4a9b069b89eb053696d218d1ce26f47/docker/docker-compose.yml#L97) (or as an environment variable on the `clearml-serving-inference` container), for example: `CLEARML_EXTRA_PYTHON_PACKAGES="scikit-learn numpy"`
3. Create the model endpoint:

   `clearml-serving --id <service_id> model add --engine custom --endpoint "test_model_custom" --preprocess "examples/custom/preprocess.py" --name "custom train model" --project "serving examples"`

   Or auto-update:

   `clearml-serving --id <service_id> model auto-update --engine custom --endpoint "test_model_custom_auto" --preprocess "examples/custom/preprocess.py" --name "custom train model" --project "serving examples" --max-versions 2`

   Or add a canary endpoint:

   `clearml-serving --id <service_id> model canary --endpoint "test_model_custom_auto" --weights 0.1 0.9 --input-endpoint-prefix test_model_custom_auto`

4. If you already have the `clearml-serving` docker-compose running, it might take it a minute or two to sync with the new endpoint.

   Or you can run the clearml-serving container independently: `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`

5. Test the new endpoint (notice the first call will trigger the model pulling, so it might take longer; from then on, it's all in memory), with `curl` here or the Python sketch after this list: `curl -X POST "http://127.0.0.1:8080/serve/test_model_custom" -H "accept: application/json" -H "Content-Type: application/json" -d '{"features": [1, 2, 3]}'`

> **_Notice:_** You can also change the serving service while it is already running!
> This includes adding/removing endpoints, adding canary model routing, etc.
> By default, new endpoints/models are automatically updated after 1 minute.
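For reference, the same test request from Python rather than `curl` (a sketch using the `requests` package, which is an assumption here; it expects the stack to be serving on the default local port 8080):

```python
import requests  # assumption: `pip install requests`

response = requests.post(
    "http://127.0.0.1:8080/serve/test_model_custom",
    json={"features": [1, 2, 3]},  # same payload as the curl example
)
response.raise_for_status()
print(response.json())  # e.g. {"predict": [0]}
```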
2 examples/custom/requirements.txt Normal file
@@ -0,0 +1,2 @@

clearml >= 1.1.6
scikit-learn
15 examples/custom/train_model.py Normal file
@@ -0,0 +1,15 @@

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_blobs
from joblib import dump
from clearml import Task

task = Task.init(project_name="serving examples", task_name="custom train model", output_uri=True)

# generate a two-class classification dataset with 3 features
X, y = make_blobs(n_samples=100, centers=2, n_features=3, random_state=1)
# fit the final model
model = LogisticRegression()
model.fit(X, y)

dump(model, filename="custom-model.pkl", compress=9)
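Because `Task.init` is called with `output_uri=True`, ClearML's joblib integration should upload `custom-model.pkl` and register it as an output model on the task, which is what the `--name "custom train model"` lookup in the serving steps resolves to. A quick local sanity check of the dumped model (a sketch; the predicted label depends on the random blobs):

```python
import numpy as np
from joblib import load

model = load("custom-model.pkl")  # file written by train_model.py
# a single 3-feature vector, shaped the way preprocess.py feeds the engine
print(model.predict(np.atleast_2d([1.0, 2.0, 3.0])))  # e.g. array([0])
```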