From e825ebbcb9e70e0a43cfcee57bf04472f340f1f6 Mon Sep 17 00:00:00 2001 From: Samuel Date: Sun, 10 Nov 2024 07:32:33 +0000 Subject: [PATCH] feat: the tools are handled using a prompt or the native API mechanism depending on the native_tool_call parameter The older code path is in fact unneeded (once you remove the tools, the new one ) but it's simpler an more tested so I leave it how it was. --- backend/open_webui/main.py | 53 +++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 351f67e58..885a9828d 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -445,7 +445,7 @@ async def handle_streaming_response(request: Request, response: Response, async def stream_wrapper(original_generator, data_items): for item in data_items: yield wrap_item(json.dumps(item)) - + citations = [] body = json.loads(request._body) generator = original_generator @@ -498,7 +498,7 @@ async def handle_streaming_response(request: Request, response: Response, if not tool_call["function"]["arguments"]: tool_function_params = {} - else: + else: tool_function_params = json.loads(tool_call["function"]["arguments"]) log.debug(f"calling {tool_function_name} with params {tool_function_params}") @@ -559,7 +559,6 @@ async def handle_nonstreaming_response(request: Request, response: Response, is_ollama = True is_openai = not is_ollama - while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \ (is_openai and "tool_calls" in response_dict.get("choices", [{}])[0].get("message",{}) ): if is_ollama: @@ -801,6 +800,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "session_id": body.pop("session_id", None), "tool_ids": body.get("tool_ids", None), "files": body.get("files", None), + "native_tool_call": body.pop("native_tool_call", False), } body["metadata"] = metadata @@ -840,6 +840,21 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): body, tools = get_tools_body(body, user, extra_params) + if model["owned_by"] == "ollama" and \ + body["metadata"]["native_tool_call"] and \ + body.get("stream", False): + log.info("Ollama models don't support function calling in streaming yet. forcing native_tool_call to False") + body["metadata"]["native_tool_call"] = False + + if not body["metadata"]["native_tool_call"]: + del body["tools"] # we won't use those + try: + body, flags = await chat_completion_tools_handler(body, user, extra_params) + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + log.exception(e) + try: body, flags = await chat_completion_files_handler(body) contexts.extend(flags.get("contexts", [])) @@ -884,12 +899,36 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): data_items.append({"citations": citations}) update_body_request(request, body) - first_response = await call_next(request) + response = await call_next(request) - if body.get("stream", False) is False: - return await handle_nonstreaming_response(request, first_response, tools, user, data_items) + if not body["metadata"]["native_tool_call"]: - return await handle_streaming_response(request, first_response, tools, data_items, call_next, user) + if not isinstance(response, StreamingResponse): + return response + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response + + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, data_items), + headers=dict(response.headers), + ) + else: + if not body.get("stream", False): + return await handle_nonstreaming_response(request, response, tools, user, data_items) + return await handle_streaming_response(request, response, tools, data_items, call_next, user) async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False}