diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index c7af4a16c..b5d916e1d 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -6,12 +6,28 @@ import aiohttp import asyncio import yaml -from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from typing import ( + Any, + Awaitable, + Callable, + get_type_hints, + get_args, + get_origin, + Dict, + List, + Tuple, + Union, + Optional, + Type, +) from functools import update_wrapper, partial from fastapi import Request from pydantic import BaseModel, Field, create_model + from langchain_core.utils.function_calling import ( convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, ) @@ -259,22 +275,25 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]: parameters = signature.parameters docstring = func.__doc__ - descriptions = parse_docstring(docstring) - tool_description = parse_description(docstring) + description = parse_description(docstring) + function_descriptions = parse_docstring(docstring) field_defs = {} for name, param in parameters.items(): + type_hint = type_hints.get(name, Any) default_value = param.default if param.default is not param.empty else ... - description = descriptions.get(name, None) - if not description: + + description = function_descriptions.get(name, None) + + if description: + field_defs[name] = type_hint, Field(default_value, description=description) + else: field_defs[name] = type_hint, default_value - continue - field_defs[name] = type_hint, Field(default_value, description=description) model = create_model(func.__name__, **field_defs) - model.__doc__ = tool_description + model.__doc__ = description return model @@ -300,11 +319,13 @@ def get_tool_specs(tool_module: object) -> list[dict]: convert_function_to_pydantic_model, get_functions_from_tool(tool_module) ) - return [ + specs = [ convert_pydantic_model_to_openai_function_spec(function_model) for function_model in function_models ] + return specs + def resolve_schema(schema, components): """ diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte index 99f1351de..605ab6352 100644 --- a/src/lib/components/chat/Messages/UserMessage.svelte +++ b/src/lib/components/chat/Messages/UserMessage.svelte @@ -192,7 +192,7 @@