mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
refac: tools
This commit is contained in:
parent
82ca88105c
commit
5a7efad59c
@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
request: Request, body: dict, user: UserModel, models, tools
|
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
async def get_content_from_response(response) -> Optional[str]:
|
async def get_content_from_response(response) -> Optional[str]:
|
||||||
content = None
|
content = None
|
||||||
@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
|
|||||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
event_caller = extra_params["__event_call__"]
|
||||||
|
metadata = extra_params["__metadata__"]
|
||||||
|
|
||||||
task_model_id = get_task_model_id(
|
task_model_id = get_task_model_id(
|
||||||
body["model"],
|
body["model"],
|
||||||
request.app.state.config.TASK_MODEL,
|
request.app.state.config.TASK_MODEL,
|
||||||
@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
|
|||||||
tool_function_params = tool_call.get("parameters", {})
|
tool_function_params = tool_call.get("parameters", {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
spec = tools[tool_function_name].get("spec", {})
|
tool = tools[tool_function_name]
|
||||||
|
|
||||||
|
spec = tool.get("spec", {})
|
||||||
allowed_params = (
|
allowed_params = (
|
||||||
spec.get("parameters", {}).get("properties", {}).keys()
|
spec.get("parameters", {}).get("properties", {}).keys()
|
||||||
)
|
)
|
||||||
tool_function = tools[tool_function_name]["callable"]
|
tool_function = tool["callable"]
|
||||||
tool_function_params = {
|
tool_function_params = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in tool_function_params.items()
|
for k, v in tool_function_params.items()
|
||||||
if k in allowed_params
|
if k in allowed_params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tool.get("direct", False):
|
||||||
tool_output = await tool_function(**tool_function_params)
|
tool_output = await tool_function(**tool_function_params)
|
||||||
|
else:
|
||||||
|
tool_output = await event_caller(
|
||||||
|
{
|
||||||
|
"type": "execute:tool",
|
||||||
|
"data": {
|
||||||
|
"id": str(uuid4()),
|
||||||
|
"tool": tool,
|
||||||
|
"params": tool_function_params,
|
||||||
|
"session_id": metadata.get("session_id", None),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_output = str(e)
|
tool_output = str(e)
|
||||||
@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
}
|
}
|
||||||
form_data["metadata"] = metadata
|
form_data["metadata"] = metadata
|
||||||
|
|
||||||
|
# Server side tools
|
||||||
tool_ids = metadata.get("tool_ids", None)
|
tool_ids = metadata.get("tool_ids", None)
|
||||||
|
# Client side tools
|
||||||
|
tool_specs = form_data.get("tool_specs", None)
|
||||||
|
|
||||||
log.debug(f"{tool_ids=}")
|
log.debug(f"{tool_ids=}")
|
||||||
|
log.debug(f"{tool_specs=}")
|
||||||
|
|
||||||
|
tools_dict = {}
|
||||||
|
|
||||||
if tool_ids:
|
if tool_ids:
|
||||||
# If tool_ids field is present, then get the tools
|
tools_dict = get_tools(
|
||||||
tools = get_tools(
|
|
||||||
request,
|
request,
|
||||||
tool_ids,
|
tool_ids,
|
||||||
user,
|
user,
|
||||||
@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
"__files__": metadata.get("files", []),
|
"__files__": metadata.get("files", []),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log.info(f"{tools=}")
|
log.info(f"{tools_dict=}")
|
||||||
|
|
||||||
|
if tool_specs:
|
||||||
|
for tool in tool_specs:
|
||||||
|
callable = tool.pop("callable", None)
|
||||||
|
tools_dict[tool["name"]] = {
|
||||||
|
"direct": True,
|
||||||
|
"callable": callable,
|
||||||
|
"spec": tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools_dict:
|
||||||
if metadata.get("function_calling") == "native":
|
if metadata.get("function_calling") == "native":
|
||||||
# If the function calling is native, then call the tools function calling handler
|
# If the function calling is native, then call the tools function calling handler
|
||||||
metadata["tools"] = tools
|
metadata["tools"] = tools_dict
|
||||||
form_data["tools"] = [
|
form_data["tools"] = [
|
||||||
{"type": "function", "function": tool.get("spec", {})}
|
{"type": "function", "function": tool.get("spec", {})}
|
||||||
for tool in tools.values()
|
for tool in tools_dict.values()
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# If the function calling is not native, then call the tools function calling handler
|
# If the function calling is not native, then call the tools function calling handler
|
||||||
try:
|
try:
|
||||||
form_data, flags = await chat_completion_tools_handler(
|
form_data, flags = await chat_completion_tools_handler(
|
||||||
request, form_data, user, models, tools
|
request, form_data, extra_params, user, models, tools_dict
|
||||||
)
|
)
|
||||||
sources.extend(flags.get("sources", []))
|
sources.extend(flags.get("sources", []))
|
||||||
|
|
||||||
@ -1774,9 +1809,25 @@ async def process_chat_response(
|
|||||||
for k, v in tool_function_params.items()
|
for k, v in tool_function_params.items()
|
||||||
if k in allowed_params
|
if k in allowed_params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tool.get("direct", False):
|
||||||
tool_result = await tool_function(
|
tool_result = await tool_function(
|
||||||
**tool_function_params
|
**tool_function_params
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tool_result = await event_caller(
|
||||||
|
{
|
||||||
|
"type": "execute:tool",
|
||||||
|
"data": {
|
||||||
|
"id": str(uuid4()),
|
||||||
|
"tool": tool,
|
||||||
|
"params": tool_function_params,
|
||||||
|
"session_id": metadata.get(
|
||||||
|
"session_id", None
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_result = str(e)
|
tool_result = str(e)
|
||||||
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, get_type_hints
|
from typing import Any, Awaitable, Callable, get_type_hints
|
||||||
from functools import update_wrapper, partial
|
from functools import update_wrapper, partial
|
||||||
|
|
||||||
|
@ -203,6 +203,22 @@
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const executeTool = async (data, cb) => {
|
||||||
|
console.log(data);
|
||||||
|
// TODO: MCP (SSE) support
|
||||||
|
// TODO: API Server support
|
||||||
|
|
||||||
|
if (cb) {
|
||||||
|
cb(
|
||||||
|
JSON.parse(
|
||||||
|
JSON.stringify({
|
||||||
|
result: null
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const chatEventHandler = async (event, cb) => {
|
const chatEventHandler = async (event, cb) => {
|
||||||
const chat = $page.url.pathname.includes(`/c/${event.chat_id}`);
|
const chat = $page.url.pathname.includes(`/c/${event.chat_id}`);
|
||||||
|
|
||||||
@ -256,6 +272,9 @@
|
|||||||
if (type === 'execute:python') {
|
if (type === 'execute:python') {
|
||||||
console.log('execute:python', data);
|
console.log('execute:python', data);
|
||||||
executePythonAsWorker(data.id, data.code, cb);
|
executePythonAsWorker(data.id, data.code, cb);
|
||||||
|
} else if (type === 'execute:tool') {
|
||||||
|
console.log('execute:tool', data);
|
||||||
|
executeTool(data, cb);
|
||||||
} else if (type === 'request:chat:completion') {
|
} else if (type === 'request:chat:completion') {
|
||||||
console.log(data, $socket.id);
|
console.log(data, $socket.id);
|
||||||
const { session_id, channel, form_data, model } = data;
|
const { session_id, channel, form_data, model } = data;
|
||||||
|
Loading…
Reference in New Issue
Block a user