Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web

fix: RAG and Web Search + RAG enhancements
Timothy Jaeryang Baek 2024-09-13 05:38:53 +01:00 committed by GitHub
commit 7dc4cb30b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 29 deletions

View File

@@ -1,5 +1,6 @@
import logging
import os
+import uuid
from typing import Optional, Union
import requests
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
raise e
-def merge_and_sort_query_results(query_results, k, reverse=False):
+def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
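The hunk above only adds type hints to merge_and_sort_query_results; the rest of the body is cut off by the diff. As a rough sketch of what merging and sorting vector-store results can look like — the Chroma-style nested "distances"/"documents"/"metadatas" lists and the dict return shape are assumptions, not taken from this diff:

def merge_and_sort_results_sketch(query_results: list[dict], k: int, reverse: bool = False) -> dict:
    # Flatten every per-collection result into (distance, document, metadata) triples.
    combined = []
    for result in query_results:
        combined.extend(
            zip(result["distances"][0], result["documents"][0], result["metadatas"][0])
        )

    # Sort by score; reverse=True puts larger values first, which is what you
    # want when the values are similarities rather than distances.
    combined.sort(key=lambda item: item[0], reverse=reverse)
    combined = combined[:k]

    if combined:
        distances, documents, metadatas = (list(x) for x in zip(*combined))
    else:
        distances, documents, metadatas = [], [], []

    return {"distances": [distances], "documents": [documents], "metadatas": [metadatas]}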
@@ -180,7 +181,7 @@ def query_collection(
query: str,
embedding_function,
k: int,
-):
+) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
@@ -192,8 +193,8 @@
embedding_function=embedding_function,
)
results.append(result)
-except Exception:
-pass
+except Exception as e:
+log.exception(f"Error when querying the collection: {e}")
else:
pass
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
k: int,
reranking_function,
r: float,
-):
+) -> dict:
results = []
+failed = 0
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
-except Exception:
-pass
+except Exception as e:
+log.exception(
+"Error when querying the collection with "
+f"hybrid_search: {e}"
+)
+failed += 1
+if failed == len(collection_names):
+raise Exception("Hybrid search failed for all collections. Using "
+"non-hybrid search as fallback.")
return merge_and_sort_query_results(results, k=k, reverse=True)
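The new failed counter changes the error handling from "silently skip" to "skip, log, and only raise when every collection failed", which is what lets the caller fall back to non-hybrid search. A minimal standalone sketch of that pattern, where query_one is a hypothetical stand-in for the per-collection hybrid query:

import logging

log = logging.getLogger(__name__)

def query_all_collections(collection_names: list[str], query_one) -> list[dict]:
    results = []
    failed = 0
    for name in collection_names:
        try:
            results.append(query_one(name))
        except Exception as e:
            # Log and keep going; one bad collection should not abort the rest.
            log.exception(f"Error when querying collection {name}: {e}")
            failed += 1
    # Only give up (and let the caller fall back) when nothing succeeded.
    if collection_names and failed == len(collection_names):
        raise Exception("Hybrid search failed for all collections.")
    return results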
def rag_template(template: str, context: str, query: str):
-template = template.replace("[context]", context)
-template = template.replace("[query]", query)
+count = template.count("[context]")
+assert count == 1, (
+f"RAG template contains an unexpected number of '[context]': {count}"
+)
+assert "[context]" in template, "RAG template does not contain '[context]'"
+if "<context>" in context and "</context>" in context:
+log.debug(
+"WARNING: Potential prompt injection attack: the RAG "
+"context contains '<context>' and '</context>'. This might be "
+"nothing, or the user might be trying to hack something."
+)
+if "[query]" in context:
+query_placeholder = str(uuid.uuid4())
+template = template.replace("[query]", query_placeholder)
+template = template.replace("[context]", context)
+template = template.replace(query_placeholder, query)
+else:
+template = template.replace("[context]", context)
+template = template.replace("[query]", query)
return template
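The UUID placeholder swap above guards against a retrieved context that itself contains the literal string "[query]": the template's own placeholder is parked under a random UUID before the context is substituted, so the copy inside the context is left untouched. A small usage sketch, assuming rag_template is importable from this module:

template = "Context:\n[context]\n\nQuestion: [query]"
context = "Retrieved passage that happens to contain [query] in its text."
query = "What does the passage say?"

prompt = rag_template(template, context, query)
# The template's own placeholder becomes the real question, while the
# literal "[query]" inside the retrieved passage stays as-is.
print(prompt)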
@@ -304,19 +331,25 @@ def get_rag_context(
continue
try:
+context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
-context = query_collection_with_hybrid_search(
-collection_names=collection_names,
-query=query,
-embedding_function=embedding_function,
-k=k,
-reranking_function=reranking_function,
-r=r,
-)
-else:
+try:
+context = query_collection_with_hybrid_search(
+collection_names=collection_names,
+query=query,
+embedding_function=embedding_function,
+k=k,
+reranking_function=reranking_function,
+r=r,
+)
+except Exception as e:
+log.debug("Error when using hybrid search, using"
+" non-hybrid search as fallback.")
+if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
@@ -325,7 +358,6 @@
)
except Exception as e:
log.exception(e)
-context = None
if context:
relevant_contexts.append({**context, "source": file})
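With context initialised to None and the hybrid call wrapped in try/except, the block above now degrades gracefully: a hybrid-search failure (or hybrid search being disabled) routes the request through the plain embedding query instead of bubbling the error up. A condensed standalone sketch of that control flow, with hybrid_query and plain_query as hypothetical stand-ins for the two retrieval helpers:

import logging

log = logging.getLogger(__name__)

def retrieve_context(hybrid_search: bool, hybrid_query, plain_query):
    context = None
    if hybrid_search:
        try:
            context = hybrid_query()
        except Exception:
            log.debug("Error when using hybrid search, using non-hybrid search as fallback.")
    # Fall through to the plain query when hybrid search is off or failed.
    if (not hybrid_search) or (context is None):
        context = plain_query()
    return context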

View File

@@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
-DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
<context>
-    [context]
+[context]
</context>
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
+<rules>
+- If you don't know, just say so.
+- If you are not sure, ask for clarification.
+- Answer in the same language as the user query.
+- If the context appears unreadable or of poor quality, tell the user, then answer as best as you can.
+- If the answer is not in the context but you think you know the answer, explain that to the user, then answer with your own knowledge.
+- Answer directly and without using xml tags.
+</rules>
-Given the context information, answer the query.
-Query: [query]"""
+<user_query>
+[query]
+</user_query>
+"""
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",

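Since rag_template now asserts that the configured template contains "[context]" exactly once, a custom RAG_TEMPLATE that drops or duplicates the placeholder will fail at query time. A purely illustrative pre-flight check one could run on a custom template (the "[query]" check is an extra precaution, not something this diff enforces):

def validate_rag_template(template: str) -> None:
    # Mirrors the assertion added in rag_template: exactly one context slot.
    count = template.count("[context]")
    if count != 1:
        raise ValueError(f"RAG template must contain '[context]' exactly once, found {count}")
    # Extra, not enforced by the diff: make sure the user query has a slot too.
    if "[query]" not in template:
        raise ValueError("RAG template should contain a '[query]' placeholder")

validate_rag_template(
    "Use the context.\n<context>\n[context]\n</context>\nQuestion: [query]"
)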
View File

@@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
+if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
+assert context_string.strip(), (
+"With a relevance threshold of 0, the RAG context cannot "
+"be empty"
+)
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
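The added assertion makes the middleware fail fast when the relevance threshold is 0 (nothing gets filtered out) yet the assembled RAG context is still empty, which points at a retrieval bug rather than over-aggressive filtering. An isolated sketch of the same guard with placeholder names:

def check_rag_context(context_string: str, relevance_threshold: float) -> None:
    # With no relevance filtering, an empty context means retrieval itself
    # returned nothing, so surface that loudly instead of sending an empty
    # context to the model.
    if relevance_threshold == 0:
        assert context_string.strip(), (
            "With a relevance threshold of 0, the RAG context cannot be empty"
        )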