diff --git a/main.py b/main.py
index 4c2e3c4..1116422 100644
--- a/main.py
+++ b/main.py
@@ -65,33 +65,73 @@ async def get_models():
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
+    """Answer an OpenAI-style chat completion request.
+
+    With ``form_data.stream`` set, the reply is sent as a server-sent event
+    stream terminated by ``data: [DONE]``; otherwise the whole reply is
+    returned as one ``chat.completion`` object.
+    """
     user_message = get_last_user_message(form_data.messages)
 
-    def stream_content():
+    if form_data.stream:
+
+        def stream_content():
+            res = get_response(user_message, messages=form_data.messages)
+
+            if isinstance(res, str):
+                message = stream_message_template(res)
+                yield f"data: {json.dumps(message)}\n\n"
+
+            elif isinstance(res, Generator):
+                for message in res:
+                    message = stream_message_template(message)
+                    yield f"data: {json.dumps(message)}\n\n"
+
+            finish_message = {
+                "id": f"rag-{str(uuid.uuid4())}",
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": MODEL_ID,
+                "choices": [
+                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+                ],
+            }
+
+            yield f"data: {json.dumps(finish_message)}\n\n"
+            # Every SSE event, including the final sentinel, must end with a
+            # blank line; without it the last event is left unterminated.
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
         res = get_response(user_message, messages=form_data.messages)
 
+        message = ""
+
         if isinstance(res, str):
-            message = stream_message_template(res)
-            yield f"data: {json.dumps(message)}\n\n"
+            # Use the text itself: stream_message_template() output is what the
+            # streaming branch serialises as chunks, not plain reply content.
+            message = res
 
         elif isinstance(res, Generator):
-            for message in res:
-                message = stream_message_template(message)
-                yield f"data: {json.dumps(message)}\n\n"
+            message = "".join(str(chunk) for chunk in res)
 
-        finish_message = {
+        return {
             "id": f"rag-{str(uuid.uuid4())}",
-            "object": "chat.completion.chunk",
+            "object": "chat.completion",
             "created": int(time.time()),
             "model": MODEL_ID,
             "choices": [
-                {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": message,
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
             ],
         }
 
-        yield f"data: {json.dumps(finish_message)}\n\n"
-        yield f"data: [DONE]"
-
-    return StreamingResponse(stream_content(), media_type="text/event-stream")
-
 
 @app.get("/")
 async def get_status():
diff --git a/schemas.py b/schemas.py
index 1c166ad..6f90d68 100644
--- a/schemas.py
+++ b/schemas.py
@@ -10,6 +10,7 @@ class OpenAIChatMessage(BaseModel):
 
 
 class OpenAIChatCompletionForm(BaseModel):
+    stream: bool = True
     model: str
     messages: List[OpenAIChatMessage]
 