mirror of
https://github.com/open-webui/open-webui
synced 2024-11-22 08:07:55 +00:00
refac: access control
This commit is contained in:
parent
a2dcbc41e5
commit
4a34ca35f0
@ -43,6 +43,7 @@ from open_webui.utils.payload import (
|
|||||||
apply_model_system_prompt_to_body,
|
apply_model_system_prompt_to_body,
|
||||||
)
|
)
|
||||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||||
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||||
@ -316,22 +317,9 @@ async def get_all_models():
|
|||||||
async def get_ollama_tags(
|
async def get_ollama_tags(
|
||||||
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
models = []
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models()
|
models = await get_all_models()
|
||||||
|
|
||||||
# TODO: Check User Group and Filter Models
|
|
||||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
|
||||||
# if user.role == "user":
|
|
||||||
# models["models"] = list(
|
|
||||||
# filter(
|
|
||||||
# lambda model: model["name"]
|
|
||||||
# in app.state.config.MODEL_FILTER_LIST,
|
|
||||||
# models["models"],
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
# return models
|
|
||||||
|
|
||||||
return models
|
|
||||||
else:
|
else:
|
||||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||||
|
|
||||||
@ -347,7 +335,7 @@ async def get_ollama_tags(
|
|||||||
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
|
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
return r.json()
|
models = r.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
error_detail = "Open WebUI: Server Connection Error"
|
error_detail = "Open WebUI: Server Connection Error"
|
||||||
@ -363,6 +351,23 @@ async def get_ollama_tags(
|
|||||||
status_code=r.status_code if r else 500,
|
status_code=r.status_code if r else 500,
|
||||||
detail=error_detail,
|
detail=error_detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if user.role == "user":
|
||||||
|
# Filter models based on user access control
|
||||||
|
filtered_models = []
|
||||||
|
for model in models.get("models", []):
|
||||||
|
model_info = Models.get_model_by_id(model["model"])
|
||||||
|
if model_info:
|
||||||
|
if has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
):
|
||||||
|
filtered_models.append(model)
|
||||||
|
else:
|
||||||
|
filtered_models.append(model)
|
||||||
|
models["models"] = filtered_models
|
||||||
|
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/version")
|
@app.get("/api/version")
|
||||||
@ -926,16 +931,9 @@ async def generate_chat_completion(
|
|||||||
if "metadata" in payload:
|
if "metadata" in payload:
|
||||||
del payload["metadata"]
|
del payload["metadata"]
|
||||||
|
|
||||||
model_id = form_data.model
|
model_id = payload["model"]
|
||||||
|
if ":" not in model_id:
|
||||||
# TODO: Check User Group and Filter Models
|
model_id = f"{model_id}:latest"
|
||||||
# if not bypass_filter:
|
|
||||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
|
||||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=403,
|
|
||||||
# detail="Model not found",
|
|
||||||
# )
|
|
||||||
|
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
@ -954,9 +952,19 @@ async def generate_chat_completion(
|
|||||||
)
|
)
|
||||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||||
|
|
||||||
|
# Check if user has access to the model
|
||||||
|
if not bypass_filter and user.role == "user" and not has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
|
||||||
if ":" not in payload["model"]:
|
if ":" not in payload["model"]:
|
||||||
payload["model"] = f"{payload['model']}:latest"
|
payload["model"] = f"{payload['model']}:latest"
|
||||||
|
|
||||||
|
|
||||||
url = await get_ollama_url(url_idx, payload["model"])
|
url = await get_ollama_url(url_idx, payload["model"])
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
||||||
@ -1015,17 +1023,11 @@ async def generate_openai_chat_completion(
|
|||||||
del payload["metadata"]
|
del payload["metadata"]
|
||||||
|
|
||||||
model_id = completion_form.model
|
model_id = completion_form.model
|
||||||
|
if ":" not in model_id:
|
||||||
|
model_id = f"{model_id}:latest"
|
||||||
|
|
||||||
# TODO: Check User Group and Filter Models
|
|
||||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
|
||||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=403,
|
|
||||||
# detail="Model not found",
|
|
||||||
# )
|
|
||||||
|
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
payload["model"] = model_info.base_model_id
|
payload["model"] = model_info.base_model_id
|
||||||
@ -1036,6 +1038,15 @@ async def generate_openai_chat_completion(
|
|||||||
payload = apply_model_params_to_body_openai(params, payload)
|
payload = apply_model_params_to_body_openai(params, payload)
|
||||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||||
|
|
||||||
|
# Check if user has access to the model
|
||||||
|
if user.role == "user" and not has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
|
||||||
if ":" not in payload["model"]:
|
if ":" not in payload["model"]:
|
||||||
payload["model"] = f"{payload['model']}:latest"
|
payload["model"] = f"{payload['model']}:latest"
|
||||||
|
|
||||||
@ -1060,32 +1071,19 @@ async def get_openai_models(
|
|||||||
url_idx: Optional[int] = None,
|
url_idx: Optional[int] = None,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
|
models = []
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models()
|
model_list = await get_all_models()
|
||||||
|
models = [
|
||||||
# TODO: Check User Group and Filter Models
|
|
||||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
|
||||||
# if user.role == "user":
|
|
||||||
# models["models"] = list(
|
|
||||||
# filter(
|
|
||||||
# lambda model: model["name"]
|
|
||||||
# in app.state.config.MODEL_FILTER_LIST,
|
|
||||||
# models["models"],
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
return {
|
|
||||||
"data": [
|
|
||||||
{
|
{
|
||||||
"id": model["model"],
|
"id": model["model"],
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": int(time.time()),
|
"created": int(time.time()),
|
||||||
"owned_by": "openai",
|
"owned_by": "openai",
|
||||||
}
|
}
|
||||||
for model in models["models"]
|
for model in model_list["models"]
|
||||||
],
|
]
|
||||||
"object": "list",
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||||
@ -1093,10 +1091,9 @@ async def get_openai_models(
|
|||||||
r = requests.request(method="GET", url=f"{url}/api/tags")
|
r = requests.request(method="GET", url=f"{url}/api/tags")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
models = r.json()
|
model_list = r.json()
|
||||||
|
|
||||||
return {
|
models = [
|
||||||
"data": [
|
|
||||||
{
|
{
|
||||||
"id": model["model"],
|
"id": model["model"],
|
||||||
"object": "model",
|
"object": "model",
|
||||||
@ -1104,10 +1101,7 @@ async def get_openai_models(
|
|||||||
"owned_by": "openai",
|
"owned_by": "openai",
|
||||||
}
|
}
|
||||||
for model in models["models"]
|
for model in models["models"]
|
||||||
],
|
]
|
||||||
"object": "list",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
error_detail = "Open WebUI: Server Connection Error"
|
error_detail = "Open WebUI: Server Connection Error"
|
||||||
@ -1123,6 +1117,27 @@ async def get_openai_models(
|
|||||||
status_code=r.status_code if r else 500,
|
status_code=r.status_code if r else 500,
|
||||||
detail=error_detail,
|
detail=error_detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if user.role == "user":
|
||||||
|
# Filter models based on user access control
|
||||||
|
filtered_models = []
|
||||||
|
for model in models:
|
||||||
|
model_info = Models.get_model_by_id(model["id"])
|
||||||
|
if model_info:
|
||||||
|
if has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
):
|
||||||
|
filtered_models.append(model)
|
||||||
|
else:
|
||||||
|
filtered_models.append(model)
|
||||||
|
models = filtered_models
|
||||||
|
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": models,
|
||||||
|
"object": "list",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class UrlForm(BaseModel):
|
class UrlForm(BaseModel):
|
||||||
|
@ -501,7 +501,7 @@ async def generate_chat_completion(
|
|||||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||||
|
|
||||||
# Check if user has access to the model
|
# Check if user has access to the model
|
||||||
if user.role == "user" and not has_access(
|
if not bypass_filter and user.role == "user" and not has_access(
|
||||||
user.id, type="read", access_control=model_info.access_control
|
user.id, type="read", access_control=model_info.access_control
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
Loading…
Reference in New Issue
Block a user