Merge pull request #4724 from michaelpoluektov/tools-refac-2.1

feat: Add `__tools__` optional param for function pipes
This commit is contained in:
Timothy Jaeryang Baek 2024-08-20 17:01:13 +02:00 committed by GitHub
commit ee526b4b07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 251 additions and 135 deletions

View File

@ -731,11 +731,8 @@ async def generate_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"{payload = }")
payload = {
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
}
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]

View File

@ -26,6 +26,7 @@ from utils.misc import (
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
from utils.tools import get_tools
from config import ( from config import (
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
async def generate_function_chat_completion(form_data, user): async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None) metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
__event_emitter__ = None __event_emitter__ = None
__event_call__ = None __event_call__ = None
@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
__event_call__ = get_event_call(metadata) __event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None) __task__ = metadata.get("task", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
}
tools_params = {
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
}
configured_tools = get_tools(app, tool_ids, user, tools_params)
extra_params["__tools__"] = configured_tools
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
function_module = get_function_module(pipe_id) function_module = get_function_module(pipe_id)
pipe = function_module.pipe pipe = function_module.pipe
params = get_function_params( params = get_function_params(function_module, form_data, user, extra_params)
function_module,
form_data,
user,
{
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
},
)
if form_data["stream"]: if form_data["stream"]:

View File

@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json CHANGELOG = changelog_json
#################################### ####################################
# SAFE_MODE # SAFE_MODE
#################################### ####################################

View File

@ -14,6 +14,7 @@ import requests
import mimetypes import mimetypes
import shutil import shutil
import inspect import inspect
from typing import Optional
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, Callable, Awaitable
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users, UserModel from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from apps.webui.utils import load_function_module_by_id
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
@ -76,6 +75,8 @@ from utils.task import (
tools_function_calling_generation_template, tools_function_calling_generation_template,
moa_response_generation_template, moa_response_generation_template,
) )
from utils.tools import get_tools
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
print(f"Error: {e}") print(f"Error: {e}")
raise e raise e
if skip_files and "files" in body: if skip_files and "files" in body.get("metadata", {}):
del body["files"] del body["metadata"]["files"]
return body, {} return body, {}
@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
} }
def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap a tool function so matching extra parameters are injected on call.

    Only entries of ``extra_params`` whose key is a parameter name in
    ``function``'s signature are retained. The returned wrapper is always a
    coroutine function taking keyword arguments only; it overlays the retained
    extras on the caller's keyword arguments (extras win on a key clash) and
    invokes ``function``, awaiting the result when ``function`` is itself a
    coroutine function.
    """
    accepted_names = inspect.signature(function).parameters
    injected = {
        name: extra_params[name] for name in extra_params if name in accepted_names
    }
    needs_await = inspect.iscoroutinefunction(function)

    async def new_function(**kwargs):
        # Later entries win, so injected values override caller-supplied ones.
        merged = {**kwargs, **injected}
        outcome = function(**merged)
        if needs_await:
            outcome = await outcome
        return outcome

    return new_function
# NOTE: ``extra_params`` is not mutated; each toolkit gets its own copy.
def get_tools(
    tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Resolve toolkit ids into a mapping of function name -> tool descriptor.

    For every id that has a stored toolkit, the toolkit module is taken from
    ``webui_app.state.TOOLS`` (loaded and cached there on a miss), its valves
    are refreshed, and each entry of ``toolkit.specs`` becomes a descriptor
    dict with the keys ``toolkit_id``, ``callable``, ``spec``,
    ``file_handler`` and ``citation``.

    Args:
        tool_ids: Toolkit ids to resolve; ``None`` or empty yields ``{}``.
        user: The requesting user; ``user.id`` selects the user valves.
        extra_params: Extra keyword arguments offered to every tool function
            (only those named in a function's signature are passed on).

    Returns:
        Dict keyed by ``spec["name"]``. On a name collision the first
        registered tool wins and the later one is discarded with warnings.
    """
    tools = {}
    for tool_id in tool_ids or []:
        toolkit = Tools.get_tool_by_id(tool_id)
        if toolkit is None:
            continue

        module = webui_app.state.TOOLS.get(tool_id, None)
        if module is None:
            module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = module

        # Per-toolkit copy: writing into the shared dict would make every
        # previously wrapped tool observe the values of the last toolkit.
        tool_params = {**extra_params, "__id__": tool_id}

        if hasattr(module, "valves") and hasattr(module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id) or {}
            module.valves = module.Valves(**valves)

        if hasattr(module, "UserValves"):
            user_valves = (
                Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) or {}
            )
            # Fresh ``__user__`` dict so each toolkit keeps its own valves;
            # also tolerates callers that supplied no ``__user__`` entry.
            tool_params["__user__"] = {
                **(extra_params.get("__user__") or {}),
                "valves": module.UserValves(**user_valves),  # type: ignore
            }

        for spec in toolkit.specs:
            # TODO: Fix hack for OpenAI API
            for val in spec.get("parameters", {}).get("properties", {}).values():
                # Properties without a "type" key are left untouched.
                if val.get("type") == "str":
                    val["type"] = "string"

            function_name = spec["name"]

            # convert to function that takes only model params and inserts custom params
            tool_callable = apply_extra_params_to_tool_function(
                getattr(module, function_name), tool_params
            )

            # TODO: This needs to be a pydantic model
            tool_dict = {
                "toolkit_id": tool_id,
                "callable": tool_callable,
                "spec": spec,
                "file_handler": hasattr(module, "file_handler") and module.file_handler,
                "citation": hasattr(module, "citation") and module.citation,
            }

            # TODO: if collision, prepend toolkit name
            if function_name in tools:
                log.warning(f"Tool {function_name} already exists in another toolkit!")
                log.warning(f"Collision between {toolkit} and {tool_id}.")
                log.warning(f"Discarding {toolkit}.{function_name}")
            else:
                tools[function_name] = tool_dict

    return tools
