Merge pull request #7182 from michaelpoluektov/fix/tools-metadata

fix: Fix tools metadata
This commit is contained in:
Timothy Jaeryang Baek 2024-11-21 16:32:08 -08:00 committed by GitHub
commit 5be7cbfdf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 110 additions and 235 deletions

View File

@ -598,7 +598,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.BRAVE_SEARCH_API_KEY = ( app.state.config.BRAVE_SEARCH_API_KEY = (
form_data.web.search.brave_search_api_key form_data.web.search.brave_search_api_key
) )
app.state.config.MOJEEK_SEARCH_API_KEY = form_data.web.search.mojeek_search_api_key app.state.config.MOJEEK_SEARCH_API_KEY = (
form_data.web.search.mojeek_search_api_key
)
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key

View File

@ -22,7 +22,7 @@ def search_mojeek(
headers = { headers = {
"Accept": "application/json", "Accept": "application/json",
} }
params = {"q": query, "api_key": api_key, 'fmt': 'json', 't': count} params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
response = requests.get(url, headers=headers, params=params) response = requests.get(url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
@ -32,10 +32,9 @@ def search_mojeek(
if filter_list: if filter_list:
results = get_filtered_results(results, filter_list) results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("desc") link=result["url"], title=result.get("title"), snippet=result.get("desc")
) )
for result in results for result in results
] ]

View File

@ -1,4 +1,3 @@
import os
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -10,7 +9,7 @@ from open_webui.apps.webui.models.tools import (
Tools, Tools,
) )
from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs from open_webui.utils.tools import get_tools_specs
@ -300,38 +299,35 @@ async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if tools: if not tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"):
Valves = tools_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
Valves = tools_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
############################ ############################

View File

@ -1313,7 +1313,6 @@ async def generate_chat_completions(
@app.post("/api/chat/completed") @app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)): async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
model_list = await get_all_models() model_list = await get_all_models()
models = {model["id"]: model for model in model_list} models = {model["id"]: model for model in model_list}

View File

@ -1,112 +0,0 @@
from ast import literal_eval
from typing import Any, Literal, Optional, Type
from pydantic import BaseModel, Field, create_model
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
    """
    Build a Pydantic model class from a tool specification.

    Args:
        tool_dict: Mapping with a "name" entry, used as the model name, and a
            "parameters" entry holding the JSON schema whose "properties"
            become the model fields.

    Returns:
        The model class produced by ``create_model``.
    """
    model_name = tool_dict["name"]
    schema = tool_dict["parameters"]

    # One field definition per schema property; a schema without
    # "properties" yields a model with no fields.
    field_definitions = {}
    for prop_name, prop_schema in schema.get("properties", {}).items():
        field_definitions[prop_name] = json_schema_to_pydantic_field(
            prop_name, prop_schema, schema.get("required", [])
        )

    return create_model(model_name, **field_definitions)
