fix: access control behaviour

This commit is contained in:
Timothy Jaeryang Baek 2024-11-17 02:51:57 -08:00
parent 892f6ba42b
commit 1d4c3a8c58
3 changed files with 70 additions and 57 deletions

View File

@ -351,22 +351,21 @@ async def get_ollama_tags(
status_code=r.status_code if r else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=error_detail,
) )
if user.role == "user": if user.role == "user":
# Filter models based on user access control # Filter models based on user access control
filtered_models = [] filtered_models = []
for model in models.get("models", []): for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"]) model_info = Models.get_model_by_id(model["model"])
if model_info: if model_info:
if has_access( if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else: else:
filtered_models.append(model) filtered_models.append(model)
models["models"] = filtered_models models["models"] = filtered_models
return models return models
@ -953,18 +952,21 @@ async def generate_chat_completion(
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model # Check if user has access to the model
if not bypass_filter and user.role == "user" and not has_access( if not bypass_filter and user.role == "user":
user.id, type="read", access_control=model_info.access_control if not (
): user.id == model_info.user_id
raise HTTPException( or has_access(
status_code=403, user.id, type="read", access_control=model_info.access_control
detail="Model not found", )
) ):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}") log.debug(f"generate_chat_completion() - 2.payload = {payload}")
@ -1026,7 +1028,6 @@ async def generate_openai_chat_completion(
if ":" not in model_id: if ":" not in model_id:
model_id = f"{model_id}:latest" model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
@ -1039,13 +1040,17 @@ async def generate_openai_chat_completion(
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model # Check if user has access to the model
if user.role == "user" and not has_access( if user.role == "user":
user.id, type="read", access_control=model_info.access_control if not (
): user.id == model_info.user_id
raise HTTPException( or has_access(
status_code=403, user.id, type="read", access_control=model_info.access_control
detail="Model not found", )
) ):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1071,19 +1076,19 @@ async def get_openai_models(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
models = [] models = []
if url_idx is None: if url_idx is None:
model_list = await get_all_models() model_list = await get_all_models()
models = [ models = [
{ {
"id": model["model"], "id": model["model"],
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "openai", "owned_by": "openai",
} }
for model in model_list["models"] for model in model_list["models"]
] ]
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -1094,14 +1099,14 @@ async def get_openai_models(
model_list = r.json() model_list = r.json()
models = [ models = [
{ {
"id": model["model"], "id": model["model"],
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "openai", "owned_by": "openai",
} }
for model in models["models"] for model in models["models"]
] ]
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
@ -1117,7 +1122,6 @@ async def get_openai_models(
status_code=r.status_code if r else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=error_detail,
) )
if user.role == "user": if user.role == "user":
# Filter models based on user access control # Filter models based on user access control
@ -1125,19 +1129,18 @@ async def get_openai_models(
for model in models: for model in models:
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if model_info: if model_info:
if has_access( if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else: else:
filtered_models.append(model) filtered_models.append(model)
models = filtered_models models = filtered_models
return { return {
"data": models, "data": models,
"object": "list", "object": "list",
} }
class UrlForm(BaseModel): class UrlForm(BaseModel):

View File

@ -420,7 +420,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
for model in models.get("data", []): for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if model_info: if model_info:
if has_access( if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
@ -501,13 +501,17 @@ async def generate_chat_completion(
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model # Check if user has access to the model
if not bypass_filter and user.role == "user" and not has_access( if not bypass_filter and user.role == "user":
user.id, type="read", access_control=model_info.access_control if not (
): user.id == model_info.user_id
raise HTTPException( or has_access(
status_code=403, user.id, type="read", access_control=model_info.access_control
detail="Model not found", )
) ):
raise HTTPException(
status_code=403,
detail="Model not found",
)
# Attemp to get urlIdx from the model # Attemp to get urlIdx from the model
models = await get_all_models() models = await get_all_models()

View File

@ -557,8 +557,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if user.role == "user": if user.role == "user":
if model_info and not has_access( if model_info and not (
user.id, type="read", access_control=model_info.access_control user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
): ):
return JSONResponse( return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -1106,7 +1109,7 @@ async def get_models(user=Depends(get_verified_user)):
for model in models: for model in models:
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if model_info: if model_info:
if has_access( if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
@ -1144,8 +1147,11 @@ async def generate_chat_completions(
# Check if user has access to the model # Check if user has access to the model
if user.role == "user": if user.role == "user":
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if not has_access( if not (
user.id, type="read", access_control=model_info.access_control user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
): ):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,