mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: toggle filter middleware
This commit is contained in:
		
							parent
							
								
									e37433f2b1
								
							
						
					
					
						commit
						1f38350128
					
				@ -1186,6 +1186,7 @@ async def chat_completion(
 | 
			
		||||
            "chat_id": form_data.pop("chat_id", None),
 | 
			
		||||
            "message_id": form_data.pop("id", None),
 | 
			
		||||
            "session_id": form_data.pop("session_id", None),
 | 
			
		||||
            "filter_ids": form_data.pop("filter_ids", None),
 | 
			
		||||
            "tool_ids": form_data.get("tool_ids", None),
 | 
			
		||||
            "tool_servers": form_data.pop("tool_servers", None),
 | 
			
		||||
            "files": form_data.get("files", None),
 | 
			
		||||
 | 
			
		||||
@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 | 
			
		||||
from open_webui.constants import TASKS
 | 
			
		||||
 | 
			
		||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
 | 
			
		||||
from open_webui.utils.filter import (
 | 
			
		||||
    get_sorted_filter_ids,
 | 
			
		||||
    process_filter_functions,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from open_webui.utils.task import get_task_model_id
 | 
			
		||||
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
 | 
			
		||||
@ -330,7 +330,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
 | 
			
		||||
    try:
 | 
			
		||||
        filter_functions = [
 | 
			
		||||
            Functions.get_function_by_id(filter_id)
 | 
			
		||||
            for filter_id in get_sorted_filter_ids(model)
 | 
			
		||||
            for filter_id in get_sorted_filter_ids(request, model)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        result, _ = await process_filter_functions(
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,20 @@ log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_sorted_filter_ids(model: dict):
 | 
			
		||||
def get_function_module(request, function_id):
 | 
			
		||||
    """
 | 
			
		||||
    Get the function module by its ID.
 | 
			
		||||
    """
 | 
			
		||||
    if function_id in request.app.state.FUNCTIONS:
 | 
			
		||||
        function_module = request.app.state.FUNCTIONS[function_id]
 | 
			
		||||
    else:
 | 
			
		||||
        function_module, _, _ = load_function_module_by_id(function_id)
 | 
			
		||||
        request.app.state.FUNCTIONS[function_id] = function_module
 | 
			
		||||
 | 
			
		||||
    return function_module
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
 | 
			
		||||
    def get_priority(function_id):
 | 
			
		||||
        function = Functions.get_function_by_id(function_id)
 | 
			
		||||
        if function is not None:
 | 
			
		||||
@ -21,14 +34,23 @@ def get_sorted_filter_ids(model: dict):
 | 
			
		||||
    if "info" in model and "meta" in model["info"]:
 | 
			
		||||
        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
 | 
			
		||||
        filter_ids = list(set(filter_ids))
 | 
			
		||||
 | 
			
		||||
    enabled_filter_ids = [
 | 
			
		||||
    active_filter_ids = [
 | 
			
		||||
        function.id
 | 
			
		||||
        for function in Functions.get_functions_by_type("filter", active_only=True)
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
 | 
			
		||||
    for filter_id in active_filter_ids:
 | 
			
		||||
        function_module = get_function_module(request, filter_id)
 | 
			
		||||
 | 
			
		||||
        if getattr(function_module, "toggle", None) and (
 | 
			
		||||
            filter_id not in enabled_filter_ids
 | 
			
		||||
        ):
 | 
			
		||||
            active_filter_ids.remove(filter_id)
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
    filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
 | 
			
		||||
    filter_ids.sort(key=get_priority)
 | 
			
		||||
 | 
			
		||||
    return filter_ids
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -43,12 +65,7 @@ async def process_filter_functions(
 | 
			
		||||
        if not filter:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if filter_id in request.app.state.FUNCTIONS:
 | 
			
		||||
            function_module = request.app.state.FUNCTIONS[filter_id]
 | 
			
		||||
        else:
 | 
			
		||||
            function_module, _, _ = load_function_module_by_id(filter_id)
 | 
			
		||||
            request.app.state.FUNCTIONS[filter_id] = function_module
 | 
			
		||||
 | 
			
		||||
        function_module = get_function_module(request, filter_id)
 | 
			
		||||
        # Prepare handler function
 | 
			
		||||
        handler = getattr(function_module, filter_type, None)
 | 
			
		||||
        if not handler:
 | 
			
		||||
 | 
			
		||||
@ -754,9 +754,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
 | 
			
		||||
        raise e
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
 | 
			
		||||
        filter_functions = [
 | 
			
		||||
            Functions.get_function_by_id(filter_id)
 | 
			
		||||
            for filter_id in get_sorted_filter_ids(model)
 | 
			
		||||
            for filter_id in get_sorted_filter_ids(
 | 
			
		||||
                request, model, metadata.get("filter_ids", [])
 | 
			
		||||
            )
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        form_data, flags = await process_filter_functions(
 | 
			
		||||
@ -1188,7 +1191,9 @@ async def process_chat_response(
 | 
			
		||||
    }
 | 
			
		||||
    filter_functions = [
 | 
			
		||||
        Functions.get_function_by_id(filter_id)
 | 
			
		||||
        for filter_id in get_sorted_filter_ids(model)
 | 
			
		||||
        for filter_id in get_sorted_filter_ids(
 | 
			
		||||
            request, model, metadata.get("filter_ids", [])
 | 
			
		||||
        )
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # Streaming response
 | 
			
		||||
 | 
			
		||||
@ -1635,6 +1635,7 @@
 | 
			
		||||
				},
 | 
			
		||||
 | 
			
		||||
				files: (files?.length ?? 0) > 0 ? files : undefined,
 | 
			
		||||
 | 
			
		||||
				filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
 | 
			
		||||
				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 | 
			
		||||
				tool_servers: $toolServers,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user