diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 570cad9f1..997a05974 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -54,7 +54,7 @@ import uuid import time import json -from typing import Iterator, Generator, Optional +from typing import Iterator, Generator, AsyncGenerator, Optional from pydantic import BaseModel app = FastAPI() @@ -411,6 +411,25 @@ async def generate_function_chat_completion(form_data, user): yield f"data: {json.dumps(finish_message)}\n\n" yield f"data: [DONE]" + if isinstance(res, AsyncGenerator): + async for line in res: + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except: + pass + + if line.startswith("data:"): + yield f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + yield f"data: {json.dumps(line)}\n\n" + return StreamingResponse(stream_content(), media_type="text/event-stream") else: @@ -434,9 +453,12 @@ async def generate_function_chat_completion(form_data, user): message = "" if isinstance(res, str): message = res - if isinstance(res, Generator): + elif isinstance(res, Generator): for stream in res: message = f"{message}{stream}" + elif isinstance(res, AsyncGenerator): + async for stream in res: + message = f"{message}{stream}" return { "id": f"{form_data['model']}-{str(uuid.uuid4())}",