async def get_content_from_response(response) -> Optional[str]: async def get_content_from_response(response) -> Optional[str]:
content = None content = None
if hasattr(response, "body_iterator"): if hasattr(response, "body_iterator"):
@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
if not tool_ids:
return body, {}
skip_files = False skip_files = False
contexts = [] contexts = []
citations = [] citations = []
task_model_id = get_task_model_id(body["model"]) task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
tool_ids = body.pop("tool_ids", None)
if not tool_ids:
return body, {}
log.debug(f"{tool_ids=}") log.debug(f"{tool_ids=}")
@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
**extra_params, **extra_params,
"__model__": app.state.MODELS[task_model_id], "__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"], "__messages__": body["messages"],
"__files__": body.get("files", []), "__files__": metadata.get("files", []),
} }
tools = get_tools(tool_ids, user, custom_params) tools = get_tools(webui_app, tool_ids, user, custom_params)
log.info(f"{tools=}") log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()] specs = [tool["spec"] for tool in tools.values()]
@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
content = await get_content_from_response(response) content = await get_content_from_response(response)
log.debug(f"{content=}") log.debug(f"{content=}")
if content is None: if not content:
return body, {} return body, {}
result = json.loads(content) result = json.loads(content)
@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
contexts.append(tool_output) contexts.append(tool_output)
except Exception as e: except Exception as e:
print(f"Error: {e}") log.exception(f"Error: {e}")
content = None content = None
log.debug(f"tool_contexts: {contexts}") log.debug(f"tool_contexts: {contexts}")
if skip_files and "files" in body: if skip_files and "files" in body.get("metadata", {}):
del body["files"] del body["metadata"]["files"]
return body, {"contexts": contexts, "citations": citations} return body, {"contexts": contexts, "citations": citations}
@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = [] contexts = []
citations = [] citations = []
if files := body.pop("files", None): if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context( contexts, citations = get_rag_context(
files=files, files=files,
messages=body["messages"], messages=body["messages"],
@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
"message_id": body.pop("id", None), "message_id": body.pop("id", None),
"session_id": body.pop("session_id", None), "session_id": body.pop("session_id", None),
"valves": body.pop("valves", None), "valves": body.pop("valves", None),
"tool_ids": body.pop("tool_ids", None),
"files": body.pop("files", None),
} }
__user__ = { __user__ = {
@ -680,37 +611,30 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
] ]
response = await call_next(request) response = await call_next(request)
if isinstance(response, StreamingResponse): if not isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers["Content-Type"]
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
return response return response
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
return response
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
async def stream_wrapper(original_generator, data_items):
for item in data_items:
yield wrap_item(json.dumps(item))
async for data in original_generator:
yield data
return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, data_items):
    """Yield each data item as a ``data: <json>`` event, then the original stream.

    Every element of ``data_items`` is JSON-encoded and emitted as one
    ``data: ...`` line followed by a blank line; afterwards the chunks of
    ``original_generator`` are passed through unchanged.
    """
    for entry in data_items:
        encoded = json.dumps(entry)
        yield f"data: {encoded}\n\n"

    async for chunk in original_generator:
        yield chunk
