From 1d4c3a8c588285163e45ec473f344f251b88034d Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sun, 17 Nov 2024 02:51:57 -0800
Subject: [PATCH] fix: access control behaviour

---
 backend/open_webui/apps/ollama/main.py | 91 +++++++++++++-------------
 backend/open_webui/apps/openai/main.py | 20 +++---
 backend/open_webui/main.py             | 16 +++--
 3 files changed, 70 insertions(+), 57 deletions(-)

diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index 449059160..463cc86cc 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -351,22 +351,21 @@ async def get_ollama_tags(
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
-
+
     if user.role == "user":
         # Filter models based on user access control
         filtered_models = []
         for model in models.get("models", []):
             model_info = Models.get_model_by_id(model["model"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
             else:
                 filtered_models.append(model)
         models["models"] = filtered_models
-
-
+
     return models
 
 
@@ -953,18 +952,21 @@ async def generate_chat_completion(
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     # Check if user has access to the model
-    if not bypass_filter and user.role == "user" and not has_access(
-        user.id, type="read", access_control=model_info.access_control
-    ):
-        raise HTTPException(
-            status_code=403,
-            detail="Model not found",
-        )
-
+    if not bypass_filter and user.role == "user":
+        if not (
+            user.id == model_info.user_id
+            or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            )
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-
     url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
@@ -1026,7 +1028,6 @@ async def generate_openai_chat_completion(
     if ":" not in model_id:
         model_id = f"{model_id}:latest"
 
-
     model_info = Models.get_model_by_id(model_id)
     if model_info:
         if model_info.base_model_id:
@@ -1039,13 +1040,17 @@
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     # Check if user has access to the model
-    if user.role == "user" and not has_access(
-        user.id, type="read", access_control=model_info.access_control
-    ):
-        raise HTTPException(
-            status_code=403,
-            detail="Model not found",
-        )
+    if user.role == "user":
+        if not (
+            user.id == model_info.user_id
+            or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            )
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
 
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
@@ -1071,19 +1076,19 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-
+
     models = []
     if url_idx is None:
         model_list = await get_all_models()
         models = [
-                {
-                    "id": model["model"],
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "openai",
-                }
-                for model in model_list["models"]
-            ]
+            {
+                "id": model["model"],
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "openai",
+            }
+            for model in model_list["models"]
+        ]
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 
@@ -1094,14 +1099,14 @@
             model_list = r.json()
 
             models = [
-                    {
-                        "id": model["model"],
-                        "object": "model",
-                        "created": int(time.time()),
-                        "owned_by": "openai",
-                    }
-                    for model in models["models"]
-                ]
+                {
+                    "id": model["model"],
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "openai",
+                }
+                for model in models["models"]
+            ]
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
@@ -1117,7 +1122,6 @@
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
-
 
     if user.role == "user":
         # Filter models based on user access control
@@ -1125,19 +1129,18 @@
         for model in models:
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
             else:
                 filtered_models.append(model)
         models = filtered_models
-
 
     return {
-            "data": models,
-            "object": "list",
-        }
+        "data": models,
+        "object": "list",
+    }
 
 
 class UrlForm(BaseModel):
diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py
index 6174cad1c..ff842a374 100644
--- a/backend/open_webui/apps/openai/main.py
+++ b/backend/open_webui/apps/openai/main.py
@@ -420,7 +420,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
         for model in models.get("data", []):
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
@@ -501,13 +501,17 @@ async def generate_chat_completion(
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     # Check if user has access to the model
-    if not bypass_filter and user.role == "user" and not has_access(
-        user.id, type="read", access_control=model_info.access_control
-    ):
-        raise HTTPException(
-            status_code=403,
-            detail="Model not found",
-        )
+    if not bypass_filter and user.role == "user":
+        if not (
+            user.id == model_info.user_id
+            or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            )
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
 
     # Attemp to get urlIdx from the model
     models = await get_all_models()
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 7fdd45c97..f639b932c 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -557,8 +557,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         model_info = Models.get_model_by_id(model["id"])
 
         if user.role == "user":
-            if model_info and not has_access(
-                user.id, type="read", access_control=model_info.access_control
+            if model_info and not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
             ):
                 return JSONResponse(
                     status_code=status.HTTP_403_FORBIDDEN,
@@ -1106,7 +1109,7 @@ async def get_models(user=Depends(get_verified_user)):
         for model in models:
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
@@ -1144,8 +1147,11 @@ async def generate_chat_completions(
     # Check if user has access to the model
     if user.role == "user":
         model_info = Models.get_model_by_id(model_id)
-        if not has_access(
-            user.id, type="read", access_control=model_info.access_control
+        if not (
+            user.id == model_info.user_id
+            or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            )
         ):
             raise HTTPException(
                 status_code=403,
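
The recurring change in every hunk above is the same guard: a regular user may list or use a model if they own the model entry, or if `has_access(user.id, type="read", access_control=model_info.access_control)` grants them read access (the owner check is what this commit adds on top of the existing `has_access` call). The sketch below expresses that rule as a single helper; it is illustrative and not part of the diff. `SimpleUser`, `SimpleModel`, and `user_can_read_model` are hypothetical names, and `has_access` here is a simplified stand-in for the project's real helper of the same name, which also handles group grants.

# Minimal sketch of the access rule applied throughout this patch.
# SimpleUser, SimpleModel, user_can_read_model, and this has_access are
# illustrative stand-ins, not the project's actual implementations.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleUser:
    id: str
    role: str = "user"


@dataclass
class SimpleModel:
    id: str
    user_id: str  # owner of the model entry
    access_control: Optional[dict] = None  # None is treated as public here


def has_access(user_id: str, type: str = "read", access_control: Optional[dict] = None) -> bool:
    # Simplified stand-in: no access_control dict means the model is public;
    # otherwise the user id must be listed for the requested permission.
    if access_control is None:
        return True
    return user_id in access_control.get(type, {}).get("user_ids", [])


def user_can_read_model(user: SimpleUser, model_info: Optional[SimpleModel]) -> bool:
    # Admins are not filtered, and models without a registered entry are kept,
    # mirroring the `else: filtered_models.append(model)` branches above.
    if user.role != "user" or model_info is None:
        return True
    return user.id == model_info.user_id or has_access(
        user.id, type="read", access_control=model_info.access_control
    )


if __name__ == "__main__":
    owner = SimpleUser(id="u1")
    stranger = SimpleUser(id="u2")
    private_model = SimpleModel(id="m1", user_id="u1", access_control={"read": {"user_ids": []}})
    print(user_can_read_model(owner, private_model))     # True: the owner always passes
    print(user_can_read_model(stranger, private_model))  # False: not owner, not granted

With a helper along these lines, the list endpoints would keep only the models for which the check returns True, and the completion endpoints would raise 403 when it returns False, which is the behaviour the hunks above implement inline.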