mirror of
https://github.com/open-webui/open-webui
synced 2025-06-09 16:07:54 +00:00
Merge pull request #9631 from XingjianXie/remove_inlet_outlet_duplication
Refactor common code between inlet and outlet
This commit is contained in:
commit
79c0b45543
@ -44,6 +44,10 @@ from open_webui.utils.response import (
|
|||||||
convert_response_ollama_to_openai,
|
convert_response_ollama_to_openai,
|
||||||
convert_streaming_response_ollama_to_openai,
|
convert_streaming_response_ollama_to_openai,
|
||||||
)
|
)
|
||||||
|
from open_webui.utils.filter import (
|
||||||
|
get_sorted_filter_ids,
|
||||||
|
process_filter_functions,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||||
|
|
||||||
@ -177,116 +181,37 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Exception(f"Error: {e}")
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
__event_emitter__ = get_event_emitter(
|
metadata = {
|
||||||
{
|
"chat_id": data["chat_id"],
|
||||||
"chat_id": data["chat_id"],
|
"message_id": data["id"],
|
||||||
"message_id": data["id"],
|
"session_id": data["session_id"],
|
||||||
"session_id": data["session_id"],
|
"user_id": user.id,
|
||||||
"user_id": user.id,
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
__event_call__ = get_event_call(
|
extra_params = {
|
||||||
{
|
"__event_emitter__": get_event_emitter(metadata),
|
||||||
"chat_id": data["chat_id"],
|
"__event_call__": get_event_call(metadata),
|
||||||
"message_id": data["id"],
|
"__user__": {
|
||||||
"session_id": data["session_id"],
|
"id": user.id,
|
||||||
"user_id": user.id,
|
"email": user.email,
|
||||||
}
|
"name": user.name,
|
||||||
)
|
"role": user.role,
|
||||||
|
},
|
||||||
|
"__metadata__": metadata,
|
||||||
|
"__request__": request,
|
||||||
|
}
|
||||||
|
|
||||||
def get_priority(function_id):
|
try:
|
||||||
function = Functions.get_function_by_id(function_id)
|
result, _ = await process_filter_functions(
|
||||||
if function is not None and hasattr(function, "valves"):
|
request=request,
|
||||||
# TODO: Fix FunctionModel to include vavles
|
filter_ids=get_sorted_filter_ids(model),
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
filter_type="outlet",
|
||||||
return 0
|
form_data=data,
|
||||||
|
extra_params=extra_params,
|
||||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
)
|
||||||
if "info" in model and "meta" in model["info"]:
|
return result
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
except Exception as e:
|
||||||
filter_ids = list(set(filter_ids))
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
enabled_filter_ids = [
|
|
||||||
function.id
|
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
||||||
]
|
|
||||||
filter_ids = [
|
|
||||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sort filter_ids by priority, using the get_priority function
|
|
||||||
filter_ids.sort(key=get_priority)
|
|
||||||
|
|
||||||
for filter_id in filter_ids:
|
|
||||||
filter = Functions.get_function_by_id(filter_id)
|
|
||||||
if not filter:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if filter_id in request.app.state.FUNCTIONS:
|
|
||||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
|
||||||
else:
|
|
||||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
|
||||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
|
||||||
valves = Functions.get_function_valves_by_id(filter_id)
|
|
||||||
function_module.valves = function_module.Valves(
|
|
||||||
**(valves if valves else {})
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hasattr(function_module, "outlet"):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
outlet = function_module.outlet
|
|
||||||
|
|
||||||
# Get the signature of the function
|
|
||||||
sig = inspect.signature(outlet)
|
|
||||||
params = {"body": data}
|
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
|
||||||
extra_params = {
|
|
||||||
"__model__": model,
|
|
||||||
"__id__": filter_id,
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
"__request__": request,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
|
||||||
for key, value in extra_params.items():
|
|
||||||
if key in sig.parameters:
|
|
||||||
params[key] = value
|
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
|
||||||
__user__ = {
|
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(function_module, "UserValves"):
|
|
||||||
__user__["valves"] = function_module.UserValves(
|
|
||||||
**Functions.get_user_valves_by_id_and_user_id(
|
|
||||||
filter_id, user.id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(outlet):
|
|
||||||
data = await outlet(**params)
|
|
||||||
else:
|
|
||||||
data = outlet(**params)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return Exception(f"Error: {e}")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||||
|
97
backend/open_webui/utils/filter.py
Normal file
97
backend/open_webui/utils/filter.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import inspect
|
||||||
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
|
from open_webui.models.functions import Functions
|
||||||
|
|
||||||
|
|
||||||
|
def get_sorted_filter_ids(model):
|
||||||
|
def get_priority(function_id):
|
||||||
|
function = Functions.get_function_by_id(function_id)
|
||||||
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel to include vavles
|
||||||
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||||
|
if "info" in model and "meta" in model["info"]:
|
||||||
|
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||||
|
filter_ids = list(set(filter_ids))
|
||||||
|
|
||||||
|
enabled_filter_ids = [
|
||||||
|
function.id
|
||||||
|
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||||
|
filter_ids.sort(key=get_priority)
|
||||||
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
|
async def process_filter_functions(
|
||||||
|
request, filter_ids, filter_type, form_data, extra_params
|
||||||
|
):
|
||||||
|
skip_files = None
|
||||||
|
|
||||||
|
for filter_id in filter_ids:
|
||||||
|
filter = Functions.get_function_by_id(filter_id)
|
||||||
|
if not filter:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if filter_id in request.app.state.FUNCTIONS:
|
||||||
|
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||||
|
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||||
|
|
||||||
|
# Check if the function has a file_handler variable
|
||||||
|
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
|
||||||
|
skip_files = function_module.file_handler
|
||||||
|
|
||||||
|
# Apply valves to the function
|
||||||
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
|
valves = Functions.get_function_valves_by_id(filter_id)
|
||||||
|
function_module.valves = function_module.Valves(
|
||||||
|
**(valves if valves else {})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare handler function
|
||||||
|
handler = getattr(function_module, filter_type, None)
|
||||||
|
if not handler:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare parameters
|
||||||
|
sig = inspect.signature(handler)
|
||||||
|
params = {"body": form_data}
|
||||||
|
|
||||||
|
# Add extra parameters that exist in the handler's signature
|
||||||
|
for key in list(extra_params.keys()):
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = extra_params[key]
|
||||||
|
|
||||||
|
# Handle user parameters
|
||||||
|
if "__user__" in sig.parameters:
|
||||||
|
if hasattr(function_module, "UserValves"):
|
||||||
|
try:
|
||||||
|
params["__user__"]["valves"] = function_module.UserValves(
|
||||||
|
**Functions.get_user_valves_by_id_and_user_id(
|
||||||
|
filter_id, params["__user__"]["id"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
# Execute handler
|
||||||
|
if inspect.iscoroutinefunction(handler):
|
||||||
|
form_data = await handler(**params)
|
||||||
|
else:
|
||||||
|
form_data = handler(**params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Handle file cleanup for inlet
|
||||||
|
if skip_files and "files" in form_data.get("metadata", {}):
|
||||||
|
del form_data["metadata"]["files"]
|
||||||
|
|
||||||
|
return form_data, {}
|
@ -68,6 +68,10 @@ from open_webui.utils.misc import (
|
|||||||
)
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
|
from open_webui.utils.filter import (
|
||||||
|
get_sorted_filter_ids,
|
||||||
|
process_filter_functions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from open_webui.tasks import create_task
|
from open_webui.tasks import create_task
|
||||||
@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
|
||||||
skip_files = None
|
|
||||||
|
|
||||||
def get_filter_function_ids(model):
|
|
||||||
def get_priority(function_id):
|
|
||||||
function = Functions.get_function_by_id(function_id)
|
|
||||||
if function is not None and hasattr(function, "valves"):
|
|
||||||
# TODO: Fix FunctionModel
|
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
filter_ids = [
|
|
||||||
function.id for function in Functions.get_global_filter_functions()
|
|
||||||
]
|
|
||||||
if "info" in model and "meta" in model["info"]:
|
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
||||||
filter_ids = list(set(filter_ids))
|
|
||||||
|
|
||||||
enabled_filter_ids = [
|
|
||||||
function.id
|
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
||||||
]
|
|
||||||
|
|
||||||
filter_ids = [
|
|
||||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
filter_ids.sort(key=get_priority)
|
|
||||||
return filter_ids
|
|
||||||
|
|
||||||
filter_ids = get_filter_function_ids(model)
|
|
||||||
for filter_id in filter_ids:
|
|
||||||
filter = Functions.get_function_by_id(filter_id)
|
|
||||||
if not filter:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if filter_id in request.app.state.FUNCTIONS:
|
|
||||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
|
||||||
else:
|
|
||||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
|
||||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
|
||||||
|
|
||||||
# Check if the function has a file_handler variable
|
|
||||||
if hasattr(function_module, "file_handler"):
|
|
||||||
skip_files = function_module.file_handler
|
|
||||||
|
|
||||||
# Apply valves to the function
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
|
||||||
valves = Functions.get_function_valves_by_id(filter_id)
|
|
||||||
function_module.valves = function_module.Valves(
|
|
||||||
**(valves if valves else {})
|
|
||||||
)
|
|
||||||
|
|
||||||
if hasattr(function_module, "inlet"):
|
|
||||||
try:
|
|
||||||
inlet = function_module.inlet
|
|
||||||
|
|
||||||
# Create a dictionary of parameters to be passed to the function
|
|
||||||
params = {"body": body} | {
|
|
||||||
k: v
|
|
||||||
for k, v in {
|
|
||||||
**extra_params,
|
|
||||||
"__model__": model,
|
|
||||||
"__id__": filter_id,
|
|
||||||
}.items()
|
|
||||||
if k in inspect.signature(inlet).parameters
|
|
||||||
}
|
|
||||||
|
|
||||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
|
||||||
try:
|
|
||||||
params["__user__"]["valves"] = function_module.UserValves(
|
|
||||||
**Functions.get_user_valves_by_id_and_user_id(
|
|
||||||
filter_id, params["__user__"]["id"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
|
||||||
body = await inlet(**params)
|
|
||||||
else:
|
|
||||||
body = inlet(**params)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
if skip_files and "files" in body.get("metadata", {}):
|
|
||||||
del body["metadata"]["files"]
|
|
||||||
|
|
||||||
return body, {}
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
request: Request, body: dict, user: UserModel, models, tools
|
request: Request, body: dict, user: UserModel, models, tools
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
@ -782,8 +693,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
form_data, flags = await chat_completion_filter_functions_handler(
|
form_data, flags = await process_filter_functions(
|
||||||
request, form_data, model, extra_params
|
request=request,
|
||||||
|
filter_ids=get_sorted_filter_ids(model),
|
||||||
|
filter_type="inlet",
|
||||||
|
form_data=form_data,
|
||||||
|
extra_params=extra_params,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Error: {e}")
|
raise Exception(f"Error: {e}")
|
||||||
@ -1124,11 +1039,15 @@ async def process_chat_response(
|
|||||||
|
|
||||||
def split_content_and_whitespace(content):
|
def split_content_and_whitespace(content):
|
||||||
content_stripped = content.rstrip()
|
content_stripped = content.rstrip()
|
||||||
original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else ''
|
original_whitespace = (
|
||||||
|
content[len(content_stripped) :]
|
||||||
|
if len(content) > len(content_stripped)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
return content_stripped, original_whitespace
|
return content_stripped, original_whitespace
|
||||||
|
|
||||||
def is_opening_code_block(content):
|
def is_opening_code_block(content):
|
||||||
backtick_segments = content.split('```')
|
backtick_segments = content.split("```")
|
||||||
# Even number of segments means the last backticks are opening a new block
|
# Even number of segments means the last backticks are opening a new block
|
||||||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||||||
|
|
||||||
@ -1198,10 +1117,15 @@ async def process_chat_response(
|
|||||||
output = block.get("output", None)
|
output = block.get("output", None)
|
||||||
lang = attributes.get("lang", "")
|
lang = attributes.get("lang", "")
|
||||||
|
|
||||||
content_stripped, original_whitespace = split_content_and_whitespace(content)
|
content_stripped, original_whitespace = (
|
||||||
|
split_content_and_whitespace(content)
|
||||||
|
)
|
||||||
if is_opening_code_block(content_stripped):
|
if is_opening_code_block(content_stripped):
|
||||||
# Remove trailing backticks that would open a new block
|
# Remove trailing backticks that would open a new block
|
||||||
content = content_stripped.rstrip('`').rstrip() + original_whitespace
|
content = (
|
||||||
|
content_stripped.rstrip("`").rstrip()
|
||||||
|
+ original_whitespace
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Keep content as is - either closing backticks or no backticks
|
# Keep content as is - either closing backticks or no backticks
|
||||||
content = content_stripped + original_whitespace
|
content = content_stripped + original_whitespace
|
||||||
|
Loading…
Reference in New Issue
Block a user