refac: model access control behaviour

This commit is contained in:
Timothy Jaeryang Baek 2024-11-17 19:15:09 -08:00
parent 85731f400c
commit 3faf9d2067
4 changed files with 36 additions and 12 deletions

View File

@ -362,8 +362,6 @@ async def get_ollama_tags(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else:
filtered_models.append(model)
models["models"] = filtered_models models["models"] = filtered_models
return models return models
@ -960,6 +958,12 @@ async def generate_chat_completion(
status_code=403, status_code=403,
detail="Model not found", detail="Model not found",
) )
else:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1048,6 +1052,12 @@ async def generate_openai_chat_completion(
status_code=403, status_code=403,
detail="Model not found", detail="Model not found",
) )
else:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1130,8 +1140,6 @@ async def get_openai_models(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else:
filtered_models.append(model)
models = filtered_models models = filtered_models
return { return {

View File

@ -424,8 +424,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else:
filtered_models.append(model)
models["data"] = filtered_models models["data"] = filtered_models
return models return models
@ -512,6 +510,12 @@ async def generate_chat_completion(
status_code=403, status_code=403,
detail="Model not found", detail="Model not found",
) )
else:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
# Attemp to get urlIdx from the model # Attemp to get urlIdx from the model
models = await get_all_models() models = await get_all_models()

View File

@ -557,7 +557,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if user.role == "user": if user.role == "user":
if model_info and not ( if not model_info:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Model not found"},
)
elif not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
@ -1113,8 +1118,6 @@ async def get_models(user=Depends(get_verified_user)):
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
): ):
filtered_models.append(model) filtered_models.append(model)
else:
filtered_models.append(model)
models = filtered_models models = filtered_models
return {"data": models} return {"data": models}
@ -1147,7 +1150,12 @@ async def generate_chat_completions(
# Check if user has access to the model # Check if user has access to the model
if user.role == "user": if user.role == "user":
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info and not ( if not model_info:
raise HTTPException(
status_code=404,
detail="Model not found",
)
elif not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control

View File

@ -79,7 +79,7 @@
let filterIds = []; let filterIds = [];
let actionIds = []; let actionIds = [];
let accessControl = null; let accessControl = {};
const addUsage = (base_model_id) => { const addUsage = (base_model_id) => {
const baseModel = $models.find((m) => m.id === base_model_id); const baseModel = $models.find((m) => m.id === base_model_id);
@ -213,7 +213,11 @@
capabilities.usage = false; capabilities.usage = false;
} }
accessControl = model?.access_control ?? null; if ('access_control' in model) {
accessControl = model.access_control;
} else {
accessControl = {};
}
console.log(model?.access_control); console.log(model?.access_control);
console.log(accessControl); console.log(accessControl);