From 4a34ca35f07b25e38a1f6286e564fa843e2e576a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 17 Nov 2024 01:46:51 -0800 Subject: [PATCH] refac: access control --- backend/open_webui/apps/ollama/main.py | 133 ++++++++++++++----------- backend/open_webui/apps/openai/main.py | 2 +- 2 files changed, 75 insertions(+), 60 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 07bf43510..449059160 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -43,6 +43,7 @@ from open_webui.utils.payload import ( apply_model_system_prompt_to_body, ) from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) @@ -316,22 +317,9 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_verified_user) ): + models = [] if url_idx is None: models = await get_all_models() - - # TODO: Check User Group and Filter Models - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user": - # models["models"] = list( - # filter( - # lambda model: model["name"] - # in app.state.config.MODEL_FILTER_LIST, - # models["models"], - # ) - # ) - # return models - - return models else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -347,7 +335,7 @@ async def get_ollama_tags( r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) r.raise_for_status() - return r.json() + models = r.json() except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -363,6 +351,23 @@ async def get_ollama_tags( status_code=r.status_code if r else 500, detail=error_detail, ) + + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models.get("models", []): + model_info = Models.get_model_by_id(model["model"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models["models"] = filtered_models + + + return models @app.get("/api/version") @@ -926,16 +931,9 @@ async def generate_chat_completion( if "metadata" in payload: del payload["metadata"] - model_id = form_data.model - - # TODO: Check User Group and Filter Models - # if not bypass_filter: - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - # raise HTTPException( - # status_code=403, - # detail="Model not found", - # ) + model_id = payload["model"] + if ":" not in model_id: + model_id = f"{model_id}:latest" model_info = Models.get_model_by_id(model_id) @@ -954,9 +952,19 @@ async def generate_chat_completion( ) payload = apply_model_system_prompt_to_body(params, payload, user) + # Check if user has access to the model + if not bypass_filter and user.role == "user" and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") @@ -1015,17 +1023,11 @@ async def generate_openai_chat_completion( del payload["metadata"] model_id = completion_form.model + if ":" not in model_id: + model_id = f"{model_id}:latest" - # TODO: Check User Group and Filter Models - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - # raise HTTPException( - # status_code=403, - # detail="Model not found", - # ) model_info = Models.get_model_by_id(model_id) - if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1036,6 +1038,15 @@ async def generate_openai_chat_completion( payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) + # Check if user has access to the model + if user.role == "user" and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" @@ -1060,32 +1071,19 @@ async def get_openai_models( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + + models = [] if url_idx is None: - models = await get_all_models() - - # TODO: Check User Group and Filter Models - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user": - # models["models"] = list( - # filter( - # lambda model: model["name"] - # in app.state.config.MODEL_FILTER_LIST, - # models["models"], - # ) - # ) - - return { - "data": [ + model_list = await get_all_models() + models = [ { "id": model["model"], "object": "model", "created": int(time.time()), "owned_by": "openai", } - for model in models["models"] - ], - "object": "list", - } + for model in model_list["models"] + ] else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -1093,10 +1091,9 @@ async def get_openai_models( r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() - models = r.json() + model_list = r.json() - return { - "data": [ + models = [ { "id": model["model"], "object": "model", @@ -1104,10 +1101,7 @@ async def get_openai_models( "owned_by": "openai", } for model in models["models"] - ], - "object": "list", - } - + ] except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" @@ -1123,6 +1117,27 @@ async def get_openai_models( status_code=r.status_code if r else 500, detail=error_detail, ) + + + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models: + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models = filtered_models + + + return { + "data": models, + "object": "list", + } class UrlForm(BaseModel): diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index cbea60467..6174cad1c 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -501,7 +501,7 @@ async def generate_chat_completion( payload = apply_model_system_prompt_to_body(params, payload, user) # Check if user has access to the model - if user.role == "user" and not has_access( + if not bypass_filter and user.role == "user" and not has_access( user.id, type="read", access_control=model_info.access_control ): raise HTTPException(