async def ollama_stream_wrapper(self, original_generator, data_items):
    """Yield each data item as one JSON line, then the original stream.

    Every element of ``data_items`` is JSON-encoded and emitted followed by a
    single newline; afterwards the chunks of ``original_generator`` are passed
    through unchanged.
    """
    for entry in data_items:
        encoded = json.dumps(entry)
        yield f"{encoded}\n"

    async for chunk in original_generator:
        yield chunk
app.add_middleware(ChatCompletionMiddleware) app.add_middleware(ChatCompletionMiddleware)
@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":

104
backend/utils/schemas.py Normal file
View File

@ -0,0 +1,104 @@
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional, Type
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
    """
    Builds a Pydantic BaseModel class from a tool spec.

    Args:
        tool_dict: The tool spec. Its "name" entry becomes the model name and
            its optional "parameters" entry is a JSON-schema object whose
            "properties" and "required" entries define the fields.

    Returns:
        A Pydantic BaseModel class with one field per schema property.

    Raises:
        KeyError: If tool_dict has no "name" entry.
    """
    model_name = tool_dict["name"]

    # A spec may omit "parameters" (e.g. a tool that takes no arguments);
    # treat that as an empty schema instead of failing with KeyError.
    schema = tool_dict.get("parameters") or {}
    required = schema.get("required", [])

    # One (type, Field) definition per declared property.
    field_definitions = {
        name: json_schema_to_pydantic_field(name, prop, required)
        for name, prop in schema.get("properties", {}).items()
    }

    return create_model(model_name, **field_definitions)
def json_schema_to_pydantic_field(
    name: str, json_schema: dict[str, Any], required: list[str]
) -> Any:
    """
    Converts a JSON schema property to a Pydantic field definition.

    Args:
        name: The field name.
        json_schema: The JSON schema property.
        required: Names of the required fields; a field listed here gets no
            default (``...``), any other field defaults to None.

    Returns:
        A ``(type, Field)`` tuple.
    """
    # Resolve the type first so an unsupported schema type fails early.
    field_type = json_schema_to_pydantic_type(json_schema)
    is_required = name in required

    field_info = Field(
        description=json_schema.get("description"),
        examples=json_schema.get("examples"),
        default=... if is_required else None,
    )
    return (field_type, field_info)
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
    """
    Converts a JSON schema type to a Pydantic type.

    Args:
        json_schema: The JSON schema to convert; its "type" entry selects the
            result, "items" refines arrays and "properties" refines objects.

    Returns:
        A Pydantic-compatible type: a builtin for primitives, ``list`` or
        ``list[T]`` for arrays, ``dict`` or a nested model class for objects,
        and ``Optional[Any]`` for "null".

    Raises:
        ValueError: If the schema type is missing or unsupported.
    """
    type_ = json_schema.get("type")

    # Both JSON-schema names and the Python-style aliases are accepted.
    primitive_types = {
        "string": str,
        "str": str,
        "integer": int,
        "int": int,
        "number": float,
        "float": float,
        "boolean": bool,
        "bool": bool,
    }
    # isinstance guard keeps unhashable values (e.g. a list of types) on the
    # ValueError path below instead of raising TypeError on the lookup.
    if isinstance(type_, str) and type_ in primitive_types:
        return primitive_types[type_]

    if type_ == "array":
        items_schema = json_schema.get("items")
        if not items_schema:
            return list
        item_type = json_schema_to_pydantic_type(items_schema)
        return list[item_type]

    if type_ == "object":
        if not json_schema.get("properties"):
            return dict
        # json_schema_to_model expects a tool-spec shape ("name" and
        # "parameters"), so wrap the property schema rather than passing it
        # directly, which would fail with KeyError on "name".
        return json_schema_to_model(
            {
                "name": json_schema.get("title") or "NestedModel",
                "parameters": json_schema,
            }
        )

    if type_ == "null":
        return Optional[Any]  # Use Optional[Any] for nullable fields

    raise ValueError(f"Unsupported JSON schema type: {type_}")

View File

