diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index 883c34405..d0523ddac 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Optional @@ -10,7 +9,7 @@ from open_webui.apps.webui.models.tools import ( Tools, ) from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports -from open_webui.config import CACHE_DIR, DATA_DIR +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs @@ -300,38 +299,35 @@ async def update_tools_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): tools = Tools.get_tool_by_id(id) - if tools: - if id in request.app.state.TOOLS: - tools_module = request.app.state.TOOLS[id] - else: - tools_module, _ = load_tools_module_by_id(id) - request.app.state.TOOLS[id] = tools_module - - if hasattr(tools_module, "Valves"): - Valves = tools_module.Valves - - try: - form_data = {k: v for k, v in form_data.items() if v is not None} - valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) - return valves.model_dump() - except Exception as e: - print(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(str(e)), - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - else: + if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) + if id in request.app.state.TOOLS: + tools_module = request.app.state.TOOLS[id] + else: + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module + + if not hasattr(tools_module, "Valves"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + Valves = tools_module.Valves + + try: + form_data = {k: v for k, v in form_data.items() if v is not None} + valves = Valves(**form_data) + Tools.update_tool_valves_by_id(id, valves.model_dump()) + return valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(str(e)), + ) ############################ diff --git a/backend/open_webui/utils/schemas.py b/backend/open_webui/utils/schemas.py index 4d1d448cd..21c9824f0 100644 --- a/backend/open_webui/utils/schemas.py +++ b/backend/open_webui/utils/schemas.py @@ -103,7 +103,10 @@ def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: elif type_ == "null": return Optional[Any] # Use Optional[Any] for nullable fields elif type_ == "literal": - return Literal[literal_eval(json_schema.get("enum"))] + enum = json_schema.get("enum") + if enum is None: + raise ValueError("Enum values must be provided for 'literal' type.") + return Literal[literal_eval(enum)] elif type_ == "optional": inner_schema = json_schema.get("items", {"type": "string"}) inner_type = json_schema_to_pydantic_type(inner_schema) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index e77386ac4..bbce8341f 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,11 +1,14 @@ import inspect import logging -from typing import Awaitable, Callable, get_type_hints +import re +from typing import Any, Awaitable, Callable, get_type_hints +from functools import update_wrapper, partial +from langchain_core.utils.function_calling import convert_to_openai_function from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.users import UserModel from open_webui.apps.webui.utils import load_tools_module_by_id -from open_webui.utils.schemas import json_schema_to_model +from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) @@ -13,18 +16,15 @@ log = logging.getLogger(__name__) def apply_extra_params_to_tool_function( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: - sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) + partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): + update_wrapper(partial_func, function) + return partial_func - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) + update_wrapper(new_function, function) return new_function @@ -55,11 +55,6 @@ def get_tools( ) for spec in tools.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - # Remove internal parameters spec["parameters"]["properties"] = { key: val @@ -72,15 +67,12 @@ def get_tools( # convert to function that takes only model params and inserts custom params original_func = getattr(module, function_name) callable = apply_extra_params_to_tool_function(original_func, extra_params) - if hasattr(original_func, "__doc__"): - callable.__doc__ = original_func.__doc__ - # TODO: This needs to be a pydantic model tool_dict = { "toolkit_id": tool_id, "callable": callable, "spec": spec, - "pydantic_model": json_schema_to_model(spec), + "pydantic_model": function_to_pydantic_model(callable), "file_handler": hasattr(module, "file_handler") and module.file_handler, "citation": hasattr(module, "citation") and module.citation, } @@ -96,78 +88,75 @@ def get_tools( return tools_dict -def doc_to_dict(docstring): - lines = docstring.split("\n") - description = lines[1].strip() - param_dict = {} +def parse_docstring(docstring): + """ + Parse a function's docstring to extract parameter descriptions in reST format. - for line in lines: - if ":param" in line: - line = line.replace(":param", "").strip() - param, desc = line.split(":", 1) - param_dict[param.strip()] = desc.strip() - ret_dict = {"description": description, "params": param_dict} - return ret_dict + Args: + docstring (str): The docstring to parse. + + Returns: + dict: A dictionary where keys are parameter names and values are descriptions. + """ + if not docstring: + return {} + + # Regex to match `:param name: description` format + param_pattern = re.compile(r":param (\w+):\s*(.+)") + param_descriptions = {} + + for line in docstring.splitlines(): + match = param_pattern.match(line.strip()) + if match: + param_name, param_description = match.groups() + param_descriptions[param_name] = param_description + + return param_descriptions -def get_tools_specs(tools) -> list[dict]: - function_list = [ - {"name": func, "function": getattr(tools, func)} - for func in dir(tools) - if callable(getattr(tools, func)) +def function_to_pydantic_model(func: Callable) -> type[BaseModel]: + """ + Converts a Python function's type hints and docstring to a Pydantic model, + including support for nested types, default values, and descriptions. + + Args: + func: The function whose type hints and docstring should be converted. + model_name: The name of the generated Pydantic model. + + Returns: + A Pydantic model class. + """ + type_hints = get_type_hints(func) + signature = inspect.signature(func) + parameters = signature.parameters + + docstring = func.__doc__ + descriptions = parse_docstring(docstring) + + field_defs = {} + for name, param in parameters.items(): + type_hint = type_hints.get(name, Any) + default_value = param.default if param.default is not param.empty else ... + description = descriptions.get(name, None) + if not description: + field_defs[name] = type_hint, default_value + continue + field_defs[name] = type_hint, Field(default_value, description=description) + + return create_model(func.__name__, **field_defs) + + +def get_callable_attributes(tool: object) -> list[Callable]: + return [ + getattr(tool, func) + for func in dir(tool) + if callable(getattr(tool, func)) and not func.startswith("__") - and not inspect.isclass(getattr(tools, func)) + and not inspect.isclass(getattr(tool, func)) ] - specs = [] - for function_item in function_list: - function_name = function_item["name"] - function = function_item["function"] - function_doc = doc_to_dict(function.__doc__ or function_name) - specs.append( - { - "name": function_name, - # TODO: multi-line desc? - "description": function_doc.get("description", function_name), - "parameters": { - "type": "object", - "properties": { - param_name: { - "type": param_annotation.__name__.lower(), - **( - { - "enum": ( - str(param_annotation.__args__) - if hasattr(param_annotation, "__args__") - else None - ) - } - if hasattr(param_annotation, "__args__") - else {} - ), - "description": function_doc.get("params", {}).get( - param_name, param_name - ), - } - for param_name, param_annotation in get_type_hints( - function - ).items() - if param_name != "return" - and not ( - param_name.startswith("__") and param_name.endswith("__") - ) - }, - "required": [ - name - for name, param in inspect.signature( - function - ).parameters.items() - if param.default is param.empty - and not (name.startswith("__") and name.endswith("__")) - ], - }, - } - ) - - return specs +def get_tools_specs(tool_class: object) -> list[dict]: + function_list = get_callable_attributes(tool_class) + models = map(function_to_pydantic_model, function_list) + return [convert_to_openai_function(tool) for tool in models]