mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
add get_configured_tools
This commit is contained in:
parent
d598d4bb93
commit
6df6170c44
@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
|
|||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional, Callable, Awaitable
|
||||||
|
|
||||||
from apps.webui.models.auths import Auths
|
from apps.webui.models.auths import Auths
|
||||||
from apps.webui.models.models import Models
|
from apps.webui.models.models import Models
|
||||||
from apps.webui.models.tools import Tools
|
from apps.webui.models.tools import Tools
|
||||||
from apps.webui.models.functions import Functions
|
from apps.webui.models.functions import Functions
|
||||||
from apps.webui.models.users import Users, User
|
from apps.webui.models.users import Users, UserModel
|
||||||
|
|
||||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||||
|
|
||||||
@ -356,6 +356,7 @@ async def get_tool_call_response(
|
|||||||
return None, None, False
|
return None, None, False
|
||||||
|
|
||||||
tools_specs = json.dumps(tool.specs, indent=2)
|
tools_specs = json.dumps(tool.specs, indent=2)
|
||||||
|
log.debug(f"{tool.specs=}")
|
||||||
content = tool_calling_generation_template(template, tools_specs)
|
content = tool_calling_generation_template(template, tools_specs)
|
||||||
payload = get_tool_call_payload(messages, task_model_id, content)
|
payload = get_tool_call_payload(messages, task_model_id, content)
|
||||||
|
|
||||||
@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params):
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_with_custom_params(
|
||||||
|
tool: Callable, custom_params: dict
|
||||||
|
) -> Callable[..., Awaitable]:
|
||||||
|
sig = inspect.signature(tool)
|
||||||
|
extra_params = {
|
||||||
|
key: value for key, value in custom_params.items() if key in sig.parameters
|
||||||
|
}
|
||||||
|
is_coroutine = inspect.iscoroutinefunction(tool)
|
||||||
|
|
||||||
|
async def new_tool(**kwargs):
|
||||||
|
extra_kwargs = kwargs | extra_params
|
||||||
|
if is_coroutine:
|
||||||
|
return await tool(**extra_kwargs)
|
||||||
|
return tool(**extra_kwargs)
|
||||||
|
|
||||||
|
return new_tool
|
||||||
|
|
||||||
|
|
||||||
|
def get_configured_tools(
|
||||||
|
tool_ids: list[str], extra_params: dict, user: UserModel
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
tools = {}
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
toolkit = Tools.get_tool_by_id(tool_id)
|
||||||
|
if toolkit is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module = webui_app.state.TOOLS.get(tool_id, None)
|
||||||
|
if module is None:
|
||||||
|
module, _ = load_toolkit_module_by_id(tool_id)
|
||||||
|
webui_app.state.TOOLS[tool_id] = module
|
||||||
|
|
||||||
|
more_params = {"__id__": tool_id}
|
||||||
|
custom_params = more_params | extra_params
|
||||||
|
has_citation = hasattr(module, "citation") and module.citation
|
||||||
|
handles_files = hasattr(module, "file_handler") and module.file_handler
|
||||||
|
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||||
|
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||||
|
module.valves = module.Valves(**valves)
|
||||||
|
|
||||||
|
if hasattr(module, "UserValves"):
|
||||||
|
custom_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||||
|
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
for spec in toolkit.specs:
|
||||||
|
name = spec["name"]
|
||||||
|
callable = getattr(module, name)
|
||||||
|
# convert to function that takes only model params and inserts custom params
|
||||||
|
custom_callable = get_tool_with_custom_params(callable, custom_params)
|
||||||
|
|
||||||
|
tool_dict = {
|
||||||
|
"spec": spec,
|
||||||
|
"citation": has_citation,
|
||||||
|
"file_handler": handles_files,
|
||||||
|
"toolkit_module": module,
|
||||||
|
"callable": custom_callable,
|
||||||
|
}
|
||||||
|
if name in tools:
|
||||||
|
log.warning(f"Tool {name} already exists in another toolkit!")
|
||||||
|
mod_name = tools[name]["toolkit_module"].__name__
|
||||||
|
log.warning(f"Collision between {toolkit} and {mod_name}.")
|
||||||
|
log.warning(f"Discarding {toolkit}.{name}")
|
||||||
|
else:
|
||||||
|
tools[name] = tool_dict
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: User, extra_params: dict
|
body: dict, user: UserModel, extra_params: dict
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
skip_files = None
|
skip_files = False
|
||||||
|
|
||||||
contexts = []
|
contexts = []
|
||||||
citations = None
|
citations = []
|
||||||
|
|
||||||
task_model_id = get_task_model_id(body["model"])
|
task_model_id = get_task_model_id(body["model"])
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
@ -507,6 +575,7 @@ async def chat_completion_tools_handler(
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
log.debug(f"tool_ids: {body['tool_ids']}")
|
log.debug(f"tool_ids: {body['tool_ids']}")
|
||||||
|
log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"messages": body["messages"],
|
"messages": body["messages"],
|
||||||
"files": body.get("files", []),
|
"files": body.get("files", []),
|
||||||
@ -515,6 +584,7 @@ async def chat_completion_tools_handler(
|
|||||||
"user": user,
|
"user": user,
|
||||||
"extra_params": extra_params,
|
"extra_params": extra_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool_id in body["tool_ids"]:
|
for tool_id in body["tool_ids"]:
|
||||||
log.debug(f"{tool_id=}")
|
log.debug(f"{tool_id=}")
|
||||||
try:
|
try:
|
||||||
@ -526,10 +596,7 @@ async def chat_completion_tools_handler(
|
|||||||
contexts.append(response)
|
contexts.append(response)
|
||||||
|
|
||||||
if citation:
|
if citation:
|
||||||
if citations is None:
|
citations.append(citation)
|
||||||
citations = [citation]
|
|
||||||
else:
|
|
||||||
citations.append(citation)
|
|
||||||
|
|
||||||
if file_handler:
|
if file_handler:
|
||||||
skip_files = True
|
skip_files = True
|
||||||
@ -540,14 +607,10 @@ async def chat_completion_tools_handler(
|
|||||||
del body["tool_ids"]
|
del body["tool_ids"]
|
||||||
log.debug(f"tool_contexts: {contexts}")
|
log.debug(f"tool_contexts: {contexts}")
|
||||||
|
|
||||||
if skip_files:
|
if skip_files and "files" in body:
|
||||||
if "files" in body:
|
del body["files"]
|
||||||
del body["files"]
|
|
||||||
|
|
||||||
return body, {
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
**({"contexts": contexts} if contexts is not None else {}),
|
|
||||||
**({"citations": citations} if citations is not None else {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_files_handler(body):
|
async def chat_completion_files_handler(body):
|
||||||
|
Loading…
Reference in New Issue
Block a user