From 31124d9deb08c8283247b7b95313be59646fa7e0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 16:10:01 -0500 Subject: [PATCH] feat: litellm config update --- backend/apps/litellm/main.py | 75 ++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 8d1132bb4..5696b6945 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -11,6 +11,9 @@ from starlette.responses import StreamingResponse import json import requests +from pydantic import BaseModel +from typing import Optional, List + from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV from constants import ERROR_MESSAGES @@ -19,15 +22,12 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) -from config import ( - MODEL_FILTER_ENABLED, - MODEL_FILTER_LIST, -) +from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR import asyncio import subprocess - +import yaml app = FastAPI() @@ -42,44 +42,51 @@ app.add_middleware( ) +LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" + +with open(LITELLM_CONFIG_DIR, "r") as file: + litellm_config = yaml.safe_load(file) + +app.state.CONFIG = litellm_config + # Global variable to store the subprocess reference background_process = None async def run_background_process(command): global background_process - print("run_background_process") + log.info("run_background_process") try: # Log the command to be executed - print(f"Executing command: {command}") + log.info(f"Executing command: {command}") # Execute the command and create a subprocess process = await asyncio.create_subprocess_exec( *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE ) background_process = process - print("Subprocess started successfully.") + log.info("Subprocess started successfully.") # Capture STDERR for debugging purposes stderr_output = await process.stderr.read() stderr_text = stderr_output.decode().strip() if stderr_text: - print(f"Subprocess STDERR: {stderr_text}") + log.info(f"Subprocess STDERR: {stderr_text}") - # Print output line by line + # log.info output line by line async for line in process.stdout: - print(line.decode().strip()) + log.info(line.decode().strip()) # Wait for the process to finish returncode = await process.wait() - print(f"Subprocess exited with return code {returncode}") + log.info(f"Subprocess exited with return code {returncode}") except Exception as e: log.error(f"Failed to start subprocess: {e}") raise # Optionally re-raise the exception if you want it to propagate async def start_litellm_background(): - print("start_litellm_background") + log.info("start_litellm_background") # Command to run in the background command = ( "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml" @@ -89,18 +96,18 @@ async def start_litellm_background(): async def shutdown_litellm_background(): - print("shutdown_litellm_background") + log.info("shutdown_litellm_background") global background_process if background_process: background_process.terminate() await background_process.wait() # Ensure the process has terminated - print("Subprocess terminated") + log.info("Subprocess terminated") @app.on_event("startup") async def startup_event(): - print("startup_event") + log.info("startup_event") # TODO: Check config.yaml file and create one asyncio.create_task(start_litellm_background()) @@ -114,8 +121,7 @@ async def get_status(): return {"status": True} -@app.get("/restart") -async def restart_litellm(user=Depends(get_admin_user)): +async def restart_litellm(): """ Endpoint to restart the litellm background service. """ @@ -126,7 +132,8 @@ async def restart_litellm(user=Depends(get_admin_user)): log.info("litellm service shutdown complete.") # Restart the background service - start_litellm_background() + + asyncio.create_task(start_litellm_background()) log.info("litellm service restart complete.") return { @@ -134,12 +141,40 @@ async def restart_litellm(user=Depends(get_admin_user)): "message": "litellm service restarted successfully.", } except Exception as e: - log.error(f"Error restarting litellm service: {e}") + log.info(f"Error restarting litellm service: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) +@app.get("/restart") +async def restart_litellm_handler(user=Depends(get_admin_user)): + return await restart_litellm() + + +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return app.state.CONFIG + + +class LiteLLMConfigForm(BaseModel): + general_settings: Optional[dict] = None + litellm_settings: Optional[dict] = None + model_list: Optional[List[dict]] = None + router_settings: Optional[dict] = None + + +@app.post("/config/update") +async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): + app.state.CONFIG = form_data.model_dump(exclude_none=True) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + return app.state.CONFIG + + @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)):