def json_schema_to_pydantic_field(
    name: str, json_schema: dict[str, Any], required: list[str]
) -> Any:
    """
    Convert one JSON schema property into a ``(type, Field)`` definition.

    Args:
        name: The field name.
        json_schema: The JSON schema of the property.
        required: Names of the required properties of the enclosing schema.

    Returns:
        A ``(type, Field(...))`` tuple in the form ``create_model`` accepts.
    """
    # Resolve the type first so an unsupported schema fails before anything
    # else is read.
    field_type = json_schema_to_pydantic_type(json_schema)

    description = json_schema.get("description")
    examples = json_schema.get("examples")

    # Ellipsis marks the field as required; every other field defaults to None.
    if name in required:
        default = ...
    else:
        default = None

    field_info = Field(
        description=description,
        examples=examples,
        default=default,
    )
    return (field_type, field_info)
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
    """
    Convert a JSON schema type description to a Python/Pydantic type.

    Both JSON schema names ("string", "integer", "number", "boolean",
    "array") and Python-style aliases ("str", "int", "float", "bool",
    "list") are accepted, plus the non-standard "literal" and "optional"
    types.

    Args:
        json_schema: The JSON schema (or schema fragment) to convert.

    Returns:
        The matching Python type, typing construct, or nested model class.

    Raises:
        ValueError: If the "type" entry is not supported, or a "literal"
            schema has no "enum" entry.
    """
    type_ = json_schema.get("type")

    # Tuple membership (not a dict lookup) so an unhashable "type" value such
    # as a list still ends in the ValueError below rather than a TypeError.
    if type_ in ("string", "str"):
        return str
    if type_ in ("integer", "int"):
        return int
    if type_ in ("number", "float"):
        return float
    if type_ in ("boolean", "bool"):
        return bool

    if type_ in ("array", "list"):
        items_schema = json_schema.get("items")
        if not items_schema:
            return list
        item_type = json_schema_to_pydantic_type(items_schema)
        return list[item_type]

    if type_ == "object":
        if not json_schema.get("properties"):
            return dict
        # json_schema_to_model expects a tool-spec shaped mapping with "name"
        # and "parameters" keys; a bare property schema has neither, so wrap
        # it instead of passing it through (which raised KeyError).
        if "name" in json_schema and "parameters" in json_schema:
            nested_spec = json_schema
        else:
            nested_spec = {
                "name": json_schema.get("title") or "NestedModel",
                "parameters": json_schema,
            }
        return json_schema_to_model(nested_spec)

    if type_ == "null":
        return Optional[Any]  # Use Optional[Any] for nullable fields

    if type_ == "literal":
        enum = json_schema.get("enum")
        if enum is None:
            raise ValueError("JSON schema of type 'literal' requires an 'enum'")
        # The enum may arrive as the string repr of a tuple, e.g. "('a', 'b')",
        # or as a plain JSON list; only the string form needs parsing.
        if isinstance(enum, str):
            enum = literal_eval(enum)
        # Literal[...] only flattens tuples, so normalise lists first.
        if isinstance(enum, list):
            enum = tuple(enum)
        return Literal[enum]

    if type_ == "optional":
        inner_schema = json_schema.get("items", {"type": "string"})
        inner_type = json_schema_to_pydantic_type(inner_schema)
        return Optional[inner_type]

    raise ValueError(f"Unsupported JSON schema type: {type_}")

View File

