mirror of
https://github.com/open-webui/open-webui
synced 2025-05-22 13:54:20 +00:00
is_chat_completion_request helper, remove nesting
This commit is contained in:
parent
3befadb29f
commit
589efcdc5f
292
backend/main.py
292
backend/main.py
@ -605,129 +605,126 @@ async def chat_completion_files_handler(body):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_chat_completion_request(request):
|
||||||
|
return request.method == "POST" and any(
|
||||||
|
endpoint in request.url.path
|
||||||
|
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.method == "POST" and any(
|
if not is_chat_completion_request(request):
|
||||||
endpoint in request.url.path
|
return await call_next(request)
|
||||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
):
|
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, model, user = await get_body_and_model_and_user(request)
|
body, model, user = await get_body_and_model_and_user(request)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
content={"detail": str(e)},
|
content={"detail": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"chat_id": body.pop("chat_id", None),
|
||||||
|
"message_id": body.pop("id", None),
|
||||||
|
"session_id": body.pop("session_id", None),
|
||||||
|
"valves": body.pop("valves", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
__event_emitter__ = get_event_emitter(metadata)
|
||||||
|
__event_call__ = get_event_call(metadata)
|
||||||
|
|
||||||
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
|
data_items = []
|
||||||
|
|
||||||
|
# Initialize context, and citations
|
||||||
|
contexts = []
|
||||||
|
citations = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_functions_handler(
|
||||||
|
body, model, user, __event_emitter__, __event_call__
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={"detail": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_tools_handler(
|
||||||
|
body, user, __event_emitter__, __event_call__
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.extend(flags.get("contexts", []))
|
||||||
|
citations.extend(flags.get("citations", []))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_files_handler(body)
|
||||||
|
|
||||||
|
contexts.extend(flags.get("contexts", []))
|
||||||
|
citations.extend(flags.get("citations", []))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If context is not empty, insert it into the messages
|
||||||
|
if len(contexts) > 0:
|
||||||
|
context_string = "/n".join(contexts).strip()
|
||||||
|
prompt = get_last_user_message(body["messages"])
|
||||||
|
|
||||||
|
# Workaround for Ollama 2.0+ system prompt issue
|
||||||
|
# TODO: replace with add_or_update_system_message
|
||||||
|
if model["owned_by"] == "ollama":
|
||||||
|
body["messages"] = prepend_to_first_user_message_content(
|
||||||
|
rag_template(
|
||||||
|
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||||
|
),
|
||||||
|
body["messages"],
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"chat_id": body.pop("chat_id", None),
|
|
||||||
"message_id": body.pop("id", None),
|
|
||||||
"session_id": body.pop("session_id", None),
|
|
||||||
"valves": body.pop("valves", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
__event_emitter__ = get_event_emitter(metadata)
|
|
||||||
__event_call__ = get_event_call(metadata)
|
|
||||||
|
|
||||||
# Initialize data_items to store additional data to be sent to the client
|
|
||||||
data_items = []
|
|
||||||
|
|
||||||
# Initialize context, and citations
|
|
||||||
contexts = []
|
|
||||||
citations = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_functions_handler(
|
|
||||||
body, model, user, __event_emitter__, __event_call__
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={"detail": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_tools_handler(
|
|
||||||
body, user, __event_emitter__, __event_call__
|
|
||||||
)
|
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
|
||||||
citations.extend(flags.get("citations", []))
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_files_handler(body)
|
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
|
||||||
citations.extend(flags.get("citations", []))
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# If context is not empty, insert it into the messages
|
|
||||||
if len(contexts) > 0:
|
|
||||||
context_string = "/n".join(contexts).strip()
|
|
||||||
prompt = get_last_user_message(body["messages"])
|
|
||||||
|
|
||||||
# Workaround for Ollama 2.0+ system prompt issue
|
|
||||||
# TODO: replace with add_or_update_system_message
|
|
||||||
if model["owned_by"] == "ollama":
|
|
||||||
body["messages"] = prepend_to_first_user_message_content(
|
|
||||||
rag_template(
|
|
||||||
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
body["messages"] = add_or_update_system_message(
|
|
||||||
rag_template(
|
|
||||||
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# If there are citations, add them to the data_items
|
|
||||||
if len(citations) > 0:
|
|
||||||
data_items.append({"citations": citations})
|
|
||||||
|
|
||||||
body["metadata"] = metadata
|
|
||||||
modified_body_bytes = json.dumps(body).encode("utf-8")
|
|
||||||
# Replace the request body with the modified one
|
|
||||||
request._body = modified_body_bytes
|
|
||||||
# Set custom header to ensure content-length matches new body length
|
|
||||||
request.headers.__dict__["_list"] = [
|
|
||||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
||||||
*[
|
|
||||||
(k, v)
|
|
||||||
for k, v in request.headers.raw
|
|
||||||
if k.lower() != b"content-length"
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
if isinstance(response, StreamingResponse):
|
|
||||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
|
||||||
content_type = response.headers.get("Content-Type")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
if "application/x-ndjson" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
else:
|
else:
|
||||||
return response
|
body["messages"] = add_or_update_system_message(
|
||||||
|
rag_template(
|
||||||
|
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||||
|
),
|
||||||
|
body["messages"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there are citations, add them to the data_items
|
||||||
|
if len(citations) > 0:
|
||||||
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
|
body["metadata"] = metadata
|
||||||
|
modified_body_bytes = json.dumps(body).encode("utf-8")
|
||||||
|
# Replace the request body with the modified one
|
||||||
|
request._body = modified_body_bytes
|
||||||
|
# Set custom header to ensure content-length matches new body length
|
||||||
|
request.headers.__dict__["_list"] = [
|
||||||
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||||
|
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
|
||||||
|
]
|
||||||
|
|
||||||
# If it's not a chat completion request, just pass it through
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
if isinstance(response, StreamingResponse):
|
||||||
|
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||||
|
content_type = response.headers.get("Content-Type")
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
return StreamingResponse(
|
||||||
|
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||||
|
)
|
||||||
|
if "application/x-ndjson" in content_type:
|
||||||
|
return StreamingResponse(
|
||||||
|
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
return {"type": "http.request", "body": body, "more_body": False}
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
@ -820,44 +817,39 @@ def filter_pipeline(payload, user):
|
|||||||
|
|
||||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.method == "POST" and (
|
if not is_chat_completion_request(request):
|
||||||
"/ollama/api/chat" in request.url.path
|
return await call_next(request)
|
||||||
or "/chat/completions" in request.url.path
|
|
||||||
):
|
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
|
||||||
|
|
||||||
# Read the original request body
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
body = await request.body()
|
|
||||||
# Decode body to string
|
|
||||||
body_str = body.decode("utf-8")
|
|
||||||
# Parse string to JSON
|
|
||||||
data = json.loads(body_str) if body_str else {}
|
|
||||||
|
|
||||||
user = get_current_user(
|
# Read the original request body
|
||||||
request,
|
body = await request.body()
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
# Decode body to string
|
||||||
|
body_str = body.decode("utf-8")
|
||||||
|
# Parse string to JSON
|
||||||
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
user = get_current_user(
|
||||||
|
request,
|
||||||
|
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = filter_pipeline(data, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
data = filter_pipeline(data, user)
|
# Replace the request body with the modified one
|
||||||
except Exception as e:
|
request._body = modified_body_bytes
|
||||||
return JSONResponse(
|
# Set custom header to ensure content-length matches new body length
|
||||||
status_code=e.args[0],
|
request.headers.__dict__["_list"] = [
|
||||||
content={"detail": e.args[1]},
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||||
)
|
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
|
||||||
|
]
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
||||||
# Replace the request body with the modified one
|
|
||||||
request._body = modified_body_bytes
|
|
||||||
# Set custom header to ensure content-length matches new body length
|
|
||||||
request.headers.__dict__["_list"] = [
|
|
||||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
||||||
*[
|
|
||||||
(k, v)
|
|
||||||
for k, v in request.headers.raw
|
|
||||||
if k.lower() != b"content-length"
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
Loading…
Reference in New Issue
Block a user