refac: access control

This commit is contained in:
Timothy J. Baek 2024-11-17 01:46:51 -08:00
parent a2dcbc41e5
commit 4a34ca35f0
2 changed files with 75 additions and 60 deletions

View File

@ -43,6 +43,7 @@ from open_webui.utils.payload import (
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
from open_webui.utils.utils import get_admin_user, get_verified_user from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@ -316,22 +317,9 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user) url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
models = []
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user":
# models["models"] = list(
# filter(
# lambda model: model["name"]
# in app.state.config.MODEL_FILTER_LIST,
# models["models"],
# )
# )
# return models
return models
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -347,7 +335,7 @@ async def get_ollama_tags(
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
r.raise_for_status() r.raise_for_status()
return r.json() models = r.json()
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
@ -364,6 +352,23 @@ async def get_ollama_tags(
detail=error_detail, detail=error_detail,
) )
if user.role == "user":
# Filter models based on user access control
filtered_models = []
for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"])
if model_info:
if has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
else:
filtered_models.append(model)
models["models"] = filtered_models
return models
@app.get("/api/version") @app.get("/api/version")
@app.get("/api/version/{url_idx}") @app.get("/api/version/{url_idx}")
@ -926,16 +931,9 @@ async def generate_chat_completion(
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]
model_id = form_data.model model_id = payload["model"]
if ":" not in model_id:
# TODO: Check User Group and Filter Models model_id = f"{model_id}:latest"
# if not bypass_filter:
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
# raise HTTPException(
# status_code=403,
# detail="Model not found",
# )
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
@ -954,9 +952,19 @@ async def generate_chat_completion(
) )
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user" and not has_access(
user.id, type="read", access_control=model_info.access_control
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}") log.debug(f"generate_chat_completion() - 2.payload = {payload}")
@ -1015,17 +1023,11 @@ async def generate_openai_chat_completion(
del payload["metadata"] del payload["metadata"]
model_id = completion_form.model model_id = completion_form.model
if ":" not in model_id:
model_id = f"{model_id}:latest"
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
# raise HTTPException(
# status_code=403,
# detail="Model not found",
# )
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
@ -1036,6 +1038,15 @@ async def generate_openai_chat_completion(
payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model
if user.role == "user" and not has_access(
user.id, type="read", access_control=model_info.access_control
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1060,32 +1071,19 @@ async def get_openai_models(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
models = []
if url_idx is None: if url_idx is None:
models = await get_all_models() model_list = await get_all_models()
models = [
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user":
# models["models"] = list(
# filter(
# lambda model: model["name"]
# in app.state.config.MODEL_FILTER_LIST,
# models["models"],
# )
# )
return {
"data": [
{ {
"id": model["model"], "id": model["model"],
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "openai", "owned_by": "openai",
} }
for model in models["models"] for model in model_list["models"]
], ]
"object": "list",
}
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -1093,10 +1091,9 @@ async def get_openai_models(
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
models = r.json() model_list = r.json()
return { models = [
"data": [
{ {
"id": model["model"], "id": model["model"],
"object": "model", "object": "model",
@ -1104,10 +1101,7 @@ async def get_openai_models(
"owned_by": "openai", "owned_by": "openai",
} }
for model in models["models"] for model in models["models"]
], ]
"object": "list",
}
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
@ -1125,6 +1119,27 @@ async def get_openai_models(
) )
if user.role == "user":
# Filter models based on user access control
filtered_models = []
for model in models:
model_info = Models.get_model_by_id(model["id"])
if model_info:
if has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
else:
filtered_models.append(model)
models = filtered_models
return {
"data": models,
"object": "list",
}
class UrlForm(BaseModel): class UrlForm(BaseModel):
url: str url: str

View File

@ -501,7 +501,7 @@ async def generate_chat_completion(
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model # Check if user has access to the model
if user.role == "user" and not has_access( if not bypass_filter and user.role == "user" and not has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
raise HTTPException( raise HTTPException(