diff --git a/backend/main.py b/backend/main.py index f19e444c3..dbe9d30bf 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,15 +51,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models -from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions from apps.webui.models.users import Users, UserModel -from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id +from apps.webui.utils import load_function_module_by_id from utils.utils import ( get_admin_user, @@ -75,6 +73,8 @@ from utils.task import ( tools_function_calling_generation_template, moa_response_generation_template, ) + +from utils.tools import get_tools from utils.misc import ( get_last_user_message, add_or_update_system_message, @@ -353,80 +353,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content): } -def apply_extra_params_to_tool_function( - function: Callable, extra_params: dict -) -> Callable[..., Awaitable]: - sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) - - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) - - return new_function - - -# Mutation on extra_params -def get_tools( - tool_ids: list[str], user: UserModel, extra_params: dict -) -> dict[str, dict]: - tools = {} - for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: - continue - - module = webui_app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module - - extra_params["__id__"] = tool_id - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) - - if hasattr(module, "UserValves"): - extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - for spec in toolkit.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - function_name = spec["name"] - - # convert to function that takes only model params and inserts custom params - callable = apply_extra_params_to_tool_function( - getattr(module, function_name), extra_params - ) - - # TODO: This needs to be a pydantic model - tool_dict = { - "toolkit_id": tool_id, - "callable": callable, - "spec": spec, - "file_handler": hasattr(module, "file_handler") and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - } - - # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") - else: - tools[function_name] = tool_dict - return tools - - async def get_content_from_response(response) -> Optional[str]: content = None if hasattr(response, "body_iterator"): @@ -467,7 +393,7 @@ async def chat_completion_tools_handler( "__messages__": body["messages"], "__files__": body.get("files", []), } - tools = get_tools(tool_ids, user, custom_params) + tools = get_tools(webui_app, tool_ids, user, custom_params) log.info(f"{tools=}") specs = [tool["spec"] for tool in tools.values()] diff --git a/backend/utils/tools.py b/backend/utils/tools.py index eac36b5d9..12642ccfd 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -1,5 +1,86 @@ import inspect -from typing import get_type_hints +import logging +from typing import Awaitable, Callable, get_type_hints + +from apps.webui.models.tools import Tools +from apps.webui.models.users import UserModel +from apps.webui.utils import load_toolkit_module_by_id + +log = logging.getLogger(__name__) + + +def apply_extra_params_to_tool_function( + function: Callable, extra_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(function) + extra_params = { + key: value for key, value in extra_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(function) + + async def new_function(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await function(**extra_kwargs) + return function(**extra_kwargs) + + return new_function + + +# Mutation on extra_params +def get_tools( + webui_app, tool_ids: list[str], user: UserModel, extra_params: dict +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + function_name = spec["name"] + + # convert to function that takes only model params and inserts custom params + callable = apply_extra_params_to_tool_function( + getattr(module, function_name), extra_params + ) + + # TODO: This needs to be a pydantic model + tool_dict = { + "toolkit_id": tool_id, + "callable": callable, + "spec": spec, + "file_handler": hasattr(module, "file_handler") and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools: + log.warning(f"Tool {function_name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{function_name}") + else: + tools[function_name] = tool_dict + return tools def doc_to_dict(docstring):