feat: tools are handled using a prompt or the native API mechanism, depending on the native_tool_call parameter

The older code path is in fact unneeded (once the tools are removed from the body, the new path covers it), but it is simpler and better tested, so I left it as it was.
parent c4fd39fe83
commit e825ebbcb9
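As a rough illustration of the new parameter, a client opts into native function calling by setting native_tool_call in the request body. This is a minimal sketch: the endpoint path, model name, tool id, and token are placeholders, not taken from this commit; only native_tool_call is the field being introduced.

import requests

# Hypothetical chat-completion request. When native_tool_call is True, the
# tools are forwarded to the model's native function-calling API; when False
# (the default), they are handled through the prompt-based path instead.
body = {
    "model": "my-model",           # placeholder model id
    "stream": False,
    "tool_ids": ["calculator"],    # placeholder tool id
    "native_tool_call": True,
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
}
resp = requests.post(
    "http://localhost:8080/api/chat/completions",  # path is an assumption
    json=body,
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json())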
@@ -445,7 +445,7 @@ async def handle_streaming_response(request: Request, response: Response,
     async def stream_wrapper(original_generator, data_items):
         for item in data_items:
             yield wrap_item(json.dumps(item))

     citations = []
     body = json.loads(request._body)
     generator = original_generator
@@ -498,7 +498,7 @@ async def handle_streaming_response(request: Request, response: Response,

             if not tool_call["function"]["arguments"]:
                 tool_function_params = {}
             else:
                 tool_function_params = json.loads(tool_call["function"]["arguments"])

             log.debug(f"calling {tool_function_name} with params {tool_function_params}")
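The guard in this hunk matters because models sometimes emit a tool call with an empty arguments string, which json.loads would reject. A standalone sketch of the same defensive parse (the helper name is mine, not the codebase's):

import json

def parse_tool_arguments(tool_call: dict) -> dict:
    # Models may return "" instead of "{}" for a no-argument call;
    # json.loads("") raises JSONDecodeError, so fall back to an empty dict.
    raw = tool_call["function"]["arguments"]
    if not raw:
        return {}
    return json.loads(raw)

assert parse_tool_arguments({"function": {"arguments": ""}}) == {}
assert parse_tool_arguments({"function": {"arguments": '{"x": 1}'}}) == {"x": 1}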
@@ -559,7 +559,6 @@ async def handle_nonstreaming_response(request: Request, response: Response,
     is_ollama = True
     is_openai = not is_ollama

-
     while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \
           (is_openai and "tool_calls" in response_dict.get("choices", [{}])[0].get("message", {})):
         if is_ollama:
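The while condition distinguishes the two response shapes: Ollama puts tool_calls on the top-level message, while OpenAI-compatible responses nest it under choices[0].message. Illustrative payloads (all field values below are made up):

# Ollama-style response: tool_calls on the top-level "message"
ollama_response = {
    "message": {
        "role": "assistant",
        "tool_calls": [{"function": {"name": "calculator", "arguments": {"x": 2}}}],
    }
}

# OpenAI-style response: tool_calls nested under choices[0].message
openai_response = {
    "choices": [{
        "message": {
            "role": "assistant",
            "tool_calls": [{"function": {"name": "calculator", "arguments": '{"x": 2}'}}],
        }
    }]
}

def has_tool_calls(response_dict: dict, is_ollama: bool) -> bool:
    # Same lookup the loop condition performs, factored out for clarity.
    if is_ollama:
        return "tool_calls" in response_dict.get("message", {})
    return "tool_calls" in response_dict.get("choices", [{}])[0].get("message", {})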
@@ -801,6 +800,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             "session_id": body.pop("session_id", None),
             "tool_ids": body.get("tool_ids", None),
             "files": body.get("files", None),
+            "native_tool_call": body.pop("native_tool_call", False),
         }
         body["metadata"] = metadata

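Note the asymmetry in this hunk: native_tool_call is popped so the upstream OpenAI or Ollama API never sees the non-standard field, while tool_ids and files are merely read and stay in the body. A toy illustration of that difference:

body = {"model": "m", "native_tool_call": True, "tool_ids": ["t1"]}
metadata = {
    "tool_ids": body.get("tool_ids"),                         # read; stays in body
    "native_tool_call": body.pop("native_tool_call", False),  # removed from body
}
assert "native_tool_call" not in body   # upstream never sees it
assert "tool_ids" in body               # still forwarded
assert metadata["native_tool_call"] is True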
@@ -840,6 +840,21 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):

         body, tools = get_tools_body(body, user, extra_params)

+        if model["owned_by"] == "ollama" and \
+                body["metadata"]["native_tool_call"] and \
+                body.get("stream", False):
+            log.info("Ollama models don't support function calling in streaming yet. forcing native_tool_call to False")
+            body["metadata"]["native_tool_call"] = False
+
+        if not body["metadata"]["native_tool_call"]:
+            del body["tools"]  # we won't use those
+            try:
+                body, flags = await chat_completion_tools_handler(body, user, extra_params)
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                log.exception(e)
+
         try:
             body, flags = await chat_completion_files_handler(body)
             contexts.extend(flags.get("contexts", []))
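When the native path is off, the tools key is deleted and chat_completion_tools_handler (not shown in this diff) drives the model through a prompt instead. A very rough sketch of what prompt-based tool handling generally looks like; the prompt text, helper name, and tool-dict shape below are assumptions for illustration, not the handler's actual implementation:

import json

def build_tools_prompt(tools: dict) -> str:
    # Hypothetical: render tool specs into an instruction the model can
    # follow by replying with a JSON object naming the tool and its params.
    specs = [{"name": name, "spec": tool["spec"]} for name, tool in tools.items()]
    return (
        "You have access to these tools:\n"
        f"{json.dumps(specs, indent=2)}\n"
        'To call one, reply only with JSON: {"name": ..., "parameters": ...}'
    )

# Assumed tool-dict shape: name -> {"spec": ...}, as an example only.
print(build_tools_prompt({"calculator": {"spec": {"description": "adds two numbers"}}}))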
@@ -884,12 +899,36 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             data_items.append({"citations": citations})

         update_body_request(request, body)
-        first_response = await call_next(request)
+        response = await call_next(request)

-        if body.get("stream", False) is False:
-            return await handle_nonstreaming_response(request, first_response, tools, user, data_items)
-
-        return await handle_streaming_response(request, first_response, tools, data_items, call_next, user)
+        if not body["metadata"]["native_tool_call"]:
+            if not isinstance(response, StreamingResponse):
+                return response
+
+            content_type = response.headers["Content-Type"]
+            is_openai = "text/event-stream" in content_type
+            is_ollama = "application/x-ndjson" in content_type
+            if not is_openai and not is_ollama:
+                return response
+
+            def wrap_item(item):
+                return f"data: {item}\n\n" if is_openai else f"{item}\n"
+
+            async def stream_wrapper(original_generator, data_items):
+                for item in data_items:
+                    yield wrap_item(json.dumps(item))
+
+                async for data in original_generator:
+                    yield data
+
+            return StreamingResponse(
+                stream_wrapper(response.body_iterator, data_items),
+                headers=dict(response.headers),
+            )
+        else:
+            if not body.get("stream", False):
+                return await handle_nonstreaming_response(request, response, tools, user, data_items)
+            return await handle_streaming_response(request, response, tools, data_items, call_next, user)

     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
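The streaming branch above prepends the collected data_items (e.g. citations) to the upstream stream, framed as SSE for OpenAI ("data: ...\n\n") or as NDJSON for Ollama (one JSON object per line). A self-contained sketch of the same wrapping technique; the /demo endpoint, the stand-in upstream generator, and the payloads are invented for illustration:

import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def upstream():
    # Stand-in for the proxied model stream (OpenAI-style SSE frames).
    yield 'data: {"choices": [{"delta": {"content": "hello"}}]}\n\n'
    yield "data: [DONE]\n\n"

@app.get("/demo")
async def demo():
    data_items = [{"citations": [{"source": "example"}]}]  # invented payload
    is_openai = True  # would be derived from the upstream Content-Type

    def wrap_item(item: str) -> str:
        # SSE frame for OpenAI streams, NDJSON line for Ollama streams.
        return f"data: {item}\n\n" if is_openai else f"{item}\n"

    async def stream_wrapper(original_generator, data_items):
        for item in data_items:                 # emit our extra items first...
            yield wrap_item(json.dumps(item))
        async for data in original_generator:   # ...then pass the stream through
            yield data

    return StreamingResponse(
        stream_wrapper(upstream(), data_items),
        media_type="text/event-stream",
    )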