diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 21b9e58a7..f7936398e 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -1,11 +1,23 @@ from litellm.proxy.proxy_server import ProxyConfig, initialize from litellm.proxy.proxy_server import app -from fastapi import FastAPI, Request, Depends, status +from fastapi import FastAPI, Request, Depends, status, Response from fastapi.responses import JSONResponse + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.responses import StreamingResponse +import json + from utils.utils import get_http_authorization_cred, get_current_user from config import ENV + +from config import ( + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) + + proxy_config = ProxyConfig() @@ -26,16 +38,67 @@ async def on_startup(): await startup() +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + + @app.middleware("http") async def auth_middleware(request: Request, call_next): auth_header = request.headers.get("Authorization", "") + request.state.user = None if ENV != "dev": try: user = get_current_user(get_http_authorization_cred(auth_header)) print(user) + request.state.user = user except Exception as e: return JSONResponse(status_code=400, content={"detail": str(e)}) response = await call_next(request) return response + + +class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + + response = await call_next(request) + user = request.state.user + + # Check if the request is for the `/models` route + if "/models" in request.url.path: + # Ensure the response is a StreamingResponse + if isinstance(response, StreamingResponse): + # Read the content of the streaming response + body = b"" + async for chunk in response.body_iterator: + body += chunk + + # Modify the content as needed + data = json.loads(body.decode("utf-8")) + + print(data) + + if app.state.MODEL_FILTER_ENABLED: + if user and user.role == "user": + data["data"] = list( + filter( + lambda model: model["id"] + in app.state.MODEL_FILTER_LIST, + data["data"], + ) + ) + + # Example modification: Add a new key-value pair + data["modified"] = True + + # Return a new JSON response with the modified content + return JSONResponse(content=data) + + return response + + +# Add the middleware to the app +app.add_middleware(ModifyModelsResponseMiddleware) diff --git a/backend/config.py b/backend/config.py index e99e248c5..9236e8a86 100644 --- a/backend/config.py +++ b/backend/config.py @@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = ( USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} -MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) +MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true" MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]