@ -1,11 +1,14 @@
import inspect import inspect
import logging import logging
from typing import Awaitable, Callable, get_type_hints import re
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial
from langchain_core.utils.function_calling import convert_to_openai_function
from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.tools import Tools
from open_webui.apps.webui.models.users import UserModel from open_webui.apps.webui.models.users import UserModel
from open_webui.apps.webui.utils import load_tools_module_by_id from open_webui.apps.webui.utils import load_tools_module_by_id
from open_webui.utils.schemas import json_schema_to_model from pydantic import BaseModel, Field, create_model
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -14,17 +17,16 @@ def apply_extra_params_to_tool_function(
function: Callable, extra_params: dict function: Callable, extra_params: dict
) -> Callable[..., Awaitable]: ) -> Callable[..., Awaitable]:
sig = inspect.signature(function) sig = inspect.signature(function)
extra_params = { extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
key: value for key, value in extra_params.items() if key in sig.parameters partial_func = partial(function, **extra_params)
} if inspect.iscoroutinefunction(function):
is_coroutine = inspect.iscoroutinefunction(function) update_wrapper(partial_func, function)
return partial_func
async def new_function(**kwargs): async def new_function(*args, **kwargs):
extra_kwargs = kwargs | extra_params return partial_func(*args, **kwargs)
if is_coroutine:
return await function(**extra_kwargs)
return function(**extra_kwargs)
update_wrapper(new_function, function)
return new_function return new_function
@ -55,11 +57,6 @@ def get_tools(
) )
for spec in tools.specs: for spec in tools.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
# Remove internal parameters # Remove internal parameters
spec["parameters"]["properties"] = { spec["parameters"]["properties"] = {
key: val key: val
@ -72,15 +69,12 @@ def get_tools(
# convert to function that takes only model params and inserts custom params # convert to function that takes only model params and inserts custom params
original_func = getattr(module, function_name) original_func = getattr(module, function_name)
callable = apply_extra_params_to_tool_function(original_func, extra_params) callable = apply_extra_params_to_tool_function(original_func, extra_params)
if hasattr(original_func, "__doc__"):
callable.__doc__ = original_func.__doc__
# TODO: This needs to be a pydantic model # TODO: This needs to be a pydantic model
tool_dict = { tool_dict = {
"toolkit_id": tool_id, "toolkit_id": tool_id,
"callable": callable, "callable": callable,
"spec": spec, "spec": spec,
"pydantic_model": json_schema_to_model(spec), "pydantic_model": function_to_pydantic_model(callable),
"file_handler": hasattr(module, "file_handler") and module.file_handler, "file_handler": hasattr(module, "file_handler") and module.file_handler,
"citation": hasattr(module, "citation") and module.citation, "citation": hasattr(module, "citation") and module.citation,
} }
@ -96,78 +90,75 @@ def get_tools(
return tools_dict return tools_dict
def doc_to_dict(docstring): def parse_docstring(docstring):
lines = docstring.split("\n") """
description = lines[1].strip() Parse a function's docstring to extract parameter descriptions in reST format.
param_dict = {}
for line in lines: Args:
if ":param" in line: docstring (str): The docstring to parse.
line = line.replace(":param", "").strip()
param, desc = line.split(":", 1) Returns:
param_dict[param.strip()] = desc.strip() dict: A dictionary where keys are parameter names and values are descriptions.
ret_dict = {"description": description, "params": param_dict} """
return ret_dict if not docstring:
return {}
# Regex to match `:param name: description` format
param_pattern = re.compile(r":param (\w+):\s*(.+)")
param_descriptions = {}
for line in docstring.splitlines():
match = param_pattern.match(line.strip())
if match:
param_name, param_description = match.groups()
param_descriptions[param_name] = param_description
return param_descriptions
def get_tools_specs(tools) -> list[dict]: def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
function_list = [ """
{"name": func, "function": getattr(tools, func)} Converts a Python function's type hints and docstring to a Pydantic model,
for func in dir(tools) including support for nested types, default values, and descriptions.
if callable(getattr(tools, func))
Args:
func: The function whose type hints and docstring should be converted.
model_name: The name of the generated Pydantic model.
Returns:
A Pydantic model class.
"""
type_hints = get_type_hints(func)
signature = inspect.signature(func)
parameters = signature.parameters
docstring = func.__doc__
descriptions = parse_docstring(docstring)
field_defs = {}
for name, param in parameters.items():
type_hint = type_hints.get(name, Any)
default_value = param.default if param.default is not param.empty else ...
description = descriptions.get(name, None)
if not description:
field_defs[name] = type_hint, default_value
continue
field_defs[name] = type_hint, Field(default_value, description=description)
return create_model(func.__name__, **field_defs)
def get_callable_attributes(tool: object) -> list[Callable]:
return [
getattr(tool, func)
for func in dir(tool)
if callable(getattr(tool, func))
and not func.startswith("__") and not func.startswith("__")
and not inspect.isclass(getattr(tools, func)) and not inspect.isclass(getattr(tool, func))
] ]
specs = []
for function_item in function_list:
function_name = function_item["name"]
function = function_item["function"]
function_doc = doc_to_dict(function.__doc__ or function_name) def get_tools_specs(tool_class: object) -> list[dict]:
specs.append( function_list = get_callable_attributes(tool_class)
{ models = map(function_to_pydantic_model, function_list)
"name": function_name, return [convert_to_openai_function(tool) for tool in models]
# TODO: multi-line desc?
"description": function_doc.get("description", function_name),
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_annotation.__name__.lower(),
**(
{
"enum": (
str(param_annotation.__args__)
if hasattr(param_annotation, "__args__")
else None
)
}
if hasattr(param_annotation, "__args__")
else {}
),
"description": function_doc.get("params", {}).get(
param_name, param_name
),
}
for param_name, param_annotation in get_type_hints(
function
).items()
if param_name != "return"
and not (
param_name.startswith("__") and param_name.endswith("__")
)
},
"required": [
name
for name, param in inspect.signature(
function
).parameters.items()
if param.default is param.empty
and not (name.startswith("__") and name.endswith("__"))
],
},
}
)
return specs