From 5e458d490acf8c57f5a09d50310a58fc1ffe57c9 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 00:52:27 -0500 Subject: [PATCH 01/13] fix: run litellm as subprocess --- backend/apps/litellm/main.py | 71 +++++++++++++++++++++++++++++------- backend/main.py | 7 +--- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index a9922aad7..39f348141 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -1,8 +1,8 @@ +from fastapi import FastAPI, Depends +from fastapi.routing import APIRoute +from fastapi.middleware.cors import CORSMiddleware + import logging - -from litellm.proxy.proxy_server import ProxyConfig, initialize -from litellm.proxy.proxy_server import app - from fastapi import FastAPI, Request, Depends, status, Response from fastapi.responses import JSONResponse @@ -23,24 +23,39 @@ from config import ( ) -proxy_config = ProxyConfig() +import asyncio +import subprocess -async def config(): - router, model_list, general_settings = await proxy_config.load_config( - router=None, config_file_path="./data/litellm/config.yaml" +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +async def run_background_process(command): + process = await asyncio.create_subprocess_exec( + *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - - await initialize(config="./data/litellm/config.yaml", telemetry=False) + return process -async def startup(): - await config() +async def start_litellm_background(): + # Command to run in the background + command = "litellm --config ./data/litellm/config.yaml" + await run_background_process(command) @app.on_event("startup") -async def on_startup(): - await startup() +async def startup_event(): + asyncio.create_task(start_litellm_background()) app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED @@ -63,6 +78,11 @@ async def auth_middleware(request: Request, call_next): return response +@app.get("/") +async def get_status(): + return {"status": True} + + class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint @@ -98,3 +118,26 @@ class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): app.add_middleware(ModifyModelsResponseMiddleware) + + +# from litellm.proxy.proxy_server import ProxyConfig, initialize +# from litellm.proxy.proxy_server import app + +# proxy_config = ProxyConfig() + + +# async def config(): +# router, model_list, general_settings = await proxy_config.load_config( +# router=None, config_file_path="./data/litellm/config.yaml" +# ) + +# await initialize(config="./data/litellm/config.yaml", telemetry=False) + + +# async def startup(): +# await config() + + +# @app.on_event("startup") +# async def on_startup(): +# await startup() diff --git a/backend/main.py b/backend/main.py index 8b5fd76bc..b5aa7e7d0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,7 +20,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app -from apps.litellm.main import app as litellm_app, startup as litellm_app_startup +from apps.litellm.main import app as litellm_app from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app @@ -168,11 +168,6 @@ async def 
check_url(request: Request, call_next): return response -@app.on_event("startup") -async def on_startup(): - await litellm_app_startup() - - app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app) From a41b195f466d7c62eae700186ccc7cc30453c7be Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 01:13:24 -0500 Subject: [PATCH 02/13] DO NOT TRACK ME >:( --- backend/apps/litellm/main.py | 185 ++++++++++++++++++++++------------- 1 file changed, 119 insertions(+), 66 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 39f348141..947456881 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Depends +from fastapi import FastAPI, Depends, HTTPException from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware @@ -9,9 +9,11 @@ from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.responses import StreamingResponse import json +import requests -from utils.utils import get_http_authorization_cred, get_current_user +from utils.utils import get_verified_user, get_current_user from config import SRC_LOG_LEVELS, ENV +from constants import ERROR_MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) @@ -49,12 +51,13 @@ async def run_background_process(command): async def start_litellm_background(): # Command to run in the background - command = "litellm --config ./data/litellm/config.yaml" + command = "litellm --telemetry False --config ./data/litellm/config.yaml" await run_background_process(command) @app.on_event("startup") async def startup_event(): + # TODO: Check config.yaml file and create one asyncio.create_task(start_litellm_background()) @@ -62,82 +65,132 @@ app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -@app.middleware("http") -async def auth_middleware(request: Request, call_next): - auth_header = request.headers.get("Authorization", "") - request.state.user = None - - try: - user = get_current_user(get_http_authorization_cred(auth_header)) - log.debug(f"user: {user}") - request.state.user = user - except Exception as e: - return JSONResponse(status_code=400, content={"detail": str(e)}) - - response = await call_next(request) - return response - - @app.get("/") async def get_status(): return {"status": True} -class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: +@app.get("/models") +@app.get("/v1/models") +async def get_models(user=Depends(get_current_user)): + url = "http://localhost:4000/v1" + r = None + try: + r = requests.request(method="GET", url=f"{url}/models") + r.raise_for_status() - response = await call_next(request) - user = request.state.user + data = r.json() - if "/models" in request.url.path: - if isinstance(response, StreamingResponse): - # Read the content of the streaming response - body = b"" - async for chunk in response.body_iterator: - body += chunk + if app.state.MODEL_FILTER_ENABLED: + if user and user.role == "user": + data["data"] = list( + filter( + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + data["data"], + ) + ) - data = json.loads(body.decode("utf-8")) + return data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() 
+ if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" - if app.state.MODEL_FILTER_ENABLED: - if user and user.role == "user": - data["data"] = list( - filter( - lambda model: model["id"] - in app.state.MODEL_FILTER_LIST, - data["data"], - ) - ) - - # Modified Flag - data["modified"] = True - return JSONResponse(content=data) - - return response + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) -app.add_middleware(ModifyModelsResponseMiddleware) +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + body = await request.body() + + url = "http://localhost:4000/v1" + + target_url = f"{url}/{path}" + + headers = {} + # headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + r = None + + try: + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + response_data = r.json() + return response_data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except: + error_detail = f"External: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, detail=error_detail + ) -# from litellm.proxy.proxy_server import ProxyConfig, initialize -# from litellm.proxy.proxy_server import app +# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): +# async def dispatch( +# self, request: Request, call_next: RequestResponseEndpoint +# ) -> Response: -# proxy_config = ProxyConfig() +# response = await call_next(request) +# user = request.state.user + +# if "/models" in request.url.path: +# if isinstance(response, StreamingResponse): +# # Read the content of the streaming response +# body = b"" +# async for chunk in response.body_iterator: +# body += chunk + +# data = json.loads(body.decode("utf-8")) + +# if app.state.MODEL_FILTER_ENABLED: +# if user and user.role == "user": +# data["data"] = list( +# filter( +# lambda model: model["id"] +# in app.state.MODEL_FILTER_LIST, +# data["data"], +# ) +# ) + +# # Modified Flag +# data["modified"] = True +# return JSONResponse(content=data) + +# return response -# async def config(): -# router, model_list, general_settings = await proxy_config.load_config( -# router=None, config_file_path="./data/litellm/config.yaml" -# ) - -# await initialize(config="./data/litellm/config.yaml", telemetry=False) - - -# async def startup(): -# await config() - - -# @app.on_event("startup") -# async def on_startup(): -# await startup() +# app.add_middleware(ModifyModelsResponseMiddleware) From 8651bec915ae23f26f02f07b34d52f9099097148 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Sun, 21 Apr 2024 01:22:02 -0500 Subject: [PATCH 03/13] pwned :) --- backend/apps/litellm/main.py | 11 ++++++++++- backend/main.py | 8 +++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 947456881..5a8b37f47 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -43,20 +43,29 @@ app.add_middleware( async def run_background_process(command): + # Start the process process = await asyncio.create_subprocess_exec( *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - return process + # Read output asynchronously + async for line in process.stdout: + print(line.decode().strip()) # Print stdout line by line + + await process.wait() # Wait for the subprocess to finish async def start_litellm_background(): + print("start_litellm_background") # Command to run in the background command = "litellm --telemetry False --config ./data/litellm/config.yaml" + await run_background_process(command) @app.on_event("startup") async def startup_event(): + + print("startup_event") # TODO: Check config.yaml file and create one asyncio.create_task(start_litellm_background()) diff --git a/backend/main.py b/backend/main.py index b5aa7e7d0..48e14f1dd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,12 +20,13 @@ from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app -from apps.litellm.main import app as litellm_app +from apps.litellm.main import app as litellm_app, start_litellm_background from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +import asyncio from pydantic import BaseModel from typing import List @@ -168,6 +169,11 @@ async def check_url(request: Request, call_next): return response +@app.on_event("startup") +async def on_startup(): + asyncio.create_task(start_litellm_background()) + + app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app) From 3c382d4c6cbea0352a4ad4bc3a90ed8f339a148b Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Sun, 21 Apr 2024 01:46:09 -0500 Subject: [PATCH 04/13] refac: close subprocess gracefully --- backend/apps/litellm/main.py | 51 +++++++++++++++++++++++++++++------- backend/main.py | 11 +++++++- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 5a8b37f47..68e48858b 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -42,16 +42,40 @@ app.add_middleware( ) -async def run_background_process(command): - # Start the process - process = await asyncio.create_subprocess_exec( - *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - # Read output asynchronously - async for line in process.stdout: - print(line.decode().strip()) # Print stdout line by line +# Global variable to store the subprocess reference +background_process = None - await process.wait() # Wait for the subprocess to finish + +async def run_background_process(command): + global background_process + print("run_background_process") + + try: + # Log the command to be executed + print(f"Executing command: {command}") + # Execute the command and create a subprocess + process = await asyncio.create_subprocess_exec( + *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + background_process = process + print("Subprocess started successfully.") + + # Capture STDERR for debugging purposes + stderr_output = await process.stderr.read() + stderr_text = stderr_output.decode().strip() + if stderr_text: + print(f"Subprocess STDERR: {stderr_text}") + + # Print output line by line + async for line in process.stdout: + print(line.decode().strip()) + + # Wait for the process to finish + returncode = await process.wait() + print(f"Subprocess exited with return code {returncode}") + except Exception as e: + log.error(f"Failed to start subprocess: {e}") + raise # Optionally re-raise the exception if you want it to propagate async def start_litellm_background(): @@ -62,6 +86,15 @@ async def start_litellm_background(): await run_background_process(command) +async def shutdown_litellm_background(): + print("shutdown_litellm_background") + global background_process + if background_process: + background_process.terminate() + await background_process.wait() # Ensure the process has terminated + print("Subprocess terminated") + + @app.on_event("startup") async def startup_event(): diff --git a/backend/main.py b/backend/main.py index 48e14f1dd..579ff2ee0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,7 +20,11 @@ from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app -from apps.litellm.main import app as litellm_app, start_litellm_background +from apps.litellm.main import ( + app as litellm_app, + start_litellm_background, + shutdown_litellm_background, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app @@ -316,3 +320,8 @@ app.mount( SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), name="spa-static-files", ) + + +@app.on_event("shutdown") +async def shutdown_event(): + await shutdown_litellm_background() From a59fb6b9eb6bcbe438d15e2020b31d2ef6cdf580 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Sun, 21 Apr 2024 01:47:35 -0500 Subject: [PATCH 05/13] fix --- backend/apps/litellm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 68e48858b..486ae4736 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -154,7 +154,7 @@ async def get_models(user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() - url = "http://localhost:4000/v1" + url = "http://localhost:4000" target_url = f"{url}/{path}" From 51191168bc77b50165e5d937cbb54f592d71d1e2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 01:51:38 -0500 Subject: [PATCH 06/13] feat: restart subprocess route --- backend/apps/litellm/main.py | 65 +++++++++++++++--------------------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 486ae4736..531e96494 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -11,7 +11,7 @@ from starlette.responses import StreamingResponse import json import requests -from utils.utils import get_verified_user, get_current_user +from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV from constants import ERROR_MESSAGES @@ -112,6 +112,32 @@ async def get_status(): return {"status": True} +@app.get("/restart") +async def restart_litellm(user=Depends(get_admin_user)): + """ + Endpoint to restart the litellm background service. + """ + log.info("Requested restart of litellm service.") + try: + # Shut down the existing process if it is running + await shutdown_litellm_background() + log.info("litellm service shutdown complete.") + + # Restart the background service + await start_litellm_background() + log.info("litellm service restart complete.") + + return { + "status": "success", + "message": "litellm service restarted successfully.", + } + except Exception as e: + log.error(f"Error restarting litellm service: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)): @@ -199,40 +225,3 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): raise HTTPException( status_code=r.status_code if r else 500, detail=error_detail ) - - -# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): -# async def dispatch( -# self, request: Request, call_next: RequestResponseEndpoint -# ) -> Response: - -# response = await call_next(request) -# user = request.state.user - -# if "/models" in request.url.path: -# if isinstance(response, StreamingResponse): -# # Read the content of the streaming response -# body = b"" -# async for chunk in response.body_iterator: -# body += chunk - -# data = json.loads(body.decode("utf-8")) - -# if app.state.MODEL_FILTER_ENABLED: -# if user and user.role == "user": -# data["data"] = list( -# filter( -# lambda model: model["id"] -# in app.state.MODEL_FILTER_LIST, -# data["data"], -# ) -# ) - -# # Modified Flag -# data["modified"] = True -# return JSONResponse(content=data) - -# return response - - -# app.add_middleware(ModifyModelsResponseMiddleware) From 2717fe7c207b3a0e19e23113e647ec8b6e78e4bc Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Sun, 21 Apr 2024 02:00:03 -0500 Subject: [PATCH 07/13] fix --- backend/apps/litellm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 531e96494..68ae54fbc 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -124,7 +124,7 @@ async def restart_litellm(user=Depends(get_admin_user)): log.info("litellm service shutdown complete.") # Restart the background service - await start_litellm_background() + start_litellm_background() log.info("litellm service restart complete.") return { From 77426266d24464d51334909ca77474f566ca1c6b Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 14:32:45 -0500 Subject: [PATCH 08/13] refac: port number update --- backend/apps/litellm/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 68ae54fbc..8d1132bb4 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -81,7 +81,9 @@ async def run_background_process(command): async def start_litellm_background(): print("start_litellm_background") # Command to run in the background - command = "litellm --telemetry False --config ./data/litellm/config.yaml" + command = ( + "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml" + ) await run_background_process(command) @@ -141,7 +143,7 @@ async def restart_litellm(user=Depends(get_admin_user)): @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)): - url = "http://localhost:4000/v1" + url = "http://localhost:14365/v1" r = None try: r = requests.request(method="GET", url=f"{url}/models") @@ -180,7 +182,7 @@ async def get_models(user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() - url = "http://localhost:4000" + url = "http://localhost:14365" target_url = f"{url}/{path}" From 8422d3ea79c134ff12e9120c3f27220a7ac2bd57 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 14:43:51 -0500 Subject: [PATCH 09/13] Update requirements.txt --- backend/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 5f41137c9..0b5e90433 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -18,6 +18,8 @@ peewee-migrate bcrypt litellm==1.35.17 +litellm['proxy']==1.35.17 + boto3 argon2-cffi From f83eb7326f7b4fcaf54493c61bc0344855429617 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 14:44:28 -0500 Subject: [PATCH 10/13] Update requirements.txt --- backend/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/requirements.txt b/backend/requirements.txt index 0b5e90433..e04551567 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -18,7 +18,7 @@ peewee-migrate bcrypt litellm==1.35.17 -litellm['proxy']==1.35.17 +litellm[proxy]==1.35.17 boto3 From 31124d9deb08c8283247b7b95313be59646fa7e0 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Sun, 21 Apr 2024 16:10:01 -0500 Subject: [PATCH 11/13] feat: litellm config update --- backend/apps/litellm/main.py | 75 ++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 8d1132bb4..5696b6945 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -11,6 +11,9 @@ from starlette.responses import StreamingResponse import json import requests +from pydantic import BaseModel +from typing import Optional, List + from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV from constants import ERROR_MESSAGES @@ -19,15 +22,12 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) -from config import ( - MODEL_FILTER_ENABLED, - MODEL_FILTER_LIST, -) +from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR import asyncio import subprocess - +import yaml app = FastAPI() @@ -42,44 +42,51 @@ app.add_middleware( ) +LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" + +with open(LITELLM_CONFIG_DIR, "r") as file: + litellm_config = yaml.safe_load(file) + +app.state.CONFIG = litellm_config + # Global variable to store the subprocess reference background_process = None async def run_background_process(command): global background_process - print("run_background_process") + log.info("run_background_process") try: # Log the command to be executed - print(f"Executing command: {command}") + log.info(f"Executing command: {command}") # Execute the command and create a subprocess process = await asyncio.create_subprocess_exec( *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE ) background_process = process - print("Subprocess started successfully.") + log.info("Subprocess started successfully.") # Capture STDERR for debugging purposes stderr_output = await process.stderr.read() stderr_text = stderr_output.decode().strip() if stderr_text: - print(f"Subprocess STDERR: {stderr_text}") + log.info(f"Subprocess STDERR: {stderr_text}") - # Print output line by line + # log.info output line by line async for line in process.stdout: - print(line.decode().strip()) + log.info(line.decode().strip()) # Wait for the process to finish returncode = await process.wait() - print(f"Subprocess exited with return code {returncode}") + log.info(f"Subprocess exited with return code {returncode}") except Exception as e: log.error(f"Failed to start subprocess: {e}") raise # Optionally re-raise the exception if you want it to propagate async def start_litellm_background(): - print("start_litellm_background") + log.info("start_litellm_background") # Command to run in the background command = ( "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml" @@ -89,18 +96,18 @@ async def start_litellm_background(): async def shutdown_litellm_background(): - print("shutdown_litellm_background") + log.info("shutdown_litellm_background") global background_process if background_process: background_process.terminate() await background_process.wait() # Ensure the process has terminated - print("Subprocess terminated") + log.info("Subprocess terminated") @app.on_event("startup") async def startup_event(): - print("startup_event") + log.info("startup_event") # TODO: Check config.yaml file and create one asyncio.create_task(start_litellm_background()) @@ -114,8 +121,7 @@ async def get_status(): return {"status": True} -@app.get("/restart") -async def 
restart_litellm(user=Depends(get_admin_user)): +async def restart_litellm(): """ Endpoint to restart the litellm background service. """ @@ -126,7 +132,8 @@ async def restart_litellm(user=Depends(get_admin_user)): log.info("litellm service shutdown complete.") # Restart the background service - start_litellm_background() + + asyncio.create_task(start_litellm_background()) log.info("litellm service restart complete.") return { @@ -134,12 +141,40 @@ async def restart_litellm(user=Depends(get_admin_user)): "message": "litellm service restarted successfully.", } except Exception as e: - log.error(f"Error restarting litellm service: {e}") + log.info(f"Error restarting litellm service: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) +@app.get("/restart") +async def restart_litellm_handler(user=Depends(get_admin_user)): + return await restart_litellm() + + +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return app.state.CONFIG + + +class LiteLLMConfigForm(BaseModel): + general_settings: Optional[dict] = None + litellm_settings: Optional[dict] = None + model_list: Optional[List[dict]] = None + router_settings: Optional[dict] = None + + +@app.post("/config/update") +async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): + app.state.CONFIG = form_data.model_dump(exclude_none=True) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + return app.state.CONFIG + + @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)): From e627b8bf21d2eb5f78f753ed6896ea9255d9e2eb Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 17:26:22 -0500 Subject: [PATCH 12/13] feat: litellm model add/delete --- backend/apps/litellm/main.py | 50 +++++++++++++++++++ .../components/chat/Settings/Models.svelte | 12 ++--- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 5696b6945..9bc08598f 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -102,6 +102,7 @@ async def shutdown_litellm_background(): background_process.terminate() await background_process.wait() # Ensure the process has terminated log.info("Subprocess terminated") + background_process = None @app.on_event("startup") @@ -178,6 +179,9 @@ async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_use @app.get("/models") @app.get("/v1/models") async def get_models(user=Depends(get_current_user)): + while not background_process: + await asyncio.sleep(0.1) + url = "http://localhost:14365/v1" r = None try: @@ -213,6 +217,52 @@ async def get_models(user=Depends(get_current_user)): ) +@app.get("/model/info") +async def get_model_list(user=Depends(get_admin_user)): + return {"data": app.state.CONFIG["model_list"]} + + +class AddLiteLLMModelForm(BaseModel): + model_name: str + litellm_params: dict + + +@app.post("/model/new") +async def add_model_to_config( + form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) +): + app.state.CONFIG["model_list"].append(form_data.model_dump()) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": "model added"} + + +class DeleteLiteLLMModelForm(BaseModel): + id: str + + +@app.post("/model/delete") +async def delete_model_from_config( + form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user) +): + 
app.state.CONFIG["model_list"] = [ + model + for model in app.state.CONFIG["model_list"] + if model["model_name"] != form_data.id + ] + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": "model deleted"} + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte index 15b054024..688774d78 100644 --- a/src/lib/components/chat/Settings/Models.svelte +++ b/src/lib/components/chat/Settings/Models.svelte @@ -35,7 +35,7 @@ let liteLLMRPM = ''; let liteLLMMaxTokens = ''; - let deleteLiteLLMModelId = ''; + let deleteLiteLLMModelName = ''; $: liteLLMModelName = liteLLMModel; @@ -472,7 +472,7 @@ }; const deleteLiteLLMModelHandler = async () => { - const res = await deleteLiteLLMModel(localStorage.token, deleteLiteLLMModelId).catch( + const res = await deleteLiteLLMModel(localStorage.token, deleteLiteLLMModelName).catch( (error) => { toast.error(error); return null; @@ -485,7 +485,7 @@ } } - deleteLiteLLMModelId = ''; + deleteLiteLLMModelName = ''; liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); models.set(await getModels()); }; @@ -1099,14 +1099,14 @@