# clearml-serving/clearml_serving/serving/main.py

import os
import shlex
import traceback
import gzip
import asyncio

from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
from fastapi.routing import APIRoute
from fastapi.responses import PlainTextResponse
from grpc.aio import AioRpcError
from starlette.background import BackgroundTask
from typing import Optional, Dict, Any, Callable, Union

from clearml_serving.version import __version__
from clearml_serving.serving.init import setup_task
from clearml_serving.serving.model_request_processor import (
    ModelRequestProcessor,
    EndpointNotFoundException,
    EndpointBackendEngineException,
    EndpointModelLoadException,
    ServingInitializationException,
)
from clearml_serving.serving.utils import parse_grpc_errors


class GzipRequest(Request):
    """Request subclass that transparently decompresses gzip-encoded request bodies."""

    async def body(self) -> bytes:
        if not hasattr(self, "_body"):
            body = await super().body()
            if "gzip" in self.headers.getlist("Content-Encoding"):
                body = gzip.decompress(body)
            self._body = body  # noqa
        return self._body


class GzipRoute(APIRoute):
    """APIRoute that wraps incoming requests in GzipRequest, enabling gzip request payloads."""

    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            request = GzipRequest(request.scope, request.receive)
            return await original_route_handler(request)

        return custom_route_handler


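# Client-side sketch (illustrative, not part of this module): with GzipRoute installed,
# a caller can gzip the request body and flag it via the Content-Encoding header.
# The endpoint URL below is hypothetical; any HTTP client works, e.g. `requests`:
#   import gzip, json, requests
#   payload = gzip.compress(json.dumps({"x": [1, 2, 3]}).encode("utf8"))
#   requests.post(
#       "http://localhost:8080/serve/my_model",
#       data=payload,
#       headers={"Content-Encoding": "gzip", "Content-Type": "application/json"},
#   )
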
# process-level lock, so that only a single process reloads models at any given time
singleton_sync_lock = None  # Lock()
# shared ModelRequestProcessor object
processor = None  # type: Optional[ModelRequestProcessor]

# create the ClearML Task and load the models
serving_service_task_id, session_logger, instance_id = setup_task()

# model polling frequency (CLEARML_SERVING_POLL_FREQ); note the value is multiplied
# by 60 when passed to processor.launch() below, so it is effectively in minutes
model_sync_frequency_secs = 5
try:
    model_sync_frequency_secs = float(os.environ.get("CLEARML_SERVING_POLL_FREQ", model_sync_frequency_secs))
except (ValueError, TypeError):
    pass
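# e.g. CLEARML_SERVING_POLL_FREQ=1 would sync endpoints/models roughly once a minute
# (an assumption based on the *60 above; the authoritative behavior is ModelRequestProcessor.launch())
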
grpc_aio_ignore_errors = parse_grpc_errors(shlex.split(os.environ.get("CLEARML_SERVING_AIO_RPC_IGNORE_ERRORS", "")))
grpc_aio_verbose_errors = parse_grpc_errors(shlex.split(os.environ.get("CLEARML_SERVING_AIO_RPC_VERBOSE_ERRORS", "")))
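# Illustrative settings (the status-code names below are assumptions; accepted values
# are whatever parse_grpc_errors() understands):
#   CLEARML_SERVING_AIO_RPC_IGNORE_ERRORS="cancelled"           # silently ignore client cancellations
#   CLEARML_SERVING_AIO_RPC_VERBOSE_ERRORS="deadline_exceeded"  # log deadline errors with full context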


class CUDAException(Exception):
    def __init__(self, exception: str):
        self.exception = exception


# start the FastAPI app
app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Serving Service router")


@app.on_event("startup")
async def startup_event():
global processor
if processor:
print(
"ModelRequestProcessor already initialized [pid={}] [service_id={}]".format(
os.getpid(), serving_service_task_id
)
)
else:
print("Starting up ModelRequestProcessor [pid={}] [service_id={}]".format(os.getpid(), serving_service_task_id))
processor = ModelRequestProcessor(
task_id=serving_service_task_id,
update_lock_guard=singleton_sync_lock,
)
print("ModelRequestProcessor [id={}] loaded".format(processor.get_id()))
processor.launch(poll_frequency_sec=model_sync_frequency_secs * 60)
@app.on_event("shutdown")
def shutdown_event():
print("RESTARTING INFERENCE SERVICE!")
async def exit_app():
loop = asyncio.get_running_loop()
loop.stop()
@app.exception_handler(CUDAException)
async def cuda_exception_handler(request, exc):
task = BackgroundTask(exit_app)
return PlainTextResponse("CUDA out of memory. Restarting service", status_code=500, background=task)
router = APIRouter(
    prefix="/serve",
    tags=["models"],
    responses={404: {"description": "Model Serving Endpoint Not found"}},
    route_class=GzipRoute,  # comment this line out to remove gzip content-encoding support
)


# cover all routing options for model version `/{model_id}`, `/{model_id}/123`, `/{model_id}?version=123`
@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
    try:
        return_value = await processor.process_request(base_url=model_id, version=version, request_body=request)
    except EndpointNotFoundException as ex:
        raise HTTPException(status_code=404, detail="Error processing request, endpoint was not found: {}".format(ex))
    except (EndpointModelLoadException, EndpointBackendEngineException) as ex:
        session_logger.report_text(
            "[{}] Exception [{}] {} while processing request: {}\n{}".format(
                instance_id, type(ex), ex, request, "".join(traceback.format_exc())
            )
        )
        raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
    except ServingInitializationException as ex:
        session_logger.report_text(
            "[{}] Exception [{}] {} while loading serving inference: {}\n{}".format(
                instance_id, type(ex), ex, request, "".join(traceback.format_exc())
            )
        )
        raise HTTPException(status_code=500, detail="Error [{}] processing request: {}".format(type(ex), ex))
    except ValueError as ex:
        session_logger.report_text(
            "[{}] Exception [{}] {} while processing request: {}\n{}".format(
                instance_id, type(ex), ex, request, "".join(traceback.format_exc())
            )
        )
        if "CUDA out of memory. " in str(ex) or "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(ex):
            raise CUDAException(exception=ex)
        else:
            raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
    except AioRpcError as ex:
        if grpc_aio_verbose_errors and ex.code() in grpc_aio_verbose_errors:
            session_logger.report_text(
                "[{}] Exception [AioRpcError] {} while processing request: {}".format(instance_id, ex, request)
            )
        elif not grpc_aio_ignore_errors or ex.code() not in grpc_aio_ignore_errors:
            session_logger.report_text("[{}] Exception [AioRpcError] status={} ".format(instance_id, ex.code()))
        raise HTTPException(
            status_code=500, detail="Error [AioRpcError] processing request: status={}".format(ex.code())
        )
    except Exception as ex:
        session_logger.report_text(
            "[{}] Exception [{}] {} while processing request: {}\n{}".format(
                instance_id, type(ex), ex, request, "".join(traceback.format_exc())
            )
        )
        raise HTTPException(status_code=500, detail="Error [{}] processing request: {}".format(type(ex), ex))
    return return_value


app.include_router(router)

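# Illustrative invocations (hypothetical host/port and endpoint name; real endpoint
# names come from the deployed serving configuration), e.g. with `requests`:
#   import requests
#   requests.post("http://localhost:8080/serve/my_model", json={"x": [1, 2, 3]})
#   requests.post("http://localhost:8080/serve/my_model/2", json={"x": [1, 2, 3]})  # version in the path
#   requests.post("http://localhost:8080/serve/my_model", params={"version": "2"}, json={"x": [1, 2, 3]})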