This commit is contained in:
Timothy Jaeryang Baek 2025-02-07 22:57:39 -08:00
parent 89669a21fc
commit 3dde2f67cf
3 changed files with 32 additions and 26 deletions

View File

@ -203,10 +203,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
try: try:
result, _ = await process_filter_functions( result, _ = await process_filter_functions(
handler_type="outlet",
filter_ids=get_sorted_filter_ids(model),
request=request, request=request,
data=data, filter_ids=get_sorted_filter_ids(model),
filter_type="outlet",
form_data=data,
extra_params=extra_params, extra_params=extra_params,
) )
return result return result

View File

@ -2,6 +2,7 @@ import inspect
from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.plugin import load_function_module_by_id
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
def get_sorted_filter_ids(model): def get_sorted_filter_ids(model):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
@ -24,12 +25,9 @@ def get_sorted_filter_ids(model):
filter_ids.sort(key=get_priority) filter_ids.sort(key=get_priority)
return filter_ids return filter_ids
async def process_filter_functions( async def process_filter_functions(
handler_type, request, filter_ids, filter_type, form_data, extra_params
filter_ids,
request,
data,
extra_params
): ):
skip_files = None skip_files = None
@ -45,7 +43,7 @@ async def process_filter_functions(
request.app.state.FUNCTIONS[filter_id] = function_module request.app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable # Check if the function has a file_handler variable
if handler_type == "inlet" and hasattr(function_module, "file_handler"): if filter_type == "inlet" and hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler skip_files = function_module.file_handler
# Apply valves to the function # Apply valves to the function
@ -56,14 +54,14 @@ async def process_filter_functions(
) )
# Prepare handler function # Prepare handler function
handler = getattr(function_module, handler_type, None) handler = getattr(function_module, filter_type, None)
if not handler: if not handler:
continue continue
try: try:
# Prepare parameters # Prepare parameters
sig = inspect.signature(handler) sig = inspect.signature(handler)
params = {"body": data} params = {"body": form_data}
# Add extra parameters that exist in the handler's signature # Add extra parameters that exist in the handler's signature
for key in list(extra_params.keys()): for key in list(extra_params.keys()):
@ -82,19 +80,18 @@ async def process_filter_functions(
except Exception as e: except Exception as e:
print(e) print(e)
# Execute handler # Execute handler
if inspect.iscoroutinefunction(handler): if inspect.iscoroutinefunction(handler):
data = await handler(**params) form_data = await handler(**params)
else: else:
data = handler(**params) form_data = handler(**params)
except Exception as e: except Exception as e:
print(f"Error in {handler_type} handler {filter_id}: {e}") print(f"Error in {filter_type} handler {filter_id}: {e}")
raise e raise e
# Handle file cleanup for inlet # Handle file cleanup for inlet
if skip_files and "files" in data.get("metadata", {}): if skip_files and "files" in form_data.get("metadata", {}):
del data["metadata"]["files"] del form_data["metadata"]["files"]
return data, {} return form_data, {}

View File

@ -694,10 +694,10 @@ async def process_chat_payload(request, form_data, metadata, user, model):
try: try:
form_data, flags = await process_filter_functions( form_data, flags = await process_filter_functions(
handler_type="inlet",
filter_ids=get_sorted_filter_ids(model),
request=request, request=request,
data=form_data, filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params, extra_params=extra_params,
) )
except Exception as e: except Exception as e:
@ -1039,11 +1039,15 @@ async def process_chat_response(
def split_content_and_whitespace(content): def split_content_and_whitespace(content):
content_stripped = content.rstrip() content_stripped = content.rstrip()
original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else '' original_whitespace = (
content[len(content_stripped) :]
if len(content) > len(content_stripped)
else ""
)
return content_stripped, original_whitespace return content_stripped, original_whitespace
def is_opening_code_block(content): def is_opening_code_block(content):
backtick_segments = content.split('```') backtick_segments = content.split("```")
# Even number of segments means the last backticks are opening a new block # Even number of segments means the last backticks are opening a new block
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
@ -1113,10 +1117,15 @@ async def process_chat_response(
output = block.get("output", None) output = block.get("output", None)
lang = attributes.get("lang", "") lang = attributes.get("lang", "")
content_stripped, original_whitespace = split_content_and_whitespace(content) content_stripped, original_whitespace = (
split_content_and_whitespace(content)
)
if is_opening_code_block(content_stripped): if is_opening_code_block(content_stripped):
# Remove trailing backticks that would open a new block # Remove trailing backticks that would open a new block
content = content_stripped.rstrip('`').rstrip() + original_whitespace content = (
content_stripped.rstrip("`").rstrip()
+ original_whitespace
)
else: else:
# Keep content as is - either closing backticks or no backticks # Keep content as is - either closing backticks or no backticks
content = content_stripped + original_whitespace content = content_stripped + original_whitespace