diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 898ac1b59..4914fe6dd 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1603,7 +1603,12 @@ The format for the JSON response is strictly:
{"name": "toolName1", "parameters": {"key1": "value1"}},
{"name": "toolName2", "parameters": {"key2": "value2"}}
]
-}"""
+}
+
+
+{{CONTEXT}}
+
+"""
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index b1e69db26..813e72c27 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -105,7 +105,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
- request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
+ request: Request, body: dict, extra_params: dict, user: UserModel, models: dict, tools: dict, context: str,
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -162,7 +162,7 @@ async def chat_completion_tools_handler(
template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
tools_function_calling_prompt = tools_function_calling_generation_template(
- template, tools_specs
+ template, tools_specs, context,
)
payload = get_tools_function_calling_payload(
body["messages"], task_model_id, tools_function_calling_prompt
@@ -717,6 +717,31 @@ def apply_params_to_form_data(form_data, model):
return form_data
+def create_context_string_from_sources(sources: list[dict[str, Any]]) -> str:
+ context_string = ""
+ citation_idx = {}
+ for source in sources:
+ if "document" in source:
+ for doc_context, doc_meta in zip(
+ source["document"], source["metadata"]
+ ):
+ source_name = source.get("source", {}).get("name", None)
+ citation_id = (
+ doc_meta.get("source", None)
+ or source.get("source", {}).get("id", None)
+ or "N/A"
+ )
+ if citation_id not in citation_idx:
+ citation_idx[citation_id] = len(citation_idx) + 1
+ context_string += (
+ f'{doc_context}\n"
+ )
+
+ return context_string.strip()
+
+
async def process_chat_payload(request, form_data, user, metadata, model):
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -897,6 +922,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
+ try:
+ form_data, flags = await chat_completion_files_handler(request, form_data, user)
+ sources.extend(flags.get("sources", []))
+ except Exception as e:
+ log.exception(e)
+
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
@@ -908,44 +939,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
else:
# If the function calling is not native, then call the tools function calling handler
try:
+ context_string = create_context_string_from_sources(sources)
form_data, flags = await chat_completion_tools_handler(
- request, form_data, extra_params, user, models, tools_dict
+ request, form_data, extra_params, user, models, tools_dict, context_string,
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
- try:
- form_data, flags = await chat_completion_files_handler(request, form_data, user)
- sources.extend(flags.get("sources", []))
- except Exception as e:
- log.exception(e)
-
# If context is not empty, insert it into the messages
if len(sources) > 0:
- context_string = ""
- citation_idx = {}
- for source in sources:
- if "document" in source:
- for doc_context, doc_meta in zip(
- source["document"], source["metadata"]
- ):
- source_name = source.get("source", {}).get("name", None)
- citation_id = (
- doc_meta.get("source", None)
- or source.get("source", {}).get("id", None)
- or "N/A"
- )
- if citation_id not in citation_idx:
- citation_idx[citation_id] = len(citation_idx) + 1
- context_string += (
- f'{doc_context}\n"
- )
-
- context_string = context_string.strip()
+ context_string = create_context_string_from_sources(sources)
prompt = get_last_user_message(form_data["messages"])
if prompt is None:
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py
index 42b44d516..3d84becb8 100644
--- a/backend/open_webui/utils/task.py
+++ b/backend/open_webui/utils/task.py
@@ -354,6 +354,7 @@ def moa_response_generation_template(
return template
-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tools_function_calling_generation_template(template: str, tools_specs: str, context: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
+ template = template.replace("{{CONTEXT}}", context)
return template