@ -1,5 +1,90 @@
import inspect import inspect
from typing import get_type_hints import logging
from typing import Awaitable, Callable, get_type_hints
from apps.webui.models.tools import Tools
from apps.webui.models.users import UserModel
from apps.webui.utils import load_toolkit_module_by_id
from utils.schemas import json_schema_to_model
log = logging.getLogger(__name__)
def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap a tool function so matching extra parameters are injected on call.

    Only entries of ``extra_params`` whose key is a parameter name in
    ``function``'s signature are retained. The returned wrapper is always a
    coroutine function taking keyword arguments only; it overlays the retained
    extras on the caller's keyword arguments (extras win on a key clash) and
    invokes ``function``, awaiting the result when ``function`` is itself a
    coroutine function.
    """
    accepted_names = inspect.signature(function).parameters
    injected = {
        name: extra_params[name] for name in extra_params if name in accepted_names
    }
    needs_await = inspect.iscoroutinefunction(function)

    async def new_function(**kwargs):
        # Later entries win, so injected values override caller-supplied ones.
        merged = {**kwargs, **injected}
        outcome = function(**merged)
        if needs_await:
            outcome = await outcome
        return outcome

    return new_function
# NOTE: ``extra_params`` is not mutated; each toolkit gets its own copy.
def get_tools(
    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Resolve toolkit ids into a mapping of function name -> tool descriptor.

    For every id that has a stored toolkit, the toolkit module is taken from
    ``webui_app.state.TOOLS`` (loaded and cached there on a miss), its valves
    are refreshed, and each entry of ``toolkit.specs`` becomes a descriptor
    dict with the keys ``toolkit_id``, ``callable``, ``spec``,
    ``pydantic_model``, ``file_handler`` and ``citation``.

    Args:
        webui_app: Application object whose ``state.TOOLS`` dict caches
            loaded toolkit modules by id.
        tool_ids: Toolkit ids to resolve; ``None`` or empty yields ``{}``.
        user: The requesting user; ``user.id`` selects the user valves.
        extra_params: Extra keyword arguments offered to every tool function
            (only those named in a function's signature are passed on).

    Returns:
        Dict keyed by ``spec["name"]``. On a name collision the first
        registered tool wins and the later one is discarded with warnings.
    """
    tools = {}
    # Request metadata may carry ``tool_ids`` as None when no tools were
    # selected, so treat a missing list as empty instead of failing.
    for tool_id in tool_ids or []:
        toolkit = Tools.get_tool_by_id(tool_id)
        if toolkit is None:
            continue

        module = webui_app.state.TOOLS.get(tool_id, None)
        if module is None:
            module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = module

        # Per-toolkit copy: writing into the shared dict would make every
        # previously wrapped tool observe the values of the last toolkit.
        tool_params = {**extra_params, "__id__": tool_id}

        if hasattr(module, "valves") and hasattr(module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id) or {}
            module.valves = module.Valves(**valves)

        if hasattr(module, "UserValves"):
            user_valves = (
                Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) or {}
            )
            # Fresh ``__user__`` dict so each toolkit keeps its own valves;
            # also tolerates callers that supplied no ``__user__`` entry.
            tool_params["__user__"] = {
                **(extra_params.get("__user__") or {}),
                "valves": module.UserValves(**user_valves),  # type: ignore
            }

        for spec in toolkit.specs:
            # TODO: Fix hack for OpenAI API
            for val in spec.get("parameters", {}).get("properties", {}).values():
                # Properties without a "type" key are left untouched.
                if val.get("type") == "str":
                    val["type"] = "string"

            function_name = spec["name"]

            # convert to function that takes only model params and inserts custom params
            original_func = getattr(module, function_name)
            tool_callable = apply_extra_params_to_tool_function(
                original_func, tool_params
            )
            # Carry the tool's docstring over to the wrapper.
            tool_callable.__doc__ = getattr(original_func, "__doc__", None)

            # TODO: This needs to be a pydantic model
            tool_dict = {
                "toolkit_id": tool_id,
                "callable": tool_callable,
                "spec": spec,
                "pydantic_model": json_schema_to_model(spec),
                "file_handler": hasattr(module, "file_handler") and module.file_handler,
                "citation": hasattr(module, "citation") and module.citation,
            }

            # TODO: if collision, prepend toolkit name
            if function_name in tools:
                log.warning(f"Tool {function_name} already exists in another toolkit!")
                log.warning(f"Collision between {toolkit} and {tool_id}.")
                log.warning(f"Discarding {toolkit}.{function_name}")
            else:
                tools[function_name] = tool_dict

    return tools
def doc_to_dict(docstring): def doc_to_dict(docstring):