diff --git a/main.py b/main.py
index 595b31d..506d235 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,7 @@
 from fastapi import FastAPI, Request, Depends, status, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
+
 from starlette.responses import StreamingResponse, Response
 from pydantic import BaseModel, ConfigDict
 
@@ -116,65 +118,81 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Model {form_data.model} not found",
         )
 
-    get_response = PIPELINES[form_data.model]["module"].get_response
+    def job():
 
-    if form_data.stream:
+        get_response = PIPELINES[form_data.model]["module"].get_response
 
-        def stream_content():
+        if form_data.stream:
 
-            res = get_response(user_message, messages=form_data.messages)
+            def stream_content():
+                res = get_response(user_message, messages=form_data.messages)
 
-            if isinstance(res, str):
-                message = stream_message_template(res)
-                yield f"data: {json.dumps(message)}\n\n"
-
-            elif isinstance(res, Generator):
-                for message in res:
-                    message = stream_message_template(message)
+                if isinstance(res, str):
+                    message = stream_message_template(res)
                     yield f"data: {json.dumps(message)}\n\n"
 
-            finish_message = {
+                elif isinstance(res, Generator):
+                    for message in res:
+                        print(message)
+                        message = stream_message_template(message)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                finish_message = {
+                    "id": f"rag-{str(uuid.uuid4())}",
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": MODEL_ID,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {},
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield f"data: [DONE]"
+
+            return StreamingResponse(stream_content(), media_type="text/event-stream")
+        else:
+            res = get_response(user_message, messages=form_data.messages)
+            message = ""
+
+            if isinstance(res, str):
+                message = res
+
+            elif isinstance(res, Generator):
+                for stream in res:
+                    message = f"{message}{stream}"
+
+            return {
                 "id": f"rag-{str(uuid.uuid4())}",
-                "object": "chat.completion.chunk",
+                "object": "chat.completion",
                 "created": int(time.time()),
                 "model": MODEL_ID,
                 "choices": [
-                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": message,
+                        },
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                    }
                 ],
             }
 
-            yield f"data: {json.dumps(finish_message)}\n\n"
-            yield f"data: [DONE]"
-
-        return StreamingResponse(stream_content(), media_type="text/event-stream")
-    else:
-        res = get_response(user_message, messages=form_data.messages)
-        message = ""
-
-        if isinstance(res, str):
-            message = res
-
-        elif isinstance(res, Generator):
-            for stream in res:
-                message = f"{message}{stream}"
-
-        return {
-            "id": f"rag-{str(uuid.uuid4())}",
-            "object": "chat.completion",
-            "created": int(time.time()),
-            "model": MODEL_ID,
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": message,
-                    },
-                    "logprobs": None,
-                    "finish_reason": "stop",
-                }
-            ],
-        }
+    try:
+        return await run_in_threadpool(job)
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=500,
+            detail=f"{e}",
+        )
 
 
 @app.get("/")