move tools utils to utils.tools

This commit is contained in:
Michael Poluektov 2024-08-19 10:53:12 +01:00
parent fd422d2e3c
commit a4a7d678f9
2 changed files with 86 additions and 79 deletions

View File

@ -51,15 +51,13 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import Optional, Callable, Awaitable
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_function_module_by_id
from utils.utils import (
get_admin_user,
@ -75,6 +73,8 @@ from utils.task import (
tools_function_calling_generation_template,
moa_response_generation_template,
)
from utils.tools import get_tools
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
@ -353,80 +353,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
}
def apply_extra_params_to_tool_function(
function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(function)
extra_params = {
key: value for key, value in extra_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(function)
async def new_function(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await function(**extra_kwargs)
return function(**extra_kwargs)
return new_function
# Mutation on extra_params
def get_tools(
tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
function_name = spec["name"]
# convert to function that takes only model params and inserts custom params
callable = apply_extra_params_to_tool_function(
getattr(module, function_name), extra_params
)
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,
"callable": callable,
"spec": spec,
"file_handler": hasattr(module, "file_handler") and module.file_handler,
"citation": hasattr(module, "citation") and module.citation,
}
# TODO: if collision, prepend toolkit name
if function_name in tools:
log.warning(f"Tool {function_name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{function_name}")
else:
tools[function_name] = tool_dict
return tools
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
@ -467,7 +393,7 @@ async def chat_completion_tools_handler(
"__messages__": body["messages"],
"__files__": body.get("files", []),
}
tools = get_tools(tool_ids, user, custom_params)
tools = get_tools(webui_app, tool_ids, user, custom_params)
log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()]

View File

@ -1,5 +1,86 @@
import inspect
from typing import get_type_hints
import logging
from typing import Awaitable, Callable, get_type_hints
from apps.webui.models.tools import Tools
from apps.webui.models.users import UserModel
from apps.webui.utils import load_toolkit_module_by_id
log = logging.getLogger(__name__)
def apply_extra_params_to_tool_function(
function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(function)
extra_params = {
key: value for key, value in extra_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(function)
async def new_function(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await function(**extra_kwargs)
return function(**extra_kwargs)
return new_function
# Mutation on extra_params
def get_tools(
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
function_name = spec["name"]
# convert to function that takes only model params and inserts custom params
callable = apply_extra_params_to_tool_function(
getattr(module, function_name), extra_params
)
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,
"callable": callable,
"spec": spec,
"file_handler": hasattr(module, "file_handler") and module.file_handler,
"citation": hasattr(module, "citation") and module.citation,
}
# TODO: if collision, prepend toolkit name
if function_name in tools:
log.warning(f"Tool {function_name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{function_name}")
else:
tools[function_name] = tool_dict
return tools
def doc_to_dict(docstring):