From 3164354c0b86517a2e168ac50f44c4abe1d319d8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 15 Aug 2024 17:03:42 +0100 Subject: [PATCH 01/17] refactor into single wrapper --- backend/main.py | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/backend/main.py b/backend/main.py index d539834ed..411c33e1c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -681,36 +681,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): response = await call_next(request) if isinstance(response, StreamingResponse): - # If it's a streaming response, inject it as SSE event or NDJSON line content_type = response.headers["Content-Type"] - if "text/event-stream" in content_type: - return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, data_items), - ) - if "application/x-ndjson" in content_type: - return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, data_items), - ) + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response + + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" + + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) + + async for data in original_generator: + yield data + + return StreamingResponse(stream_wrapper(response.body_iterator, data_items)) return response async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} - async def openai_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"data: {json.dumps(item)}\n\n" - - async for data in original_generator: - yield data - - async def ollama_stream_wrapper(self, original_generator, data_items): - for item in data_items: - yield f"{json.dumps(item)}\n" - - async for data in original_generator: - yield data - app.add_middleware(ChatCompletionMiddleware) From 32874a816d894dc8d012eee24ce96887d6851d5d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 10:04:08 +0100 Subject: [PATCH 02/17] add filter toggle envvars --- backend/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/config.py b/backend/config.py index 3453ee24b..1850fcefc 100644 --- a/backend/config.py +++ b/backend/config.py @@ -174,6 +174,12 @@ for version in soup.find_all("h2"): CHANGELOG = changelog_json +#################################### +# FILTERS +#################################### + +ENABLE_TOOLS_FILTER = os.environ.get("ENABLE_TOOLS_FILTER", "True").lower() == "true" +ENABLE_FILES_FILTER = os.environ.get("ENABLE_FILES_FILTER", "True").lower() == "true" #################################### # SAFE_MODE From ce7a1a73ac4a5554b251e2970145879326d445f9 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 10:34:44 +0100 Subject: [PATCH 03/17] remove more nesting --- backend/main.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index 411c33e1c..948c9081e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -680,26 +680,26 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ] response = await call_next(request) - if isinstance(response, StreamingResponse): - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - if not is_openai and not is_ollama: - return response + if not isinstance(response, StreamingResponse): + return response - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" + content_type = response.headers["Content-Type"] + is_openai = "text/event-stream" in content_type + is_ollama = "application/x-ndjson" in content_type + if not is_openai and not is_ollama: + return response - async def stream_wrapper(original_generator, data_items): - for item in data_items: - yield wrap_item(json.dumps(item)) + def wrap_item(item): + return f"data: {item}\n\n" if is_openai else f"{item}\n" - async for data in original_generator: - yield data + async def stream_wrapper(original_generator, data_items): + for item in data_items: + yield wrap_item(json.dumps(item)) - return StreamingResponse(stream_wrapper(response.body_iterator, data_items)) + async for data in original_generator: + yield data - return response + return StreamingResponse(stream_wrapper(response.body_iterator, data_items)) async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} From fd422d2e3c0340cd0dd02da46e3071e4e96e6bde Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 10:46:52 +0100 Subject: [PATCH 04/17] use filters envvars --- backend/main.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backend/main.py b/backend/main.py index 948c9081e..f19e444c3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -118,6 +118,8 @@ from config import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_TOOLS_FILTER, + ENABLE_FILES_FILTER, AppConfig, ) @@ -443,6 +445,10 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: + log.debug(f"{ENABLE_TOOLS_FILTER=}") + if not ENABLE_TOOLS_FILTER: + return body, {} + skip_files = False contexts = [] citations = [] @@ -533,6 +539,10 @@ async def chat_completion_tools_handler( async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: + log.debug(f"{ENABLE_FILES_FILTER=}") + if not ENABLE_FILES_FILTER: + return body, {} + contexts = [] citations = [] From a4a7d678f9908e0466e3395ff437824e7fe88ee3 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 10:53:12 +0100 Subject: [PATCH 05/17] move tools utils to utils.tools --- backend/main.py | 82 ++--------------------------------------- backend/utils/tools.py | 83 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 79 deletions(-) diff --git a/backend/main.py b/backend/main.py index f19e444c3..dbe9d30bf 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,15 +51,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models -from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions from apps.webui.models.users import Users, UserModel -from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id +from apps.webui.utils import load_function_module_by_id from utils.utils import ( get_admin_user, @@ -75,6 +73,8 @@ from utils.task import ( tools_function_calling_generation_template, moa_response_generation_template, ) + +from utils.tools import get_tools from utils.misc import ( get_last_user_message, add_or_update_system_message, @@ -353,80 +353,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content): } -def apply_extra_params_to_tool_function( - function: Callable, extra_params: dict -) -> Callable[..., Awaitable]: - sig = inspect.signature(function) - extra_params = { - key: value for key, value in extra_params.items() if key in sig.parameters - } - is_coroutine = inspect.iscoroutinefunction(function) - - async def new_function(**kwargs): - extra_kwargs = kwargs | extra_params - if is_coroutine: - return await function(**extra_kwargs) - return function(**extra_kwargs) - - return new_function - - -# Mutation on extra_params -def get_tools( - tool_ids: list[str], user: UserModel, extra_params: dict -) -> dict[str, dict]: - tools = {} - for tool_id in tool_ids: - toolkit = Tools.get_tool_by_id(tool_id) - if toolkit is None: - continue - - module = webui_app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = module - - extra_params["__id__"] = tool_id - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) - - if hasattr(module, "UserValves"): - extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - for spec in toolkit.specs: - # TODO: Fix hack for OpenAI API - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - function_name = spec["name"] - - # convert to function that takes only model params and inserts custom params - callable = apply_extra_params_to_tool_function( - getattr(module, function_name), extra_params - ) - - # TODO: This needs to be a pydantic model - tool_dict = { - "toolkit_id": tool_id, - "callable": callable, - "spec": spec, - "file_handler": hasattr(module, "file_handler") and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - } - - # TODO: if collision, prepend toolkit name - if function_name in tools: - log.warning(f"Tool {function_name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{function_name}") - else: - tools[function_name] = tool_dict - return tools - - async def get_content_from_response(response) -> Optional[str]: content = None if hasattr(response, "body_iterator"): @@ -467,7 +393,7 @@ async def chat_completion_tools_handler( "__messages__": body["messages"], "__files__": body.get("files", []), } - tools = get_tools(tool_ids, user, custom_params) + tools = get_tools(webui_app, tool_ids, user, custom_params) log.info(f"{tools=}") specs = [tool["spec"] for tool in tools.values()] diff --git a/backend/utils/tools.py b/backend/utils/tools.py index eac36b5d9..12642ccfd 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -1,5 +1,86 @@ import inspect -from typing import get_type_hints +import logging +from typing import Awaitable, Callable, get_type_hints + +from apps.webui.models.tools import Tools +from apps.webui.models.users import UserModel +from apps.webui.utils import load_toolkit_module_by_id + +log = logging.getLogger(__name__) + + +def apply_extra_params_to_tool_function( + function: Callable, extra_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(function) + extra_params = { + key: value for key, value in extra_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(function) + + async def new_function(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await function(**extra_kwargs) + return function(**extra_kwargs) + + return new_function + + +# Mutation on extra_params +def get_tools( + webui_app, tool_ids: list[str], user: UserModel, extra_params: dict +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + function_name = spec["name"] + + # convert to function that takes only model params and inserts custom params + callable = apply_extra_params_to_tool_function( + getattr(module, function_name), extra_params + ) + + # TODO: This needs to be a pydantic model + tool_dict = { + "toolkit_id": tool_id, + "callable": callable, + "spec": spec, + "file_handler": hasattr(module, "file_handler") and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools: + log.warning(f"Tool {function_name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{function_name}") + else: + tools[function_name] = tool_dict + return tools def doc_to_dict(docstring): From 18965dcdacfc793fc6eeaa72b56b035f38785ba8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 11:03:55 +0100 Subject: [PATCH 06/17] delete keys if envvars are set --- backend/apps/ollama/main.py | 9 ++++----- backend/main.py | 6 +++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 810a05999..37b72a105 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -732,11 +732,10 @@ async def generate_chat_completion( ): log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") - payload = { - **form_data.model_dump(exclude_none=True, exclude=["metadata"]), - } - if "metadata" in payload: - del payload["metadata"] + payload = {**form_data.model_dump(exclude_none=True)} + for key in ["metadata", "files", "tool_ids"]: + if key in payload: + del payload[key] model_id = form_data.model model_info = Models.get_model_by_id(model_id) diff --git a/backend/main.py b/backend/main.py index dbe9d30bf..49984e9bc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -453,7 +453,7 @@ async def chat_completion_tools_handler( contexts.append(tool_output) except Exception as e: - print(f"Error: {e}") + log.exception(f"Error: {e}") content = None log.debug(f"tool_contexts: {contexts}") @@ -997,6 +997,10 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) + + for key in ["tool_ids", "files"]: + if key in form_data: + del form_data[key] if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: From 13c03bfd7d5a69ae8f00ce370d47158b55cd13e4 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 11:08:27 +0100 Subject: [PATCH 07/17] add __tools__ custom param --- backend/apps/webui/main.py | 30 +++++++++++++++++++----------- backend/main.py | 8 +++----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index dddf3fbb2..06c1a0921 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -26,6 +26,7 @@ from utils.misc import ( apply_model_system_prompt_to_body, ) +from utils.tools import get_tools from config import ( SHOW_ADMIN_DETAILS, @@ -47,6 +48,7 @@ from config import ( OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_EMAIL_CLAIM, + ENABLE_TOOLS_FILTER, ) from apps.socket.main import get_event_call, get_event_emitter @@ -271,7 +273,7 @@ def get_function_params(function_module, form_data, user, extra_params={}): return params -async def generate_function_chat_completion(form_data, user): +async def generate_function_chat_completion(form_data, user, files, tool_ids): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", None) @@ -286,6 +288,21 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + } + if not ENABLE_TOOLS_FILTER: + tools_params = { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + } + configured_tools = get_tools(app, tool_ids, user, tools_params) + + extra_params["__tools__"] = configured_tools if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id @@ -298,16 +315,7 @@ async def generate_function_chat_completion(form_data, user): function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_function_params( - function_module, - form_data, - user, - { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - }, - ) + params = get_function_params(function_module, form_data, user, extra_params) if form_data["stream"]: diff --git a/backend/main.py b/backend/main.py index 49984e9bc..0bd49ef3a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -994,13 +994,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] + files = form_data.pop("files", None) + tool_ids = form_data.pop("tool_ids", None) if model.get("pipe"): - return await generate_function_chat_completion(form_data, user=user) - - for key in ["tool_ids", "files"]: - if key in form_data: - del form_data[key] + return await generate_function_chat_completion(form_data, user, files, tool_ids) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: From a933319adb28b556cdd085ca98eaecce74343130 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 11:11:00 +0100 Subject: [PATCH 08/17] import error? --- backend/main.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0bd49ef3a..ad1d7388c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,6 +14,7 @@ import requests import mimetypes import shutil import inspect +from typing import Optional from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles @@ -375,15 +376,16 @@ async def chat_completion_tools_handler( if not ENABLE_TOOLS_FILTER: return body, {} + # If tool_ids field is present, call the functions + tool_ids = body.pop("tool_ids", None) + if not tool_ids: + return body, {} + skip_files = False contexts = [] citations = [] task_model_id = get_task_model_id(body["model"]) - # If tool_ids field is present, call the functions - tool_ids = body.pop("tool_ids", None) - if not tool_ids: - return body, {} log.debug(f"{tool_ids=}") From 528df12bf1c227cbd69e8416468ffcb047152d95 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 11:15:22 +0100 Subject: [PATCH 09/17] fix: nonetype error --- backend/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index ad1d7388c..53c92020f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -996,8 +996,8 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] - files = form_data.pop("files", None) - tool_ids = form_data.pop("tool_ids", None) + files = form_data.pop("files", []) + tool_ids = form_data.pop("tool_ids", []) if model.get("pipe"): return await generate_function_chat_completion(form_data, user, files, tool_ids) From 5edc211392ebff06ac3687021f5ea6c04d75e065 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:27:21 +0100 Subject: [PATCH 10/17] pass docstring to function --- backend/utils/tools.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/utils/tools.py b/backend/utils/tools.py index 12642ccfd..14519f1be 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -60,9 +60,10 @@ def get_tools( function_name = spec["name"] # convert to function that takes only model params and inserts custom params - callable = apply_extra_params_to_tool_function( - getattr(module, function_name), extra_params - ) + original_func = getattr(module, function_name) + callable = apply_extra_params_to_tool_function(original_func, extra_params) + if hasattr(original_func, "__doc__"): + callable.__doc__ = original_func.__doc__ # TODO: This needs to be a pydantic model tool_dict = { From 9d7037b730dc3fff09cecdc3d17d8a444cb0cebd Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:27:38 +0100 Subject: [PATCH 11/17] add pydantic model from json --- backend/utils/schemas.py | 104 +++++++++++++++++++++++++++++++++++++++ backend/utils/tools.py | 3 ++ 2 files changed, 107 insertions(+) create mode 100644 backend/utils/schemas.py diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py new file mode 100644 index 000000000..09b24897b --- /dev/null +++ b/backend/utils/schemas.py @@ -0,0 +1,104 @@ +from pydantic import BaseModel, Field, create_model +from typing import Any, Optional, Type + + +def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: + """ + Converts a JSON schema to a Pydantic BaseModel class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic BaseModel class. + """ + + # Extract the model name from the schema title. + model_name = tool_dict["name"] + schema = tool_dict["parameters"] + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field(name, prop, schema.get("required", [])) + for name, prop in schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, json_schema: dict[str, Any], required: list[str] +) -> Any: + """ + Converts a JSON schema property to a Pydantic field definition. + + Args: + name: The field name. + json_schema: The JSON schema property. + + Returns: + A Pydantic field definition. + """ + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The 'required' flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any: + """ + Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + + type_ = json_schema.get("type") + + if type_ == "string" or type_ == "str": + return str + elif type_ == "integer" or type_ == "int": + return int + elif type_ == "number" or type_ == "float": + return float + elif type_ == "boolean" or type_ == "bool": + return bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + return list[item_type] + else: + return list + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + return nested_model + else: + return dict + elif type_ == "null": + return Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") diff --git a/backend/utils/tools.py b/backend/utils/tools.py index 14519f1be..1a2fea32b 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -6,6 +6,8 @@ from apps.webui.models.tools import Tools from apps.webui.models.users import UserModel from apps.webui.utils import load_toolkit_module_by_id +from utils.schemas import json_schema_to_model + log = logging.getLogger(__name__) @@ -70,6 +72,7 @@ def get_tools( "toolkit_id": tool_id, "callable": callable, "spec": spec, + "pydantic_model": json_schema_to_model(spec), "file_handler": hasattr(module, "file_handler") and module.file_handler, "citation": hasattr(module, "citation") and module.citation, } From 556bc8669a744471ad69771da01e87c00ee0bb04 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:50:52 +0100 Subject: [PATCH 12/17] remove config options for now --- backend/apps/webui/main.py | 17 ++++++++--------- backend/config.py | 7 ------- backend/main.py | 10 ---------- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 06c1a0921..e8b12f683 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -293,16 +293,15 @@ async def generate_function_chat_completion(form_data, user, files, tool_ids): "__event_call__": __event_call__, "__task__": __task__, } - if not ENABLE_TOOLS_FILTER: - tools_params = { - **extra_params, - "__model__": app.state.MODELS[form_data["model"]], - "__messages__": form_data["messages"], - "__files__": files, - } - configured_tools = get_tools(app, tool_ids, user, tools_params) + tools_params = { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + } + configured_tools = get_tools(app, tool_ids, user, tools_params) - extra_params["__tools__"] = configured_tools + extra_params["__tools__"] = configured_tools if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id diff --git a/backend/config.py b/backend/config.py index 1850fcefc..6d73eec0d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -174,13 +174,6 @@ for version in soup.find_all("h2"): CHANGELOG = changelog_json -#################################### -# FILTERS -#################################### - -ENABLE_TOOLS_FILTER = os.environ.get("ENABLE_TOOLS_FILTER", "True").lower() == "true" -ENABLE_FILES_FILTER = os.environ.get("ENABLE_FILES_FILTER", "True").lower() == "true" - #################################### # SAFE_MODE #################################### diff --git a/backend/main.py b/backend/main.py index 53c92020f..823d1a9f4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -119,8 +119,6 @@ from config import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_TOOLS_FILTER, - ENABLE_FILES_FILTER, AppConfig, ) @@ -372,10 +370,6 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: - log.debug(f"{ENABLE_TOOLS_FILTER=}") - if not ENABLE_TOOLS_FILTER: - return body, {} - # If tool_ids field is present, call the functions tool_ids = body.pop("tool_ids", None) if not tool_ids: @@ -467,10 +461,6 @@ async def chat_completion_tools_handler( async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: - log.debug(f"{ENABLE_FILES_FILTER=}") - if not ENABLE_FILES_FILTER: - return body, {} - contexts = [] citations = [] From c89df923c51261e63fe11fcaa0575416c6978ba9 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:52:42 +0100 Subject: [PATCH 13/17] fix import error --- backend/apps/webui/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index e8b12f683..2ed35bf17 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -48,7 +48,6 @@ from config import ( OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_EMAIL_CLAIM, - ENABLE_TOOLS_FILTER, ) from apps.socket.main import get_event_call, get_event_emitter From 9652c8f8af4909b7c44447aad3fe74d59c539e11 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 16:57:29 +0100 Subject: [PATCH 14/17] dont delete files and tool_ids --- backend/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index 823d1a9f4..461ae937b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -371,7 +371,7 @@ async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions - tool_ids = body.pop("tool_ids", None) + tool_ids = body.get("tool_ids", None) if not tool_ids: return body, {} @@ -464,7 +464,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts = [] citations = [] - if files := body.pop("files", None): + if files := body.get("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], From 44966db50571567af2897c48fd859cec958d178d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 19 Aug 2024 17:04:57 +0100 Subject: [PATCH 15/17] avoid ugly exception --- backend/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 461ae937b..3e3d265a2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -414,7 +414,7 @@ async def chat_completion_tools_handler( content = await get_content_from_response(response) log.debug(f"{content=}") - if content is None: + if not content: return body, {} result = json.loads(content) From 2e3146263c8b0bbd30fa58eb11fa1d436329f0ad Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 20 Aug 2024 15:41:49 +0100 Subject: [PATCH 16/17] put tool_ids and files in metadata --- backend/apps/ollama/main.py | 8 +++----- backend/apps/webui/main.py | 6 ++++-- backend/main.py | 20 +++++++++----------- src/lib/components/chat/Chat.svelte | 12 ++++++++---- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 0fa3abb6d..d3931b1ab 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -731,12 +731,10 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=") - payload = {**form_data.model_dump(exclude_none=True)} - for key in ["metadata", "files", "tool_ids"]: - if key in payload: - del payload[key] + log.debug(f"{payload = }") + if "metadata" in payload: + del payload["metadata"] model_id = form_data.model model_info = Models.get_model_by_id(model_id) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 2dbe7f787..375615180 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -273,10 +273,12 @@ def get_function_params(function_module, form_data, user, extra_params={}): return params -async def generate_function_chat_completion(form_data, user, files, tool_ids): +async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", None) + metadata = form_data.pop("metadata", {}) + files = metadata.get("files", []) + tool_ids = metadata.get("tool_ids", []) __event_emitter__ = None __event_call__ = None diff --git a/backend/main.py b/backend/main.py index 1557de2b9..1fc292e5f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -326,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): print(f"Error: {e}") raise e - if skip_files and "files" in body: - del body["files"] + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] return body, {} @@ -371,7 +371,8 @@ async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions - tool_ids = body.get("tool_ids", None) + metadata = body.get("metadata", {}) + tool_ids = metadata.get("tool_ids", None) if not tool_ids: return body, {} @@ -387,7 +388,7 @@ async def chat_completion_tools_handler( **extra_params, "__model__": app.state.MODELS[task_model_id], "__messages__": body["messages"], - "__files__": body.get("files", []), + "__files__": metadata.get("files", []), } tools = get_tools(webui_app, tool_ids, user, custom_params) log.info(f"{tools=}") @@ -454,8 +455,8 @@ async def chat_completion_tools_handler( log.debug(f"tool_contexts: {contexts}") - if skip_files and "files" in body: - del body["files"] + if skip_files and "files" in body.get("metadata", {}): + del body["metadata"]["files"] return body, {"contexts": contexts, "citations": citations} @@ -464,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts = [] citations = [] - if files := body.get("files", None): + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -986,11 +987,8 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] - files = form_data.pop("files", []) - tool_ids = form_data.pop("tool_ids", []) - if model.get("pipe"): - return await generate_function_chat_completion(form_data, user, files, tool_ids) + return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index c805d82fa..95161bf56 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -844,8 +844,10 @@ }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - files: files.length > 0 ? files : undefined, + metadata: { + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + files: files.length > 0 ? files : undefined + }, session_id: $socket?.id, chat_id: $chatId, id: responseMessageId @@ -1136,8 +1138,10 @@ frequency_penalty: params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined, max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - files: files.length > 0 ? files : undefined, + metadata: { + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + files: files.length > 0 ? files : undefined + }, session_id: $socket?.id, chat_id: $chatId, id: responseMessageId From 454f59d59aa3089b83131e0bd90b7fa316b3e841 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 20 Aug 2024 15:48:14 +0100 Subject: [PATCH 17/17] undo frontend change --- backend/main.py | 2 ++ src/lib/components/chat/Chat.svelte | 12 ++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/backend/main.py b/backend/main.py index 1fc292e5f..f00d6c382 100644 --- a/backend/main.py +++ b/backend/main.py @@ -526,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "message_id": body.pop("id", None), "session_id": body.pop("session_id", None), "valves": body.pop("valves", None), + "tool_ids": body.pop("tool_ids", None), + "files": body.pop("files", None), } __user__ = { diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 95161bf56..c805d82fa 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -844,10 +844,8 @@ }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, - metadata: { - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - files: files.length > 0 ? files : undefined - }, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + files: files.length > 0 ? files : undefined, session_id: $socket?.id, chat_id: $chatId, id: responseMessageId @@ -1138,10 +1136,8 @@ frequency_penalty: params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined, max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, - metadata: { - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - files: files.length > 0 ? files : undefined - }, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + files: files.length > 0 ? files : undefined, session_id: $socket?.id, chat_id: $chatId, id: responseMessageId