From 69822e4c25f038e7aace0a1f029c40009836c267 Mon Sep 17 00:00:00 2001 From: Steven Kreitzer Date: Thu, 25 Apr 2024 20:00:47 -0500 Subject: [PATCH] fix: sort ranking hybrid --- backend/apps/rag/utils.py | 29 ++++++++++++----------------- backend/main.py | 1 + 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 0e6e3dd68..62c29b2be 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -18,8 +18,6 @@ from langchain.retrievers import ( EnsembleRetriever, ) -from sentence_transformers import CrossEncoder - from typing import Optional from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -34,14 +32,13 @@ def query_embeddings_doc( embeddings_function, reranking_function, k: int, - r: Optional[float] = None, - hybrid: Optional[bool] = False, + r: int, + hybrid: bool, ): try: - if hybrid: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection(name=collection_name) + collection = CHROMA_CLIENT.get_collection(name=collection_name) + if hybrid: documents = collection.get() # get all documents bm25_retriever = BM25Retriever.from_texts( texts=documents.get("documents"), @@ -77,24 +74,19 @@ def query_embeddings_doc( "metadatas": [[d.metadata for d in result]], } else: - # if you use docker use the model from the environment variable query_embeddings = embeddings_function(query) - - log.info(f"query_embeddings_doc {query_embeddings}") - collection = CHROMA_CLIENT.get_collection(name=collection_name) - result = collection.query( query_embeddings=[query_embeddings], n_results=k, ) - log.info(f"query_embeddings_doc:result {result}") + log.info(f"query_embeddings_doc:result {result}") return result except Exception as e: raise e -def merge_and_sort_query_results(query_results, k): +def merge_and_sort_query_results(query_results, k, reverse=False): # Initialize lists to store combined data combined_distances = [] combined_documents = [] @@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k): combined = list(zip(combined_distances, combined_documents, combined_metadatas)) # Sort the list based on distances - combined.sort(key=lambda x: x[0]) + combined.sort(key=lambda x: x[0], reverse=reverse) # We don't have anything :-( if not combined: @@ -162,7 +154,8 @@ def query_embeddings_collection( except: pass - return merge_and_sort_query_results(results, k) + reverse = hybrid and reranking_function is not None + return merge_and_sort_query_results(results, k=k, reverse=reverse) def rag_template(template: str, context: str, query: str): @@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor): (d, s) for d, s in docs_with_scores if s >= self.r_score ] - result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) + reverse = self.reranking_function is not None + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse) + final_results = [] for doc, doc_score in result[: self.top_n]: metadata = doc.metadata diff --git a/backend/main.py b/backend/main.py index 1b92ae733..284d83719 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware): rag_app.state.RAG_TEMPLATE, rag_app.state.TOP_K, rag_app.state.RELEVANCE_THRESHOLD, + rag_app.state.HYBRID, rag_app.state.RAG_EMBEDDING_ENGINE, rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.sentence_transformer_ef,