diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 77d97814c..6d87c98e3 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -177,35 +177,34 @@ def merge_and_sort_query_results(
def query_collection(
collection_names: list[str],
- query: str,
+ queries: list[str],
embedding_function,
k: int,
) -> dict:
-
results = []
- query_embedding = embedding_function(query)
-
- for collection_name in collection_names:
- if collection_name:
- try:
- result = query_doc(
- collection_name=collection_name,
- k=k,
- query_embedding=query_embedding,
- )
- if result is not None:
- results.append(result.model_dump())
- except Exception as e:
- log.exception(f"Error when querying the collection: {e}")
- else:
- pass
+ for query in queries:
+ query_embedding = embedding_function(query)
+ for collection_name in collection_names:
+ if collection_name:
+ try:
+ result = query_doc(
+ collection_name=collection_name,
+ k=k,
+ query_embedding=query_embedding,
+ )
+ if result is not None:
+ results.append(result.model_dump())
+ except Exception as e:
+ log.exception(f"Error when querying the collection: {e}")
+ else:
+ pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: list[str],
- query: str,
+ queries: list[str],
embedding_function,
k: int,
reranking_function,
@@ -215,15 +214,16 @@ def query_collection_with_hybrid_search(
error = False
for collection_name in collection_names:
try:
- result = query_doc_with_hybrid_search(
- collection_name=collection_name,
- query=query,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- r=r,
- )
- results.append(result)
+ for query in queries:
+ result = query_doc_with_hybrid_search(
+ collection_name=collection_name,
+ query=query,
+ embedding_function=embedding_function,
+ k=k,
+ reranking_function=reranking_function,
+ r=r,
+ )
+ results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
@@ -309,15 +309,14 @@ def get_embedding_function(
def get_rag_context(
files,
- messages,
+ queries,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
):
- log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
- query = get_last_user_message(messages)
+ log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
extracted_collections = []
relevant_contexts = []
@@ -359,7 +358,7 @@ def get_rag_context(
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
- query=query,
+ queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
@@ -374,7 +373,7 @@ def get_rag_context(
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
- query=query,
+ queries=queries,
embedding_function=embedding_function,
k=k,
)
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 621ebf35b..c33895396 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -941,19 +941,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
)
-ENABLE_SEARCH_QUERY = PersistentConfig(
- "ENABLE_SEARCH_QUERY",
- "task.search.enable",
- os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
+
+ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
+ "ENABLE_SEARCH_QUERY_GENERATION",
+ "task.query.search.enable",
+ os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
+)
+
+ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
+ "ENABLE_RETRIEVAL_QUERY_GENERATION",
+ "task.query.retrieval.enable",
+ os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
)
-SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
- "task.search.prompt_template",
- os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
+QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+ "QUERY_GENERATION_PROMPT_TEMPLATE",
+ "task.query.prompt_template",
+ os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)
+DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
+Based on the chat history, determine whether a search is necessary, and if so, generate 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
+
+### Guidelines:
+- Respond exclusively with a JSON object.
+- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
+- If no search query is necessary, output should be: { "queries": [] }
+- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
+- Be concise, focusing strictly on composing search queries with no additional commentary or text.
+- When in doubt, prefer to suggest a search for comprehensiveness.
+- Today's date is: {{CURRENT_DATE}}
+
+### Output:
+JSON format: {
+ "queries": ["query1", "query2"]
+}
+
+### Chat History:
+
+{{MESSAGES:END:6}}
+
+"""
+
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -1127,27 +1157,6 @@ RAG_TEXT_SPLITTER = PersistentConfig(
)
-ENABLE_RAG_QUERY_GENERATION = PersistentConfig(
- "ENABLE_RAG_QUERY_GENERATION",
- "rag.query_generation.enable",
- os.environ.get("ENABLE_RAG_QUERY_GENERATION", "False").lower() == "true",
-)
-
-DEFAULT_RAG_QUERY_GENERATION_TEMPLATE = """Given the user's message and interaction history, decide if a file search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt.
-User Message:
-{{prompt:end:4000}}
-Interaction History:
-{{MESSAGES:END:6}}
-Search Query:"""
-
-
-RAG_QUERY_GENERATION_TEMPLATE = PersistentConfig(
- "RAG_QUERY_GENERATION_TEMPLATE",
- "rag.query_generation.template",
- os.environ.get("RAG_QUERY_GENERATION_TEMPLATE", ""),
-)
-
-
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(
"TIKTOKEN_ENCODING_NAME",
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 6f2c7ac42..04c86395a 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -78,11 +78,13 @@ from open_webui.config import (
ENV,
FRONTEND_BUILD_DIR,
OAUTH_PROVIDERS,
- ENABLE_SEARCH_QUERY,
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
+ ENABLE_SEARCH_QUERY_GENERATION,
+ ENABLE_RETRIEVAL_QUERY_GENERATION,
+ QUERY_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
- search_query_generation_template,
+ query_generation_template,
emoji_generation_template,
title_generation_template,
tools_function_calling_generation_template,
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
-app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
-)
+app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
+app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
+app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
return body, {"contexts": contexts, "citations": citations}
-async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+async def chat_completion_files_handler(
+ body: dict, user: UserModel
+) -> tuple[dict, dict[str, list]]:
contexts = []
citations = []
+ try:
+ queries_response = await generate_queries(
+ {
+ "model": body["model"],
+ "messages": body["messages"],
+ "type": "retrieval",
+ },
+ user,
+ )
+ queries_response = queries_response["choices"][0]["message"]["content"]
+
+ try:
+ queries_response = json.loads(queries_response)
+ except Exception as e:
+ queries_response = {"queries": []}
+
+ queries = queries_response.get("queries", [])
+ except Exception as e:
+ queries = []
+
+ if len(queries) == 0:
+ queries = [get_last_user_message(body["messages"])]
+
+ print(f"{queries=}")
+
if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context(
files=files,
- messages=body["messages"],
+ queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf,
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
log.exception(e)
try:
- body, flags = await chat_completion_files_handler(body)
+ body, flags = await chat_completion_files_handler(body, user)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
- "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+ "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+ "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+ "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str
TAGS_GENERATION_PROMPT_TEMPLATE: str
ENABLE_TAGS_GENERATION: bool
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
- ENABLE_SEARCH_QUERY: bool
+ ENABLE_SEARCH_QUERY_GENERATION: bool
+ ENABLE_RETRIEVAL_QUERY_GENERATION: bool
+ QUERY_GENERATION_PROMPT_TEMPLATE: str
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
)
app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
-
- app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
- form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
+ form_data.ENABLE_SEARCH_QUERY_GENERATION
+ )
+ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
+ form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
+ )
+
+ app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
+ form_data.QUERY_GENERATION_PROMPT_TEMPLATE
)
- app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
- "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
- "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
+ "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+ "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+ "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
return await generate_chat_completions(form_data=payload, user=user)
-@app.post("/api/task/query/completions")
-async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
- print("generate_search_query")
- if not app.state.config.ENABLE_SEARCH_QUERY:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Search query generation is disabled",
- )
+@app.post("/api/task/queries/completions")
+async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
+ print("generate_queries")
+ type = form_data.get("type")
+ if type == "web_search":
+ if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Search query generation is disabled",
+ )
+ elif type == "retrieval":
+ if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Query generation is disabled",
+ )
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model = models[task_model_id]
- if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
- template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+ if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
+ template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
else:
- template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
+ template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
-User Message:
-{{prompt:end:4000}}
-
-Interaction History:
-{{MESSAGES:END:6}}
-
-Search Query:"""
-
- content = search_query_generation_template(
+ content = query_generation_template(
template, form_data["messages"], {"name": user.name}
)
@@ -1851,13 +1887,6 @@ Search Query:"""
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
- **(
- {"max_tokens": 30}
- if models[task_model_id]["owned_by"] == "ollama"
- else {
- "max_completion_tokens": 30,
- }
- ),
"metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
}
log.debug(payload)
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py
index 799cca11a..28b07da37 100644
--- a/backend/open_webui/utils/task.py
+++ b/backend/open_webui/utils/task.py
@@ -163,7 +163,7 @@ def emoji_generation_template(
return template
-def search_query_generation_template(
+def query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index d22923670..56d4d9bc6 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -348,15 +348,16 @@ export const generateEmoji = async (
return null;
};
-export const generateSearchQuery = async (
+export const generateQueries = async (
token: string = '',
model: string,
messages: object[],
- prompt: string
+ prompt: string,
+	type: string = 'web_search'
) => {
let error = null;
- const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+ const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
@@ -366,7 +367,8 @@ export const generateSearchQuery = async (
body: JSON.stringify({
model: model,
messages: messages,
- prompt: prompt
+ prompt: prompt,
+ type: type
})
})
.then(async (res) => {
@@ -385,7 +387,40 @@ export const generateSearchQuery = async (
throw error;
}
- return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+
+ try {
+ // Step 1: Safely extract the response string
+ const response = res?.choices[0]?.message?.content ?? '';
+
+ // Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Normalize single/curly quotes and backticks to double quotes — NOTE(review): this also corrupts apostrophes inside query text; consider parsing first and sanitizing only on failure
+
+ // Step 3: Find the relevant JSON block within the response
+ const jsonStartIndex = sanitizedResponse.indexOf('{');
+ const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+ // Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+ if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+ const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+ // Step 5: Parse the JSON block
+ const parsed = JSON.parse(jsonResponse);
+
+ // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array
+ if (parsed && parsed.queries) {
+ return Array.isArray(parsed.queries) ? parsed.queries : [];
+ } else {
+ return [];
+ }
+ }
+
+ // If no valid JSON block found, return an empty array
+ return [];
+ } catch (e) {
+ // Catch and safely return empty array on any parsing errors
+ console.error('Failed to parse response: ', e);
+ return [];
+ }
};
export const generateMoACompletion = async (
diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte
index b79ce0747..3043ffae4 100644
--- a/src/lib/components/admin/Settings/Interface.svelte
+++ b/src/lib/components/admin/Settings/Interface.svelte
@@ -26,8 +26,9 @@
TITLE_GENERATION_PROMPT_TEMPLATE: '',
TAGS_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_TAGS_GENERATION: true,
- ENABLE_SEARCH_QUERY: true,
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
+ ENABLE_SEARCH_QUERY_GENERATION: true,
+ ENABLE_RETRIEVAL_QUERY_GENERATION: true,
+ QUERY_GENERATION_PROMPT_TEMPLATE: ''
};
let promptSuggestions = [];
@@ -164,31 +165,35 @@
+
+
+ {$i18n.t('Enable Retrieval Query Generation')}
+
+
+
+
+
{$i18n.t('Enable Web Search Query Generation')}
-
+
- {#if taskConfig.ENABLE_SEARCH_QUERY}
-
-
{$i18n.t('Search Query Generation Prompt')}
+
+
{$i18n.t('Query Generation Prompt')}
-
-
-
-
- {/if}
+
+
+
+
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 081d99b28..0cc03b497 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -66,7 +66,7 @@
import {
chatCompleted,
generateTitle,
- generateSearchQuery,
+ generateQueries,
chatAction,
generateMoACompletion,
generateTags
@@ -2047,17 +2047,17 @@
history.messages[responseMessageId] = responseMessage;
const prompt = userMessage.content;
- let searchQuery = await generateSearchQuery(
+ let queries = await generateQueries(
localStorage.token,
model,
messages.filter((message) => message?.content?.trim()),
prompt
).catch((error) => {
console.log(error);
- return prompt;
+ return [];
});
- if (!searchQuery || searchQuery == '') {
+ if (queries.length === 0) {
responseMessage.statusHistory.push({
done: true,
error: true,
@@ -2068,6 +2068,8 @@
return;
}
+ const searchQuery = queries[0];
+
responseMessage.statusHistory.push({
done: false,
action: 'web_search',