diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 43fbd1596..e64b71297 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -1,5 +1,6 @@
import logging
import os
+import uuid
from typing import Optional, Union
import requests
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
raise e
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
@@ -180,7 +181,7 @@ def query_collection(
query: str,
embedding_function,
k: int,
-):
+) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
@@ -192,8 +193,8 @@ def query_collection(
embedding_function=embedding_function,
)
results.append(result)
- except Exception:
- pass
+ except Exception as e:
+ log.exception(f"Error when querying the collection: {e}")
else:
pass
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
results = []
+ failed = 0
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
- except Exception:
- pass
+ except Exception as e:
+ log.exception(
+ "Error when querying the collection with "
+ f"hybrid_search: {e}"
+ )
+ failed += 1
+ if failed == len(collection_names):
+ raise Exception("Hybrid search failed for all collections. Using "
+ "Non hybrid search as fallback.")
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
- template = template.replace("[context]", context)
- template = template.replace("[query]", query)
+ count = template.count("[context]")
+ assert count == 1, (
+ f"RAG template contains an unexpected number of '[context]' : {count}"
+ )
+ assert "[context]" in template, "RAG template does not contain '[context]'"
+    if "<context>" in context and "</context>" in context:
+        log.debug(
+            "WARNING: Potential prompt injection attack: the RAG "
+            "context contains '<context>' and '</context>'. This might be "
+            "nothing, or the user might be trying to hack something."
+        )
+
+ if "[query]" in context:
+ query_placeholder = str(uuid.uuid4())
+        template = template.replace("[query]", query_placeholder)
+ template = template.replace("[context]", context)
+ template = template.replace(query_placeholder, query)
+ else:
+ template = template.replace("[context]", context)
+ template = template.replace("[query]", query)
return template
@@ -304,19 +331,25 @@ def get_rag_context(
continue
try:
+ context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
- context = query_collection_with_hybrid_search(
- collection_names=collection_names,
- query=query,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- r=r,
- )
- else:
+ try:
+ context = query_collection_with_hybrid_search(
+ collection_names=collection_names,
+ query=query,
+ embedding_function=embedding_function,
+ k=k,
+ reranking_function=reranking_function,
+ r=r,
+ )
+ except Exception as e:
+ log.debug("Error when using hybrid search, using"
+ " non hybrid search as fallback.")
+
+ if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
@@ -325,7 +358,6 @@ def get_rag_context(
)
except Exception as e:
log.exception(e)
- context = None
if context:
relevant_contexts.append({**context, "source": file})
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 019cc8847..aff6ba547 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags.
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+
+<context>
- [context]
+[context]
+</context>
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
+
+<rules>
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
+- Answer directly and without using xml tags.
+</rules>
-Given the context information, answer the query.
-Query: [query]"""
+
+<user_query>
+[query]
+</user_query>
+"""
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index a47115977..49a559d58 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
+ if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
+ assert context_string.strip(), (
+ "With a 0 relevancy threshold for RAG, the context cannot "
+ "be empty"
+ )
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":