Add model ensemble and model pipelines support

allegroai 2022-03-09 04:02:03 +02:00
parent 34e5a0b2c8
commit d684169367
16 changed files with 346 additions and 142 deletions

View File

@ -251,10 +251,12 @@ Example:
### :fire: Model Serving Examples
- Scikit-Learn [example](examples/sklearn/readme.md) - random data
- Scikit-Learn Model Ensemble [example](examples/ensemble/readme.md) - random data
- XGBoost [example](examples/xgboost/readme.md) - iris dataset
- LightGBM [example](examples/lightgbm/readme.md) - iris dataset
- PyTorch [example](examples/pytorch/readme.md) - mnist dataset
- TensorFlow/Keras [example](examples/keras/readme.md) - mnist dataset
- Model Pipeline [example](examples/pipeline/readme.md) - random data
### :pray: Status
@ -279,8 +281,8 @@ Example:
- [x] LightGBM example
- [x] PyTorch example
- [x] TensorFlow/Keras example
- [ ] Model ensemble example
- [ ] Model pipeline example
- [x] Model ensemble example
- [x] Model pipeline example
- [ ] Statistics Service
- [ ] Kafka install instructions
- [ ] Prometheus install instructions

View File

@ -4,7 +4,7 @@ from argparse import ArgumentParser
from pathlib import Path
from clearml_serving.serving.model_request_processor import ModelRequestProcessor, CanaryEP
from clearml_serving.serving.preprocess_service import ModelMonitoring, ModelEndpoint
from clearml_serving.serving.endpoints import ModelMonitoring, ModelEndpoint
verbosity = False
@ -92,8 +92,8 @@ def func_model_remove(args):
elif request_processor.remove_canary_endpoint(endpoint_url=args.endpoint):
print("Removing model canary endpoint: {}".format(args.endpoint))
else:
print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
return
raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))
print("Updating serving service")
request_processor.serialize()
@ -111,8 +111,7 @@ def func_canary_add(args):
load_endpoint_prefix=args.input_endpoint_prefix,
)
):
print("Error: Could not add canary endpoint URL: {}".format(args.endpoint))
return
raise ValueError("Could not add canary endpoint URL: {}".format(args.endpoint))
print("Updating serving service")
request_processor.serialize()
@ -152,7 +151,8 @@ def func_model_auto_update_add(args):
),
preprocess_code=args.preprocess
):
print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))
print("Updating serving service")
request_processor.serialize()
@ -192,7 +192,8 @@ def func_model_endpoint_add(args):
model_tags=args.tags or None,
model_published=args.published,
):
print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))
print("Updating serving service")
request_processor.serialize()

View File

@ -2,17 +2,18 @@ import os
import re
import shutil
import subprocess
import numpy as np
from argparse import ArgumentParser
from time import time
from typing import Optional
from pathlib2 import Path
import numpy as np
from clearml import Task, Logger, InputModel
from clearml.backend_api.utils import get_http_session_with_retry
from clearml_serving.serving.model_request_processor import ModelRequestProcessor, ModelEndpoint
from clearml.utilities.pyhocon import ConfigFactory, ConfigTree, HOCONConverter
from pathlib2 import Path
from clearml_serving.serving.endpoints import ModelEndpoint
from clearml_serving.serving.model_request_processor import ModelRequestProcessor
class TritonHelper(object):
@ -268,6 +269,7 @@ class TritonHelper(object):
Full spec available here:
https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md
"""
def _convert_lists(config):
if isinstance(config, list):
return [_convert_lists(i) for i in config]
@ -346,7 +348,7 @@ class TritonHelper(object):
if config_dict:
config_dict = _convert_lists(config_dict)
# Convert HOCON standard to predefined message format
config_pbtxt = "\n" + HOCONConverter.to_hocon(config_dict).\
config_pbtxt = "\n" + HOCONConverter.to_hocon(config_dict). \
replace("=", ":").replace(" : ", ": ")
# conform types (remove string quotes)
if input_type:

View File

@ -1,40 +0,0 @@
from typing import Any, Optional
import numpy as np
# Notice Preprocess class Must be named "Preprocess"
class Preprocess(object):
serving_config = None
# example: {
# 'base_serving_url': 'http://127.0.0.1:8080/serve/',
# 'triton_grpc_server': '127.0.0.1:9001',
# }"
def __init__(self):
# set internal state, this will be called only once. (i.e. not per request)
pass
def load(self, local_file_name: str) -> Optional[Any]:
"""
Optional, provide loading method for the model
useful if we need to load a model in a specific way for the prediction engine to work
:param local_file_name: file name / path to read load the model from
:return: Object that will be called with .predict() method for inference
"""
pass
def preprocess(self, body: dict) -> Any:
# do something with the request data, return any type of object.
# The returned object will be passed as is to the inference engine
return body
def postprocess(self, data: Any) -> dict:
# post process the data returned from the model inference engine
# returned dict will be passed back as the request result as is.
return data
def process(self, data: Any) -> Any:
# do something with the actual data, return any type of object.
# The returned object will be passed as is to the postprocess function engine
return data

View File

@ -0,0 +1,66 @@
from typing import Any, Optional
# Notice Preprocess class Must be named "Preprocess"
# Otherwise there are No limitations, No need to inherit or to implement all methods
class Preprocess(object):
serving_config = None
# example: {
# 'base_serving_url': 'http://127.0.0.1:8080/serve/',
# 'triton_grpc_server': '127.0.0.1:9001',
# }"
def __init__(self):
# set internal state, this will be called only once. (i.e. not per request)
pass
def load(self, local_file_name: str) -> Optional[Any]: # noqa
"""
Optional, provide loading method for the model
useful if we need to load a model in a specific way for the prediction engine to work
:param local_file_name: file name / path to load the model from
:return: Object that will be called with .predict() method for inference
"""
pass
def preprocess(self, body: dict) -> Any: # noqa
"""
do something with the request data, return any type of object.
The returned object will be passed as is to the inference engine
"""
return body
def postprocess(self, data: Any) -> dict: # noqa
"""
post process the data returned from the model inference engine
returned dict will be passed back as the request result as is.
"""
return data
def process(self, data: Any) -> Any: # noqa
"""
do something with the actual data, return any type of object.
The returned object will be passed as is to the postprocess function
"""
return data
def send_request( # noqa
self,
endpoint: str,
version: Optional[str] = None,
data: Optional[dict] = None
) -> Optional[dict]:
"""
NOTICE: This method will be replaced at runtime by the inference service
Helper method to send model inference requests to the inference service itself.
This is designed to help with model ensemble, model pipelines, etc.
On request error returns None, otherwise returns the request result data dictionary
Usage example:
>>> x0, x1 = 1, 2
>>> result = self.send_request(endpoint="test_model_sklearn", version="1", data={"x0": x0, "x1": x1})
>>> y = result["y"]
"""
return None

View File

@ -0,0 +1,75 @@
import numpy as np
from attr import attrib, attrs, asdict
def _engine_validator(inst, attr, value): # noqa
from .preprocess_service import BasePreprocessRequest
if not BasePreprocessRequest.validate_engine_type(value):
raise TypeError("{} not supported engine type".format(value))
def _matrix_type_validator(inst, attr, value): # noqa
if value and not np.dtype(value):
raise TypeError("{} not supported matrix type".format(value))
@attrs
class ModelMonitoring(object):
base_serving_url = attrib(type=str) # serving point url prefix (example: "detect_cat")
engine_type = attrib(type=str, validator=_engine_validator) # engine type
monitor_project = attrib(type=str, default=None) # monitor model project (for model auto update)
monitor_name = attrib(type=str, default=None) # monitor model name (for model auto update, regexp selection)
monitor_tags = attrib(type=list, default=[]) # monitor model tag (for model auto update)
only_published = attrib(type=bool, default=False) # only select published models
max_versions = attrib(type=int, default=None) # Maximum number of models to keep serving (latest X models)
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
auxiliary_cfg = attrib(type=dict, default=None) # Auxiliary configuration (e.g. triton conf), Union[str, dict]
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
@attrs
class ModelEndpoint(object):
engine_type = attrib(type=str, validator=_engine_validator) # engine type
serving_url = attrib(type=str) # full serving point url (including version) example: "detect_cat/v1"
model_id = attrib(type=str, default=None) # model ID to serve (and download)
version = attrib(type=str, default="") # key (version string), default no version
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
auxiliary_cfg = attrib(type=dict, default=None) # Optional: Auxiliary configuration (e.g. triton conf), [str, dict]
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
@attrs
class CanaryEP(object):
endpoint = attrib(type=str) # load balancer endpoint
weights = attrib(type=list) # list of weights (order should be matching fixed_endpoints or prefix)
load_endpoints = attrib(type=list, default=[]) # list of endpoints to balance and route
load_endpoint_prefix = attrib(
type=str, default=None) # endpoint prefix to list
# (any endpoint starting with this prefix will be listed, sorted lexicographically, or broken into /<int>)
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
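A minimal usage sketch of these attrs containers (illustration only; it assumes they are importable from `clearml_serving.serving.endpoints` as added in this commit, that the "sklearn" engine is registered, and the model ID is a hypothetical placeholder):
```python
from clearml_serving.serving.endpoints import ModelEndpoint, CanaryEP

# describe a single serving endpoint; unused optional fields stay None
endpoint = ModelEndpoint(
    engine_type="sklearn",             # must pass _engine_validator (a registered engine)
    serving_url="test_model_sklearn",  # endpoint url prefix
    model_id="<model-id>",             # hypothetical ClearML model ID
    version="1",
)

# canary routing that load-balances across versions sharing a prefix
canary = CanaryEP(
    endpoint="test_model_sklearn_canary",
    weights=[0.1, 0.9],
    load_endpoint_prefix="test_model_sklearn",
)

# as_dict(remove_null_entries=True) drops the None-valued optional fields
print(endpoint.as_dict(remove_null_entries=True))
print(canary.as_dict(remove_null_entries=True))
```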

View File

@ -19,7 +19,7 @@ class GzipRequest(Request):
body = await super().body()
if "gzip" in self.headers.getlist("Content-Encoding"):
body = gzip.decompress(body)
self._body = body
self._body = body # noqa
return self._body
@ -83,7 +83,7 @@ router = APIRouter(
@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
try:
return_value = processor.process_request(
base_url=model_id,

View File

@ -8,25 +8,10 @@ import threading
from multiprocessing import Lock
from numpy.random import choice
from attr import attrib, attrs, asdict
from clearml import Task, Model
from clearml.storage.util import hash_dict
from .preprocess_service import ModelEndpoint, ModelMonitoring, BasePreprocessRequest
@attrs
class CanaryEP(object):
endpoint = attrib(type=str) # load balancer endpoint
weights = attrib(type=list) # list of weights (order should be matching fixed_endpoints or prefix)
load_endpoints = attrib(type=list, default=[]) # list of endpoints to balance and route
load_endpoint_prefix = attrib(
type=str, default=None) # endpoint prefix to list
# (any endpoint starting with this prefix will be listed, sorted lexicographically, or broken into /<int>)
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
from .preprocess_service import BasePreprocessRequest
from .endpoints import ModelEndpoint, ModelMonitoring, CanaryEP
class FastWriteCounter(object):
@ -98,6 +83,7 @@ class ModelRequestProcessor(object):
sleep(1)
# retry to process
return self.process_request(base_url=base_url, version=version, request_body=request_body)
try:
# normalize url and version
url = self._normalize_endpoint_url(base_url, version)
@ -120,9 +106,8 @@ class ModelRequestProcessor(object):
self._engine_processor_lookup[url] = processor
return_value = self._process_request(processor=processor, url=url, body=request_body)
except Exception:
finally:
self._request_processing_state.dec()
raise
return return_value
@ -194,7 +179,7 @@ class ModelRequestProcessor(object):
if url in self._endpoints:
print("Warning: Model endpoint \'{}\' overwritten".format(url))
if not endpoint.model_id:
if not endpoint.model_id and any([model_project, model_name, model_tags]):
model_query = dict(
project_name=model_project,
model_name=model_name,
@ -208,6 +193,8 @@ class ModelRequestProcessor(object):
if len(models) > 1:
print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
endpoint.model_id = models[0].id
elif not endpoint.model_id:
print("Warning: No Model provided for \'{}\'".format(url))
# upload as new artifact
if preprocess_code:
@ -237,6 +224,11 @@ class ModelRequestProcessor(object):
if not isinstance(monitoring, ModelMonitoring):
monitoring = ModelMonitoring(**monitoring)
# make sure we actually have something to monitor
if not any([monitoring.monitor_project, monitoring.monitor_name, monitoring.monitor_tags]):
raise ValueError("Model monitoring requires at least a "
"project / name / tag to monitor, none were provided.")
# make sure we have everything configured
self._validate_model(monitoring)
@ -384,6 +376,10 @@ class ModelRequestProcessor(object):
# release stall lock
self._update_lock_flag = False
# update the state on the inference task
if Task.current_task() and Task.current_task().id != self._task.id:
self.serialize(task=Task.current_task())
return True
def serialize(self, task: Optional[Task] = None) -> None:
@ -878,7 +874,7 @@ class ModelRequestProcessor(object):
return task
@classmethod
def _normalize_endpoint_url(cls, endpoint: str, version : Optional[str] = None) -> str:
def _normalize_endpoint_url(cls, endpoint: str, version: Optional[str] = None) -> str:
return "{}/{}".format(endpoint.rstrip("/"), version or "").rstrip("/")
@classmethod
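For illustration, the normalization above behaves as follows (a quick sketch derived from the one-line implementation, not part of the commit):
```python
# "{}/{}".format(endpoint.rstrip("/"), version or "").rstrip("/")
assert ModelRequestProcessor._normalize_endpoint_url("test_model_sklearn", "1") == "test_model_sklearn/1"
assert ModelRequestProcessor._normalize_endpoint_url("test_model_sklearn/", None) == "test_model_sklearn"
assert ModelRequestProcessor._normalize_endpoint_url("detect_cat/", "v1") == "detect_cat/v1"
```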

View File

@ -1,73 +1,20 @@
import numpy as np
import os
from typing import Optional, Any, Callable, List
from attr import attrib, attrs, asdict
import numpy as np
from clearml import Task, Model
from clearml.binding.artifacts import Artifacts
from clearml.storage.util import sha256sum
from requests import post as request_post
def _engine_validator(inst, attr, value): # noqa
if not BasePreprocessRequest.validate_engine_type(value):
raise TypeError("{} not supported engine type".format(value))
def _matrix_type_validator(inst, attr, value): # noqa
if value and not np.dtype(value):
raise TypeError("{} not supported matrix type".format(value))
@attrs
class ModelMonitoring(object):
base_serving_url = attrib(type=str) # serving point url prefix (example: "detect_cat")
monitor_project = attrib(type=str) # monitor model project (for model auto update)
monitor_name = attrib(type=str) # monitor model name (for model auto update, regexp selection)
monitor_tags = attrib(type=list) # monitor model tag (for model auto update)
engine_type = attrib(type=str, validator=_engine_validator) # engine type
only_published = attrib(type=bool, default=False) # only select published models
max_versions = attrib(type=int, default=None) # Maximum number of models to keep serving (latest X models)
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
auxiliary_cfg = attrib(type=dict, default=None) # Auxiliary configuration (e.g. triton conf), Union[str, dict]
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
@attrs
class ModelEndpoint(object):
engine_type = attrib(type=str, validator=_engine_validator) # engine type
serving_url = attrib(type=str) # full serving point url (including version) example: "detect_cat/v1"
model_id = attrib(type=str) # list of model IDs to serve (order implies versions first is v1)
version = attrib(type=str, default="") # key (version string), default no version
preprocess_artifact = attrib(
type=str, default=None) # optional artifact name storing the model preprocessing code
input_size = attrib(type=list, default=None) # optional, model matrix size
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
output_size = attrib(type=list, default=None) # optional, model matrix size
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
auxiliary_cfg = attrib(type=dict, default=None) # Optional: Auxiliary configuration (e.g. triton conf), [str, dict]
def as_dict(self, remove_null_entries=False):
if not remove_null_entries:
return asdict(self)
return {k: v for k, v in asdict(self).items() if v is not None}
from .endpoints import ModelEndpoint
class BasePreprocessRequest(object):
__preprocessing_lookup = {}
__preprocessing_modules = set()
_default_serving_base_url = "http://127.0.0.1:8080/serve/"
_timeout = None # timeout in seconds for the entire request, set in __init__
def __init__(
self,
@ -83,6 +30,8 @@ class BasePreprocessRequest(object):
self._preprocess = None
self._model = None
self._server_config = server_config or {}
if self._timeout is None:
self._timeout = int(float(os.environ.get('GUNICORN_SERVING_TIMEOUT', 600)) * 0.8)
# load preprocessing code here
if self.model_endpoint.preprocess_artifact:
if not task or self.model_endpoint.preprocess_artifact not in task.artifacts:
@ -111,7 +60,10 @@ class BasePreprocessRequest(object):
spec = importlib.util.spec_from_file_location("Preprocess", path)
_preprocess = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_preprocess)
self._preprocess = _preprocess.Preprocess() # noqa
Preprocess = _preprocess.Preprocess # noqa
# override `send_request` method
Preprocess.send_request = BasePreprocessRequest._preprocess_send_request
self._preprocess = Preprocess()
self._preprocess.serving_config = server_config or {}
if callable(getattr(self._preprocess, 'load', None)):
self._model = self._preprocess.load(self._get_local_model_file())
@ -125,7 +77,7 @@ class BasePreprocessRequest(object):
Raise exception to report an error
Return value will be passed to serving engine
"""
if self._preprocess is not None:
if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
return self._preprocess.preprocess(request)
return request
@ -135,7 +87,7 @@ class BasePreprocessRequest(object):
Raise exception to report an error
Return value will be passed to serving engine
"""
if self._preprocess is not None:
if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
return self._preprocess.postprocess(data)
return data
@ -162,6 +114,7 @@ class BasePreprocessRequest(object):
"""
A decorator to register an annotation type name for classes deriving from Annotation
"""
def wrapper(cls):
cls.__preprocessing_lookup[engine_name] = cls
return cls
@ -181,6 +134,17 @@ class BasePreprocessRequest(object):
except (ImportError, TypeError):
pass
@staticmethod
def _preprocess_send_request(self, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
base_url = self.serving_config.get("base_serving_url") if self.serving_config else None
base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
url = "{}/{}".format(base_url, endpoint.strip("/"))
return_value = request_post(url, json=data, timeout=BasePreprocessRequest._timeout)
if not return_value.ok:
return None
return return_value.json()
@BasePreprocessRequest.register_engine("triton", modules=["grpc", "tritonclient"])
class TritonPreprocessRequest(BasePreprocessRequest):
@ -224,7 +188,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
Detect gRPC server and send the request to it
"""
# allow override by the preprocessing class
if self._preprocess is not None and getattr(self._preprocess, "process", None):
if self._preprocess is not None and hasattr(self._preprocess, "process"):
return self._preprocess.process(data)
# Create gRPC stub for communicating with the server
@ -268,7 +232,11 @@ class TritonPreprocessRequest(BasePreprocessRequest):
output0.name = self.model_endpoint.output_name
request.outputs.extend([output0])
response = grpc_stub.ModelInfer(request, compression=self._ext_grpc.Compression.Gzip)
response = grpc_stub.ModelInfer(
request,
compression=self._ext_grpc.Compression.Gzip,
timeout=self._timeout
)
output_results = []
index = 0
@ -351,6 +319,6 @@ class CustomPreprocessRequest(BasePreprocessRequest):
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None:
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return self._preprocess.process(data)
return None

View File

@ -14,3 +14,4 @@ grpcio
Pillow
xgboost
lightgbm
requests

View File

@ -0,0 +1,19 @@
from typing import Any
import numpy as np
# Notice Preprocess class Must be named "Preprocess"
class Preprocess(object):
def __init__(self):
# set internal state, this will be called only once. (i.e. not per request)
pass
def preprocess(self, body: dict) -> Any:
# we expect to get two valid entries in the dict: x0 and x1
return [[body.get("x0", None), body.get("x1", None)], ]
def postprocess(self, data: Any) -> dict:
# post process the data returned from the model inference engine
# data is the return value from model.predict, we return it inside the result dict as y
return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)

View File

@ -0,0 +1,31 @@
# Train and Deploy Scikit-Learn model ensemble
## training a mock voting regression model
Run the mock Python training code
```bash
pip install -r examples/ensemble/requirements.txt
python examples/ensemble/train_ensemble.py
```
The output will be a model created in the project "serving examples", named "train model ensemble"
## setting up the serving service
1. Create serving Service: `clearml-serving create --name "serving example"` (write down the service ID)
2. Create model endpoint:
`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_ensemble" --preprocess "examples/ensemble/preprocess.py" --name "train model ensemble" --project "serving examples"`
Or auto update
`clearml-serving --id <service_id> model auto-update --engine sklearn --endpoint "test_model_ensemble_auto" --preprocess "examples/ensemble/preprocess.py" --name "train model ensemble" --project "serving examples" --max-versions 2`
Or add Canary endpoint
`clearml-serving --id <service_id> model canary --endpoint "test_model_ensemble_auto" --weights 0.1 0.9 --input-endpoint-prefix test_model_ensemble_auto`
3. Run the clearml-serving container `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`
4. Test new endpoint: `curl -X POST "http://127.0.0.1:8080/serve/test_model_ensemble" -H "accept: application/json" -H "Content-Type: application/json" -d '{"x0": 1, "x1": 2}'`
> **_Notice:_** You can also change the serving service while it is already running!
This includes adding/removing endpoints, adding canary model routing etc.
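The request from step 4 can also be issued from Python (a minimal sketch, assuming the serving container from step 3 is listening on 127.0.0.1:8080 and the `requests` package is installed):
```python
import requests

# POST the two features expected by examples/ensemble/preprocess.py ("x0", "x1")
response = requests.post(
    "http://127.0.0.1:8080/serve/test_model_ensemble",
    json={"x0": 1, "x1": 2},
)
response.raise_for_status()
print(response.json())  # e.g. {"y": [...]} as built by the postprocess() step
```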

View File

@ -0,0 +1,2 @@
clearml >= 1.1.6
scikit-learn

View File

@ -0,0 +1,23 @@
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import VotingRegressor
from sklearn.datasets import make_blobs
from joblib import dump
from clearml import Task
task = Task.init(project_name="serving examples", task_name="train model ensemble", output_uri=True)
# generate 2d classification dataset
X, y = make_blobs(n_samples=100, centers=2, n_features=2, random_state=1)
knn = KNeighborsRegressor(n_neighbors=5)
knn.fit(X, y)
rf = RandomForestRegressor(n_estimators=50)
rf.fit(X, y)
estimators = [("knn", knn), ("rf", rf), ]
ensemble = VotingRegressor(estimators)
ensemble.fit(X, y)
dump(ensemble, filename="ensemble-vr.pkl", compress=9)
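As a quick local sanity check of the saved ensemble (illustration only, not part of the training script), the pickled file can be loaded back and queried with a single two-feature sample, matching the `{"x0": 1, "x1": 2}` payload used by the serving examples:
```python
from joblib import load

ensemble = load("ensemble-vr.pkl")
print(ensemble.predict([[1, 2]]))  # one sample with two features -> one averaged prediction
```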

View File

@ -0,0 +1,32 @@
from typing import Any, List
# Notice Preprocess class Must be named "Preprocess"
class Preprocess(object):
def __init__(self):
# set internal state, this will be called only once. (i.e. not per request)
pass
def postprocess(self, data: List[dict]) -> dict:
# here we average the results and return the new value
# assume data is a list of at least two result dicts
# average result
return dict(y=0.5 * data[0]['y'][0] + 0.5 * data[1]['y'][0])
def process(self, data: Any) -> Any:
"""
do something with the actual data, return any type of object.
The returned object will be passed as is to the postprocess function
"""
predict_a = self.send_request(endpoint="/test_model_sklearn_a/", version=None, data=data)
predict_b = self.send_request(endpoint="/test_model_sklearn_b/", version=None, data=data)
if not predict_b or not predict_a:
raise ValueError("Error requesting inference endpoint test_model_sklearn a/b")
return [predict_a, predict_b]
def send_request(self, endpoint, version, data) -> List[dict]:
# Mock Function!
# replaced by real send request function when constructed by the inference service
pass

View File

@ -0,0 +1,26 @@
# Deploy a model inference pipeline
## prerequisites
Train a scikit-learn model (see examples/sklearn)
## setting up the serving service
1. Create serving Service (if not already running):
`clearml-serving create --name "serving example"` (write down the service ID)
2. Create the two base model endpoints:
`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_sklearn_a" --preprocess "examples/sklearn/preprocess.py" --name "train sklearn model" --project "serving examples"`
`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_sklearn_b" --preprocess "examples/sklearn/preprocess.py" --name "train sklearn model" --project "serving examples"`
3. Create pipeline model endpoint:
`clearml-serving --id <service_id> model add --engine custom --endpoint "test_model_pipeline" --preprocess "examples/pipeline/preprocess.py"`
4. Run the clearml-serving container `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`
5. Test new endpoint: `curl -X POST "http://127.0.0.1:8080/serve/test_model_pipeline" -H "accept: application/json" -H "Content-Type: application/json" -d '{"x0": 1, "x1": 2}'`
> **_Notice:_** You can also change the serving service while it is already running!
This includes adding/removing endpoints, adding canary model routing etc.
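To sanity-check the pipeline from Python (a minimal sketch, assuming the container from step 4 is running locally and `requests` is installed), query the two base endpoints and the pipeline endpoint with the same payload; the pipeline result should be the average of the two individual predictions, as computed in examples/pipeline/preprocess.py:
```python
import requests

payload = {"x0": 1, "x1": 2}
base = "http://127.0.0.1:8080/serve"

y_a = requests.post("{}/test_model_sklearn_a".format(base), json=payload).json()["y"][0]
y_b = requests.post("{}/test_model_sklearn_b".format(base), json=payload).json()["y"][0]
y_pipeline = requests.post("{}/test_model_pipeline".format(base), json=payload).json()["y"]

# postprocess() in examples/pipeline/preprocess.py returns 0.5 * y_a + 0.5 * y_b
print(y_a, y_b, y_pipeline)
```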