diff --git a/clearml_serving/serving/entrypoint.sh b/clearml_serving/serving/entrypoint.sh
index e1a5bbc..a5efea1 100755
--- a/clearml_serving/serving/entrypoint.sh
+++ b/clearml_serving/serving/entrypoint.sh
@@ -2,6 +2,7 @@
 
 # print configuration
 echo CLEARML_SERVING_TASK_ID="$CLEARML_SERVING_TASK_ID"
+echo CLEARML_INFERENCE_TASK_ID="$CLEARML_INFERENCE_TASK_ID"
 echo CLEARML_SERVING_PORT="$CLEARML_SERVING_PORT"
 echo CLEARML_USE_GUNICORN="$CLEARML_USE_GUNICORN"
 echo CLEARML_EXTRA_PYTHON_PACKAGES="$CLEARML_EXTRA_PYTHON_PACKAGES"
diff --git a/clearml_serving/serving/init.py b/clearml_serving/serving/init.py
index 2ae54a8..0c75712 100644
--- a/clearml_serving/serving/init.py
+++ b/clearml_serving/serving/init.py
@@ -6,6 +6,7 @@ from clearml_serving.serving.preprocess_service import BasePreprocessRequest
 
 def setup_task(force_threaded_logging=None):
     serving_service_task_id = os.environ.get("CLEARML_SERVING_TASK_ID", None)
+    inference_service_task_id = os.environ.get("CLEARML_INFERENCE_TASK_ID", False)
     # according Task.init() docs
     # always use background thread, it requires less memory
     if force_threaded_logging or os.environ.get("CLEARML_BKG_THREAD_REPORT") in ("1", "Y", "y", "true"):
@@ -24,6 +25,7 @@ def setup_task(force_threaded_logging=None):
         project_name=serving_task.get_project_name(),
         task_name="{} - serve instance".format(serving_task.name),
         task_type="inference",  # noqa
+        continue_last_task=inference_service_task_id,
     )
     instance_task.set_system_tags(["service"])
     # make sure we start logging thread/process
diff --git a/clearml_serving/serving/main.py b/clearml_serving/serving/main.py
index 6865c93..10ce9c9 100644
--- a/clearml_serving/serving/main.py
+++ b/clearml_serving/serving/main.py
@@ -1,9 +1,13 @@
 import os
 import traceback
 import gzip
+import asyncio
 
 from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
 from fastapi.routing import APIRoute
+from fastapi.responses import PlainTextResponse
+
+from starlette.background import BackgroundTask
 
 from typing import Optional, Dict, Any, Callable, Union
 
@@ -48,6 +52,9 @@ try:
 except (ValueError, TypeError):
     pass
 
+class CUDAException(Exception):
+    def __init__(self, exception: str):
+        self.exception = exception
 
 # start FastAPI app
 app = FastAPI(title="ClearML Serving Service", version=__version__, description="ClearML Service Service router")
@@ -70,6 +77,20 @@ async def startup_event():
     processor.launch(poll_frequency_sec=model_sync_frequency_secs*60)
 
 
+@app.on_event('shutdown')
+def shutdown_event():
+    print('RESTARTING INFERENCE SERVICE!')
+
+async def exit_app():
+    loop = asyncio.get_running_loop()
+    loop.stop()
+
+@app.exception_handler(CUDAException)
+async def cuda_exception_handler(request, exc):
+    task = BackgroundTask(exit_app)
+    return PlainTextResponse("CUDA out of memory. Restarting service", status_code=500, background=task)
+
+
 router = APIRouter(
     prefix="/serve",
     tags=["models"],
@@ -102,7 +123,10 @@ async def serve_model(model_id: str, version: Optional[str] = None, request: Uni
     except ValueError as ex:
         session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
             instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
-        raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
+        if "CUDA out of memory. " in str(ex) or "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(ex):
+            raise CUDAException(exception=ex)
+        else:
+            raise HTTPException(status_code=422, detail="Error [{}] processing request: {}".format(type(ex), ex))
     except Exception as ex:
         session_logger.report_text("[{}] Exception [{}] {} while processing request: {}\n{}".format(
             instance_id, type(ex), ex, request, "".join(traceback.format_exc())))
diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py
index ba9242d..35f5120 100644
--- a/clearml_serving/serving/model_request_processor.py
+++ b/clearml_serving/serving/model_request_processor.py
@@ -1,5 +1,7 @@
 import json
 import os
+import gc
+import torch
 from collections import deque
 from pathlib import Path
 from random import random
@@ -915,7 +917,12 @@ class ModelRequestProcessor(object):
                 for k in list(self._engine_processor_lookup.keys()):
                     if k not in self._endpoints:
                         # atomic
+                        self._engine_processor_lookup[k]._model = None
+                        self._engine_processor_lookup[k]._preprocess = None
+                        del self._engine_processor_lookup[k]
                         self._engine_processor_lookup.pop(k, None)
+                        gc.collect()
+                        torch.cuda.empty_cache()
                 cleanup = False
                 model_monitor_update = False
             except Exception as ex: