mirror of
https://github.com/open-webui/open-webui
synced 2025-06-10 00:17:52 +00:00
feat: toggle filter middleware
This commit is contained in:
parent
e37433f2b1
commit
1f38350128
@ -1186,6 +1186,7 @@ async def chat_completion(
|
|||||||
"chat_id": form_data.pop("chat_id", None),
|
"chat_id": form_data.pop("chat_id", None),
|
||||||
"message_id": form_data.pop("id", None),
|
"message_id": form_data.pop("id", None),
|
||||||
"session_id": form_data.pop("session_id", None),
|
"session_id": form_data.pop("session_id", None),
|
||||||
|
"filter_ids": form_data.pop("filter_ids", None),
|
||||||
"tool_ids": form_data.get("tool_ids", None),
|
"tool_ids": form_data.get("tool_ids", None),
|
||||||
"tool_servers": form_data.pop("tool_servers", None),
|
"tool_servers": form_data.pop("tool_servers", None),
|
||||||
"files": form_data.get("files", None),
|
"files": form_data.get("files", None),
|
||||||
|
@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
|||||||
from open_webui.constants import TASKS
|
from open_webui.constants import TASKS
|
||||||
|
|
||||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||||
from open_webui.utils.filter import (
|
|
||||||
get_sorted_filter_ids,
|
|
||||||
process_filter_functions,
|
|
||||||
)
|
|
||||||
from open_webui.utils.task import get_task_model_id
|
from open_webui.utils.task import get_task_model_id
|
||||||
|
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
|
@ -330,7 +330,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||||||
try:
|
try:
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(model)
|
for filter_id in get_sorted_filter_ids(request, model)
|
||||||
]
|
]
|
||||||
|
|
||||||
result, _ = await process_filter_functions(
|
result, _ = await process_filter_functions(
|
||||||
|
@ -9,7 +9,20 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
def get_sorted_filter_ids(model: dict):
|
def get_function_module(request, function_id):
|
||||||
|
"""
|
||||||
|
Get the function module by its ID.
|
||||||
|
"""
|
||||||
|
if function_id in request.app.state.FUNCTIONS:
|
||||||
|
function_module = request.app.state.FUNCTIONS[function_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(function_id)
|
||||||
|
request.app.state.FUNCTIONS[function_id] = function_module
|
||||||
|
|
||||||
|
return function_module
|
||||||
|
|
||||||
|
|
||||||
|
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None:
|
if function is not None:
|
||||||
@ -21,14 +34,23 @@ def get_sorted_filter_ids(model: dict):
|
|||||||
if "info" in model and "meta" in model["info"]:
|
if "info" in model and "meta" in model["info"]:
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||||
filter_ids = list(set(filter_ids))
|
filter_ids = list(set(filter_ids))
|
||||||
|
active_filter_ids = [
|
||||||
enabled_filter_ids = [
|
|
||||||
function.id
|
function.id
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
for filter_id in active_filter_ids:
|
||||||
|
function_module = get_function_module(request, filter_id)
|
||||||
|
|
||||||
|
if getattr(function_module, "toggle", None) and (
|
||||||
|
filter_id not in enabled_filter_ids
|
||||||
|
):
|
||||||
|
active_filter_ids.remove(filter_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||||
filter_ids.sort(key=get_priority)
|
filter_ids.sort(key=get_priority)
|
||||||
|
|
||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
@ -43,12 +65,7 @@ async def process_filter_functions(
|
|||||||
if not filter:
|
if not filter:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if filter_id in request.app.state.FUNCTIONS:
|
function_module = get_function_module(request, filter_id)
|
||||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
|
||||||
else:
|
|
||||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
|
||||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
|
||||||
|
|
||||||
# Prepare handler function
|
# Prepare handler function
|
||||||
handler = getattr(function_module, filter_type, None)
|
handler = getattr(function_module, filter_type, None)
|
||||||
if not handler:
|
if not handler:
|
||||||
|
@ -754,9 +754,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(model)
|
for filter_id in get_sorted_filter_ids(
|
||||||
|
request, model, metadata.get("filter_ids", [])
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
form_data, flags = await process_filter_functions(
|
form_data, flags = await process_filter_functions(
|
||||||
@ -1188,7 +1191,9 @@ async def process_chat_response(
|
|||||||
}
|
}
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(model)
|
for filter_id in get_sorted_filter_ids(
|
||||||
|
request, model, metadata.get("filter_ids", [])
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
|
@ -1635,6 +1635,7 @@
|
|||||||
},
|
},
|
||||||
|
|
||||||
files: (files?.length ?? 0) > 0 ? files : undefined,
|
files: (files?.length ?? 0) > 0 ? files : undefined,
|
||||||
|
|
||||||
filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
|
filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
|
||||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||||
tool_servers: $toolServers,
|
tool_servers: $toolServers,
|
||||||
|
Loading…
Reference in New Issue
Block a user