From 70838148e77f77b1577899e416c5871ba36260b5 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 21 Nov 2024 17:19:56 +0000 Subject: [PATCH 1/4] fix tools metadata --- .../open_webui/apps/webui/routers/tools.py | 56 +++--- backend/open_webui/utils/schemas.py | 5 +- backend/open_webui/utils/tools.py | 167 ++++++++---------- 3 files changed, 108 insertions(+), 120 deletions(-) diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index 883c34405..d0523ddac 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Optional @@ -10,7 +9,7 @@ from open_webui.apps.webui.models.tools import ( Tools, ) from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports -from open_webui.config import CACHE_DIR, DATA_DIR +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs @@ -300,38 +299,35 @@ async def update_tools_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): tools = Tools.get_tool_by_id(id) - if tools: - if id in request.app.state.TOOLS: - tools_module = request.app.state.TOOLS[id] - else: - tools_module, _ = load_tools_module_by_id(id) - request.app.state.TOOLS[id] = tools_module - - if hasattr(tools_module, "Valves"): - Valves = tools_module.Valves - - try: - form_data = {k: v for k, v in form_data.items() if v is not None} - valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) - return valves.model_dump() - except Exception as e: - print(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(str(e)), - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - else: + if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) + if id in request.app.state.TOOLS: + tools_module = request.app.state.TOOLS[id] + else: + tools_module, _ = load_tools_module_by_id(id) + request.app.state.TOOLS[id] = tools_module + + if not hasattr(tools_module, "Valves"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + Valves = tools_module.Valves + + try: + form_data = {k: v for k, v in form_data.items() if v is not None} + valves = Valves(**form_data) + Tools.update_tool_valves_by_id(id, valves.model_dump()) + return valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(str(e)), + ) ############################ diff --git a/backend/open_webui/utils/schemas.py b/backend/open_webui/utils/schemas.py index 4d1d448cd..21c9824f0 100644 --- a/backend/open_webui/utils/schemas.py +++ b/backend/open_webui/utils/schemas.py @@ -103,7 +103,10 @@ def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: elif type_ == "null": return Optional[Any] # Use Optional[Any] for nullable fields elif type_ == "literal": - return Literal[literal_eval(json_schema.get("enum"))] + enum = json_schema.get("enum") + if enum is None: + raise ValueError("Enum values must be provided for 'literal' type.") + return Literal[literal_eval(enum)] elif type_ == "optional": inner_schema = json_schema.get("items", {"type": "string"}) inner_type = json_schema_to_pydantic_type(inner_schema) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index e77386ac4..bbce8341f 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,11 +1,14 @@ import inspect import logging -from typing import Awaitable, Callable, get_type_hints +import re +from typing import Any, Awaitable, Callable, get_type_hints +from functools import update_wrapper, partial +from langchain_core.utils.function_calling import convert_to_openai_function from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.users import UserModel from open_webui.apps.webui.utils import load_tools_module_by_id -from open_webui.utils.schemas import json_schema_to_model +from pydantic import BaseModel, Field, create_model log = logging.getLogger(__name__) @@ -13,18 +16,15 @@ log = logging.getLogger(__name__) def apply_extra_params_to_tool_function( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: - sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) + partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): + update_wrapper(partial_func, function) + return partial_func - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) + update_wrapper(new_function, function) return new_function @@ -55,11 +55,6 @@ def get_tools( ) for spec in tools.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - # Remove internal parameters spec["parameters"]["properties"] = { key: val @@ -72,15 +67,12 @@ def get_tools( # convert to function that takes only model params and inserts custom params original_func = getattr(module, function_name) callable = apply_extra_params_to_tool_function(original_func, extra_params) - if hasattr(original_func, "__doc__"): - callable.__doc__ = original_func.__doc__ - # TODO: This needs to be a pydantic model tool_dict = { "toolkit_id": tool_id, "callable": callable, "spec": spec, - "pydantic_model": json_schema_to_model(spec), + "pydantic_model": function_to_pydantic_model(callable), "file_handler": hasattr(module, "file_handler") and module.file_handler, "citation": hasattr(module, "citation") and module.citation, } @@ -96,78 +88,75 @@ def get_tools( return tools_dict -def doc_to_dict(docstring): - lines = docstring.split("\n") - description = lines[1].strip() - param_dict = {} +def parse_docstring(docstring): + """ + Parse a function's docstring to extract parameter descriptions in reST format. - for line in lines: - if ":param" in line: - line = line.replace(":param", "").strip() - param, desc = line.split(":", 1) - param_dict[param.strip()] = desc.strip() - ret_dict = {"description": description, "params": param_dict} - return ret_dict + Args: + docstring (str): The docstring to parse. + + Returns: + dict: A dictionary where keys are parameter names and values are descriptions. + """ + if not docstring: + return {} + + # Regex to match `:param name: description` format + param_pattern = re.compile(r":param (\w+):\s*(.+)") + param_descriptions = {} + + for line in docstring.splitlines(): + match = param_pattern.match(line.strip()) + if match: + param_name, param_description = match.groups() + param_descriptions[param_name] = param_description + + return param_descriptions -def get_tools_specs(tools) -> list[dict]: - function_list = [ - {"name": func, "function": getattr(tools, func)} - for func in dir(tools) - if callable(getattr(tools, func)) +def function_to_pydantic_model(func: Callable) -> type[BaseModel]: + """ + Converts a Python function's type hints and docstring to a Pydantic model, + including support for nested types, default values, and descriptions. + + Args: + func: The function whose type hints and docstring should be converted. + model_name: The name of the generated Pydantic model. + + Returns: + A Pydantic model class. + """ + type_hints = get_type_hints(func) + signature = inspect.signature(func) + parameters = signature.parameters + + docstring = func.__doc__ + descriptions = parse_docstring(docstring) + + field_defs = {} + for name, param in parameters.items(): + type_hint = type_hints.get(name, Any) + default_value = param.default if param.default is not param.empty else ... + description = descriptions.get(name, None) + if not description: + field_defs[name] = type_hint, default_value + continue + field_defs[name] = type_hint, Field(default_value, description=description) + + return create_model(func.__name__, **field_defs) + + +def get_callable_attributes(tool: object) -> list[Callable]: + return [ + getattr(tool, func) + for func in dir(tool) + if callable(getattr(tool, func)) and not func.startswith("__") - and not inspect.isclass(getattr(tools, func)) + and not inspect.isclass(getattr(tool, func)) ] - specs = [] - for function_item in function_list: - function_name = function_item["name"] - function = function_item["function"] - function_doc = doc_to_dict(function.__doc__ or function_name) - specs.append( - { - "name": function_name, - # TODO: multi-line desc? - "description": function_doc.get("description", function_name), - "parameters": { - "type": "object", - "properties": { - param_name: { - "type": param_annotation.__name__.lower(), - **( - { - "enum": ( - str(param_annotation.__args__) - if hasattr(param_annotation, "__args__") - else None - ) - } - if hasattr(param_annotation, "__args__") - else {} - ), - "description": function_doc.get("params", {}).get( - param_name, param_name - ), - } - for param_name, param_annotation in get_type_hints( - function - ).items() - if param_name != "return" - and not ( - param_name.startswith("__") and param_name.endswith("__") - ) - }, - "required": [ - name - for name, param in inspect.signature( - function - ).parameters.items() - if param.default is param.empty - and not (name.startswith("__") and name.endswith("__")) - ], - }, - } - ) - - return specs +def get_tools_specs(tool_class: object) -> list[dict]: + function_list = get_callable_attributes(tool_class) + models = map(function_to_pydantic_model, function_list) + return [convert_to_openai_function(tool) for tool in models] From c03bfd141e3a34e636f834f8ef2f8aa6c2f5f591 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 21 Nov 2024 17:41:35 +0000 Subject: [PATCH 2/4] fix optional args not present --- backend/open_webui/main.py | 1 - backend/open_webui/utils/tools.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6a7cbb7eb..6c6b82289 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1313,7 +1313,6 @@ async def generate_chat_completions( @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() models = {model["id"]: model for model in model_list} diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index bbce8341f..2c6d53e3b 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -16,6 +16,8 @@ log = logging.getLogger(__name__) def apply_extra_params_to_tool_function( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: + sig = inspect.signature(function) + extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) if inspect.iscoroutinefunction(function): update_wrapper(partial_func, function) From e1a85c99ab04191954537bd4149765c4bb632389 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 21 Nov 2024 17:52:19 +0000 Subject: [PATCH 3/4] format --- backend/open_webui/apps/retrieval/main.py | 4 +++- backend/open_webui/apps/retrieval/web/mojeek.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index a2c1250fd..10d6ff9a7 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -598,7 +598,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.BRAVE_SEARCH_API_KEY = ( form_data.web.search.brave_search_api_key ) - app.state.config.MOJEEK_SEARCH_API_KEY = form_data.web.search.mojeek_search_api_key + app.state.config.MOJEEK_SEARCH_API_KEY = ( + form_data.web.search.mojeek_search_api_key + ) app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key diff --git a/backend/open_webui/apps/retrieval/web/mojeek.py b/backend/open_webui/apps/retrieval/web/mojeek.py index 46a6526c7..f257c92aa 100644 --- a/backend/open_webui/apps/retrieval/web/mojeek.py +++ b/backend/open_webui/apps/retrieval/web/mojeek.py @@ -22,7 +22,7 @@ def search_mojeek( headers = { "Accept": "application/json", } - params = {"q": query, "api_key": api_key, 'fmt': 'json', 't': count} + params = {"q": query, "api_key": api_key, "fmt": "json", "t": count} response = requests.get(url, headers=headers, params=params) response.raise_for_status() @@ -32,10 +32,9 @@ def search_mojeek( if filter_list: results = get_filtered_results(results, filter_list) - return [ SearchResult( link=result["url"], title=result.get("title"), snippet=result.get("desc") ) for result in results - ] \ No newline at end of file + ] From 8abf5d57c1b34040762770fefa509760c8449c19 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 21 Nov 2024 23:49:58 +0000 Subject: [PATCH 4/4] remove unused schemas file --- backend/open_webui/utils/schemas.py | 115 ---------------------------- 1 file changed, 115 deletions(-) delete mode 100644 backend/open_webui/utils/schemas.py diff --git a/backend/open_webui/utils/schemas.py b/backend/open_webui/utils/schemas.py deleted file mode 100644 index 21c9824f0..000000000 --- a/backend/open_webui/utils/schemas.py +++ /dev/null @@ -1,115 +0,0 @@ -from ast import literal_eval -from typing import Any, Literal, Optional, Type - -from pydantic import BaseModel, Field, create_model - - -def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: - """ - Converts a JSON schema to a Pydantic BaseModel class. - - Args: - json_schema: The JSON schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - - # Extract the model name from the schema title. - model_name = tool_dict["name"] - schema = tool_dict["parameters"] - - # Extract the field definitions from the schema properties. - field_definitions = { - name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) - for name, prop in schema.get("properties", {}).items() - } - - # Create the BaseModel class using create_model(). - return create_model(model_name, **field_definitions) - - -def json_schema_to_pydantic_field( - name: str, json_schema: dict[str, Any], required: list[str] -) -> Any: - """ - Converts a JSON schema property to a Pydantic field definition. - - Args: - name: The field name. - json_schema: The JSON schema property. - - Returns: - A Pydantic field definition. - """ - - # Get the field type. - type_ = json_schema_to_pydantic_type(json_schema) - - # Get the field description. - description = json_schema.get("description") - - # Get the field examples. - examples = json_schema.get("examples") - - # Create a Field object with the type, description, and examples. - # The 'required' flag will be set later when creating the model. - return ( - type_, - Field( - description=description, - examples=examples, - default=... if name in required else None, - ), - ) - - -def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: - """ - Converts a JSON schema type to a Pydantic type. - - Args: - json_schema: The JSON schema to convert. - - Returns: - A Pydantic type. - """ - - type_ = json_schema.get("type") - - if type_ == "string" or type_ == "str": - return str - elif type_ == "integer" or type_ == "int": - return int - elif type_ == "number" or type_ == "float": - return float - elif type_ == "boolean" or type_ == "bool": - return bool - elif type_ == "array" or type_ == "list": - items_schema = json_schema.get("items") - if items_schema: - item_type = json_schema_to_pydantic_type(items_schema) - return list[item_type] - else: - return list - elif type_ == "object": - # Handle nested models. - properties = json_schema.get("properties") - if properties: - nested_model = json_schema_to_model(json_schema) - return nested_model - else: - return dict - elif type_ == "null": - return Optional[Any] # Use Optional[Any] for nullable fields - elif type_ == "literal": - enum = json_schema.get("enum") - if enum is None: - raise ValueError("Enum values must be provided for 'literal' type.") - return Literal[literal_eval(enum)] - elif type_ == "optional": - inner_schema = json_schema.get("items", {"type": "string"}) - inner_type = json_schema_to_pydantic_type(inner_schema) - return Optional[inner_type] - else: - raise ValueError(f"Unsupported JSON schema type: {type_}")