From 866c3dff116bd96707f8ed3f11f6331298d6857f Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 20:39:55 -0800 Subject: [PATCH] fix --- backend/open_webui/main.py | 38 +++++++++++++++++++--------- backend/open_webui/routers/ollama.py | 18 ++++++------- backend/open_webui/routers/openai.py | 2 +- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index c1fab6b9c..a49c225b3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -70,6 +70,15 @@ from open_webui.routers import ( users, utils, ) + +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) + +from open_webui.routers.ollama import ( + generate_chat_completion as generate_ollama_chat_completion, +) + from open_webui.routers.retrieval import ( get_embedding_function, get_ef, @@ -1019,8 +1028,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS try: body, model, user = await get_body_and_model_and_user(request, models) @@ -1257,7 +1266,7 @@ class PipelineMiddleware(BaseHTTPMiddleware): content={"detail": e.detail}, ) - await get_all_models() + await get_all_models(request) models = app.state.MODELS try: @@ -1924,6 +1933,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)): @app.post("/api/chat/completions") async def generate_chat_completions( + request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False, @@ -1931,8 +1941,7 @@ async def generate_chat_completions( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - model_list = app.state.MODELS - models = {model["id"]: model for model in model_list} + models = app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -1981,7 +1990,7 @@ async def generate_chat_completions( if model_ids and filter_mode == "exclude": model_ids = [ model["id"] - for model in await get_all_models() + for model in await get_all_models(request) if model.get("owned_by") != "arena" and model["id"] not in model_ids ] @@ -1991,7 +2000,7 @@ async def generate_chat_completions( else: model_ids = [ model["id"] - for model in await get_all_models() + for model in await get_all_models(request) if model.get("owned_by") != "arena" ] selected_model_id = random.choice(model_ids) @@ -2028,6 +2037,7 @@ async def generate_chat_completions( # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) response = await generate_ollama_chat_completion( + request=request, form_data=form_data, user=user, bypass_filter=bypass_filter ) if form_data.stream: @@ -2040,6 +2050,8 @@ async def generate_chat_completions( return convert_response_ollama_to_openai(response) else: return await generate_openai_chat_completion( + request=request, + form_data, user=user, bypass_filter=bypass_filter ) @@ -2048,8 +2060,8 @@ async def generate_chat_completions( async def chat_completed( request: Request, form_data: dict, user=Depends(get_verified_user) ): - model_list = await get_all_models(request) - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS data = form_data model_id = data["model"] @@ -2183,7 +2195,9 @@ async def chat_completed( @app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): +async def chat_action( + request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) +): if "." in action_id: action_id, sub_action_id = action_id.split(".") else: @@ -2196,8 +2210,8 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified detail="Action not found", ) - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} + await get_all_models(request) + models = app.state.MODELS data = form_data model_id = data["model"] diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index c36c2d730..233e30ce5 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -344,7 +344,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -565,7 +565,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -620,7 +620,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -670,7 +670,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -734,7 +734,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -803,7 +803,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models() + await get_all_models(request) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -878,8 +878,8 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -1200,7 +1200,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models() + model_list = await get_all_models(request) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 657f3662a..f7f78be85 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -404,7 +404,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models() + models = await get_all_models(request) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx]