From bfc5f6b2f642ce5d6c25dd26802677b921661bc3 Mon Sep 17 00:00:00 2001 From: smonux Date: Sun, 6 Oct 2024 11:09:32 +0200 Subject: [PATCH] feat: Add tool handling and response modification for non-streaming responses --- backend/open_webui/main.py | 85 +++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 20 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8da4c6a48..1af142e84 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -388,6 +388,63 @@ async def get_content_from_response(response) -> Optional[str]: content = response["choices"][0]["message"]["content"] return content +def get_tools_body( + body: dict, user: UserModel, extra_params: dict +) -> tuple[dict, dict]: + metadata = body.get("metadata", {}) + + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + if not tool_ids: + return body, {} + + task_model_id = get_task_model_id(body["model"]) + tools = get_tools( + webui_app, + tool_ids, + user, + { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + specs = [tool["spec"] for tool in tools.values()] + + tools_args = [ { "type" : "function", "function" : spec } \ + for spec in specs ] + body["tools"] = tools_args + + return body, tools + +def update_body_request(request: Request, + body: dict) -> None: + modified_body_bytes = json.dumps(body).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] + return None + +async def handle_nonstreaming_response(request : Request, + response : Response, + tools : dict) -> Response: + + response_dict = json.loads(response) + body = json.loads(request._body) + # body["messages"] + while response_dict["choices"][0]["finish_reason"] == "tool_calls": + for tool_call in response_dict["choices"][0]["tool_calls"]: + + + return response + async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict @@ -601,12 +658,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): } body["metadata"] = metadata - try: - body, flags = await chat_completion_tools_handler(body, user, extra_params) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - log.exception(e) + body, tools = get_tools_body(body, user, extra_params) try: body, flags = await chat_completion_files_handler(body) @@ -651,20 +703,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(citations) > 0: data_items.append({"citations": citations}) - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], - ] - - response = await call_next(request) - if not isinstance(response, StreamingResponse): + update_body_request(request, body) + first_response = await call_next(request) + if not isinstance(first_response, StreamingResponse): + response = await handle_nonstreaming_response(request, response, tools) return response - content_type = response.headers["Content-Type"] + content_type = first_response.headers["Content-Type"] is_openai = "text/event-stream" in content_type is_ollama = "application/x-ndjson" in content_type if not is_openai and not is_ollama: @@ -681,8 +726,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): yield data return StreamingResponse( - stream_wrapper(response.body_iterator, data_items), - headers=dict(response.headers), + stream_wrapper(first_response.body_iterator, data_items), + headers=dict(first_response.headers), ) async def _receive(self, body: bytes):