enh: AsyncGenerator support

This commit is contained in:
Timothy J. Baek 2024-07-24 11:29:57 +01:00
parent edff071cd2
commit 23e69bcdb4

View File

@ -54,7 +54,7 @@ import uuid
import time
import json
from typing import Iterator, Generator, Optional
from typing import Iterator, Generator, AsyncGenerator, Optional
from pydantic import BaseModel
app = FastAPI()
@ -411,6 +411,25 @@ async def generate_function_chat_completion(form_data, user):
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
if isinstance(res, AsyncGenerator):
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
@ -434,9 +453,12 @@ async def generate_function_chat_completion(form_data, user):
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
elif isinstance(res, AsyncGenerator):
async for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",