From d85480b4d6b364329e710189dad849e79a30f650 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 23 Oct 2024 15:05:43 -0700 Subject: [PATCH] fix: bypass_filter for arena models --- backend/open_webui/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 58c29e725..1c7e5dd21 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1082,7 +1082,9 @@ async def get_models(user=Depends(get_verified_user)): @app.post("/api/chat/completions") -async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): +async def generate_chat_completions( + form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False +): model_id = form_data["model"] if model_id not in app.state.MODELS: @@ -1091,7 +1093,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) - if app.state.config.ENABLE_MODEL_FILTER: + if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER: if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -1103,7 +1105,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == 'exclude': + if model_ids and filter_mode == "exclude": model_ids = [ model["id"] for model in await get_all_models() @@ -1133,13 +1135,17 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u async for chunk in stream: yield chunk - response = await generate_chat_completions(form_data, user) + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) return StreamingResponse( stream_wrapper(response.body_iterator), media_type="text/event-stream" ) else: return { - **(await generate_chat_completions(form_data, user)), + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), "selected_model_id": selected_model_id, } if model.get("pipe"):