add get_configured_tools

This commit is contained in:
Michael Poluektov 2024-08-12 14:48:57 +01:00
parent d598d4bb93
commit 6df6170c44

View File

@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional, Callable, Awaitable
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users, User from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@ -356,6 +356,7 @@ async def get_tool_call_response(
return None, None, False return None, None, False
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
log.debug(f"{tool.specs=}")
content = tool_calling_generation_template(template, tools_specs) content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_call_payload(messages, task_model_id, content) payload = get_tool_call_payload(messages, task_model_id, content)
@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params):
return body, {} return body, {}
def get_tool_with_custom_params(
tool: Callable, custom_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(tool)
extra_params = {
key: value for key, value in custom_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(tool)
async def new_tool(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await tool(**extra_kwargs)
return tool(**extra_kwargs)
return new_tool
def get_configured_tools(
tool_ids: list[str], extra_params: dict, user: UserModel
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
more_params = {"__id__": tool_id}
custom_params = more_params | extra_params
has_citation = hasattr(module, "citation") and module.citation
handles_files = hasattr(module, "file_handler") and module.file_handler
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
custom_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
name = spec["name"]
callable = getattr(module, name)
# convert to function that takes only model params and inserts custom params
custom_callable = get_tool_with_custom_params(callable, custom_params)
tool_dict = {
"spec": spec,
"citation": has_citation,
"file_handler": handles_files,
"toolkit_module": module,
"callable": custom_callable,
}
if name in tools:
log.warning(f"Tool {name} already exists in another toolkit!")
mod_name = tools[name]["toolkit_module"].__name__
log.warning(f"Collision between {toolkit} and {mod_name}.")
log.warning(f"Discarding {toolkit}.{name}")
else:
tools[name] = tool_dict
return tools
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: User, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
skip_files = None skip_files = False
contexts = [] contexts = []
citations = None citations = []
task_model_id = get_task_model_id(body["model"]) task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
@ -507,6 +575,7 @@ async def chat_completion_tools_handler(
return body, {} return body, {}
log.debug(f"tool_ids: {body['tool_ids']}") log.debug(f"tool_ids: {body['tool_ids']}")
log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
kwargs = { kwargs = {
"messages": body["messages"], "messages": body["messages"],
"files": body.get("files", []), "files": body.get("files", []),
@ -515,6 +584,7 @@ async def chat_completion_tools_handler(
"user": user, "user": user,
"extra_params": extra_params, "extra_params": extra_params,
} }
for tool_id in body["tool_ids"]: for tool_id in body["tool_ids"]:
log.debug(f"{tool_id=}") log.debug(f"{tool_id=}")
try: try:
@ -526,10 +596,7 @@ async def chat_completion_tools_handler(
contexts.append(response) contexts.append(response)
if citation: if citation:
if citations is None: citations.append(citation)
citations = [citation]
else:
citations.append(citation)
if file_handler: if file_handler:
skip_files = True skip_files = True
@ -540,14 +607,10 @@ async def chat_completion_tools_handler(
del body["tool_ids"] del body["tool_ids"]
log.debug(f"tool_contexts: {contexts}") log.debug(f"tool_contexts: {contexts}")
if skip_files: if skip_files and "files" in body:
if "files" in body: del body["files"]
del body["files"]
return body, { return body, {"contexts": contexts, "citations": citations}
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
async def chat_completion_files_handler(body): async def chat_completion_files_handler(body):