diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py
index 39f348141..947456881 100644
--- a/backend/apps/litellm/main.py
+++ b/backend/apps/litellm/main.py
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Depends
+from fastapi import FastAPI, Depends, HTTPException
 from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
 
@@ -9,9 +9,11 @@ from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.responses import StreamingResponse
 import json
+import requests
 
-from utils.utils import get_http_authorization_cred, get_current_user
+from utils.utils import get_verified_user, get_current_user
 from config import SRC_LOG_LEVELS, ENV
+from constants import ERROR_MESSAGES
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["LITELLM"])
@@ -49,12 +51,13 @@ async def run_background_process(command):
 
 async def start_litellm_background():
     # Command to run in the background
-    command = "litellm --config ./data/litellm/config.yaml"
+    command = "litellm --telemetry False --config ./data/litellm/config.yaml"
 
     await run_background_process(command)
 
 
 @app.on_event("startup")
 async def startup_event():
+    # TODO: Check config.yaml file and create one
     asyncio.create_task(start_litellm_background())
 
 
@@ -62,82 +65,132 @@ app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
-@app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-    request.state.user = None
-
-    try:
-        user = get_current_user(get_http_authorization_cred(auth_header))
-        log.debug(f"user: {user}")
-        request.state.user = user
-    except Exception as e:
-        return JSONResponse(status_code=400, content={"detail": str(e)})
-
-    response = await call_next(request)
-    return response
-
-
 @app.get("/")
 async def get_status():
     return {"status": True}
 
 
-class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
-    async def dispatch(
-        self, request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
+@app.get("/models")
+@app.get("/v1/models")
+async def get_models(user=Depends(get_current_user)):
+    url = "http://localhost:4000/v1"
+    r = None
+    try:
+        r = requests.request(method="GET", url=f"{url}/models")
+        r.raise_for_status()
 
-        response = await call_next(request)
-        user = request.state.user
+        data = r.json()
 
-        if "/models" in request.url.path:
-            if isinstance(response, StreamingResponse):
-                # Read the content of the streaming response
-                body = b""
-                async for chunk in response.body_iterator:
-                    body += chunk
+        if app.state.MODEL_FILTER_ENABLED:
+            if user and user.role == "user":
+                data["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        data["data"],
+                    )
+                )
 
-                data = json.loads(body.decode("utf-8"))
+        return data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
 
-                if app.state.MODEL_FILTER_ENABLED:
-                    if user and user.role == "user":
-                        data["data"] = list(
-                            filter(
-                                lambda model: model["id"]
-                                in app.state.MODEL_FILTER_LIST,
-                                data["data"],
-                            )
-                        )
-
-                # Modified Flag
-                data["modified"] = True
-                return JSONResponse(content=data)
-
-        return response
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
 
 
-app.add_middleware(ModifyModelsResponseMiddleware)
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    body = await request.body()
+
+    url = "http://localhost:4000/v1"
+
+    target_url = f"{url}/{path}"
+
+    headers = {}
+    # headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
+
+    r = None
+
+    try:
+        r = requests.request(
+            method=request.method,
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            response_data = r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500, detail=error_detail
+        )
 
 
-# from litellm.proxy.proxy_server import ProxyConfig, initialize
-# from litellm.proxy.proxy_server import app
+# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
+#     async def dispatch(
+#         self, request: Request, call_next: RequestResponseEndpoint
+#     ) -> Response:
 
-# proxy_config = ProxyConfig()
+#         response = await call_next(request)
+#         user = request.state.user
+
+#         if "/models" in request.url.path:
+#             if isinstance(response, StreamingResponse):
+#                 # Read the content of the streaming response
+#                 body = b""
+#                 async for chunk in response.body_iterator:
+#                     body += chunk
+
+#                 data = json.loads(body.decode("utf-8"))
+
+#                 if app.state.MODEL_FILTER_ENABLED:
+#                     if user and user.role == "user":
+#                         data["data"] = list(
+#                             filter(
+#                                 lambda model: model["id"]
+#                                 in app.state.MODEL_FILTER_LIST,
+#                                 data["data"],
+#                             )
+#                         )
+
+#                 # Modified Flag
+#                 data["modified"] = True
+#                 return JSONResponse(content=data)
+
+#         return response
 
-# async def config():
-#     router, model_list, general_settings = await proxy_config.load_config(
-#         router=None, config_file_path="./data/litellm/config.yaml"
-#     )
-
-#     await initialize(config="./data/litellm/config.yaml", telemetry=False)
-
-
-# async def startup():
-#     await config()
-
-
-# @app.on_event("startup")
-# async def on_startup():
-#     await startup()
+# app.add_middleware(ModifyModelsResponseMiddleware)
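
Note on the "TODO: Check config.yaml file and create one" comment added to startup_event: a minimal sketch of what that check could look like, assuming the config lives at ./data/litellm/config.yaml as in the diff. The helper name ensure_litellm_config and the default file contents are illustrative assumptions, not part of this change.

from pathlib import Path

LITELLM_CONFIG_PATH = Path("./data/litellm/config.yaml")

# Hypothetical default: an empty-but-valid LiteLLM proxy config so the
# background `litellm` process can start before any models are registered.
DEFAULT_LITELLM_CONFIG = "general_settings: {}\nlitellm_settings: {}\nmodel_list: []\n"


def ensure_litellm_config(path: Path = LITELLM_CONFIG_PATH) -> None:
    # Create the data directory and write a minimal config on first run;
    # leave any existing user-provided config untouched.
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(DEFAULT_LITELLM_CONFIG)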
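
The new /models handler re-implements the filtering the removed ModifyModelsResponseMiddleware used to do: when MODEL_FILTER_ENABLED is set and the caller has the "user" role, the upstream model list is reduced to ids in MODEL_FILTER_LIST. A toy illustration of that filter with made-up data:

MODEL_FILTER_LIST = ["gpt-3.5-turbo"]  # example allow-list

data = {"data": [{"id": "gpt-3.5-turbo"}, {"id": "gpt-4"}]}
data["data"] = list(
    filter(lambda model: model["id"] in MODEL_FILTER_LIST, data["data"])
)
assert data["data"] == [{"id": "gpt-3.5-turbo"}]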
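
For reviewers who want to poke at the new catch-all proxy route: a rough client sketch. BASE_URL (where this sub-app is mounted) and TOKEN (a bearer token accepted by get_verified_user) are assumptions; adjust to your deployment. Streaming responses come back as text/event-stream relayed in 8 KiB chunks, so the client reassembles SSE lines itself.

import json
import requests

BASE_URL = "http://localhost:8080/litellm"  # assumption: mount point of this app
TOKEN = "YOUR_JWT"  # placeholder bearer token

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,  # makes the upstream reply with text/event-stream
}

with requests.post(
    f"{BASE_URL}/chat/completions",
    headers={"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"},
    data=json.dumps(payload),
    stream=True,
) as r:
    r.raise_for_status()
    # The proxy forwards the SSE body verbatim; parse `data:` lines as they arrive.
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: ") and line != "data: [DONE]":
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)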