enh: arena model send selected model id

This commit is contained in:
Timothy J. Baek 2024-10-22 17:43:39 -07:00
parent 2d99f275a3
commit 6d52f913d2

View File

@ -1102,9 +1102,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
model_id = None
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
model_id = random.choice(model_ids)
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
@ -1112,10 +1112,26 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
if model.get("owned_by") != "arena"
and not model.get("info", {}).get("meta", {}).get("hidden", False)
]
model_id = random.choice(model_ids)
selected_model_id = random.choice(model_ids)
form_data["model"] = model_id
return await generate_chat_completions(form_data, user)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completions(form_data, user)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(await generate_chat_completions(form_data, user)),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":