Mirror of https://github.com/clearml/clearml-serving (synced 2025-06-26 18:16:00 +00:00)

Add model ensemble and model pipelines support

This commit is contained in:
parent 34e5a0b2c8
commit d684169367
@@ -251,10 +251,12 @@ Example:
 ### :fire: Model Serving Examples

 - Scikit-Learn [example](examples/sklearn/readme.md) - random data
+- Scikit-Learn Model Ensemble [example](examples/ensemble/readme.md) - random data
 - XGBoost [example](examples/xgboost/readme.md) - iris dataset
 - LightGBM [example](examples/lightgbm/readme.md) - iris dataset
 - PyTorch [example](examples/pytorch/readme.md) - mnist dataset
 - TensorFlow/Keras [example](examples/keras/readme.md) - mnist dataset
+- Model Pipeline [example](examples/pipeline/readme.md) - random data

 ### :pray: Status

@@ -279,8 +281,8 @@ Example:
 - [x] LightGBM example
 - [x] PyTorch example
 - [x] TensorFlow/Keras example
-- [ ] Model ensemble example
-- [ ] Model pipeline example
+- [x] Model ensemble example
+- [x] Model pipeline example
 - [ ] Statistics Service
 - [ ] Kafka install instructions
 - [ ] Prometheus install instructions
@@ -4,7 +4,7 @@ from argparse import ArgumentParser
 from pathlib import Path

 from clearml_serving.serving.model_request_processor import ModelRequestProcessor, CanaryEP
-from clearml_serving.serving.preprocess_service import ModelMonitoring, ModelEndpoint
+from clearml_serving.serving.endpoints import ModelMonitoring, ModelEndpoint

 verbosity = False

@@ -92,8 +92,8 @@ def func_model_remove(args):
     elif request_processor.remove_canary_endpoint(endpoint_url=args.endpoint):
         print("Removing model canary endpoint: {}".format(args.endpoint))
     else:
-        print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
-        return
+        raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))

     print("Updating serving service")
     request_processor.serialize()

@@ -111,8 +111,7 @@ def func_canary_add(args):
             load_endpoint_prefix=args.input_endpoint_prefix,
         )
     ):
-        print("Error: Could not add canary endpoint URL: {}".format(args.endpoint))
-        return
+        raise ValueError("Could not add canary endpoint URL: {}".format(args.endpoint))

     print("Updating serving service")
     request_processor.serialize()
@@ -152,7 +151,8 @@ def func_model_auto_update_add(args):
         ),
         preprocess_code=args.preprocess
     ):
-        print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
+        raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))

     print("Updating serving service")
     request_processor.serialize()

@@ -192,7 +192,8 @@ def func_model_endpoint_add(args):
         model_tags=args.tags or None,
         model_published=args.published,
     ):
-        print("Error: Could not find base endpoint URL: {}".format(args.endpoint))
+        raise ValueError("Could not find base endpoint URL: {}".format(args.endpoint))

     print("Updating serving service")
     request_processor.serialize()

@@ -2,17 +2,18 @@ import os
 import re
 import shutil
 import subprocess
-import numpy as np
 from argparse import ArgumentParser
 from time import time
 from typing import Optional

-from pathlib2 import Path
+import numpy as np

 from clearml import Task, Logger, InputModel
 from clearml.backend_api.utils import get_http_session_with_retry
-from clearml_serving.serving.model_request_processor import ModelRequestProcessor, ModelEndpoint
 from clearml.utilities.pyhocon import ConfigFactory, ConfigTree, HOCONConverter
+from pathlib2 import Path
+
+from clearml_serving.serving.endpoints import ModelEndpoint
+from clearml_serving.serving.model_request_processor import ModelRequestProcessor


 class TritonHelper(object):
@@ -268,6 +269,7 @@ class TritonHelper(object):
         Full spec available here:
         https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md
         """

         def _convert_lists(config):
             if isinstance(config, list):
                 return [_convert_lists(i) for i in config]
@@ -346,7 +348,7 @@ class TritonHelper(object):
         if config_dict:
             config_dict = _convert_lists(config_dict)
             # Convert HOCON standard to predefined message format
-            config_pbtxt = "\n" + HOCONConverter.to_hocon(config_dict).\
+            config_pbtxt = "\n" + HOCONConverter.to_hocon(config_dict). \
                 replace("=", ":").replace(" : ", ": ")
             # conform types (remove string quotes)
             if input_type:
@@ -1,40 +0,0 @@
-from typing import Any, Optional
-
-import numpy as np
-
-
-# Notice Preprocess class Must be named "Preprocess"
-class Preprocess(object):
-    serving_config = None
-    # example: {
-    #     'base_serving_url': 'http://127.0.0.1:8080/serve/',
-    #     'triton_grpc_server': '127.0.0.1:9001',
-    # }"
-
-    def __init__(self):
-        # set internal state, this will be called only once. (i.e. not per request)
-        pass
-
-    def load(self, local_file_name: str) -> Optional[Any]:
-        """
-        Optional, provide loading method for the model
-        useful if we need to load a model in a specific way for the prediction engine to work
-        :param local_file_name: file name / path to read load the model from
-        :return: Object that will be called with .predict() method for inference
-        """
-        pass
-
-    def preprocess(self, body: dict) -> Any:
-        # do something with the request data, return any type of object.
-        # The returned object will be passed as is to the inference engine
-        return body
-
-    def postprocess(self, data: Any) -> dict:
-        # post process the data returned from the model inference engine
-        # returned dict will be passed back as the request result as is.
-        return data
-
-    def process(self, data: Any) -> Any:
-        # do something with the actual data, return any type of object.
-        # The returned object will be passed as is to the postprocess function engine
-        return data
clearml_serving/preprocess/preprocess_template.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+from typing import Any, Optional
+
+
+# Notice Preprocess class Must be named "Preprocess"
+# Otherwise there are No limitations, No need to inherit or to implement all methods
+class Preprocess(object):
+    serving_config = None
+    # example: {
+    #     'base_serving_url': 'http://127.0.0.1:8080/serve/',
+    #     'triton_grpc_server': '127.0.0.1:9001',
+    # }"
+
+    def __init__(self):
+        # set internal state, this will be called only once. (i.e. not per request)
+        pass
+
+    def load(self, local_file_name: str) -> Optional[Any]:  # noqa
+        """
+        Optional, provide loading method for the model
+        useful if we need to load a model in a specific way for the prediction engine to work
+        :param local_file_name: file name / path to read load the model from
+        :return: Object that will be called with .predict() method for inference
+        """
+        pass
+
+    def preprocess(self, body: dict) -> Any:  # noqa
+        """
+        do something with the request data, return any type of object.
+        The returned object will be passed as is to the inference engine
+        """
+        return body
+
+    def postprocess(self, data: Any) -> dict:  # noqa
+        """
+        post process the data returned from the model inference engine
+        returned dict will be passed back as the request result as is.
+        """
+        return data
+
+    def process(self, data: Any) -> Any:  # noqa
+        """
+        do something with the actual data, return any type of object.
+        The returned object will be passed as is to the postprocess function engine
+        """
+        return data
+
+    def send_request(  # noqa
+            self,
+            endpoint: str,
+            version: Optional[str] = None,
+            data: Optional[dict] = None
+    ) -> Optional[dict]:
+        """
+        NOTICE: This method will be replaced in runtime, by the inference service
+
+        Helper method to send model inference requests to the inference service itself.
+        This is designed to help with model ensemble, model pipelines, etc.
+        On request error return None, otherwise the request result data dictionary
+
+        Usage example:
+
+        >>> x0, x1 = 1, 2
+        >>> result = self.send_request(endpoint="test_model_sklearn", version="1", data={"x0": x0, "x1": x1})
+        >>> y = result["y"]
+        """
+        return None
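As an orientation aid, a concrete `Preprocess` built on this template could look roughly like the sketch below. This is illustrative only and not part of the committed diff; it assumes a joblib-serialized scikit-learn model and a two-feature request body, and the fuller committed versions live under `examples/ensemble` and `examples/pipeline` further down.

```python
from typing import Any, Optional

import numpy as np
from joblib import load


class Preprocess(object):
    def load(self, local_file_name: str) -> Optional[Any]:
        # assumption: the served model was saved with joblib (as in the ensemble example)
        return load(local_file_name)

    def preprocess(self, body: dict) -> Any:
        # build a single-row feature matrix from the request body
        return [[body.get("x0"), body.get("x1")]]

    def postprocess(self, data: Any) -> dict:
        # wrap the prediction in a JSON-serializable dict
        return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)
```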
clearml_serving/serving/endpoints.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+import numpy as np
+from attr import attrib, attrs, asdict
+
+
+def _engine_validator(inst, attr, value):  # noqa
+    from .preprocess_service import BasePreprocessRequest
+    if not BasePreprocessRequest.validate_engine_type(value):
+        raise TypeError("{} not supported engine type".format(value))
+
+
+def _matrix_type_validator(inst, attr, value):  # noqa
+    if value and not np.dtype(value):
+        raise TypeError("{} not supported matrix type".format(value))
+
+
+@attrs
+class ModelMonitoring(object):
+    base_serving_url = attrib(type=str)  # serving point url prefix (example: "detect_cat")
+    engine_type = attrib(type=str, validator=_engine_validator)  # engine type
+    monitor_project = attrib(type=str, default=None)  # monitor model project (for model auto update)
+    monitor_name = attrib(type=str, default=None)  # monitor model name (for model auto update, regexp selection)
+    monitor_tags = attrib(type=list, default=[])  # monitor model tag (for model auto update)
+    only_published = attrib(type=bool, default=False)  # only select published models
+    max_versions = attrib(type=int, default=None)  # Maximum number of models to keep serving (latest X models)
+    input_size = attrib(type=list, default=None)  # optional, model matrix size
+    input_type = attrib(type=str, default=None, validator=_matrix_type_validator)  # optional, model matrix type
+    input_name = attrib(type=str, default=None)  # optional, layer name to push the input to
+    output_size = attrib(type=list, default=None)  # optional, model matrix size
+    output_type = attrib(type=str, default=None, validator=_matrix_type_validator)  # optional, model matrix type
+    output_name = attrib(type=str, default=None)  # optional, layer name to pull the results from
+    preprocess_artifact = attrib(
+        type=str, default=None)  # optional artifact name storing the model preprocessing code
+    auxiliary_cfg = attrib(type=dict, default=None)  # Auxiliary configuration (e.g. triton conf), Union[str, dict]
+
+    def as_dict(self, remove_null_entries=False):
+        if not remove_null_entries:
+            return asdict(self)
+        return {k: v for k, v in asdict(self).items() if v is not None}
+
+
+@attrs
+class ModelEndpoint(object):
+    engine_type = attrib(type=str, validator=_engine_validator)  # engine type
+    serving_url = attrib(type=str)  # full serving point url (including version) example: "detect_cat/v1"
+    model_id = attrib(type=str, default=None)  # model ID to serve (and download)
+    version = attrib(type=str, default="")  # key (version string), default no version
+    preprocess_artifact = attrib(
+        type=str, default=None)  # optional artifact name storing the model preprocessing code
+    input_size = attrib(type=list, default=None)  # optional, model matrix size
+    input_type = attrib(type=str, default=None, validator=_matrix_type_validator)  # optional, model matrix type
+    input_name = attrib(type=str, default=None)  # optional, layer name to push the input to
+    output_size = attrib(type=list, default=None)  # optional, model matrix size
+    output_type = attrib(type=str, default=None, validator=_matrix_type_validator)  # optional, model matrix type
+    output_name = attrib(type=str, default=None)  # optional, layer name to pull the results from
+    auxiliary_cfg = attrib(type=dict, default=None)  # Optional: Auxiliary configuration (e.g. triton conf), [str, dict]
+
+    def as_dict(self, remove_null_entries=False):
+        if not remove_null_entries:
+            return asdict(self)
+        return {k: v for k, v in asdict(self).items() if v is not None}
+
+
+@attrs
+class CanaryEP(object):
+    endpoint = attrib(type=str)  # load balancer endpoint
+    weights = attrib(type=list)  # list of weights (order should be matching fixed_endpoints or prefix)
+    load_endpoints = attrib(type=list, default=[])  # list of endpoints to balance and route
+    load_endpoint_prefix = attrib(
+        type=str, default=None)  # endpoint prefix to list
+    # (any endpoint starting with this prefix will be listed, sorted lexicographically, or broken into /<int>)
+
+    def as_dict(self, remove_null_entries=False):
+        if not remove_null_entries:
+            return asdict(self)
+        return {k: v for k, v in asdict(self).items() if v is not None}
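A short usage sketch for these new attrs-based endpoint objects (illustrative only; the engine name and model ID below are placeholders, not values taken from this commit):

```python
# Illustrative sketch: construct an endpoint description and serialize it.
from clearml_serving.serving.endpoints import ModelEndpoint

endpoint = ModelEndpoint(
    engine_type="sklearn",            # placeholder engine name, checked by _engine_validator
    serving_url="test_model_sklearn/1",
    model_id="<your-model-id>",       # placeholder model ID
    version="1",
)

# drop unset optional fields before storing the configuration
print(endpoint.as_dict(remove_null_entries=True))
```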
@@ -19,7 +19,7 @@ class GzipRequest(Request):
             body = await super().body()
             if "gzip" in self.headers.getlist("Content-Encoding"):
                 body = gzip.decompress(body)
-            self._body = body
+            self._body = body  # noqa
         return self._body


@@ -83,7 +83,7 @@ router = APIRouter(
 @router.post("/{model_id}/{version}")
 @router.post("/{model_id}/")
 @router.post("/{model_id}")
-async def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
+def serve_model(model_id: str, version: Optional[str] = None, request: Dict[Any, Any] = None):
     try:
         return_value = processor.process_request(
             base_url=model_id,
@@ -8,25 +8,10 @@ import threading
 from multiprocessing import Lock
 from numpy.random import choice

-from attr import attrib, attrs, asdict
 from clearml import Task, Model
 from clearml.storage.util import hash_dict
-from .preprocess_service import ModelEndpoint, ModelMonitoring, BasePreprocessRequest
+from .preprocess_service import BasePreprocessRequest
+from .endpoints import ModelEndpoint, ModelMonitoring, CanaryEP
-
-
-@attrs
-class CanaryEP(object):
-    endpoint = attrib(type=str)  # load balancer endpoint
-    weights = attrib(type=list)  # list of weights (order should be matching fixed_endpoints or prefix)
-    load_endpoints = attrib(type=list, default=[])  # list of endpoints to balance and route
-    load_endpoint_prefix = attrib(
-        type=str, default=None)  # endpoint prefix to list
-    # (any endpoint starting with this prefix will be listed, sorted lexicographically, or broken into /<int>)
-
-    def as_dict(self, remove_null_entries=False):
-        if not remove_null_entries:
-            return asdict(self)
-        return {k: v for k, v in asdict(self).items() if v is not None}


 class FastWriteCounter(object):
@@ -98,6 +83,7 @@ class ModelRequestProcessor(object):
             sleep(1)
             # retry to process
             return self.process_request(base_url=base_url, version=version, request_body=request_body)

         try:
             # normalize url and version
             url = self._normalize_endpoint_url(base_url, version)
@@ -120,9 +106,8 @@ class ModelRequestProcessor(object):
                 self._engine_processor_lookup[url] = processor

             return_value = self._process_request(processor=processor, url=url, body=request_body)
-        except Exception:
+        finally:
             self._request_processing_state.dec()
-            raise

         return return_value

@@ -194,7 +179,7 @@ class ModelRequestProcessor(object):
         if url in self._endpoints:
             print("Warning: Model endpoint \'{}\' overwritten".format(url))

-        if not endpoint.model_id:
+        if not endpoint.model_id and any([model_project, model_name, model_tags]):
             model_query = dict(
                 project_name=model_project,
                 model_name=model_name,
@@ -208,6 +193,8 @@ class ModelRequestProcessor(object):
             if len(models) > 1:
                 print("Warning: Found multiple Models for \'{}\', selecting id={}".format(model_query, models[0].id))
             endpoint.model_id = models[0].id
+        elif not endpoint.model_id:
+            print("Warning: No Model provided for \'{}\'".format(url))

         # upload as new artifact
         if preprocess_code:
@@ -237,6 +224,11 @@ class ModelRequestProcessor(object):
         if not isinstance(monitoring, ModelMonitoring):
             monitoring = ModelMonitoring(**monitoring)

+        # make sure we actually have something to monitor
+        if not any([monitoring.monitor_project, monitoring.monitor_name, monitoring.monitor_tags]):
+            raise ValueError("Model monitoring requires at least a "
+                             "project / name / tag to monitor, none were provided.")
+
         # make sure we have everything configured
         self._validate_model(monitoring)

@@ -384,6 +376,10 @@ class ModelRequestProcessor(object):
         # release stall lock
         self._update_lock_flag = False

+        # update the state on the inference task
+        if Task.current_task() and Task.current_task().id != self._task.id:
+            self.serialize(task=Task.current_task())
+
         return True

     def serialize(self, task: Optional[Task] = None) -> None:
@@ -878,7 +874,7 @@ class ModelRequestProcessor(object):
         return task

     @classmethod
-    def _normalize_endpoint_url(cls, endpoint: str, version : Optional[str] = None) -> str:
+    def _normalize_endpoint_url(cls, endpoint: str, version: Optional[str] = None) -> str:
         return "{}/{}".format(endpoint.rstrip("/"), version or "").rstrip("/")

     @classmethod
|
@ -1,73 +1,20 @@
|
|||||||
import numpy as np
|
import os
|
||||||
from typing import Optional, Any, Callable, List
|
from typing import Optional, Any, Callable, List
|
||||||
|
|
||||||
from attr import attrib, attrs, asdict
|
import numpy as np
|
||||||
|
|
||||||
from clearml import Task, Model
|
from clearml import Task, Model
|
||||||
from clearml.binding.artifacts import Artifacts
|
from clearml.binding.artifacts import Artifacts
|
||||||
from clearml.storage.util import sha256sum
|
from clearml.storage.util import sha256sum
|
||||||
|
from requests import post as request_post
|
||||||
|
|
||||||
|
from .endpoints import ModelEndpoint
|
||||||
def _engine_validator(inst, attr, value): # noqa
|
|
||||||
if not BasePreprocessRequest.validate_engine_type(value):
|
|
||||||
raise TypeError("{} not supported engine type".format(value))
|
|
||||||
|
|
||||||
|
|
||||||
def _matrix_type_validator(inst, attr, value): # noqa
|
|
||||||
if value and not np.dtype(value):
|
|
||||||
raise TypeError("{} not supported matrix type".format(value))
|
|
||||||
|
|
||||||
|
|
||||||
@attrs
|
|
||||||
class ModelMonitoring(object):
|
|
||||||
base_serving_url = attrib(type=str) # serving point url prefix (example: "detect_cat")
|
|
||||||
monitor_project = attrib(type=str) # monitor model project (for model auto update)
|
|
||||||
monitor_name = attrib(type=str) # monitor model name (for model auto update, regexp selection)
|
|
||||||
monitor_tags = attrib(type=list) # monitor model tag (for model auto update)
|
|
||||||
engine_type = attrib(type=str, validator=_engine_validator) # engine type
|
|
||||||
only_published = attrib(type=bool, default=False) # only select published models
|
|
||||||
max_versions = attrib(type=int, default=None) # Maximum number of models to keep serving (latest X models)
|
|
||||||
input_size = attrib(type=list, default=None) # optional, model matrix size
|
|
||||||
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
|
|
||||||
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
|
|
||||||
output_size = attrib(type=list, default=None) # optional, model matrix size
|
|
||||||
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
|
|
||||||
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
|
|
||||||
preprocess_artifact = attrib(
|
|
||||||
type=str, default=None) # optional artifact name storing the model preprocessing code
|
|
||||||
auxiliary_cfg = attrib(type=dict, default=None) # Auxiliary configuration (e.g. triton conf), Union[str, dict]
|
|
||||||
|
|
||||||
def as_dict(self, remove_null_entries=False):
|
|
||||||
if not remove_null_entries:
|
|
||||||
return asdict(self)
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
|
|
||||||
@attrs
|
|
||||||
class ModelEndpoint(object):
|
|
||||||
engine_type = attrib(type=str, validator=_engine_validator) # engine type
|
|
||||||
serving_url = attrib(type=str) # full serving point url (including version) example: "detect_cat/v1"
|
|
||||||
model_id = attrib(type=str) # list of model IDs to serve (order implies versions first is v1)
|
|
||||||
version = attrib(type=str, default="") # key (version string), default no version
|
|
||||||
preprocess_artifact = attrib(
|
|
||||||
type=str, default=None) # optional artifact name storing the model preprocessing code
|
|
||||||
input_size = attrib(type=list, default=None) # optional, model matrix size
|
|
||||||
input_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
|
|
||||||
input_name = attrib(type=str, default=None) # optional, layer name to push the input to
|
|
||||||
output_size = attrib(type=list, default=None) # optional, model matrix size
|
|
||||||
output_type = attrib(type=str, default=None, validator=_matrix_type_validator) # optional, model matrix type
|
|
||||||
output_name = attrib(type=str, default=None) # optional, layer name to pull the results from
|
|
||||||
auxiliary_cfg = attrib(type=dict, default=None) # Optional: Auxiliary configuration (e.g. triton conf), [str, dict]
|
|
||||||
|
|
||||||
def as_dict(self, remove_null_entries=False):
|
|
||||||
if not remove_null_entries:
|
|
||||||
return asdict(self)
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
|
|
||||||
class BasePreprocessRequest(object):
|
class BasePreprocessRequest(object):
|
||||||
__preprocessing_lookup = {}
|
__preprocessing_lookup = {}
|
||||||
__preprocessing_modules = set()
|
__preprocessing_modules = set()
|
||||||
|
_default_serving_base_url = "http://127.0.0.1:8080/serve/"
|
||||||
|
_timeout = None # timeout in seconds for the entire request, set in __init__
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -83,6 +30,8 @@ class BasePreprocessRequest(object):
         self._preprocess = None
         self._model = None
         self._server_config = server_config or {}
+        if self._timeout is None:
+            self._timeout = int(float(os.environ.get('GUNICORN_SERVING_TIMEOUT', 600)) * 0.8)
         # load preprocessing code here
         if self.model_endpoint.preprocess_artifact:
             if not task or self.model_endpoint.preprocess_artifact not in task.artifacts:
@@ -111,7 +60,10 @@ class BasePreprocessRequest(object):
         spec = importlib.util.spec_from_file_location("Preprocess", path)
         _preprocess = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(_preprocess)
-        self._preprocess = _preprocess.Preprocess()  # noqa
+        Preprocess = _preprocess.Preprocess  # noqa
+        # override `send_request` method
+        Preprocess.send_request = BasePreprocessRequest._preprocess_send_request
+        self._preprocess = Preprocess()
         self._preprocess.serving_config = server_config or {}
         if callable(getattr(self._preprocess, 'load', None)):
             self._model = self._preprocess.load(self._get_local_model_file())
@@ -125,7 +77,7 @@ class BasePreprocessRequest(object):
         Raise exception to report an error
         Return value will be passed to serving engine
         """
-        if self._preprocess is not None:
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
             return self._preprocess.preprocess(request)
         return request

@@ -135,7 +87,7 @@ class BasePreprocessRequest(object):
         Raise exception to report an error
         Return value will be passed to serving engine
         """
-        if self._preprocess is not None:
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
             return self._preprocess.postprocess(data)
         return data

@@ -162,6 +114,7 @@ class BasePreprocessRequest(object):
         """
         A decorator to register an annotation type name for classes deriving from Annotation
         """

         def wrapper(cls):
             cls.__preprocessing_lookup[engine_name] = cls
             return cls
@@ -181,6 +134,17 @@ class BasePreprocessRequest(object):
         except (ImportError, TypeError):
             pass

+    @staticmethod
+    def _preprocess_send_request(self, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
+        endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
+        base_url = self.serving_config.get("base_serving_url") if self.serving_config else None
+        base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
+        url = "{}/{}".format(base_url, endpoint.strip("/"))
+        return_value = request_post(url, json=data, timeout=BasePreprocessRequest._timeout)
+        if not return_value.ok:
+            return None
+        return return_value.json()
+

 @BasePreprocessRequest.register_engine("triton", modules=["grpc", "tritonclient"])
 class TritonPreprocessRequest(BasePreprocessRequest):
@@ -224,7 +188,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         Detect gRPC server and send the request to it
         """
         # allow to override bt preprocessing class
-        if self._preprocess is not None and getattr(self._preprocess, "process", None):
+        if self._preprocess is not None and hasattr(self._preprocess, "process"):
             return self._preprocess.process(data)

         # Create gRPC stub for communicating with the server
@@ -268,7 +232,11 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         output0.name = self.model_endpoint.output_name

         request.outputs.extend([output0])
-        response = grpc_stub.ModelInfer(request, compression=self._ext_grpc.Compression.Gzip)
+        response = grpc_stub.ModelInfer(
+            request,
+            compression=self._ext_grpc.Compression.Gzip,
+            timeout=self._timeout
+        )

         output_results = []
         index = 0
@@ -351,6 +319,6 @@ class CustomPreprocessRequest(BasePreprocessRequest):
         The actual processing function.
         We run the process in this context
         """
-        if self._preprocess is not None:
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
             return self._preprocess.process(data)
         return None
@@ -14,3 +14,4 @@ grpcio
 Pillow
 xgboost
 lightgbm
+requests
examples/ensemble/preprocess.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from typing import Any
+
+import numpy as np
+
+
+# Notice Preprocess class Must be named "Preprocess"
+class Preprocess(object):
+    def __init__(self):
+        # set internal state, this will be called only once. (i.e. not per request)
+        pass
+
+    def preprocess(self, body: dict) -> Any:
+        # we expect to get two valid on the dict x0, and x1
+        return [[body.get("x0", None), body.get("x1", None)], ]
+
+    def postprocess(self, data: Any) -> dict:
+        # post process the data returned from the model inference engine
+        # data is the return value from model.predict we will put is inside a return value as Y
+        return dict(y=data.tolist() if isinstance(data, np.ndarray) else data)
examples/ensemble/readme.md (new file, 31 lines)
@@ -0,0 +1,31 @@
+# Train and Deploy Scikit-Learn model ensemble
+
+## training mock voting regression model
+
+Run the mock python training code
+```bash
+pip install -r examples/ensemble/requirements.txt
+python examples/ensemble/train_ensemble.py
+```
+
+The output will be a model created on the project "serving examples", by the name "train model ensemble"
+
+## setting up the serving service
+
+1. Create serving Service: `clearml-serving create --name "serving example"` (write down the service ID)
+2. Create model endpoint:
+`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_ensemble" --preprocess "examples/ensemble/preprocess.py" --name "train model ensemble" --project "serving examples"`
+
+Or auto update
+
+`clearml-serving --id <service_id> model auto-update --engine sklearn --endpoint "test_model_ensemble_auto" --preprocess "examples/ensemble/preprocess.py" --name "train model ensemble" --project "serving examples" --max-versions 2`
+
+Or add Canary endpoint
+
+`clearml-serving --id <service_id> model canary --endpoint "test_model_ensemble_auto" --weights 0.1 0.9 --input-endpoint-prefix test_model_ensemble_auto`
+
+3. Run the clearml-serving container `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`
+4. Test new endpoint: `curl -X POST "http://127.0.0.1:8080/serve/test_model_ensemble" -H "accept: application/json" -H "Content-Type: application/json" -d '{"x0": 1, "x1": 2}'`
+
+> **_Notice:_** You can also change the serving service while it is already running!
+This includes adding/removing endpoints, adding canary model routing etc.
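The curl smoke test in step 4 of this readme can also be issued from Python; a minimal sketch, assuming the serving container is reachable on 127.0.0.1:8080 as in the example above:

```python
import requests

# same payload as the curl test against the "test_model_ensemble" endpoint
response = requests.post(
    "http://127.0.0.1:8080/serve/test_model_ensemble",
    json={"x0": 1, "x1": 2},
)
response.raise_for_status()
print(response.json())
```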
examples/ensemble/requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+clearml >= 1.1.6
+scikit-learn
examples/ensemble/train_ensemble.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from sklearn.neighbors import KNeighborsRegressor
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.ensemble import VotingRegressor
+from sklearn.datasets import make_blobs
+from joblib import dump
+from clearml import Task
+
+task = Task.init(project_name="serving examples", task_name="train model ensemble", output_uri=True)
+
+# generate 2d classification dataset
+X, y = make_blobs(n_samples=100, centers=2, n_features=2, random_state=1)
+
+knn = KNeighborsRegressor(n_neighbors=5)
+knn.fit(X, y)
+
+rf = RandomForestRegressor(n_estimators=50)
+rf.fit(X, y)
+
+estimators = [("knn", knn), ("rf", rf), ]
+ensemble = VotingRegressor(estimators)
+ensemble.fit(X, y)
+
+dump(ensemble, filename="ensemble-vr.pkl", compress=9)
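A quick local sanity check of the dumped ensemble (illustrative, not part of the commit): reload the file written by the script above and predict on one two-feature sample, mirroring the `{"x0", "x1"}` request body used by the serving examples.

```python
from joblib import load

# reload the voting regressor written by train_ensemble.py
ensemble = load("ensemble-vr.pkl")

# predict on a single two-feature sample
print(ensemble.predict([[1, 2]]))
```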
examples/pipeline/preprocess.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from typing import Any, List
+
+
+# Notice Preprocess class Must be named "Preprocess"
+class Preprocess(object):
+    def __init__(self):
+        # set internal state, this will be called only once. (i.e. not per request)
+        pass
+
+    def postprocess(self, data: List[dict]) -> dict:
+        # we will here average the results and return the new value
+        # assume data is a list of dicts greater than 1
+
+        # average result
+        return dict(y=0.5 * data[0]['y'][0] + 0.5 * data[1]['y'][0])
+
+    def process(self, data: Any) -> Any:
+        """
+        do something with the actual data, return any type of object.
+        The returned object will be passed as is to the postprocess function engine
+        """
+        predict_a = self.send_request(endpoint="/test_model_sklearn_a/", version=None, data=data)
+        predict_b = self.send_request(endpoint="/test_model_sklearn_b/", version=None, data=data)
+        if not predict_b or not predict_a:
+            raise ValueError("Error requesting inference endpoint test_model_sklearn a/b")
+
+        return [predict_a, predict_b]
+
+    def send_request(self, endpoint, version, data) -> List[dict]:
+        # Mock Function!
+        # replaced by real send request function when constructed by the inference service
+        pass
examples/pipeline/readme.md (new file, 26 lines)
@@ -0,0 +1,26 @@
+# Deploy a model inference pipeline
+
+## prerequisites
+
+Training a scikit-learn model (see example/sklearn)
+
+## setting up the serving service
+
+1. Create serving Service (if not already running):
+`clearml-serving create --name "serving example"` (write down the service ID)
+
+2. Create model base two endpoints:
+`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_sklearn_a" --preprocess "examples/sklearn/preprocess.py" --name "train sklearn model" --project "serving examples"`
+
+`clearml-serving --id <service_id> model add --engine sklearn --endpoint "test_model_sklearn_b" --preprocess "examples/sklearn/preprocess.py" --name "train sklearn model" --project "serving examples"`
+
+3. Create pipeline model endpoint:
+`clearml-serving --id <service_id> model add --engine custom --endpoint "test_model_pipeline" --preprocess "examples/pipeline/preprocess.py"`
+
+4. Run the clearml-serving container `docker run -v ~/clearml.conf:/root/clearml.conf -p 8080:8080 -e CLEARML_SERVING_TASK_ID=<service_id> clearml-serving:latest`
+
+5. Test new endpoint: `curl -X POST "http://127.0.0.1:8080/serve/test_model_pipeline" -H "accept: application/json" -H "Content-Type: application/json" -d '{"x0": 1, "x1": 2}'`
+
+
+> **_Notice:_** You can also change the serving service while it is already running!
+This includes adding/removing endpoints, adding canary model routing etc.