From 3faf9d206765caa034d8cfbd5c303abc0ce03e09 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 17 Nov 2024 19:15:09 -0800 Subject: [PATCH] refac: model access control behaviour --- backend/open_webui/apps/ollama/main.py | 16 ++++++++++++---- backend/open_webui/apps/openai/main.py | 8 ++++++-- backend/open_webui/main.py | 16 ++++++++++++---- .../workspace/Models/ModelEditor.svelte | 8 ++++++-- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 09862c111..3a94b721d 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -362,8 +362,6 @@ async def get_ollama_tags( user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models["models"] = filtered_models return models @@ -960,6 +958,12 @@ async def generate_chat_completion( status_code=403, detail="Model not found", ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1048,6 +1052,12 @@ async def generate_openai_chat_completion( status_code=403, detail="Model not found", ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1130,8 +1140,6 @@ async def get_openai_models( user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models = filtered_models return { diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index ff842a374..4c710f840 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -424,8 +424,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models["data"] = filtered_models return models @@ -512,6 +510,12 @@ async def generate_chat_completion( status_code=403, detail="Model not found", ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) # Attemp to get urlIdx from the model models = await get_all_models() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 37c94247a..cd1c73fbf 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -557,7 +557,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): model_info = Models.get_model_by_id(model["id"]) if user.role == "user": - if model_info and not ( + if not model_info: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"detail": "Model not found"}, + ) + elif not ( user.id == model_info.user_id or has_access( user.id, type="read", access_control=model_info.access_control @@ -1113,8 +1118,6 @@ async def get_models(user=Depends(get_verified_user)): user.id, type="read", access_control=model_info.access_control ): filtered_models.append(model) - else: - filtered_models.append(model) models = filtered_models return {"data": models} @@ -1147,7 +1150,12 @@ async def generate_chat_completions( # Check if user has access to the model if user.role == "user": model_info = Models.get_model_by_id(model_id) - if model_info and not ( + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( user.id == model_info.user_id or has_access( user.id, type="read", access_control=model_info.access_control diff --git a/src/lib/components/workspace/Models/ModelEditor.svelte b/src/lib/components/workspace/Models/ModelEditor.svelte index c2fb9b675..7a7b376f5 100644 --- a/src/lib/components/workspace/Models/ModelEditor.svelte +++ b/src/lib/components/workspace/Models/ModelEditor.svelte @@ -79,7 +79,7 @@ let filterIds = []; let actionIds = []; - let accessControl = null; + let accessControl = {}; const addUsage = (base_model_id) => { const baseModel = $models.find((m) => m.id === base_model_id); @@ -213,7 +213,11 @@ capabilities.usage = false; } - accessControl = model?.access_control ?? null; + if ('access_control' in model) { + accessControl = model.access_control; + } else { + accessControl = {}; + } console.log(model?.access_control); console.log(accessControl);