fix: bypass_filter for arena models

This commit is contained in:
Timothy J. Baek 2024-10-23 15:05:43 -07:00
parent ee2f8d3552
commit d85480b4d6

View File

@ -1082,7 +1082,9 @@ async def get_models(user=Depends(get_verified_user)):
@app.post("/api/chat/completions") @app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): async def generate_chat_completions(
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
):
model_id = form_data["model"] model_id = form_data["model"]
if model_id not in app.state.MODELS: if model_id not in app.state.MODELS:
@ -1091,7 +1093,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found", detail="Model not found",
) )
if app.state.config.ENABLE_MODEL_FILTER: if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -1103,7 +1105,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
if model["owned_by"] == "arena": if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids") model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == 'exclude': if model_ids and filter_mode == "exclude":
model_ids = [ model_ids = [
model["id"] model["id"]
for model in await get_all_models() for model in await get_all_models()
@ -1133,13 +1135,17 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
async for chunk in stream: async for chunk in stream:
yield chunk yield chunk
response = await generate_chat_completions(form_data, user) response = await generate_chat_completions(
form_data, user, bypass_filter=True
)
return StreamingResponse( return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream" stream_wrapper(response.body_iterator), media_type="text/event-stream"
) )
else: else:
return { return {
**(await generate_chat_completions(form_data, user)), **(
await generate_chat_completions(form_data, user, bypass_filter=True)
),
"selected_model_id": selected_model_id, "selected_model_id": selected_model_id,
} }
if model.get("pipe"): if model.get("pipe"):