diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 43fbd1596..e64b71297 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import uuid
 from typing import Optional, Union
 
 import requests
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
     k: int,
     reranking_function,
     r: float,
-):
+) -> dict:
     try:
         result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
 
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
         raise e
 
 
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
     # Initialize lists to store combined data
     combined_distances = []
     combined_documents = []
@@ -180,7 +181,7 @@ def query_collection(
     query: str,
     embedding_function,
     k: int,
-):
+) -> dict:
     results = []
     for collection_name in collection_names:
         if collection_name:
@@ -192,8 +193,8 @@ def query_collection(
                 embedding_function=embedding_function,
             )
             results.append(result)
-        except Exception:
-            pass
+        except Exception as e:
+            log.exception(f"Error when querying the collection: {e}")
         else:
             pass
 
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
     k: int,
     reranking_function,
     r: float,
-):
+) -> dict:
     results = []
+    failed = 0
     for collection_name in collection_names:
         try:
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
                 r=r,
             )
             results.append(result)
-        except Exception:
-            pass
+        except Exception as e:
+            log.exception(
+                "Error when querying the collection with "
+                f"hybrid_search: {e}"
+            )
+            failed += 1
+    if failed == len(collection_names):
+        raise Exception("Hybrid search failed for all collections. Using "
+            "Non hybrid search as fallback.")
 
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 
 def rag_template(template: str, context: str, query: str):
-    template = template.replace("[context]", context)
-    template = template.replace("[query]", query)
+    count = template.count("[context]")
+    assert count == 1, (
+        f"RAG template contains an unexpected number of '[context]' : {count}"
+    )
+    assert "[context]" in template, "RAG template does not contain '[context]'"
+    if "<context>" in context and "</context>" in context:
+        log.debug(
+            "WARNING: Potential prompt injection attack: the RAG "
+            "context contains '<context>' and '</context>'. This might be "
+            "nothing, or the user might be trying to hack something."
+        )
+
+    if "[query]" in context:
+        query_placeholder = str(uuid.uuid4())
+        template = template.replace("[query]", query_placeholder)
+        template = template.replace("[context]", context)
+        template = template.replace(query_placeholder, query)
+    else:
+        template = template.replace("[context]", context)
+        template = template.replace("[query]", query)
     return template
 
@@ -304,19 +331,25 @@ def get_rag_context(
             continue
 
         try:
+            context = None
             if file["type"] == "text":
                 context = file["content"]
             else:
                 if hybrid_search:
-                    context = query_collection_with_hybrid_search(
-                        collection_names=collection_names,
-                        query=query,
-                        embedding_function=embedding_function,
-                        k=k,
-                        reranking_function=reranking_function,
-                        r=r,
-                    )
-                else:
+                    try:
+                        context = query_collection_with_hybrid_search(
+                            collection_names=collection_names,
+                            query=query,
+                            embedding_function=embedding_function,
+                            k=k,
+                            reranking_function=reranking_function,
+                            r=r,
+                        )
+                    except Exception as e:
+                        log.debug("Error when using hybrid search, using"
+                            " non hybrid search as fallback.")
+
+                if (not hybrid_search) or (context is None):
                     context = query_collection(
                         collection_names=collection_names,
                         query=query,
@@ -325,7 +358,6 @@ def get_rag_context(
             )
         except Exception as e:
             log.exception(e)
-            context = None
 
         if context:
             relevant_contexts.append({**context, "source": file})
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 019cc8847..aff6ba547 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
     int(os.environ.get("CHUNK_OVERLAP", "100")),
 )
 
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
-<context>
-    [context]
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-
-Given the context information, answer the query.
-Query: [query]"""
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
+
+<context>
+[context]
+</context>
+
+<rules>
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
+- Answer directly and without using xml tags.
+</rules>
+
+<user_query>
+[query]
+</user_query>
+"""
 
 RAG_TEMPLATE = PersistentConfig(
     "RAG_TEMPLATE",
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index a47115977..49a559d58 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         prompt = get_last_user_message(body["messages"])
         if prompt is None:
             raise Exception("No user message found")
+        if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
+            assert context_string.strip(), (
+                "With a 0 relevancy threshold for RAG, the context cannot "
+                "be empty"
+            )
 
         # Workaround for Ollama 2.0+ system prompt issue
         # TODO: replace with add_or_update_system_message
         if model["owned_by"] == "ollama":