Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web

fix: RAG and Web Search + RAG enhancements
This commit is contained in:
Timothy Jaeryang Baek
2024-09-13 05:38:53 +01:00
committed by GitHub
3 changed files with 72 additions and 29 deletions

View File

@@ -1,5 +1,6 @@
import logging
import os
import uuid
from typing import Optional, Union
import requests
@@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
k: int,
reranking_function,
r: float,
):
) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
@@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
raise e
def merge_and_sort_query_results(query_results, k, reverse=False):
def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
@@ -180,7 +181,7 @@ def query_collection(
query: str,
embedding_function,
k: int,
):
) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
@@ -192,8 +193,8 @@ def query_collection(
embedding_function=embedding_function,
)
results.append(result)
except Exception:
pass
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
@@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
k: int,
reranking_function,
r: float,
):
) -> dict:
results = []
failed = 0
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
@@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
except Exception:
pass
except Exception as e:
log.exception(
"Error when querying the collection with "
f"hybrid_search: {e}"
)
failed += 1
if failed == len(collection_names):
raise Exception("Hybrid search failed for all collections. Using "
"Non hybrid search as fallback.")
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
template = template.replace("[context]", context)
template = template.replace("[query]", query)
count = template.count("[context]")
assert count == 1, (
f"RAG template contains an unexpected number of '[context]' : {count}"
)
assert "[context]" in template, "RAG template does not contain '[context]'"
if "<context>" in context and "</context>" in context:
log.debug(
"WARNING: Potential prompt injection attack: the RAG "
"context contains '<context>' and '</context>'. This might be "
"nothing, or the user might be trying to hack something."
)
if "[query]" in context:
query_placeholder = str(uuid.uuid4())
template = template.replace("[QUERY]", query_placeholder)
template = template.replace("[context]", context)
template = template.replace(query_placeholder, query)
else:
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
@@ -304,19 +331,25 @@ def get_rag_context(
continue
try:
context = None
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
else:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug("Error when using hybrid search, using"
" non hybrid search as fallback.")
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
@@ -325,7 +358,6 @@ def get_rag_context(
)
except Exception as e:
log.exception(e)
context = None
if context:
relevant_contexts.append({**context, "source": file})