From 6d52f913d284aa4098e6eb27189d924b8b3d126d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 22 Oct 2024 17:43:39 -0700 Subject: [PATCH] enh: arena model send selected model id --- backend/open_webui/main.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 1d3239d25..be1cef611 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1102,9 +1102,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - model_id = None + selected_model_id = None if isinstance(model_ids, list) and model_ids: - model_id = random.choice(model_ids) + selected_model_id = random.choice(model_ids) else: model_ids = [ model["id"] @@ -1112,10 +1112,26 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model.get("owned_by") != "arena" and not model.get("info", {}).get("meta", {}).get("hidden", False) ] - model_id = random.choice(model_ids) + selected_model_id = random.choice(model_ids) - form_data["model"] = model_id - return await generate_chat_completions(form_data, user) + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions(form_data, user) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **(await generate_chat_completions(form_data, user)), + "selected_model_id": selected_model_id, + } if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama":