From 6df6170c4493020baf937f18b609281523dec582 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 12 Aug 2024 14:48:57 +0100 Subject: [PATCH] add get_configured_tools --- backend/main.py | 99 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 18 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0a1768fd6..50b83f437 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,13 +51,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional +from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users, User +from apps.webui.models.users import Users, UserModel from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id @@ -356,6 +356,7 @@ async def get_tool_call_response( return None, None, False tools_specs = json.dumps(tool.specs, indent=2) + log.debug(f"{tool.specs=}") content = tool_calling_generation_template(template, tools_specs) payload = get_tool_call_payload(messages, task_model_id, content) @@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params): return body, {} +def get_tool_with_custom_params( + tool: Callable, custom_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(tool) + extra_params = { + key: value for key, value in custom_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(tool) + + async def new_tool(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await tool(**extra_kwargs) + return tool(**extra_kwargs) + + return new_tool + + +def get_configured_tools( + tool_ids: list[str], extra_params: dict, user: UserModel +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + more_params = {"__id__": tool_id} + custom_params = more_params | extra_params + has_citation = hasattr(module, "citation") and module.citation + handles_files = hasattr(module, "file_handler") and module.file_handler + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + custom_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + name = spec["name"] + callable = getattr(module, name) + # convert to function that takes only model params and inserts custom params + custom_callable = get_tool_with_custom_params(callable, custom_params) + + tool_dict = { + "spec": spec, + "citation": has_citation, + "file_handler": handles_files, + "toolkit_module": module, + "callable": custom_callable, + } + if name in tools: + log.warning(f"Tool {name} already exists in another toolkit!") + mod_name = tools[name]["toolkit_module"].__name__ + log.warning(f"Collision between {toolkit} and {mod_name}.") + log.warning(f"Discarding {toolkit}.{name}") + else: + tools[name] = tool_dict + + return tools + + async def chat_completion_tools_handler( - body: dict, user: User, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: - skip_files = None - + skip_files = False contexts = [] - citations = None - + citations = [] task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions @@ -507,6 +575,7 @@ async def chat_completion_tools_handler( return body, {} log.debug(f"tool_ids: {body['tool_ids']}") + log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}") kwargs = { "messages": body["messages"], "files": body.get("files", []), @@ -515,6 +584,7 @@ async def chat_completion_tools_handler( "user": user, "extra_params": extra_params, } + for tool_id in body["tool_ids"]: log.debug(f"{tool_id=}") try: @@ -526,10 +596,7 @@ async def chat_completion_tools_handler( contexts.append(response) if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) + citations.append(citation) if file_handler: skip_files = True @@ -540,14 +607,10 @@ async def chat_completion_tools_handler( del body["tool_ids"] log.debug(f"tool_contexts: {contexts}") - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body: + del body["files"] - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} async def chat_completion_files_handler(body):