Mirror of https://github.com/open-webui/pipelines, synced 2025-05-10 15:40:55 +00:00
refac
This commit is contained in:
parent fe49f4af7e
commit db1262def9
51 main.py
@@ -10,9 +10,11 @@ import time
 import json
 import uuid
 
-from utils import stream_message_template
+from utils import get_last_user_message, stream_message_template
+from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
 
+from rag_pipeline import get_response
 
 app = FastAPI(docs_url="/docs", redoc_url=None)
 
@@ -58,48 +60,25 @@ async def get_models():
     }
 
 
-class OpenAIChatMessage(BaseModel):
-    role: str
-    content: str
-
-    model_config = ConfigDict(extra="allow")
-
-
-class OpenAIChatCompletionForm(BaseModel):
-    model: str
-    messages: List[OpenAIChatMessage]
-
-    model_config = ConfigDict(extra="allow")
-
-
-def get_response(user_message):
-    return f"rag response to: {user_message}"
-
-
 @app.post("/chat/completions")
 @app.post("/v1/chat/completions")
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
-
-    # get last user message (role == 'user') from the form_data
-    # last message might be role == 'assistant' or 'system' or 'user'
-    user_message = form_data.messages[-1].content
-
-    res = get_response(user_message)
-
-    finish_message = {
-        "id": f"rag-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": MODEL_ID,
-        "choices": [
-            {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
-        ],
-    }
+    user_message = get_last_user_message(form_data.messages)
+    res = get_response(user_message, messages=form_data.messages)
 
     def stream_content():
         message = stream_message_template(res)
 
         yield f"data: {json.dumps(message)}\n\n"
 
+        finish_message = {
+            "id": f"rag-{str(uuid.uuid4())}",
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": MODEL_ID,
+            "choices": [
+                {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+            ],
+        }
         yield f"data: {json.dumps(finish_message)}\n\n"
         yield f"data: [DONE]"
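
A note on usage: assuming the route wraps stream_content() in a streaming response and that stream_message_template emits OpenAI-style chunks with the text under choices[0].delta.content, a minimal client sketch might look like this. The host, port, and model id below are illustrative assumptions, not taken from this commit.

# client_sketch.py -- minimal SSE consumer for the endpoint above.
# Assumptions (not from this commit): server on localhost:9099, model id "rag".
import json

import requests

payload = {
    "model": "rag",
    "messages": [{"role": "user", "content": "hello"}],
}

with requests.post(
    "http://localhost:9099/v1/chat/completions", json=payload, stream=True
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue  # skip the blank separator lines between SSE events
        data = raw.decode().removeprefix("data: ")
        if data == "[DONE]":
            break  # final sentinel yielded by stream_content()
        chunk = json.loads(data)
        # the finish chunk carries an empty delta, so default to ""
        print(chunk["choices"][0]["delta"].get("content", ""), end="")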
8 rag_pipeline.py (new file)
@@ -0,0 +1,8 @@
+from typing import List
+from schemas import OpenAIChatMessage
+
+
+def get_response(user_message: str, messages: List[OpenAIChatMessage]):
+    print(messages)
+    print(user_message)
+    return f"rag response to: {user_message}"
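
get_response is deliberately a stub here; the point of the refactor is that a real retrieval step can later be swapped in behind the same signature. A purely hypothetical sketch, with an in-memory dict standing in for an actual vector store:

# rag_pipeline.py (hypothetical variant) -- same signature as the stub above.
from typing import List

from schemas import OpenAIChatMessage

# Toy corpus standing in for a real document/vector store (assumption).
DOCS = {
    "pipelines": "Pipelines plug custom logic into an OpenAI-compatible API.",
}


def get_response(user_message: str, messages: List[OpenAIChatMessage]):
    # Naive keyword retrieval; a real pipeline would embed and search here.
    context = [doc for key, doc in DOCS.items() if key in user_message.lower()]
    if context:
        return f"rag response to: {user_message} (context: {context[0]})"
    return f"rag response to: {user_message}"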
16 schemas.py (new file)
@@ -0,0 +1,16 @@
+from typing import List
+from pydantic import BaseModel, ConfigDict
+
+
+class OpenAIChatMessage(BaseModel):
+    role: str
+    content: str
+
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    model: str
+    messages: List[OpenAIChatMessage]
+
+    model_config = ConfigDict(extra="allow")
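
Both models set extra="allow" via pydantic v2's ConfigDict, so OpenAI-style fields they don't declare (such as name or tool_calls) are carried along instead of failing validation. A quick illustrative check:

from schemas import OpenAIChatMessage

# "name" is not declared on the model, but extra="allow" keeps it.
msg = OpenAIChatMessage(role="user", content="hi", name="alice")
print(msg.role)         # user
print(msg.model_extra)  # {'name': 'alice'}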
10 utils.py
@@ -1,5 +1,8 @@
 import uuid
 import time
 
+from typing import List
+from schemas import OpenAIChatMessage
 from config import MODEL_ID
 
 
@@ -18,3 +21,10 @@ def stream_message_template(message: str):
             }
         ],
     }
+
+
+def get_last_user_message(messages: List[OpenAIChatMessage]) -> str:
+    for message in reversed(messages):
+        if message.role == "user":
+            return message.content
+    return None
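
get_last_user_message walks the history in reverse because the final entry is not necessarily a user turn; it can be an assistant or system message, as the comment removed from main.py noted. A small illustrative run:

from schemas import OpenAIChatMessage
from utils import get_last_user_message

history = [
    OpenAIChatMessage(role="user", content="ping"),
    OpenAIChatMessage(role="assistant", content="pong"),
]
print(get_last_user_message(history))  # ping  (not the trailing "pong")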