feat: toggle filter middleware

This commit is contained in:
Timothy Jaeryang Baek 2025-05-16 23:33:02 +04:00
parent e37433f2b1
commit 1f38350128
6 changed files with 38 additions and 17 deletions

View File

@ -1186,6 +1186,7 @@ async def chat_completion(
"chat_id": form_data.pop("chat_id", None), "chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None), "message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None), "session_id": form_data.pop("session_id", None),
"filter_ids": form_data.pop("filter_ids", None),
"tool_ids": form_data.get("tool_ids", None), "tool_ids": form_data.get("tool_ids", None),
"tool_servers": form_data.pop("tool_servers", None), "tool_servers": form_data.pop("tool_servers", None),
"files": form_data.get("files", None), "files": form_data.get("files", None),

View File

@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.task import get_task_model_id from open_webui.utils.task import get_task_model_id
from open_webui.config import ( from open_webui.config import (

View File

@ -330,7 +330,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
try: try:
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model) for filter_id in get_sorted_filter_ids(request, model)
] ]
result, _ = await process_filter_functions( result, _ = await process_filter_functions(

View File

@ -9,7 +9,20 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model: dict): def get_function_module(request, function_id):
"""
Get the function module by its ID.
"""
if function_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[function_id]
else:
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
return function_module
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None: if function is not None:
@ -21,14 +34,23 @@ def get_sorted_filter_ids(model: dict):
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids)) filter_ids = list(set(filter_ids))
active_filter_ids = [
enabled_filter_ids = [
function.id function.id
for function in Functions.get_functions_by_type("filter", active_only=True) for function in Functions.get_functions_by_type("filter", active_only=True)
] ]
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids] for filter_id in active_filter_ids:
function_module = get_function_module(request, filter_id)
if getattr(function_module, "toggle", None) and (
filter_id not in enabled_filter_ids
):
active_filter_ids.remove(filter_id)
continue
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority) filter_ids.sort(key=get_priority)
return filter_ids return filter_ids
@ -43,12 +65,7 @@ async def process_filter_functions(
if not filter: if not filter:
continue continue
if filter_id in request.app.state.FUNCTIONS: function_module = get_function_module(request, filter_id)
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
# Prepare handler function # Prepare handler function
handler = getattr(function_module, filter_type, None) handler = getattr(function_module, filter_type, None)
if not handler: if not handler:

View File

@ -754,9 +754,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e raise e
try: try:
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model) for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
] ]
form_data, flags = await process_filter_functions( form_data, flags = await process_filter_functions(
@ -1188,7 +1191,9 @@ async def process_chat_response(
} }
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model) for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
] ]
# Streaming response # Streaming response

View File

@ -1635,6 +1635,7 @@
}, },
files: (files?.length ?? 0) > 0 ? files : undefined, files: (files?.length ?? 0) > 0 ? files : undefined,
filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined, filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
tool_servers: $toolServers, tool_servers: $toolServers,