mirror of
https://github.com/open-webui/open-webui
synced 2025-04-10 15:45:45 +00:00
refac: tools
This commit is contained in:
parent
82ca88105c
commit
5a7efad59c
@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
async def get_content_from_response(response) -> Optional[str]:
|
||||
content = None
|
||||
@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
|
||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||
}
|
||||
|
||||
event_caller = extra_params["__event_call__"]
|
||||
metadata = extra_params["__metadata__"]
|
||||
|
||||
task_model_id = get_task_model_id(
|
||||
body["model"],
|
||||
request.app.state.config.TASK_MODEL,
|
||||
@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
|
||||
tool_function_params = tool_call.get("parameters", {})
|
||||
|
||||
try:
|
||||
spec = tools[tool_function_name].get("spec", {})
|
||||
tool = tools[tool_function_name]
|
||||
|
||||
spec = tool.get("spec", {})
|
||||
allowed_params = (
|
||||
spec.get("parameters", {}).get("properties", {}).keys()
|
||||
)
|
||||
tool_function = tools[tool_function_name]["callable"]
|
||||
tool_function = tool["callable"]
|
||||
tool_function_params = {
|
||||
k: v
|
||||
for k, v in tool_function_params.items()
|
||||
if k in allowed_params
|
||||
}
|
||||
tool_output = await tool_function(**tool_function_params)
|
||||
|
||||
if tool.get("direct", False):
|
||||
tool_output = await tool_function(**tool_function_params)
|
||||
else:
|
||||
tool_output = await event_caller(
|
||||
{
|
||||
"type": "execute:tool",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"tool": tool,
|
||||
"params": tool_function_params,
|
||||
"session_id": metadata.get("session_id", None),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
tool_output = str(e)
|
||||
@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
# Server side tools
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
# Client side tools
|
||||
tool_specs = form_data.get("tool_specs", None)
|
||||
|
||||
log.debug(f"{tool_ids=}")
|
||||
log.debug(f"{tool_specs=}")
|
||||
|
||||
tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
# If tool_ids field is present, then get the tools
|
||||
tools = get_tools(
|
||||
tools_dict = get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
log.info(f"{tools=}")
|
||||
log.info(f"{tools_dict=}")
|
||||
|
||||
if tool_specs:
|
||||
for tool in tool_specs:
|
||||
callable = tool.pop("callable", None)
|
||||
tools_dict[tool["name"]] = {
|
||||
"direct": True,
|
||||
"callable": callable,
|
||||
"spec": tool,
|
||||
}
|
||||
|
||||
if tools_dict:
|
||||
if metadata.get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools
|
||||
metadata["tools"] = tools_dict
|
||||
form_data["tools"] = [
|
||||
{"type": "function", "function": tool.get("spec", {})}
|
||||
for tool in tools.values()
|
||||
for tool in tools_dict.values()
|
||||
]
|
||||
else:
|
||||
# If the function calling is not native, then call the tools function calling handler
|
||||
try:
|
||||
form_data, flags = await chat_completion_tools_handler(
|
||||
request, form_data, user, models, tools
|
||||
request, form_data, extra_params, user, models, tools_dict
|
||||
)
|
||||
sources.extend(flags.get("sources", []))
|
||||
|
||||
@ -1774,9 +1809,25 @@ async def process_chat_response(
|
||||
for k, v in tool_function_params.items()
|
||||
if k in allowed_params
|
||||
}
|
||||
tool_result = await tool_function(
|
||||
**tool_function_params
|
||||
)
|
||||
|
||||
if tool.get("direct", False):
|
||||
tool_result = await tool_function(
|
||||
**tool_function_params
|
||||
)
|
||||
else:
|
||||
tool_result = await event_caller(
|
||||
{
|
||||
"type": "execute:tool",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"tool": tool,
|
||||
"params": tool_function_params,
|
||||
"session_id": metadata.get(
|
||||
"session_id", None
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = str(e)
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import inspect
|
||||
import uuid
|
||||
|
||||
from typing import Any, Awaitable, Callable, get_type_hints
|
||||
from functools import update_wrapper, partial
|
||||
|
||||
|
@ -203,6 +203,22 @@
|
||||
};
|
||||
};
|
||||
|
||||
const executeTool = async (data, cb) => {
|
||||
console.log(data);
|
||||
// TODO: MCP (SSE) support
|
||||
// TODO: API Server support
|
||||
|
||||
if (cb) {
|
||||
cb(
|
||||
JSON.parse(
|
||||
JSON.stringify({
|
||||
result: null
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const chatEventHandler = async (event, cb) => {
|
||||
const chat = $page.url.pathname.includes(`/c/${event.chat_id}`);
|
||||
|
||||
@ -256,6 +272,9 @@
|
||||
if (type === 'execute:python') {
|
||||
console.log('execute:python', data);
|
||||
executePythonAsWorker(data.id, data.code, cb);
|
||||
} else if (type === 'execute:tool') {
|
||||
console.log('execute:tool', data);
|
||||
executeTool(data, cb);
|
||||
} else if (type === 'request:chat:completion') {
|
||||
console.log(data, $socket.id);
|
||||
const { session_id, channel, form_data, model } = data;
|
||||
|
Loading…
Reference in New Issue
Block a user