diff --git a/main.py b/main.py
index df65dcf..b1715ee 100644
--- a/main.py
+++ b/main.py
@@ -10,9 +10,11 @@
 import time
 import json
 import uuid
 
-from utils import stream_message_template
+from utils import get_last_user_message, stream_message_template
+from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
+from rag_pipeline import get_response
 
 app = FastAPI(docs_url="/docs", redoc_url=None)
 
@@ -58,48 +60,26 @@ async def get_models():
     }
 
 
-class OpenAIChatMessage(BaseModel):
-    role: str
-    content: str
-
-    model_config = ConfigDict(extra="allow")
-
-
-class OpenAIChatCompletionForm(BaseModel):
-    model: str
-    messages: List[OpenAIChatMessage]
-
-    model_config = ConfigDict(extra="allow")
-
-
-def get_response(user_message):
-    return f"rag response to: {user_message}"
-
-
 @app.post("/chat/completions")
 @app.post("/v1/chat/completions")
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
-
-    # get last user message (role == 'user') from the form_data
-    # last message might be role == 'assistant' or 'system' or 'user'
-    user_message = form_data.messages[-1].content
-
-    res = get_response(user_message)
-
-    finish_message = {
-        "id": f"rag-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": MODEL_ID,
-        "choices": [
-            {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
-        ],
-    }
+    user_message = get_last_user_message(form_data.messages)
+    res = get_response(user_message, messages=form_data.messages)
 
     def stream_content():
         message = stream_message_template(res)
         yield f"data: {json.dumps(message)}\n\n"
+
+        finish_message = {
+            "id": f"rag-{str(uuid.uuid4())}",
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": MODEL_ID,
+            "choices": [
+                {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+            ],
+        }
         yield f"data: {json.dumps(finish_message)}\n\n"
         yield f"data: [DONE]"
 
 
diff --git a/rag_pipeline.py b/rag_pipeline.py
new file mode 100644
index 0000000..1f74754
--- /dev/null
+++ b/rag_pipeline.py
@@ -0,0 +1,9 @@
+from typing import List
+from schemas import OpenAIChatMessage
+
+
+def get_response(user_message: str, messages: List[OpenAIChatMessage]):
+    # placeholder: echo the question until a real retrieval pipeline is wired in
+    print(messages)
+    print(user_message)
+    return f"rag response to: {user_message}"
diff --git a/schemas.py b/schemas.py
new file mode 100644
index 0000000..1c166ad
--- /dev/null
+++ b/schemas.py
@@ -0,0 +1,16 @@
+from typing import List
+from pydantic import BaseModel, ConfigDict
+
+
+class OpenAIChatMessage(BaseModel):
+    role: str
+    content: str
+
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    model: str
+    messages: List[OpenAIChatMessage]
+
+    model_config = ConfigDict(extra="allow")
diff --git a/utils.py b/utils.py
index 3360a65..67af85c 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,8 @@
 import uuid
 import time
+
+from typing import List, Optional
+from schemas import OpenAIChatMessage
 from config import MODEL_ID
 
 
@@ -18,3 +21,10 @@ def stream_message_template(message: str):
             }
         ],
     }
+
+
+def get_last_user_message(messages: List[OpenAIChatMessage]) -> Optional[str]:
+    for message in reversed(messages):
+        if message.role == "user":
+            return message.content
+    return None
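
Note that `rag_pipeline.get_response` is deliberately a stub. As one illustration of the shape a real implementation could take, here is a minimal sketch assuming an in-memory corpus and naive keyword-overlap retrieval; `DOCUMENTS`, `retrieve`, and the prompt format are hypothetical and not part of this diff:

```python
# rag_pipeline.py -- illustrative sketch only, not part of this patch.
# Assumes an in-memory corpus and naive keyword-overlap scoring.
from typing import List

from schemas import OpenAIChatMessage

DOCUMENTS = [
    "FastAPI apps are typically served with an ASGI server such as uvicorn.",
    "Server-sent events frame each payload as a 'data: ...' line.",
]


def retrieve(query: str, k: int = 2) -> List[str]:
    # naive retrieval: rank documents by word overlap with the query
    query_words = set(query.lower().split())
    scored = sorted(
        DOCUMENTS,
        key=lambda doc: len(query_words & set(doc.lower().split())),
        reverse=True,
    )
    return scored[:k]


def get_response(user_message: str, messages: List[OpenAIChatMessage]) -> str:
    # messages carries the full chat history for history-aware retrieval;
    # a real pipeline would pass the grounded prompt to an LLM
    context = "\n".join(retrieve(user_message))
    return f"Answer based on:\n{context}\n\nQuestion: {user_message}"
```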
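
A note on the schemas: `extra="allow"` means OpenAI-style parameters the forms do not declare (`temperature`, `stream`, and so on) pass validation instead of being rejected, and stay readable on the model instance. A quick check, assuming Pydantic v2 (which `ConfigDict` implies); the model id is a placeholder:

```python
# extra="allow" lets undeclared OpenAI params through validation;
# Pydantic v2 keeps them as attributes (and in .model_extra)
from schemas import OpenAIChatCompletionForm

form = OpenAIChatCompletionForm.model_validate(
    {
        "model": "rag-model",  # hypothetical model id
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.2,  # not declared on the schema, still accepted
    }
)
print(form.temperature)  # 0.2
print(form.model_extra)  # {'temperature': 0.2}
```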
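
To exercise the streaming endpoint end to end, something like the snippet below should work. Assumptions: the app runs at uvicorn's default `http://localhost:8000`, the route returns `stream_content()` as an SSE `StreamingResponse` (the `return` falls outside the hunk context above), `stream_message_template` emits an OpenAI-style chunk whose `delta` carries the content, and `"rag-model"` stands in for the configured `MODEL_ID`:

```python
# stream_client.py -- minimal SSE consumer for /v1/chat/completions
import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "rag-model",  # placeholder for MODEL_ID
        "messages": [{"role": "user", "content": "What is RAG?"}],
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    # SSE frames look like "data: <json>"; blank lines separate events
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:  # the finish chunk has an empty delta
        print(delta["content"], end="", flush=True)
```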