From a07ff56c5010b127fc8b64f82d1f1e14d293e27a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:15:23 -0800 Subject: [PATCH] wip --- backend/open_webui/main.py | 34 +++++++------ backend/open_webui/routers/openai.py | 71 ++++++++++++++-------------- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dbb9518af..a632a3874 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models): class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] + if not ( + request.method == "POST" + and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) ): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") @@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware) class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if not request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] + if not ( + request.method == "POST" + and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) ): return await call_next(request) @@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}): return openai_chat_completion_message_template(form_data["model"], message) -async def get_all_base_models(): +async def get_all_base_models(request): function_models = [] openai_models = [] ollama_models = [] if app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models() + openai_models = await openai.get_all_models(request) openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models() + ollama_models = await ollama.get_all_models(request) ollama_models = [ { "id": model["model"], @@ -1729,8 +1735,8 @@ async def get_all_base_models(): @cached(ttl=3) -async def get_all_models(): - models = await get_all_base_models() +async def get_all_models(request): + models = await get_all_base_models(request) # If there are no models, return an empty list if len([model for model in models if not model.get("arena", False)]) == 0: @@ -1859,8 +1865,8 @@ async def get_all_models(): @app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() +async def get_models(request: Request, user=Depends(get_verified_user)): + models = await get_all_models(request) # Filter out filter pipelines models = [ @@ -2042,7 +2048,7 @@ async def generate_chat_completions( async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models() + model_list = await get_all_models(request) models = {model["id"]: model for model in model_list} data = form_data diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 34c5683a8..657f3662a 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -def merge_models_lists(model_lists): - log.debug(f"merge_models_lists {model_lists}") - merged_list = [] - - for idx, models in enumerate(model_lists): - if models is not None and "error" not in models: - merged_list.extend( - [ - { - **model, - "name": model.get("name", model["id"]), - "owned_by": "openai", - "openai": model, - "urlIdx": idx, - } - for model in models - if "api.openai.com" - not in request.app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] - ) - - return merged_list - - async def get_all_models_responses(request: Request) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]: if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses() + responses = await get_all_models_responses(request) def extract_data(response): if response and "data" in response: @@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]: return response return None + def merge_models_lists(model_lists): + log.debug(f"merge_models_lists {model_lists}") + merged_list = [] + + for idx, models in enumerate(model_lists): + if models is not None and "error" not in models: + merged_list.extend( + [ + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } + for model in models + if "api.openai.com" + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] + or not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + ) + + return merged_list + models = {"data": merge_models_lists(map(extract_data, responses))} log.debug(f"models: {models}")