diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 898ac1b59..4914fe6dd 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1603,7 +1603,12 @@ The format for the JSON response is strictly: {"name": "toolName1", "parameters": {"key1": "value1"}}, {"name": "toolName2", "parameters": {"key2": "value2"}} ] -}""" +} + + +{{CONTEXT}} + +""" DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index b1e69db26..813e72c27 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -105,7 +105,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) async def chat_completion_tools_handler( - request: Request, body: dict, extra_params: dict, user: UserModel, models, tools + request: Request, body: dict, extra_params: dict, user: UserModel, models: dict, tools: dict, context: str, ) -> tuple[dict, dict]: async def get_content_from_response(response) -> Optional[str]: content = None @@ -162,7 +162,7 @@ async def chat_completion_tools_handler( template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE tools_function_calling_prompt = tools_function_calling_generation_template( - template, tools_specs + template, tools_specs, context, ) payload = get_tools_function_calling_payload( body["messages"], task_model_id, tools_function_calling_prompt @@ -717,6 +717,31 @@ def apply_params_to_form_data(form_data, model): return form_data +def create_context_string_from_sources(sources: list[dict[str, Any]]) -> str: + context_string = "" + citation_idx = {} + for source in sources: + if "document" in source: + for doc_context, doc_meta in zip( + source["document"], source["metadata"] + ): + source_name = 
source.get("source", {}).get("name", None) + citation_id = ( + doc_meta.get("source", None) + or source.get("source", {}).get("id", None) + or "N/A" + ) + if citation_id not in citation_idx: + citation_idx[citation_id] = len(citation_idx) + 1 + context_string += ( + f'<source id="{citation_idx[citation_id]}" name="{source_name}">{doc_context}</source>\n' + ) + + return context_string.strip() + + async def process_chat_payload(request, form_data, user, metadata, model): form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") @@ -897,6 +922,12 @@ async def process_chat_payload(request, form_data, user, metadata, model): "server": tool_server, } + try: + form_data, flags = await chat_completion_files_handler(request, form_data, user) + sources.extend(flags.get("sources", [])) + except Exception as e: + log.exception(e) + if tools_dict: if metadata.get("function_calling") == "native": # If the function calling is native, then call the tools function calling handler @@ -908,44 +939,18 @@ async def process_chat_payload(request, form_data, user, metadata, model): else: # If the function calling is not native, then call the tools function calling handler try: + context_string = create_context_string_from_sources(sources) form_data, flags = await chat_completion_tools_handler( - request, form_data, extra_params, user, models, tools_dict + request, form_data, extra_params, user, models, tools_dict, context_string, ) sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) - try: - form_data, flags = await chat_completion_files_handler(request, form_data, user) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) - # If context is not empty, insert it into the messages if len(sources) > 0: - context_string = "" - citation_idx = {} - for source in sources: - if "document" in source: - for doc_context, doc_meta in zip( - source["document"], source["metadata"] - ): - source_name = source.get("source", {}).get("name", None) - citation_id = ( - 
doc_meta.get("source", None) - or source.get("source", {}).get("id", None) - or "N/A" - ) - if citation_id not in citation_idx: - citation_idx[citation_id] = len(citation_idx) + 1 - context_string += ( - f'<source id="{citation_idx[citation_id]}" name="{source_name}">{doc_context}</source>\n' - ) - - context_string = context_string.strip() + context_string = create_context_string_from_sources(sources) prompt = get_last_user_message(form_data["messages"]) if prompt is None: diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 42b44d516..3d84becb8 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -354,6 +354,7 @@ def moa_response_generation_template( return template -def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: +def tools_function_calling_generation_template(template: str, tools_specs: str, context: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) + template = template.replace("{{CONTEXT}}", context) return template