Timothy J. Baek 2024-05-21 14:18:25 -07:00
parent fe49f4af7e
commit db1262def9
4 changed files with 49 additions and 36 deletions

main.py (51 changed lines)

@@ -10,9 +10,11 @@ import time
 import json
 import uuid
 
-from utils import stream_message_template
+from utils import get_last_user_message, stream_message_template
+from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
+from rag_pipeline import get_response
 
 app = FastAPI(docs_url="/docs", redoc_url=None)
@@ -58,48 +60,25 @@ async def get_models():
     }
 
 
-class OpenAIChatMessage(BaseModel):
-    role: str
-    content: str
-
-    model_config = ConfigDict(extra="allow")
-
-
-class OpenAIChatCompletionForm(BaseModel):
-    model: str
-    messages: List[OpenAIChatMessage]
-
-    model_config = ConfigDict(extra="allow")
-
-
-def get_response(user_message):
-    return f"rag response to: {user_message}"
-
-
-@app.post("/chat/completions")
+@app.post("/v1/chat/completions")
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
-    # get last user message (role == 'user') from the form_data
-    # last message might be role == 'assistant' or 'system' or 'user'
-    user_message = form_data.messages[-1].content
-    res = get_response(user_message)
-
-    finish_message = {
-        "id": f"rag-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": MODEL_ID,
-        "choices": [
-            {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
-        ],
-    }
+    user_message = get_last_user_message(form_data.messages)
+    res = get_response(user_message, messages=form_data.messages)
 
     def stream_content():
         message = stream_message_template(res)
         yield f"data: {json.dumps(message)}\n\n"
 
+        finish_message = {
+            "id": f"rag-{str(uuid.uuid4())}",
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": MODEL_ID,
+            "choices": [
+                {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+            ],
+        }
         yield f"data: {json.dumps(finish_message)}\n\n"
         yield f"data: [DONE]"
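The endpoint now streams Server-Sent Events: one content chunk from stream_message_template, the finish chunk, then a data: [DONE] sentinel. Below is a minimal client sketch, not part of the commit, assuming the app is served locally (e.g. uvicorn main:app --port 8000) and that stream_message_template emits OpenAI-style chunks with a delta.content field, as the finish chunk's shape suggests; the host, port, model string, and the requests dependency are all illustrative.

import json
import requests

payload = {
    "model": "rag-model",  # placeholder; the server stamps MODEL_ID on its replies
    "messages": [{"role": "user", "content": "What do the docs say?"}],
}

with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue  # skip blank separator lines between SSE events
        data = raw.decode().removeprefix("data: ")
        if data == "[DONE]":  # end-of-stream sentinel from stream_content()
            break
        chunk = json.loads(data)
        # the finish chunk carries an empty delta, so .get() defaults to ""
        print(chunk["choices"][0]["delta"].get("content", ""), end="")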

rag_pipeline.py (new file, 8 lines)

@@ -0,0 +1,8 @@
+from typing import List
+from schemas import OpenAIChatMessage
+
+
+def get_response(user_message: str, messages: List[OpenAIChatMessage]):
+    print(messages)
+    print(user_message)
+    return f"rag response to: {user_message}"
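get_response is still a stub: it logs its inputs and echoes the user message back, but the full messages list is now threaded through so a real retrieval step can see the whole conversation later. A quick sanity check, assuming both new modules are on the import path:

from schemas import OpenAIChatMessage
from rag_pipeline import get_response

msgs = [OpenAIChatMessage(role="user", content="hello")]
print(get_response("hello", messages=msgs))
# prints the message list and "hello", then returns: rag response to: hello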

schemas.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from typing import List
+from pydantic import BaseModel, ConfigDict
+
+
+class OpenAIChatMessage(BaseModel):
+    role: str
+    content: str
+
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    model: str
+    messages: List[OpenAIChatMessage]
+
+    model_config = ConfigDict(extra="allow")
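Both models set model_config = ConfigDict(extra="allow"), so fields the schema does not declare (sampling parameters, a per-message name, and so on) pass pydantic validation and stay accessible instead of raising. A small illustration with made-up values:

from schemas import OpenAIChatCompletionForm

form = OpenAIChatCompletionForm(
    model="rag-model",
    messages=[{"role": "user", "content": "hi", "name": "alice"}],
    temperature=0.7,  # not declared on the model, accepted anyway
)
print(form.temperature)       # 0.7
print(form.messages[0].name)  # alice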

utils.py (10 changed lines)

@@ -1,5 +1,8 @@
 import uuid
 import time
+from typing import List, Optional
+
+from schemas import OpenAIChatMessage
 from config import MODEL_ID
@@ -18,3 +21,10 @@ def stream_message_template(message: str):
             }
         ],
     }
+
+
+def get_last_user_message(messages: List[OpenAIChatMessage]) -> Optional[str]:
+    for message in reversed(messages):
+        if message.role == "user":
+            return message.content
+    return None
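Unlike the old form_data.messages[-1].content, the helper walks the list backwards for the last role == "user" entry, so a trailing assistant or system turn no longer gets treated as the query. For example:

from schemas import OpenAIChatMessage
from utils import get_last_user_message

msgs = [
    OpenAIChatMessage(role="system", content="be brief"),
    OpenAIChatMessage(role="user", content="second question"),
    OpenAIChatMessage(role="assistant", content="an answer"),
]
print(get_last_user_message(msgs))  # second question, not msgs[-1].content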