mirror of
https://github.com/open-webui/open-webui
synced 2024-11-26 22:09:51 +00:00
feat: Add tool handling and response modification for non-streaming responses
This commit is contained in:
parent
1f488a0072
commit
bfc5f6b2f6
@ -388,6 +388,63 @@ async def get_content_from_response(response) -> Optional[str]:
|
|||||||
content = response["choices"][0]["message"]["content"]
|
content = response["choices"][0]["message"]["content"]
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
def get_tools_body(
    body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]:
    """Attach the specs of the requested tools to a chat-completion body.

    The tool ids are read from ``body["metadata"]["tool_ids"]``. When none are
    present the body is returned untouched together with an empty dict.
    Otherwise the tools are resolved through ``get_tools`` and every tool's
    ``spec`` is wrapped as ``{"type": "function", "function": spec}`` and
    stored under ``body["tools"]``.

    Args:
        body: Chat-completion request payload; mutated in place when tools
            are requested.
        user: User on whose behalf the tools are resolved.
        extra_params: Extra parameters forwarded to ``get_tools``; the
            ``__model__``, ``__messages__`` and ``__files__`` keys set here
            take precedence over same-named entries.

    Returns:
        A ``(body, tools)`` pair, where ``tools`` is the mapping returned by
        ``get_tools`` (or ``{}`` when no tool ids were requested).
    """
    metadata = body.get("metadata", {})
    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        # Nothing requested: hand the body back unchanged with no tools.
        return body, {}

    task_model_id = get_task_model_id(body["model"])
    tool_params = {
        **extra_params,
        "__model__": app.state.MODELS[task_model_id],
        "__messages__": body["messages"],
        "__files__": metadata.get("files", []),
    }
    tools = get_tools(webui_app, tool_ids, user, tool_params)
    log.info(f"{tools=}")

    # One function-type entry per resolved tool, in the mapping's own order.
    body["tools"] = [
        {"type": "function", "function": tool["spec"]}
        for tool in tools.values()
    ]
    return body, tools
|
||||||
|
|
||||||
|
def update_body_request(request: Request,
                        body: dict) -> None:
    """Swap the request's cached body for ``body`` and fix ``content-length``.

    ``body`` is serialised as UTF-8 encoded JSON and stored on
    ``request._body``. The header list kept in ``request.headers`` is then
    rebuilt with a fresh ``content-length`` entry first, followed by every
    existing raw header except any previous ``content-length``.

    Args:
        request: Incoming request whose body and headers are rewritten in
            place.
        body: JSON-serialisable payload that replaces the request body.
    """
    payload = json.dumps(body).encode("utf-8")
    request._body = payload

    # The declared length must match the new payload, so drop the old
    # content-length header and put the recomputed one at the front.
    content_length = (b"content-length", str(len(payload)).encode("utf-8"))
    other_headers = [
        (name, value)
        for name, value in request.headers.raw
        if name.lower() != b"content-length"
    ]
    request.headers.__dict__["_list"] = [content_length, *other_headers]
|
||||||
|
|
||||||
|
async def handle_nonstreaming_response(request: Request,
                                       response: Response,
                                       tools: dict) -> Response:
    """Inspect a non-streaming completion response for requested tool calls.

    NOTE(review): executing the tool calls and re-querying the model is not
    implemented yet. This function only detects the situation and always
    returns ``response`` unchanged, so it is safe to call on any response.

    The previous draft passed the response object itself to ``json.loads``,
    looped on a ``finish_reason`` that was never refreshed, and left the
    per-tool-call loop without a body; all three are replaced by the guarded
    pass-through below.

    Args:
        request: The (already rewritten) incoming request. Currently unused.
        response: Response produced by the downstream handler.
        tools: Tools resolved for this request. Currently unused.

    Returns:
        The same ``response`` object that was passed in.
    """
    # Only responses that expose an in-memory body can be inspected.
    raw_body = getattr(response, "body", None)
    if not raw_body:
        return response

    try:
        choice = json.loads(raw_body)["choices"][0]
    except (TypeError, ValueError, KeyError, IndexError):
        # Not a JSON chat-completion payload; leave it alone.
        return response

    if isinstance(choice, dict) and choice.get("finish_reason") == "tool_calls":
        # Assumes the OpenAI-style layout where tool calls live on the
        # message rather than on the choice itself.
        message = choice.get("message")
        tool_calls = message.get("tool_calls") if isinstance(message, dict) else None
        log.debug(
            "model requested %d tool call(s); tool execution is not implemented",
            len(tool_calls or []),
        )

    return response
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: UserModel, extra_params: dict
|
body: dict, user: UserModel, extra_params: dict
|
||||||
@ -601,12 +658,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
}
|
}
|
||||||
body["metadata"] = metadata
|
body["metadata"] = metadata
|
||||||
|
|
||||||
try:
|
body, tools = get_tools_body(body, user, extra_params)
|
||||||
body, flags = await chat_completion_tools_handler(body, user, extra_params)
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
|
||||||
citations.extend(flags.get("citations", []))
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(e)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_files_handler(body)
|
body, flags = await chat_completion_files_handler(body)
|
||||||
@ -651,20 +703,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
if len(citations) > 0:
|
if len(citations) > 0:
|
||||||
data_items.append({"citations": citations})
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(body).encode("utf-8")
|
update_body_request(request, body)
|
||||||
# Replace the request body with the modified one
|
first_response = await call_next(request)
|
||||||
request._body = modified_body_bytes
|
if not isinstance(first_response, StreamingResponse):
|
||||||
# Set custom header to ensure content-length matches new body length
|
response = await handle_nonstreaming_response(request, response, tools)
|
||||||
request.headers.__dict__["_list"] = [
|
|
||||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
||||||
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
if not isinstance(response, StreamingResponse):
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
content_type = response.headers["Content-Type"]
|
content_type = first_response.headers["Content-Type"]
|
||||||
is_openai = "text/event-stream" in content_type
|
is_openai = "text/event-stream" in content_type
|
||||||
is_ollama = "application/x-ndjson" in content_type
|
is_ollama = "application/x-ndjson" in content_type
|
||||||
if not is_openai and not is_ollama:
|
if not is_openai and not is_ollama:
|
||||||
@ -681,8 +726,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
yield data
|
yield data
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_wrapper(response.body_iterator, data_items),
|
stream_wrapper(first_response.body_iterator, data_items),
|
||||||
headers=dict(response.headers),
|
headers=dict(first_response.headers),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
|
Loading…
Reference in New Issue
Block a user