From 3aa6b0fea916bf13a12aa5757d4c2e41890c710d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 17 May 2024 19:11:14 -0700 Subject: [PATCH] fix: model filter issue --- backend/apps/litellm/main.py | 8 ++++---- backend/apps/ollama/main.py | 14 ++++++++------ backend/apps/openai/main.py | 9 +++++---- backend/main.py | 8 ++++---- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 6db426439..d70056ee6 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file: litellm_config = yaml.safe_load(file) +app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + + app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config @@ -151,10 +155,6 @@ async def shutdown_litellm_background(): background_process = None -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST - - @app.get("/") async def get_status(): return {"status": True} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index b4451d35d..df268067f 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -64,8 +64,8 @@ app.add_middleware( app.state.config = AppConfig() -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -178,11 +178,12 @@ async def get_ollama_tags( if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in app.state.config.MODEL_FILTER_LIST, models["models"], ) ) @@ -1046,11 +1047,12 @@ async def get_openai_models( if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in app.state.config.MODEL_FILTER_LIST, models["models"], ) ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index a153bde0b..85ee531f1 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -47,10 +47,11 @@ app.add_middleware( allow_headers=["*"], ) + app.state.config = AppConfig() -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API @@ -259,11 +260,11 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, models["data"], ) ) diff --git a/backend/main.py b/backend/main.py index 165ba24b6..209199591 100644 --- a/backend/main.py +++ b/backend/main.py @@ -292,11 +292,11 @@ async def update_model_filter_config( app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models - ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST + ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST + openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST