From 833cf5130ce346585563c73a8f60e4c5f9116218 Mon Sep 17 00:00:00 2001 From: zzzevaka Date: Sat, 31 May 2025 22:09:12 +0200 Subject: [PATCH] Feat: models are aware of attached files when they choose a tool to call --- backend/open_webui/config.py | 7 ++- backend/open_webui/utils/middleware.py | 67 ++++++++++++++------------ backend/open_webui/utils/task.py | 3 +- 3 files changed, 44 insertions(+), 33 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0f4948361..0ce652b3d 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1558,7 +1558,12 @@ The format for the JSON response is strictly: {"name": "toolName1", "parameters": {"key1": "value1"}}, {"name": "toolName2", "parameters": {"key2": "value2"}} ] -}""" +} + + +{{CONTEXT}} + +""" DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). 
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7b5659d51..cd71dbd90 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -99,7 +99,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) async def chat_completion_tools_handler( - request: Request, body: dict, extra_params: dict, user: UserModel, models, tools + request: Request, body: dict, extra_params: dict, user: UserModel, models: dict, tools: dict, context: str, ) -> tuple[dict, dict]: async def get_content_from_response(response) -> Optional[str]: content = None @@ -156,7 +156,7 @@ async def chat_completion_tools_handler( template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs + template, tools_specs, context, ) payload = get_tools_function_calling_payload( body["messages"], task_model_id, tools_function_calling_prompt ) @@ -716,6 +716,31 @@ def apply_params_to_form_data(form_data, model): return form_data +def create_context_string_from_sources(sources: list[dict[str, Any]]) -> str: + context_string = "" + citation_idx = {} + for source in sources: + if "document" in source: + for doc_context, doc_meta in zip( + source["document"], source["metadata"] + ): + source_name = source.get("source", {}).get("name", None) + citation_id = ( + doc_meta.get("source", None) + or source.get("source", {}).get("id", None) + or "N/A" + ) + if citation_id not in citation_idx: + citation_idx[citation_id] = len(citation_idx) + 1 + context_string += ( + f'<source id="{citation_idx[citation_id]}"' + + (f' name="{source_name}"' if source_name else "") + + f">{doc_context}</source>\n" + ) + + return context_string.strip() + + async def process_chat_payload(request, form_data, user, metadata, model): form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") @@ -901,6 +926,12 @@ async def process_chat_payload(request, form_data, user, metadata, model): "server": tool_server, } + try: + form_data, flags = await 
chat_completion_files_handler(request, form_data, user) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + if tools_dict: if metadata.get("function_calling") == "native": # If the function calling is native, then call the tools function calling handler @@ -912,44 +943,18 @@ else: # If the function calling is not native, then call the tools function calling handler try: + context_string = create_context_string_from_sources(sources) form_data, flags = await chat_completion_tools_handler( - request, form_data, extra_params, user, models, tools_dict + request, form_data, extra_params, user, models, tools_dict, context_string, ) sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) - try: - form_data, flags = await chat_completion_files_handler(request, form_data, user) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - # If context is not empty, insert it into the messages if len(sources) > 0: - context_string = "" - citation_idx = {} - for source in sources: - if "document" in source: - for doc_context, doc_meta in zip( - source["document"], source["metadata"] - ): - source_name = source.get("source", {}).get("name", None) - citation_id = ( - doc_meta.get("source", None) - or source.get("source", {}).get("id", None) - or "N/A" - ) - if citation_id not in citation_idx: - citation_idx[citation_id] = len(citation_idx) + 1 - context_string += ( - f'<source id="{citation_idx[citation_id]}"' - + (f' name="{source_name}"' if source_name else "") - + f">{doc_context}</source>\n" - ) - - context_string = context_string.strip() + context_string = create_context_string_from_sources(sources) prompt = get_last_user_message(form_data["messages"]) if prompt is None: diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 95018eef1..6f3a900b2 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -336,6 +336,7 @@ def moa_response_generation_template( 
return template -def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: +def tools_function_calling_generation_template(template: str, tools_specs: str, context: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) + template = template.replace("{{CONTEXT}}", context) return template