From 53f03f65561f74927382b115d5eee5ed826015e7 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:34:52 +0200
Subject: [PATCH 1/8] fix: log exceptions raised while querying collections
---
backend/open_webui/apps/rag/utils.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 2bf8a02e4..c179f47c0 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -150,8 +150,8 @@ def query_collection(
embedding_function=embedding_function,
)
results.append(result)
- except Exception:
- pass
+ except Exception as e:
+ log.exception(f"Error when querying the collection: {e}")
else:
pass
@@ -178,8 +178,11 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
- except Exception:
- pass
+ except Exception as e:
+ log.exception(
+ "Error when querying the collection with "
+ f"hybrid_search: {e}"
+ )
return merge_and_sort_query_results(results, k=k, reverse=True)
From ed2a1e7db960fa6aa4c1c298272164c6f48f5bc0 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:58:26 +0200
Subject: [PATCH 2/8] enh: use non hybrid search as fallback if hybrid search
failed
---
backend/open_webui/apps/rag/utils.py | 30 ++++++++++++++++++----------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index c179f47c0..dd4cdd007 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -167,6 +167,7 @@ def query_collection_with_hybrid_search(
r: float,
):
results = []
+ failed = 0
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
@@ -183,6 +184,10 @@ def query_collection_with_hybrid_search(
"Error when querying the collection with "
f"hybrid_search: {e}"
)
+ failed += 1
+ if failed == len(collection_names):
+ raise Exception("Hybrid search failed for all collections. Using "
+ "Non hybrid search as fallback.")
return merge_and_sort_query_results(results, k=k, reverse=True)
@@ -265,19 +270,25 @@ def get_rag_context(
continue
try:
+ context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
- context = query_collection_with_hybrid_search(
- collection_names=collection_names,
- query=query,
- embedding_function=embedding_function,
- k=k,
- reranking_function=reranking_function,
- r=r,
- )
- else:
+ try:
+ context = query_collection_with_hybrid_search(
+ collection_names=collection_names,
+ query=query,
+ embedding_function=embedding_function,
+ k=k,
+ reranking_function=reranking_function,
+ r=r,
+ )
+ except Exception as e:
+ log.debug("Error when using hybrid search, using"
+ " non hybrid search as fallback.")
+
+ if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
@@ -286,7 +297,6 @@ def get_rag_context(
)
except Exception as e:
log.exception(e)
- context = None
if context:
relevant_contexts.append({**context, "source": file})
From 209e246e6f5fce408421eb30001c1886885aee5a Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:31:49 +0200
Subject: [PATCH 3/8] fix: much improved RAG template
---
backend/open_webui/config.py | 24 +++++++++++++++---------
1 file changed, 15 insertions(+), 9 deletions(-)
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 5ccb40d47..c87256a08 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1085,19 +1085,25 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags.
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+
- [context]
+[context]
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
+
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
+- Answer directly and without using xml tags.
+
-Given the context information, answer the query.
-Query: [query]"""
+
+[query]
+
+"""
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
From adf26789b8c1c9c256f22aef98f64449d8f684a5 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:19:24 +0200
Subject: [PATCH 4/8] logs: fail loudly if the RAG template is invalid
---
backend/open_webui/apps/rag/utils.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index dd4cdd007..5bddf0a96 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -192,6 +192,11 @@ def query_collection_with_hybrid_search(
def rag_template(template: str, context: str, query: str):
+ count = template.count("[context]")
+ assert count == 1, (
+ f"RAG template contains an unexpected number of '[context]' : {count}"
+ )
+ assert "[context]" in template, "RAG template does not contain '[context]'"
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
From 9661fee55428b81c3f65bc2863efbcd70c7e403a Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:19:40 +0200
Subject: [PATCH 5/8] fix: handle case where [query] happens in the RAG context
---
backend/open_webui/apps/rag/utils.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 5bddf0a96..41fc4e2d4 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -1,5 +1,6 @@
import logging
import os
+import uuid
from typing import Optional, Union
import requests
@@ -197,8 +198,15 @@ def rag_template(template: str, context: str, query: str):
f"RAG template contains an unexpected number of '[context]' : {count}"
)
assert "[context]" in template, "RAG template does not contain '[context]'"
- template = template.replace("[context]", context)
- template = template.replace("[query]", query)
+
+ if "[query]" in context:
+ query_placeholder = str(uuid.uuid4())
+ template = template.replace("[query]", query_placeholder)
+ template = template.replace("[context]", context)
+ template = template.replace(query_placeholder, query)
+ else:
+ template = template.replace("[context]", context)
+ template = template.replace("[query]", query)
return template
From b4ad64586aceeae45dd5cc4178f43b4b6c00fb77 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 16:36:37 +0200
Subject: [PATCH 6/8] fix: add check that the context for RAG is not empty if
the threshold is 0
---
backend/open_webui/main.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8914cb491..bf586e56d 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
+ if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
+ assert context_string.strip(), (
+ "With a 0 relevancy threshold for RAG, the context cannot "
+ "be empty"
+ )
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
From e872f5dc78e5585b205e378bab87b26de5ff59e2 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 16:04:41 +0200
Subject: [PATCH 7/8] log: added a debug log if detecting a potential prompt
injection attack
---
backend/open_webui/apps/rag/utils.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 41fc4e2d4..29b12d0b0 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -198,6 +198,12 @@ def rag_template(template: str, context: str, query: str):
f"RAG template contains an unexpected number of '[context]' : {count}"
)
assert "[context]" in template, "RAG template does not contain '[context]'"
+ if "<context>" in context and "</context>" in context:
+ log.debug(
+ "WARNING: Potential prompt injection attack: the RAG "
+ "context contains '<context>' and '</context>'. This might be "
+ "nothing, or the user might be trying to hack something."
+ )
if "[query]" in context:
query_placeholder = str(uuid.uuid4())
From 65d5545cf0cd85c1b4a5e55266289b42879c4673 Mon Sep 17 00:00:00 2001
From: thiswillbeyourgithub
<26625900+thiswillbeyourgithub@users.noreply.github.com>
Date: Thu, 12 Sep 2024 15:50:18 +0200
Subject: [PATCH 8/8] added a few type hints
---
backend/open_webui/apps/rag/utils.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 29b12d0b0..9da08e28a 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -48,7 +48,7 @@ def query_doc_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
@@ -93,7 +93,7 @@ def query_doc_with_hybrid_search(
raise e
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
@@ -139,7 +139,7 @@ def query_collection(
query: str,
embedding_function,
k: int,
-):
+) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
@@ -166,7 +166,7 @@ def query_collection_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
results = []
failed = 0
for collection_name in collection_names: