fix: arena access control

This commit is contained in:
Timothy Jaeryang Baek 2024-11-18 07:40:37 -08:00
parent 269151cd2c
commit f37d847521
3 changed files with 62 additions and 38 deletions

View File

@ -958,7 +958,7 @@ async def generate_chat_completion(
status_code=403, status_code=403,
detail="Model not found", detail="Model not found",
) )
else: elif not bypass_filter:
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,

View File

@ -510,7 +510,7 @@ async def generate_chat_completion(
status_code=403, status_code=403,
detail="Model not found", detail="Model not found",
) )
else: elif not bypass_filter:
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,

View File

@ -557,21 +557,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if user.role == "user": if user.role == "user":
if not model_info: if model.get("arena"):
return JSONResponse( if not has_access(
status_code=status.HTTP_404_NOT_FOUND, user.id,
content={"detail": "Model not found"}, type="read",
) access_control=model.get("info", {})
elif not ( .get("meta", {})
user.id == model_info.user_id .get("access_control", {}),
or has_access( ):
user.id, type="read", access_control=model_info.access_control raise HTTPException(
) status_code=403,
): detail="Model not found",
return JSONResponse( )
status_code=status.HTTP_403_FORBIDDEN, else:
content={"detail": "User does not have access to the model"}, if not model_info:
) return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Model not found"},
)
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "User does not have access to the model"},
)
metadata = { metadata = {
"chat_id": body.pop("chat_id", None), "chat_id": body.pop("chat_id", None),
@ -1160,24 +1173,38 @@ async def generate_chat_completions(
) )
model = models[model_id] model = models[model_id]
# Check if user has access to the model # Check if user has access to the model
if user.role == "user": if not bypass_filter and user.role == "user":
model_info = Models.get_model_by_id(model_id) if model.get("arena"):
if not model_info: if not has_access(
raise HTTPException( user.id,
status_code=404, type="read",
detail="Model not found", access_control=model.get("info", {})
) .get("meta", {})
elif not ( .get("access_control", {}),
user.id == model_info.user_id ):
or has_access( raise HTTPException(
user.id, type="read", access_control=model_info.access_control status_code=403,
) detail="Model not found",
): )
raise HTTPException( else:
status_code=403, model_info = Models.get_model_by_id(model_id)
detail="Model not found", if not model_info:
) raise HTTPException(
status_code=404,
detail="Model not found",
)
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if model["owned_by"] == "arena": if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids") model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
@ -1186,9 +1213,7 @@ async def generate_chat_completions(
model_ids = [ model_ids = [
model["id"] model["id"]
for model in await get_all_models() for model in await get_all_models()
if model.get("owned_by") != "arena" if model.get("owned_by") != "arena" and model["id"] not in model_ids
and not model.get("info", {}).get("meta", {}).get("hidden", False)
and model["id"] not in model_ids
] ]
selected_model_id = None selected_model_id = None
@ -1199,7 +1224,6 @@ async def generate_chat_completions(
model["id"] model["id"]
for model in await get_all_models() for model in await get_all_models()
if model.get("owned_by") != "arena" if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
] ]
selected_model_id = random.choice(model_ids) selected_model_id = random.choice(model_ids)