mirror of
https://github.com/open-webui/open-webui
synced 2025-02-21 21:01:09 +00:00
Merge pull request #4724 from michaelpoluektov/tools-refac-2.1
feat: Add `__tools__` optional param for function pipes
This commit is contained in:
commit
ee526b4b07
@ -731,11 +731,8 @@ async def generate_chat_completion(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
|
||||
}
|
||||
payload = {**form_data.model_dump(exclude_none=True)}
|
||||
log.debug(f"{payload = }")
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
|
||||
|
@ -26,6 +26,7 @@ from utils.misc import (
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
|
||||
from utils.tools import get_tools
|
||||
|
||||
from config import (
|
||||
SHOW_ADMIN_DETAILS,
|
||||
@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
|
||||
async def generate_function_chat_completion(form_data, user):
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
metadata = form_data.pop("metadata", None)
|
||||
metadata = form_data.pop("metadata", {})
|
||||
files = metadata.get("files", [])
|
||||
tool_ids = metadata.get("tool_ids", [])
|
||||
|
||||
__event_emitter__ = None
|
||||
__event_call__ = None
|
||||
@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
|
||||
__event_call__ = get_event_call(metadata)
|
||||
__task__ = metadata.get("task", None)
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__task__": __task__,
|
||||
}
|
||||
tools_params = {
|
||||
**extra_params,
|
||||
"__model__": app.state.MODELS[form_data["model"]],
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": files,
|
||||
}
|
||||
configured_tools = get_tools(app, tool_ids, user, tools_params)
|
||||
|
||||
extra_params["__tools__"] = configured_tools
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
form_data["model"] = model_info.base_model_id
|
||||
@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
|
||||
function_module = get_function_module(pipe_id)
|
||||
|
||||
pipe = function_module.pipe
|
||||
params = get_function_params(
|
||||
function_module,
|
||||
form_data,
|
||||
user,
|
||||
{
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__task__": __task__,
|
||||
},
|
||||
)
|
||||
params = get_function_params(function_module, form_data, user, extra_params)
|
||||
|
||||
if form_data["stream"]:
|
||||
|
||||
|
@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
|
||||
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
|
||||
####################################
|
||||
# SAFE_MODE
|
||||
####################################
|
||||
|
157
backend/main.py
157
backend/main.py
@ -14,6 +14,7 @@ import requests
|
||||
import mimetypes
|
||||
import shutil
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Callable, Awaitable
|
||||
|
||||
from apps.webui.models.auths import Auths
|
||||
from apps.webui.models.models import Models
|
||||
from apps.webui.models.tools import Tools
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.users import Users, UserModel
|
||||
|
||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||
from apps.webui.utils import load_function_module_by_id
|
||||
|
||||
from utils.utils import (
|
||||
get_admin_user,
|
||||
@ -76,6 +75,8 @@ from utils.task import (
|
||||
tools_function_calling_generation_template,
|
||||
moa_response_generation_template,
|
||||
)
|
||||
|
||||
from utils.tools import get_tools
|
||||
from utils.misc import (
|
||||
get_last_user_message,
|
||||
add_or_update_system_message,
|
||||
@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
|
||||
print(f"Error: {e}")
|
||||
raise e
|
||||
|
||||
if skip_files and "files" in body:
|
||||
del body["files"]
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {}
|
||||
|
||||
@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
|
||||
}
|
||||
|
||||
|
||||
def apply_extra_params_to_tool_function(
|
||||
function: Callable, extra_params: dict
|
||||
) -> Callable[..., Awaitable]:
|
||||
sig = inspect.signature(function)
|
||||
extra_params = {
|
||||
key: value for key, value in extra_params.items() if key in sig.parameters
|
||||
}
|
||||
is_coroutine = inspect.iscoroutinefunction(function)
|
||||
|
||||
async def new_function(**kwargs):
|
||||
extra_kwargs = kwargs | extra_params
|
||||
if is_coroutine:
|
||||
return await function(**extra_kwargs)
|
||||
return function(**extra_kwargs)
|
||||
|
||||
return new_function
|
||||
|
||||
|
||||
# Mutation on extra_params
|
||||
def get_tools(
|
||||
tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
tools = {}
|
||||
for tool_id in tool_ids:
|
||||
toolkit = Tools.get_tool_by_id(tool_id)
|
||||
if toolkit is None:
|
||||
continue
|
||||
|
||||
module = webui_app.state.TOOLS.get(tool_id, None)
|
||||
if module is None:
|
||||
module, _ = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = module
|
||||
|
||||
extra_params["__id__"] = tool_id
|
||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
module.valves = module.Valves(**valves)
|
||||
|
||||
if hasattr(module, "UserValves"):
|
||||
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||
)
|
||||
|
||||
for spec in toolkit.specs:
|
||||
# TODO: Fix hack for OpenAI API
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
val["type"] = "string"
|
||||
function_name = spec["name"]
|
||||
|
||||
# convert to function that takes only model params and inserts custom params
|
||||
callable = apply_extra_params_to_tool_function(
|
||||
getattr(module, function_name), extra_params
|
||||
)
|
||||
|
||||
# TODO: This needs to be a pydantic model
|
||||
tool_dict = {
|
||||
"toolkit_id": tool_id,
|
||||
"callable": callable,
|
||||
"spec": spec,
|
||||
"file_handler": hasattr(module, "file_handler") and module.file_handler,
|
||||
"citation": hasattr(module, "citation") and module.citation,
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools:
|
||||
log.warning(f"Tool {function_name} already exists in another toolkit!")
|
||||
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
||||
log.warning(f"Discarding {toolkit}.{function_name}")
|
||||
else:
|
||||
tools[function_name] = tool_dict
|
||||
return tools
|
||||
|
||||
|
||||
async def get_content_from_response(response) -> Optional[str]:
|
||||
content = None
|
||||
if hasattr(response, "body_iterator"):
|
||||
@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
|
||||
async def chat_completion_tools_handler(
|
||||
body: dict, user: UserModel, extra_params: dict
|
||||
) -> tuple[dict, dict]:
|
||||
# If tool_ids field is present, call the functions
|
||||
metadata = body.get("metadata", {})
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
if not tool_ids:
|
||||
return body, {}
|
||||
|
||||
skip_files = False
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
task_model_id = get_task_model_id(body["model"])
|
||||
# If tool_ids field is present, call the functions
|
||||
tool_ids = body.pop("tool_ids", None)
|
||||
if not tool_ids:
|
||||
return body, {}
|
||||
|
||||
log.debug(f"{tool_ids=}")
|
||||
|
||||
@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
|
||||
**extra_params,
|
||||
"__model__": app.state.MODELS[task_model_id],
|
||||
"__messages__": body["messages"],
|
||||
"__files__": body.get("files", []),
|
||||
"__files__": metadata.get("files", []),
|
||||
}
|
||||
tools = get_tools(tool_ids, user, custom_params)
|
||||
tools = get_tools(webui_app, tool_ids, user, custom_params)
|
||||
log.info(f"{tools=}")
|
||||
|
||||
specs = [tool["spec"] for tool in tools.values()]
|
||||
@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
|
||||
content = await get_content_from_response(response)
|
||||
log.debug(f"{content=}")
|
||||
|
||||
if content is None:
|
||||
if not content:
|
||||
return body, {}
|
||||
|
||||
result = json.loads(content)
|
||||
@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
|
||||
contexts.append(tool_output)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
log.exception(f"Error: {e}")
|
||||
content = None
|
||||
|
||||
log.debug(f"tool_contexts: {contexts}")
|
||||
|
||||
if skip_files and "files" in body:
|
||||
del body["files"]
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {"contexts": contexts, "citations": citations}
|
||||
|
||||
@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
if files := body.pop("files", None):
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
contexts, citations = get_rag_context(
|
||||
files=files,
|
||||
messages=body["messages"],
|
||||
@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
"message_id": body.pop("id", None),
|
||||
"session_id": body.pop("session_id", None),
|
||||
"valves": body.pop("valves", None),
|
||||
"tool_ids": body.pop("tool_ids", None),
|
||||
"files": body.pop("files", None),
|
||||
}
|
||||
|
||||
__user__ = {
|
||||
@ -680,37 +611,30 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
if isinstance(response, StreamingResponse):
|
||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||
content_type = response.headers["Content-Type"]
|
||||
if "text/event-stream" in content_type:
|
||||
return StreamingResponse(
|
||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||
)
|
||||
if "application/x-ndjson" in content_type:
|
||||
return StreamingResponse(
|
||||
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
||||
)
|
||||
if not isinstance(response, StreamingResponse):
|
||||
return response
|
||||
|
||||
return response
|
||||
content_type = response.headers["Content-Type"]
|
||||
is_openai = "text/event-stream" in content_type
|
||||
is_ollama = "application/x-ndjson" in content_type
|
||||
if not is_openai and not is_ollama:
|
||||
return response
|
||||
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
||||
|
||||
async def stream_wrapper(original_generator, data_items):
|
||||
for item in data_items:
|
||||
yield wrap_item(json.dumps(item))
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
|
||||
|
||||
async def _receive(self, body: bytes):
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def openai_stream_wrapper(self, original_generator, data_items):
|
||||
for item in data_items:
|
||||
yield f"data: {json.dumps(item)}\n\n"
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
async def ollama_stream_wrapper(self, original_generator, data_items):
|
||||
for item in data_items:
|
||||
yield f"{json.dumps(item)}\n"
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
|
||||
app.add_middleware(ChatCompletionMiddleware)
|
||||
|
||||
@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
detail="Model not found",
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
if model.get("pipe"):
|
||||
return await generate_function_chat_completion(form_data, user=user)
|
||||
if model["owned_by"] == "ollama":
|
||||
|
104
backend/utils/schemas.py
Normal file
104
backend/utils/schemas.py
Normal file
@ -0,0 +1,104 @@
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
|
||||
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
|
||||
"""
|
||||
Converts a JSON schema to a Pydantic BaseModel class.
|
||||
|
||||
Args:
|
||||
json_schema: The JSON schema to convert.
|
||||
|
||||
Returns:
|
||||
A Pydantic BaseModel class.
|
||||
"""
|
||||
|
||||
# Extract the model name from the schema title.
|
||||
model_name = tool_dict["name"]
|
||||
schema = tool_dict["parameters"]
|
||||
|
||||
# Extract the field definitions from the schema properties.
|
||||
field_definitions = {
|
||||
name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
|
||||
for name, prop in schema.get("properties", {}).items()
|
||||
}
|
||||
|
||||
# Create the BaseModel class using create_model().
|
||||
return create_model(model_name, **field_definitions)
|
||||
|
||||
|
||||
def json_schema_to_pydantic_field(
|
||||
name: str, json_schema: dict[str, Any], required: list[str]
|
||||
) -> Any:
|
||||
"""
|
||||
Converts a JSON schema property to a Pydantic field definition.
|
||||
|
||||
Args:
|
||||
name: The field name.
|
||||
json_schema: The JSON schema property.
|
||||
|
||||
Returns:
|
||||
A Pydantic field definition.
|
||||
"""
|
||||
|
||||
# Get the field type.
|
||||
type_ = json_schema_to_pydantic_type(json_schema)
|
||||
|
||||
# Get the field description.
|
||||
description = json_schema.get("description")
|
||||
|
||||
# Get the field examples.
|
||||
examples = json_schema.get("examples")
|
||||
|
||||
# Create a Field object with the type, description, and examples.
|
||||
# The 'required' flag will be set later when creating the model.
|
||||
return (
|
||||
type_,
|
||||
Field(
|
||||
description=description,
|
||||
examples=examples,
|
||||
default=... if name in required else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Converts a JSON schema type to a Pydantic type.
|
||||
|
||||
Args:
|
||||
json_schema: The JSON schema to convert.
|
||||
|
||||
Returns:
|
||||
A Pydantic type.
|
||||
"""
|
||||
|
||||
type_ = json_schema.get("type")
|
||||
|
||||
if type_ == "string" or type_ == "str":
|
||||
return str
|
||||
elif type_ == "integer" or type_ == "int":
|
||||
return int
|
||||
elif type_ == "number" or type_ == "float":
|
||||
return float
|
||||
elif type_ == "boolean" or type_ == "bool":
|
||||
return bool
|
||||
elif type_ == "array":
|
||||
items_schema = json_schema.get("items")
|
||||
if items_schema:
|
||||
item_type = json_schema_to_pydantic_type(items_schema)
|
||||
return list[item_type]
|
||||
else:
|
||||
return list
|
||||
elif type_ == "object":
|
||||
# Handle nested models.
|
||||
properties = json_schema.get("properties")
|
||||
if properties:
|
||||
nested_model = json_schema_to_model(json_schema)
|
||||
return nested_model
|
||||
else:
|
||||
return dict
|
||||
elif type_ == "null":
|
||||
return Optional[Any] # Use Optional[Any] for nullable fields
|
||||
else:
|
||||
raise ValueError(f"Unsupported JSON schema type: {type_}")
|
@ -1,5 +1,90 @@
|
||||
import inspect
|
||||
from typing import get_type_hints
|
||||
import logging
|
||||
from typing import Awaitable, Callable, get_type_hints
|
||||
|
||||
from apps.webui.models.tools import Tools
|
||||
from apps.webui.models.users import UserModel
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.schemas import json_schema_to_model
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_extra_params_to_tool_function(
|
||||
function: Callable, extra_params: dict
|
||||
) -> Callable[..., Awaitable]:
|
||||
sig = inspect.signature(function)
|
||||
extra_params = {
|
||||
key: value for key, value in extra_params.items() if key in sig.parameters
|
||||
}
|
||||
is_coroutine = inspect.iscoroutinefunction(function)
|
||||
|
||||
async def new_function(**kwargs):
|
||||
extra_kwargs = kwargs | extra_params
|
||||
if is_coroutine:
|
||||
return await function(**extra_kwargs)
|
||||
return function(**extra_kwargs)
|
||||
|
||||
return new_function
|
||||
|
||||
|
||||
# Mutation on extra_params
|
||||
def get_tools(
|
||||
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
tools = {}
|
||||
for tool_id in tool_ids:
|
||||
toolkit = Tools.get_tool_by_id(tool_id)
|
||||
if toolkit is None:
|
||||
continue
|
||||
|
||||
module = webui_app.state.TOOLS.get(tool_id, None)
|
||||
if module is None:
|
||||
module, _ = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = module
|
||||
|
||||
extra_params["__id__"] = tool_id
|
||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
module.valves = module.Valves(**valves)
|
||||
|
||||
if hasattr(module, "UserValves"):
|
||||
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||
)
|
||||
|
||||
for spec in toolkit.specs:
|
||||
# TODO: Fix hack for OpenAI API
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
val["type"] = "string"
|
||||
function_name = spec["name"]
|
||||
|
||||
# convert to function that takes only model params and inserts custom params
|
||||
original_func = getattr(module, function_name)
|
||||
callable = apply_extra_params_to_tool_function(original_func, extra_params)
|
||||
if hasattr(original_func, "__doc__"):
|
||||
callable.__doc__ = original_func.__doc__
|
||||
|
||||
# TODO: This needs to be a pydantic model
|
||||
tool_dict = {
|
||||
"toolkit_id": tool_id,
|
||||
"callable": callable,
|
||||
"spec": spec,
|
||||
"pydantic_model": json_schema_to_model(spec),
|
||||
"file_handler": hasattr(module, "file_handler") and module.file_handler,
|
||||
"citation": hasattr(module, "citation") and module.citation,
|
||||
}
|
||||
|
||||
# TODO: if collision, prepend toolkit name
|
||||
if function_name in tools:
|
||||
log.warning(f"Tool {function_name} already exists in another toolkit!")
|
||||
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
||||
log.warning(f"Discarding {toolkit}.{function_name}")
|
||||
else:
|
||||
tools[function_name] = tool_dict
|
||||
return tools
|
||||
|
||||
|
||||
def doc_to_dict(docstring):
|
||||
|
Loading…
Reference in New Issue
Block a user