diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 486ae4736..531e96494 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -11,7 +11,7 @@ from starlette.responses import StreamingResponse import json import requests -from utils.utils import get_verified_user, get_current_user +from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV from constants import ERROR_MESSAGES @@ -112,6 +112,32 @@ async def get_status(): return {"status": True} +@app.get("/restart") +async def restart_litellm(user=Depends(get_admin_user)): + """ + Endpoint to restart the litellm background service. + """ + log.info("Requested restart of litellm service.") + try: + # Shut down the existing process if it is running + await shutdown_litellm_background() + log.info("litellm service shutdown complete.") + + # Restart the background service + await start_litellm_background() + log.info("litellm service restart complete.") + + return { + "status": "success", + "message": "litellm service restarted successfully.", + } + except Exception as e: + log.error(f"Error restarting litellm service: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)): @@ -199,40 +225,3 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): raise HTTPException( status_code=r.status_code if r else 500, detail=error_detail ) - - -# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): -# async def dispatch( -# self, request: Request, call_next: RequestResponseEndpoint -# ) -> Response: - -# response = await call_next(request) -# user = request.state.user - -# if "/models" in request.url.path: -# if isinstance(response, StreamingResponse): -# # Read the content of the streaming response -# body = b"" -# async for chunk in response.body_iterator: -# body += chunk - -# data = json.loads(body.decode("utf-8")) - -# if app.state.MODEL_FILTER_ENABLED: -# if user and user.role == "user": -# data["data"] = list( -# filter( -# lambda model: model["id"] -# in app.state.MODEL_FILTER_LIST, -# data["data"], -# ) -# ) - -# # Modified Flag -# data["modified"] = True -# return JSONResponse(content=data) - -# return response - - -# app.add_middleware(ModifyModelsResponseMiddleware)