Merge pull request #4724 from michaelpoluektov/tools-refac-2.1

feat: Add `__tools__` optional param for function pipes
Timothy Jaeryang Baek 2024-08-20 17:01:13 +02:00 committed by GitHub
commit ee526b4b07
6 changed files with 251 additions and 135 deletions


@@ -731,11 +731,8 @@ async def generate_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
payload = {
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
}
payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"{payload = }")
if "metadata" in payload:
del payload["metadata"]
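
The hunk above swaps between deleting `metadata` after `model_dump` and excluding it during the dump. A minimal sketch of the two equivalent approaches, assuming a Pydantic v2 form model (`ChatForm` below is a hypothetical stand-in for the real completion form):

```python
from typing import Optional
from pydantic import BaseModel


class ChatForm(BaseModel):
    # Hypothetical, simplified stand-in for the real completion form model.
    model: str
    messages: list[dict]
    metadata: Optional[dict] = None


form_data = ChatForm(model="llama3", messages=[], metadata={"files": []})

# One approach: dump everything, then strip the key afterwards.
payload = {**form_data.model_dump(exclude_none=True)}
payload.pop("metadata", None)

# The other: exclude the key at dump time; the resulting payload is the same.
payload = {**form_data.model_dump(exclude_none=True, exclude={"metadata"})}
```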


@@ -26,6 +26,7 @@ from utils.misc import (
apply_model_system_prompt_to_body,
)
from utils.tools import get_tools
from config import (
SHOW_ADMIN_DETAILS,
@@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
__event_emitter__ = None
__event_call__ = None
@@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
}
tools_params = {
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
}
configured_tools = get_tools(app, tool_ids, user, tools_params)
extra_params["__tools__"] = configured_tools
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
@@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
function_module = get_function_module(pipe_id)
pipe = function_module.pipe
params = get_function_params(
function_module,
form_data,
user,
{
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
},
)
params = get_function_params(function_module, form_data, user, extra_params)
if form_data["stream"]:
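
The hunks above are the core of this PR: `generate_function_chat_completion` now resolves the chat's `tool_ids` into a tools dict via `get_tools` and exposes it to the pipe as `__tools__`. Because `get_function_params` only forwards extra params that appear in the pipe's signature, existing pipes keep working unchanged. A hedged sketch of a hypothetical pipe module that opts in (the pipe body and the tool call are illustrative, not part of this commit):

```python
from typing import Optional


class Pipe:
    # Hypothetical function pipe; declaring __tools__ opts into the new parameter.
    async def pipe(self, body: dict, __tools__: Optional[dict] = None) -> str:
        if not __tools__:
            return "No tools configured for this chat."

        # Each value mirrors the tool_dict built by get_tools():
        # {"toolkit_id", "callable", "spec", "pydantic_model", "file_handler", "citation"}
        specs = [tool["spec"] for tool in __tools__.values()]

        # The callables already have __user__, __event_emitter__, etc. bound in,
        # so only model-facing arguments need to be supplied here.
        name, tool = next(iter(__tools__.items()))
        output = await tool["callable"]()  # assumes this tool takes no required args
        return f"{len(specs)} tool(s) available; {name} returned: {output}"
```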


@@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################


@@ -14,6 +14,7 @@ import requests
import mimetypes
import shutil
import inspect
from typing import Optional
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
@@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import Optional, Callable, Awaitable
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_function_module_by_id
from utils.utils import (
get_admin_user,
@@ -76,6 +75,8 @@ from utils.task import (
tools_function_calling_generation_template,
moa_response_generation_template,
)
from utils.tools import get_tools
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
@@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
print(f"Error: {e}")
raise e
if skip_files and "files" in body:
del body["files"]
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {}
@@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
}
def apply_extra_params_to_tool_function(
function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(function)
extra_params = {
key: value for key, value in extra_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(function)
async def new_function(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await function(**extra_kwargs)
return function(**extra_kwargs)
return new_function
# Mutation on extra_params
def get_tools(
tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
function_name = spec["name"]
# convert to function that takes only model params and inserts custom params
callable = apply_extra_params_to_tool_function(
getattr(module, function_name), extra_params
)
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,
"callable": callable,
"spec": spec,
"file_handler": hasattr(module, "file_handler") and module.file_handler,
"citation": hasattr(module, "citation") and module.citation,
}
# TODO: if collision, prepend toolkit name
if function_name in tools:
log.warning(f"Tool {function_name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{function_name}")
else:
tools[function_name] = tool_dict
return tools
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
@@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
if not tool_ids:
return body, {}
skip_files = False
contexts = []
citations = []
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
tool_ids = body.pop("tool_ids", None)
if not tool_ids:
return body, {}
log.debug(f"{tool_ids=}")
@@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
**extra_params,
"__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"],
"__files__": body.get("files", []),
"__files__": metadata.get("files", []),
}
tools = get_tools(tool_ids, user, custom_params)
tools = get_tools(webui_app, tool_ids, user, custom_params)
log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()]
@@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
content = await get_content_from_response(response)
log.debug(f"{content=}")
if content is None:
if not content:
return body, {}
result = json.loads(content)
@@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
contexts.append(tool_output)
except Exception as e:
print(f"Error: {e}")
log.exception(f"Error: {e}")
content = None
log.debug(f"tool_contexts: {contexts}")
if skip_files and "files" in body:
del body["files"]
if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
return body, {"contexts": contexts, "citations": citations}
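
Between the hunks above, the handler asks the task model to pick a tool, parses the returned JSON, and invokes the matching entry from `get_tools`. A condensed, hedged sketch of that dispatch step (names follow the surrounding code; error handling and the file_handler/citation bookkeeping are omitted, and the `pydantic_model` validation is illustrative rather than the handler's exact logic):

```python
import json
from typing import Optional


async def dispatch_tool_call(content: str, tools: dict[str, dict]) -> Optional[str]:
    # The task model is expected to reply with JSON such as
    # {"name": "<tool name>", "parameters": {...}}.
    result = json.loads(content)
    tool = tools.get(result.get("name"))
    if tool is None:
        return None

    params = result.get("parameters", {})
    # Illustrative: the spec-derived Pydantic model from get_tools() can be
    # used to validate the model-supplied arguments before the call.
    tool["pydantic_model"](**params)

    # __user__, __event_emitter__, __id__, etc. are already bound by
    # apply_extra_params_to_tool_function, so only model-facing parameters
    # are passed here.
    output = await tool["callable"](**params)
    return output if isinstance(output, str) else str(output)
```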
@@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = []
citations = []
if files := body.pop("files", None):
if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
@@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
"message_id": body.pop("id", None),
"session_id": body.pop("session_id", None),
"valves": body.pop("valves", None),
"tool_ids": body.pop("tool_ids", None),
"files": body.pop("files", None),
}
__user__ = {
@@ -680,37 +611,30 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
]
response = await call_next(request)
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers["Content-Type"]
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
if not isinstance(response, StreamingResponse):
return response
return response
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
return response
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
async def stream_wrapper(original_generator, data_items):
for item in data_items:
yield wrap_item(json.dumps(item))
async for data in original_generator:
yield data
return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator:
yield data
app.add_middleware(ChatCompletionMiddleware)
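
The middleware refactor above collapses the two wrapper coroutines into a single `stream_wrapper` that differs only in how each injected item is framed. A self-contained sketch of that framing (hypothetical data; SSE framing for OpenAI-style streams, NDJSON for Ollama-style ones):

```python
import asyncio
import json


def wrap_item(item: str, is_openai: bool) -> str:
    # SSE frames each event as "data: <payload>\n\n"; NDJSON emits one JSON object per line.
    return f"data: {item}\n\n" if is_openai else f"{item}\n"


async def stream_wrapper(original_generator, data_items, is_openai: bool):
    # Injected items (e.g. citations) are emitted first, then the upstream
    # chunks are passed through untouched.
    for item in data_items:
        yield wrap_item(json.dumps(item), is_openai)
    async for data in original_generator:
        yield data


async def demo():
    async def upstream():
        yield 'data: {"choices": []}\n\n'

    async for chunk in stream_wrapper(upstream(), [{"citations": []}], is_openai=True):
        print(repr(chunk))


asyncio.run(demo())
```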
@@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found",
)
model = app.state.MODELS[model_id]
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":

backend/utils/schemas.py (new file, 104 lines)

@@ -0,0 +1,104 @@
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional, Type
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
"""
Converts a JSON schema to a Pydantic BaseModel class.
Args:
json_schema: The JSON schema to convert.
Returns:
A Pydantic BaseModel class.
"""
# Extract the model name from the schema title.
model_name = tool_dict["name"]
schema = tool_dict["parameters"]
# Extract the field definitions from the schema properties.
field_definitions = {
name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
for name, prop in schema.get("properties", {}).items()
}
# Create the BaseModel class using create_model().
return create_model(model_name, **field_definitions)
def json_schema_to_pydantic_field(
name: str, json_schema: dict[str, Any], required: list[str]
) -> Any:
"""
Converts a JSON schema property to a Pydantic field definition.
Args:
name: The field name.
json_schema: The JSON schema property.
Returns:
A Pydantic field definition.
"""
# Get the field type.
type_ = json_schema_to_pydantic_type(json_schema)
# Get the field description.
description = json_schema.get("description")
# Get the field examples.
examples = json_schema.get("examples")
# Create a Field object with the type, description, and examples.
# The 'required' flag will be set later when creating the model.
return (
type_,
Field(
description=description,
examples=examples,
default=... if name in required else None,
),
)
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
"""
Converts a JSON schema type to a Pydantic type.
Args:
json_schema: The JSON schema to convert.
Returns:
A Pydantic type.
"""
type_ = json_schema.get("type")
if type_ == "string" or type_ == "str":
return str
elif type_ == "integer" or type_ == "int":
return int
elif type_ == "number" or type_ == "float":
return float
elif type_ == "boolean" or type_ == "bool":
return bool
elif type_ == "array":
items_schema = json_schema.get("items")
if items_schema:
item_type = json_schema_to_pydantic_type(items_schema)
return list[item_type]
else:
return list
elif type_ == "object":
# Handle nested models.
properties = json_schema.get("properties")
if properties:
nested_model = json_schema_to_model(json_schema)
return nested_model
else:
return dict
elif type_ == "null":
return Optional[Any] # Use Optional[Any] for nullable fields
else:
raise ValueError(f"Unsupported JSON schema type: {type_}")
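
A usage sketch for the new `json_schema_to_model` helper; the `spec` dict below is a made-up example in the same shape as a toolkit spec:

```python
from pydantic import ValidationError

from utils.schemas import json_schema_to_model

# Hypothetical toolkit spec in the shape consumed by json_schema_to_model.
spec = {
    "name": "get_weather",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City to look up"},
            "days": {"type": "integer"},
        },
        "required": ["city"],
    },
}

WeatherArgs = json_schema_to_model(spec)

args = WeatherArgs(city="Berlin")   # "days" is optional and defaults to None
print(args.model_dump())            # {'city': 'Berlin', 'days': None}

try:
    WeatherArgs(days=3)             # missing the required "city" field
except ValidationError as err:
    print(err)
```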


@@ -1,5 +1,90 @@
import inspect
from typing import get_type_hints
import logging
from typing import Awaitable, Callable, get_type_hints
from apps.webui.models.tools import Tools
from apps.webui.models.users import UserModel
from apps.webui.utils import load_toolkit_module_by_id
from utils.schemas import json_schema_to_model
log = logging.getLogger(__name__)
def apply_extra_params_to_tool_function(
function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(function)
extra_params = {
key: value for key, value in extra_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(function)
async def new_function(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await function(**extra_kwargs)
return function(**extra_kwargs)
return new_function
# Mutation on extra_params
def get_tools(
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
function_name = spec["name"]
# convert to function that takes only model params and inserts custom params
original_func = getattr(module, function_name)
callable = apply_extra_params_to_tool_function(original_func, extra_params)
if hasattr(original_func, "__doc__"):
callable.__doc__ = original_func.__doc__
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,
"callable": callable,
"spec": spec,
"pydantic_model": json_schema_to_model(spec),
"file_handler": hasattr(module, "file_handler") and module.file_handler,
"citation": hasattr(module, "citation") and module.citation,
}
# TODO: if collision, prepend toolkit name
if function_name in tools:
log.warning(f"Tool {function_name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{function_name}")
else:
tools[function_name] = tool_dict
return tools
def doc_to_dict(docstring):
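
To round off, a usage sketch for `apply_extra_params_to_tool_function` as moved into `utils/tools.py` above: extra params are filtered against the tool's signature, so a toolkit function only receives the reserved parameters it actually declares (the `search_notes` tool below is hypothetical):

```python
import asyncio

from utils.tools import apply_extra_params_to_tool_function


# Hypothetical toolkit function: it declares __user__, so that extra param is
# injected; it does not declare __event_emitter__, so that one is filtered out.
def search_notes(query: str, __user__: dict = None) -> str:
    return f"{__user__['name']} searched for {query!r}"


extra_params = {
    "__user__": {"id": "u1", "name": "Alice"},
    "__event_emitter__": lambda event: None,  # dropped: not in the signature
}

wrapped = apply_extra_params_to_tool_function(search_notes, extra_params)

# Callers (the pipe or the tools handler) pass only the model-facing
# parameters; the wrapper is always awaitable, even for sync tools.
print(asyncio.run(wrapped(query="tools refactor")))
# -> Alice searched for 'tools refactor'
```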