mirror of
https://github.com/open-webui/open-webui
synced 2025-05-19 20:57:54 +00:00
refac
This commit is contained in:
parent
89669a21fc
commit
3dde2f67cf
@ -203,10 +203,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result, _ = await process_filter_functions(
|
result, _ = await process_filter_functions(
|
||||||
handler_type="outlet",
|
|
||||||
filter_ids=get_sorted_filter_ids(model),
|
|
||||||
request=request,
|
request=request,
|
||||||
data=data,
|
filter_ids=get_sorted_filter_ids(model),
|
||||||
|
filter_type="outlet",
|
||||||
|
form_data=data,
|
||||||
extra_params=extra_params,
|
extra_params=extra_params,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
@ -2,6 +2,7 @@ import inspect
|
|||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
|
|
||||||
|
|
||||||
def get_sorted_filter_ids(model):
|
def get_sorted_filter_ids(model):
|
||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
@ -24,12 +25,9 @@ def get_sorted_filter_ids(model):
|
|||||||
filter_ids.sort(key=get_priority)
|
filter_ids.sort(key=get_priority)
|
||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
async def process_filter_functions(
|
async def process_filter_functions(
|
||||||
handler_type,
|
request, filter_ids, filter_type, form_data, extra_params
|
||||||
filter_ids,
|
|
||||||
request,
|
|
||||||
data,
|
|
||||||
extra_params
|
|
||||||
):
|
):
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
@ -45,7 +43,7 @@ async def process_filter_functions(
|
|||||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||||
|
|
||||||
# Check if the function has a file_handler variable
|
# Check if the function has a file_handler variable
|
||||||
if handler_type == "inlet" and hasattr(function_module, "file_handler"):
|
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
|
||||||
skip_files = function_module.file_handler
|
skip_files = function_module.file_handler
|
||||||
|
|
||||||
# Apply valves to the function
|
# Apply valves to the function
|
||||||
@ -56,14 +54,14 @@ async def process_filter_functions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Prepare handler function
|
# Prepare handler function
|
||||||
handler = getattr(function_module, handler_type, None)
|
handler = getattr(function_module, filter_type, None)
|
||||||
if not handler:
|
if not handler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare parameters
|
# Prepare parameters
|
||||||
sig = inspect.signature(handler)
|
sig = inspect.signature(handler)
|
||||||
params = {"body": data}
|
params = {"body": form_data}
|
||||||
|
|
||||||
# Add extra parameters that exist in the handler's signature
|
# Add extra parameters that exist in the handler's signature
|
||||||
for key in list(extra_params.keys()):
|
for key in list(extra_params.keys()):
|
||||||
@ -82,19 +80,18 @@ async def process_filter_functions(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
# Execute handler
|
# Execute handler
|
||||||
if inspect.iscoroutinefunction(handler):
|
if inspect.iscoroutinefunction(handler):
|
||||||
data = await handler(**params)
|
form_data = await handler(**params)
|
||||||
else:
|
else:
|
||||||
data = handler(**params)
|
form_data = handler(**params)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in {handler_type} handler {filter_id}: {e}")
|
print(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Handle file cleanup for inlet
|
# Handle file cleanup for inlet
|
||||||
if skip_files and "files" in data.get("metadata", {}):
|
if skip_files and "files" in form_data.get("metadata", {}):
|
||||||
del data["metadata"]["files"]
|
del form_data["metadata"]["files"]
|
||||||
|
|
||||||
return data, {}
|
return form_data, {}
|
||||||
|
@ -694,10 +694,10 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
form_data, flags = await process_filter_functions(
|
form_data, flags = await process_filter_functions(
|
||||||
handler_type="inlet",
|
|
||||||
filter_ids=get_sorted_filter_ids(model),
|
|
||||||
request=request,
|
request=request,
|
||||||
data=form_data,
|
filter_ids=get_sorted_filter_ids(model),
|
||||||
|
filter_type="inlet",
|
||||||
|
form_data=form_data,
|
||||||
extra_params=extra_params,
|
extra_params=extra_params,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1039,11 +1039,15 @@ async def process_chat_response(
|
|||||||
|
|
||||||
def split_content_and_whitespace(content):
|
def split_content_and_whitespace(content):
|
||||||
content_stripped = content.rstrip()
|
content_stripped = content.rstrip()
|
||||||
original_whitespace = content[len(content_stripped):] if len(content) > len(content_stripped) else ''
|
original_whitespace = (
|
||||||
|
content[len(content_stripped) :]
|
||||||
|
if len(content) > len(content_stripped)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
return content_stripped, original_whitespace
|
return content_stripped, original_whitespace
|
||||||
|
|
||||||
def is_opening_code_block(content):
|
def is_opening_code_block(content):
|
||||||
backtick_segments = content.split('```')
|
backtick_segments = content.split("```")
|
||||||
# Even number of segments means the last backticks are opening a new block
|
# Even number of segments means the last backticks are opening a new block
|
||||||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||||||
|
|
||||||
@ -1113,10 +1117,15 @@ async def process_chat_response(
|
|||||||
output = block.get("output", None)
|
output = block.get("output", None)
|
||||||
lang = attributes.get("lang", "")
|
lang = attributes.get("lang", "")
|
||||||
|
|
||||||
content_stripped, original_whitespace = split_content_and_whitespace(content)
|
content_stripped, original_whitespace = (
|
||||||
|
split_content_and_whitespace(content)
|
||||||
|
)
|
||||||
if is_opening_code_block(content_stripped):
|
if is_opening_code_block(content_stripped):
|
||||||
# Remove trailing backticks that would open a new block
|
# Remove trailing backticks that would open a new block
|
||||||
content = content_stripped.rstrip('`').rstrip() + original_whitespace
|
content = (
|
||||||
|
content_stripped.rstrip("`").rstrip()
|
||||||
|
+ original_whitespace
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Keep content as is - either closing backticks or no backticks
|
# Keep content as is - either closing backticks or no backticks
|
||||||
content = content_stripped + original_whitespace
|
content = content_stripped + original_whitespace
|
||||||
|
Loading…
Reference in New Issue
Block a user