feat: support stream=false

This commit is contained in:
Timothy J. Baek 2024-05-21 14:46:09 -07:00
parent 3fee0347a1
commit 32928c754e
2 changed files with 46 additions and 13 deletions

58
main.py
View File

@@ -65,33 +65,65 @@ async def get_models():
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """Handle an OpenAI-compatible chat completion request.

    When ``form_data.stream`` is true, respond with Server-Sent Events
    (``chat.completion.chunk`` objects followed by a stop chunk and
    ``[DONE]``); otherwise accumulate the full response and return a single
    ``chat.completion`` object.

    Args:
        form_data: Parsed request body; ``messages`` holds the conversation
            and ``stream`` selects SSE vs. one-shot JSON.

    Returns:
        A ``StreamingResponse`` (SSE) or a plain dict in the OpenAI
        ``chat.completion`` shape.
    """
    user_message = get_last_user_message(form_data.messages)

    if form_data.stream:

        def stream_content():
            # get_response may return either a complete string or a
            # generator of incremental text chunks.
            res = get_response(user_message, messages=form_data.messages)

            if isinstance(res, str):
                message = stream_message_template(res)
                yield f"data: {json.dumps(message)}\n\n"
            elif isinstance(res, Generator):
                for message in res:
                    message = stream_message_template(message)
                    yield f"data: {json.dumps(message)}\n\n"

            # Terminal chunk: empty delta with finish_reason="stop", per the
            # OpenAI streaming protocol.
            finish_message = {
                "id": f"rag-{str(uuid.uuid4())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
                ],
            }
            yield f"data: {json.dumps(finish_message)}\n\n"
            # Fix: was an f-string with no placeholders and lacked the
            # SSE-required blank-line terminator present on every other event.
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        res = get_response(user_message, messages=form_data.messages)
        message = ""

        if isinstance(res, str):
            # Fix: previously `message = stream_message_template(res)`, which
            # yields a chunk *dict* (see the streaming branch above), so the
            # non-streaming "content" field held a template object instead of
            # the response text. The raw string is the message content.
            message = res
        elif isinstance(res, Generator):
            # Concatenate all streamed pieces into one response string.
            for stream in res:
                message = f"{message}{stream}"

        return {
            "id": f"rag-{str(uuid.uuid4())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": message,
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        }
@app.get("/") @app.get("/")
async def get_status(): async def get_status():

View File

@@ -10,6 +10,7 @@ class OpenAIChatMessage(BaseModel):
class OpenAIChatCompletionForm(BaseModel):
    """Request body for the OpenAI-compatible ``/chat/completions`` endpoint.

    Mirrors the OpenAI chat completion request: ``stream`` selects SSE
    streaming (the default) versus a single JSON response, ``model`` names
    the target model, and ``messages`` carries the conversation history.
    """

    # Defaults to True to match OpenAI client behavior; set False for a
    # one-shot JSON response.
    stream: bool = True
    model: str
    messages: List[OpenAIChatMessage]