mirror of
https://github.com/open-webui/open-webui
synced 2025-03-24 22:49:22 +00:00
refac: chat completion middleware
This commit is contained in:
parent
448ca9d836
commit
6b8a7b9939
@ -316,7 +316,7 @@ async def get_function_call_response(
|
|||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
return_citations = False
|
data_items = []
|
||||||
|
|
||||||
if request.method == "POST" and (
|
if request.method == "POST" and (
|
||||||
"/ollama/api/chat" in request.url.path
|
"/ollama/api/chat" in request.url.path
|
||||||
@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Read the original request body
|
# Read the original request body
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
# Decode body to string
|
|
||||||
body_str = body.decode("utf-8")
|
body_str = body.decode("utf-8")
|
||||||
# Parse string to JSON
|
|
||||||
data = json.loads(body_str) if body_str else {}
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
model_id = data["model"]
|
||||||
user = get_current_user(
|
user = get_current_user(
|
||||||
request,
|
request,
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove the citations from the body
|
|
||||||
return_citations = data.get("citations", False)
|
|
||||||
if "citations" in data:
|
|
||||||
del data["citations"]
|
|
||||||
|
|
||||||
# Set the task model
|
# Set the task model
|
||||||
task_model_id = data["model"]
|
task_model_id = model_id
|
||||||
if task_model_id not in app.state.MODELS:
|
if task_model_id not in app.state.MODELS:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
):
|
):
|
||||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||||
|
|
||||||
|
skip_files = False
|
||||||
prompt = get_last_user_message(data["messages"])
|
prompt = get_last_user_message(data["messages"])
|
||||||
context = ""
|
context = ""
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
|
|
||||||
skip_files = False
|
|
||||||
if "tool_ids" in data:
|
if "tool_ids" in data:
|
||||||
print(data["tool_ids"])
|
print(data["tool_ids"])
|
||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
context += ("\n" if context != "" else "") + rag_context
|
context += ("\n" if context != "" else "") + rag_context
|
||||||
|
|
||||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||||
else:
|
|
||||||
return_citations = False
|
if citations:
|
||||||
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
del data["files"]
|
del data["files"]
|
||||||
|
|
||||||
@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
print(system_prompt)
|
print(system_prompt)
|
||||||
data["messages"] = add_or_update_system_message(
|
data["messages"] = add_or_update_system_message(
|
||||||
f"\n{system_prompt}", data["messages"]
|
system_prompt, data["messages"]
|
||||||
)
|
)
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
if return_citations:
|
# If there are data_items to inject into the response
|
||||||
# Inject the citations into the response
|
if len(data_items) > 0:
|
||||||
if isinstance(response, StreamingResponse):
|
if isinstance(response, StreamingResponse):
|
||||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||||
content_type = response.headers.get("Content-Type")
|
content_type = response.headers.get("Content-Type")
|
||||||
if "text/event-stream" in content_type:
|
if "text/event-stream" in content_type:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
self.openai_stream_wrapper(response.body_iterator, citations),
|
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||||
)
|
)
|
||||||
if "application/x-ndjson" in content_type:
|
if "application/x-ndjson" in content_type:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
self.ollama_stream_wrapper(response.body_iterator, citations),
|
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
return {"type": "http.request", "body": body, "more_body": False}
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
|
||||||
async def openai_stream_wrapper(self, original_generator, citations):
|
async def openai_stream_wrapper(self, original_generator, data_items):
|
||||||
yield f"data: {json.dumps({'citations': citations})}\n\n"
|
for item in data_items:
|
||||||
|
yield f"data: {json.dumps(item)}\n\n"
|
||||||
|
|
||||||
async for data in original_generator:
|
async for data in original_generator:
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
async def ollama_stream_wrapper(self, original_generator, citations):
|
async def ollama_stream_wrapper(self, original_generator, data_items):
|
||||||
yield f"{json.dumps({'citations': citations})}\n"
|
for item in data_items:
|
||||||
|
yield f"{json.dumps(item)}\n"
|
||||||
|
|
||||||
async for data in original_generator:
|
async for data in original_generator:
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user