clearml/clearml-serving: Merge 61be8733c8 into 724c99c605 (commit b67eba1093)
@@ -41,6 +41,17 @@ class Preprocess(object):
         """
         pass
 
+    def unload(self) -> None:
+        """
+        OPTIONAL: provide unloading method for the model
+        For example:
+        ```py
+        import torch
+        torch.cuda.empty_cache()
+        ```
+        """
+        pass
+
     def preprocess(
             self,
             body: Union[bytes, dict],
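For reference, a minimal sketch of how a user-side `Preprocess` class might implement the optional `unload()` hook added above. The `load()` signature follows the repository's preprocess template, but the torch checkpoint loading and the `_model` attribute are illustrative assumptions, not part of this diff.

```python
from typing import Any


class Preprocess(object):
    def __init__(self):
        # set by load(); illustrative placeholder
        self._model = None

    def load(self, local_file_name: str) -> Any:
        # hypothetical example: load a TorchScript model from the downloaded artifact
        import torch
        self._model = torch.jit.load(local_file_name)
        return self._model

    def unload(self) -> None:
        # drop the model reference and release cached GPU memory
        import torch
        self._model = None
        torch.cuda.empty_cache()
```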
@@ -9,6 +9,7 @@ echo CLEARML_EXTRA_PYTHON_PACKAGES="$CLEARML_EXTRA_PYTHON_PACKAGES"
 echo CLEARML_SERVING_NUM_PROCESS="$CLEARML_SERVING_NUM_PROCESS"
 echo CLEARML_SERVING_POLL_FREQ="$CLEARML_SERVING_POLL_FREQ"
 echo CLEARML_DEFAULT_KAFKA_SERVE_URL="$CLEARML_DEFAULT_KAFKA_SERVE_URL"
+echo CLEARML_SERVING_RESTART_ON_FAILURE="$CLEARML_SERVING_RESTART_ON_FAILURE"
 
 SERVING_PORT="${CLEARML_SERVING_PORT:-8080}"
 GUNICORN_NUM_PROCESS="${CLEARML_SERVING_NUM_PROCESS:-4}"
@@ -40,29 +41,36 @@ then
     python3 -m pip install $CLEARML_EXTRA_PYTHON_PACKAGES
 fi
 
-if [ -z "$CLEARML_USE_GUNICORN" ]
-then
-  if [ -z "$CLEARML_SERVING_NUM_PROCESS" ]
-  then
-    echo "Starting Uvicorn server - single worker"
-    PYTHONPATH=$(pwd) python3 -m uvicorn \
-        clearml_serving.serving.main:app --log-level $UVICORN_LOG_LEVEL --host 0.0.0.0 --port $SERVING_PORT --loop $UVICORN_SERVE_LOOP \
-        $UVICORN_EXTRA_ARGS
-  else
-    echo "Starting Uvicorn server - multi worker"
-    PYTHONPATH=$(pwd) python3 clearml_serving/serving/uvicorn_mp_entrypoint.py \
-        clearml_serving.serving.main:app --log-level $UVICORN_LOG_LEVEL --host 0.0.0.0 --port $SERVING_PORT --loop $UVICORN_SERVE_LOOP \
-        --workers $CLEARML_SERVING_NUM_PROCESS $UVICORN_EXTRA_ARGS
-  fi
-else
-  echo "Starting Gunicorn server"
-  # start service
-  PYTHONPATH=$(pwd) python3 -m gunicorn \
-      --preload clearml_serving.serving.main:app \
-      --workers $GUNICORN_NUM_PROCESS \
-      --worker-class uvicorn.workers.UvicornWorker \
-      --max-requests $GUNICORN_MAX_REQUESTS \
-      --timeout $GUNICORN_SERVING_TIMEOUT \
-      --bind 0.0.0.0:$SERVING_PORT \
-      $GUNICORN_EXTRA_ARGS
-fi
+while : ; do
+  if [ -z "$CLEARML_USE_GUNICORN" ]
+  then
+    if [ -z "$CLEARML_SERVING_NUM_PROCESS" ]
+    then
+      echo "Starting Uvicorn server - single worker"
+      PYTHONPATH=$(pwd) python3 -m uvicorn \
+          clearml_serving.serving.main:app --log-level $UVICORN_LOG_LEVEL --host 0.0.0.0 --port $SERVING_PORT --loop $UVICORN_SERVE_LOOP \
+          $UVICORN_EXTRA_ARGS
+    else
+      echo "Starting Uvicorn server - multi worker"
+      PYTHONPATH=$(pwd) python3 clearml_serving/serving/uvicorn_mp_entrypoint.py \
+          clearml_serving.serving.main:app --log-level $UVICORN_LOG_LEVEL --host 0.0.0.0 --port $SERVING_PORT --loop $UVICORN_SERVE_LOOP \
+          --workers $CLEARML_SERVING_NUM_PROCESS $UVICORN_EXTRA_ARGS
+    fi
+  else
+    echo "Starting Gunicorn server"
+    # start service
+    PYTHONPATH=$(pwd) python3 -m gunicorn \
+        --preload clearml_serving.serving.main:app \
+        --workers $GUNICORN_NUM_PROCESS \
+        --worker-class uvicorn.workers.UvicornWorker \
+        --max-requests $GUNICORN_MAX_REQUESTS \
+        --timeout $GUNICORN_SERVING_TIMEOUT \
+        --bind 0.0.0.0:$SERVING_PORT \
+        $GUNICORN_EXTRA_ARGS
+  fi
+
+  if [ -z "$CLEARML_SERVING_RESTART_ON_FAILURE" ]
+  then
+    break
+  fi
+done
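The loop above keeps relaunching the chosen server whenever the process exits, and only stops if CLEARML_SERVING_RESTART_ON_FAILURE is unset or empty. A rough Python sketch of the same supervise-and-restart idea, under the assumption of a placeholder server command (the real entrypoint assembles the uvicorn/gunicorn command line shown in the hunk):

```python
import os
import subprocess


def supervise(cmd):
    # re-run the server process until restart-on-failure is disabled
    while True:
        subprocess.run(cmd)  # blocks until the server process exits or crashes
        if not os.environ.get("CLEARML_SERVING_RESTART_ON_FAILURE"):
            break


if __name__ == "__main__":
    # placeholder command for illustration only
    supervise(["python3", "-m", "uvicorn", "clearml_serving.serving.main:app"])
```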
@@ -1,13 +1,9 @@
 import os
 import traceback
 import gzip
-import asyncio
 
 from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
 from fastapi.routing import APIRoute
-from fastapi.responses import PlainTextResponse
 
-from starlette.background import BackgroundTask
-
 from typing import Optional, Dict, Any, Callable, Union
 
@@ -52,9 +48,6 @@ try:
 except (ValueError, TypeError):
     pass
 
-class CUDAException(Exception):
-    def __init__(self, exception: str):
-        self.exception = exception
 
 # start FastAPI app
 app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -77,20 +70,6 @@ async def startup_event():
     processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
 
 
-@app.on_event('shutdown')
-def shutdown_event():
-    print('RESTARTING INFERENCE SERVICE!')
-
-async def exit_app():
-    loop = asyncio.get_running_loop()
-    loop.stop()
-
-@app.exception_handler(CUDAException)
-async def cuda_exception_handler(request, exc):
-    task = BackgroundTask(exit_app)
-    return PlainTextResponse("CUDA out of memory. Restarting service", status_code=500, background=task)
-
-
 router = APIRouter(
     prefix="/serve",
     tags=["models"],
@@ -124,9 +103,9 @@ async def serve_model(model_id: str, version: Optional[str] = None, request: Uni
         session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
             instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
         if "CUDA out of memory. " in str(ex) or "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(ex):
-            raise CUDAException(exception=ex)
-        else:
-            raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
+            # can't always recover from this - prefer to exit the program such that it can be restarted
+            os._exit(1)
+        raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
     except Exception as ex:
         session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
             instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
@@ -1,7 +1,6 @@
 import json
 import os
 import gc
-import torch
 from collections import deque
 from pathlib import Path
 from random import random
@@ -169,6 +168,8 @@ class ModelRequestProcessor(object):
             # retry to process
             return await self.process_request(base_url=base_url, version=version, request_body=request_body)
 
+        processor = None
+        url = None
         try:
             # normalize url and version
             url = self._normalize_endpoint_url(base_url, version)
@@ -190,6 +191,8 @@
 
             return_value = await self._process_request(processor=processor, url=url, body=request_body)
         finally:
+            if url and processor is not None and processor is not self._engine_processor_lookup.get(url):
+                gc.collect()
             self._request_processing_state.dec()
 
         return return_value
@@ -907,22 +910,22 @@ class ModelRequestProcessor(object):
                 if cleanup or model_monitor_update:
                     self._update_serving_plot()
                 if cleanup:
+                    gc.collect()
                     self._engine_processor_lookup = dict()
             except Exception as ex:
                 print("Exception occurred in monitoring thread: {}".format(ex))
             sleep(poll_frequency_sec)
             try:
                 # we assume that by now all old deleted endpoints requests already returned
+                call_gc_collect = False
                 if model_monitor_update and not cleanup:
                     for k in list(self._engine_processor_lookup.keys()):
                         if k not in self._endpoints:
                             # atomic
-                            self._engine_processor_lookup.pop(k, None)
-                            gc.collect()
-                            torch.cuda.empty_cache()
+                            self._engine_processor_lookup[k]._model = None
+                            self._engine_processor_lookup[k]._preprocess = None
+                            del self._engine_processor_lookup[k]
+                            call_gc_collect = True
+                if call_gc_collect:
+                    gc.collect()
                 cleanup = False
                 model_monitor_update = False
             except Exception as ex:
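The change above clears the processor's model and preprocess references before deleting the lookup entry, then runs a single gc.collect() once the loop finishes instead of collecting per entry. A self-contained sketch of that reference-release pattern, with illustrative class and attribute names that only mimic the real processor objects:

```python
import gc


class _Processor:
    # stand-in for an engine processor holding a loaded model
    def __init__(self, model):
        self._model = model
        self._preprocess = object()


lookup = {"endpoint/1": _Processor(model=bytearray(10_000_000))}
active_endpoints = set()  # pretend "endpoint/1" was removed from serving

call_gc_collect = False
for key in list(lookup.keys()):
    if key not in active_endpoints:
        # clear references held by the processor, then drop the entry itself
        lookup[key]._model = None
        lookup[key]._preprocess = None
        del lookup[key]
        call_gc_collect = True

if call_gc_collect:
    gc.collect()  # one collection after the loop, instead of one per removed entry
```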
@@ -90,7 +90,18 @@ class BasePreprocessRequest(object):
             sys.modules[spec.name] = _preprocess
             spec.loader.exec_module(_preprocess)
 
-            Preprocess = _preprocess.Preprocess # noqa
+            class PreprocessDelWrapper(_preprocess.Preprocess):
+                def __del__(self):
+                    super_ = super(PreprocessDelWrapper, self)
+                    if callable(getattr(super_, "unload", None)):
+                        try:
+                            super_.unload()
+                        except Exception as ex:
+                            print("Failed unloading model: {}".format(ex))
+                    if callable(getattr(super_, "__del__", None)):
+                        super_.__del__()
+
+            Preprocess = PreprocessDelWrapper # noqa
             # override `send_request` method
             Preprocess.send_request = BasePreprocessRequest._preprocess_send_request
             # create preprocess class
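A standalone illustration of the wrapper pattern introduced above: a subclass whose __del__ delegates to the user's optional unload() hook when the preprocess object is garbage collected. The UserPreprocess class here is a hypothetical stand-in for the Preprocess class loaded from the user's preprocess module.

```python
class UserPreprocess(object):
    # stand-in for the user-provided Preprocess class
    def unload(self):
        print("releasing model resources")


class PreprocessDelWrapper(UserPreprocess):
    def __del__(self):
        super_ = super(PreprocessDelWrapper, self)
        # call the optional unload() hook, if the user defined one
        if callable(getattr(super_, "unload", None)):
            try:
                super_.unload()
            except Exception as ex:
                print("Failed unloading model: {}".format(ex))
        # chain to a user-defined __del__, if any
        if callable(getattr(super_, "__del__", None)):
            super_.__del__()


obj = PreprocessDelWrapper()
del obj  # in CPython the refcount hits zero, __del__ runs, unload() is called
```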
@@ -96,6 +96,7 @@ services:
       CLEARML_DEFAULT_KAFKA_SERVE_URL: ${CLEARML_DEFAULT_KAFKA_SERVE_URL:-clearml-serving-kafka:9092}
       CLEARML_DEFAULT_TRITON_GRPC_ADDR: ${CLEARML_DEFAULT_TRITON_GRPC_ADDR:-clearml-serving-triton:8001}
       CLEARML_USE_GUNICORN: ${CLEARML_USE_GUNICORN:-}
+      CLEARML_SERVING_RESTART_ON_FAILURE: ${CLEARML_SERVING_RESTART_ON_FAILURE:-}
       CLEARML_SERVING_NUM_PROCESS: ${CLEARML_SERVING_NUM_PROCESS:-}
       CLEARML_EXTRA_PYTHON_PACKAGES: ${CLEARML_EXTRA_PYTHON_PACKAGES:-}
       AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
@@ -96,6 +96,7 @@ services:
       CLEARML_DEFAULT_KAFKA_SERVE_URL: ${CLEARML_DEFAULT_KAFKA_SERVE_URL:-clearml-serving-kafka:9092}
       CLEARML_DEFAULT_TRITON_GRPC_ADDR: ${CLEARML_DEFAULT_TRITON_GRPC_ADDR:-clearml-serving-triton:8001}
       CLEARML_USE_GUNICORN: ${CLEARML_USE_GUNICORN:-}
+      CLEARML_SERVING_RESTART_ON_FAILURE: ${CLEARML_SERVING_RESTART_ON_FAILURE:-}
       CLEARML_SERVING_NUM_PROCESS: ${CLEARML_SERVING_NUM_PROCESS:-}
       CLEARML_EXTRA_PYTHON_PACKAGES: ${CLEARML_EXTRA_PYTHON_PACKAGES:-}
       AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
@@ -96,6 +96,7 @@ services:
       CLEARML_DEFAULT_KAFKA_SERVE_URL: ${CLEARML_DEFAULT_KAFKA_SERVE_URL:-clearml-serving-kafka:9092}
       CLEARML_DEFAULT_TRITON_GRPC_ADDR: ${CLEARML_DEFAULT_TRITON_GRPC_ADDR:-}
       CLEARML_USE_GUNICORN: ${CLEARML_USE_GUNICORN:-}
+      CLEARML_SERVING_RESTART_ON_FAILURE: ${CLEARML_SERVING_RESTART_ON_FAILURE:-}
       CLEARML_SERVING_NUM_PROCESS: ${CLEARML_SERVING_NUM_PROCESS:-}
       CLEARML_EXTRA_PYTHON_PACKAGES: ${CLEARML_EXTRA_PYTHON_PACKAGES:-}
       AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
@@ -73,6 +73,8 @@ CLEARML_EXTRA_PYTHON_PACKAGES=transformers
 # Change this depending on your machine and performance needs
 CLEARML_USE_GUNICORN=1
 CLEARML_SERVING_NUM_PROCESS=8
+# Restarts if the serving process crashes
+CLEARML_SERVING_RESTART_ON_FAILURE=1
 ```
 
 Huggingface models require Triton engine support, please use `docker-compose-triton.yml` / `docker-compose-triton-gpu.yml` or if running on Kubernetes, the matching helm chart to set things up. Check the repository main readme documentation if you need help.