diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index cc82089b5..d3931b1ab 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -731,11 +731,8 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") - - payload = { - **form_data.model_dump(exclude_none=True, exclude=["metadata"]), - } + payload = {**form_data.model_dump(exclude_none=True)} + log.debug(f"{payload = }") if "metadata" in payload: del payload["metadata"] diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 82bdea165..375615180 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -26,6 +26,7 @@ from utils.misc import ( apply_model_system_prompt_to_body, ) +from utils.tools import get_tools from config import ( SHOW_ADMIN_DETAILS, @@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}): async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", None) + metadata = form_data.pop("metadata", {}) + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) __event_emitter__ = None __event_call__ = None @@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + } + tools_params = { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + } + configured_tools = get_tools(app, tool_ids, user, tools_params) + + extra_params["__tools__"] = configured_tools if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id @@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user): function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_function_params( - function_module, - form_data, - user, - { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - }, - ) + params = get_function_params(function_module, form_data, user, extra_params) if form_data["stream"]: diff --git a/backend/config.py b/backend/config.py index 716bc6be2..72f3b5e5a 100644 --- a/backend/config.py +++ b/backend/config.py @@ -176,7 +176,6 @@ for version in soup.find_all("h2"): CHANGELOG = changelog_json - #################################### # SAFE_MODE #################################### diff --git a/backend/main.py b/backend/main.py index fbd6a9439..f00d6c382 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,6 +14,7 @@ import requests import mimetypes import shutil import inspect +from typing import Optional from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles @@ -51,15 +52,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models -from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions from apps.webui.models.users import Users, UserModel -from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id +from apps.webui.utils import load_function_module_by_id from utils.utils import ( get_admin_user, @@ -76,6 +75,8 @@ from utils.task import ( tools_function_calling_generation_template, moa_response_generation_template, ) + +from utils.tools import get_tools from utils.misc import ( get_last_user_message, add_or_update_system_message, @@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): print(f"Error: {e}") raise e - if skip_files and "files" in body: - del body["files"] + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] return body, {} @@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content): } -def apply_extra_params_to_tool_function( - function: Callable, extra_params: dict -) -> Callable[..., Awaitable]: - sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) - - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) - - return new_function - - -# Mutation on extra_params -def get_tools( - tool_ids: list[str], user: UserModel, extra_params: dict -) -> dict[str, dict]: - tools = {} - for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: - continue - - module = webui_app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module - - extra_params["__id__"] = tool_id - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) - - if hasattr(module, "UserValves"): - extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - for spec in toolkit.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - function_name = spec["name"] - - # convert to function that takes only model params and inserts custom params - callable = apply_extra_params_to_tool_function( - getattr(module, function_name), extra_params - ) - - # TODO: This needs to be a pydantic model - tool_dict = { - "toolkit_id": tool_id, - "callable": callable, - "spec": spec, - "file_handler": hasattr(module, "file_handler") and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - } - - # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") - else: - tools[function_name] = tool_dict - return tools - - async def get_content_from_response(response) -> Optional[str]: content = None if hasattr(response, "body_iterator"): @@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: + # If tool_ids field is present, call the functions + metadata = body.get("metadata", {}) + tool_ids = metadata.get("tool_ids", None) + if not tool_ids: + return body, {} + skip_files = False contexts = [] citations = [] task_model_id = get_task_model_id(body["model"]) - # If tool_ids field is present, call the functions - tool_ids = body.pop("tool_ids", None) - if not tool_ids: - return body, {} log.debug(f"{tool_ids=}") @@ -459,9 +388,9 @@ async def chat_completion_tools_handler( **extra_params, "__model__": app.state.MODELS[task_model_id], "__messages__": body["messages"], - "__files__": body.get("files", []), + "__files__": metadata.get("files", []), } - tools = get_tools(tool_ids, user, custom_params) + tools = get_tools(webui_app, tool_ids, user, custom_params) log.info(f"{tools=}") specs = [tool["spec"] for tool in tools.values()] @@ -486,7 +415,7 @@ async def chat_completion_tools_handler( content = await get_content_from_response(response) log.debug(f"{content=}") - if content is None: + if not content: return body, {} result = json.loads(content) @@ -521,13 +450,13 @@ async def chat_completion_tools_handler( contexts.append(tool_output) except Exception as e: - print(f"Error: {e}") + log.exception(f"Error: {e}") content = None log.debug(f"tool_contexts: {contexts}") - if skip_files and "files" in body: - del body["files"] + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] return body, {"contexts": contexts, "citations": citations} @@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts = [] citations = [] - if files := body.pop("files", None): + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "message_id": body.pop("id", None), "session_id": body.pop("session_id", None), "valves": body.pop("valves", None), + "tool_ids": body.pop("tool_ids", None), + "files": body.pop("files", None), } __user__ = { @@ -680,37 +611,30 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ] response = await call_next(request) - if isinstance(response, StreamingResponse): - # If it's a streaming response, inject it as SSE event or NDJSON line - content_type = response.headers["Content-Type"] - if "text/event-stream" in content_type: - return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, data_items), - ) - if "application/x-ndjson" in content_type: - return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, data_items), - ) + if not isinstance(response, StreamingResponse): + return response - return response + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response + + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) + + async for data in original_generator: + yield data + + return StreamingResponse(stream_wrapper(response.body_iterator, data_items)) async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} - async def openai_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"data: {json.dumps(item)}\n\n" - - async for data in original_generator: - yield data - - async def ollama_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"{json.dumps(item)}\n" - - async for data in original_generator: - yield data - app.add_middleware(ChatCompletionMiddleware) @@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] - if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py new file mode 100644 index 000000000..09b24897b --- /dev/null +++ b/backend/utils/schemas.py @@ -0,0 +1,104 @@ +from pydantic import BaseModel, Field, create_model +from typing import Any, Optional, Type + + +def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: + """ + Converts a JSON schema to a Pydantic BaseModel class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + + # Extract the model name from the schema title. + model_name = tool_dict["name"] + schema = tool_dict["parameters"] + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) + for name, prop in schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, json_schema: dict[str, Any], required: list[str] +) -> Any: + """ + Converts a JSON schema property to a Pydantic field definition. + + Args: + name: The field name. + json_schema: The JSON schema property. + + Returns: + A Pydantic field definition. + """ + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The 'required' flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: + """ + Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + + type_ = json_schema.get("type") + + if type_ == "string" or type_ == "str": + return str + elif type_ == "integer" or type_ == "int": + return int + elif type_ == "number" or type_ == "float": + return float + elif type_ == "boolean" or type_ == "bool": + return bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + return list[item_type] + else: + return list + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + return nested_model + else: + return dict + elif type_ == "null": + return Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") diff --git a/backend/utils/tools.py b/backend/utils/tools.py index eac36b5d9..1a2fea32b 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -1,5 +1,90 @@ import inspect -from typing import get_type_hints +import logging +from typing import Awaitable, Callable, get_type_hints + +from apps.webui.models.tools import Tools +from apps.webui.models.users import UserModel +from apps.webui.utils import load_toolkit_module_by_id + +from utils.schemas import json_schema_to_model + +log = logging.getLogger(__name__) + + +def apply_extra_params_to_tool_function( + function: Callable, extra_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(function) + extra_params = { + key: value for key, value in extra_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(function) + + async def new_function(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await function(**extra_kwargs) + return function(**extra_kwargs) + + return new_function + + +# Mutation on extra_params +def get_tools( + webui_app, tool_ids: list[str], user: UserModel, extra_params: dict +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + function_name = spec["name"] + + # convert to function that takes only model params and inserts custom params + original_func = getattr(module, function_name) + callable = apply_extra_params_to_tool_function(original_func, extra_params) + if hasattr(original_func, "__doc__"): + callable.__doc__ = original_func.__doc__ + + # TODO: This needs to be a pydantic model + tool_dict = { + "toolkit_id": tool_id, + "callable": callable, + "spec": spec, + "pydantic_model": json_schema_to_model(spec), + "file_handler": hasattr(module, "file_handler") and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools: + log.warning(f"Tool {function_name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{function_name}") + else: + tools[function_name] = tool_dict + return tools def doc_to_dict(docstring):