diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 260d305f0..f50d1dc36 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -233,6 +233,16 @@ async def generate_function_chat_completion(form_data, user): res = await pipe(**params) else: res = pipe(**params) + + # Directly return if the response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + except Exception as e: print(f"Error: {e}") yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" @@ -285,15 +295,13 @@ async def generate_function_chat_completion(form_data, user): res = await pipe(**params) else: res = pipe(**params) + + if isinstance(res, StreamingResponse): + return res except Exception as e: print(f"Error: {e}") return {"error": {"detail": str(e)}} - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - if isinstance(res, dict): return res elif isinstance(res, BaseModel):