From 984dbf13abb15a933fc044e6d11a31cf46860823 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 25 Apr 2024 17:03:00 -0400 Subject: [PATCH] revert: original rag pipeline --- backend/apps/rag/utils.py | 85 +++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index b5351217b..da71495bb 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -18,6 +18,9 @@ from langchain.retrievers import ( EnsembleRetriever, ) +from sentence_transformers import CrossEncoder + +from typing import Optional from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -28,50 +31,64 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) def query_embeddings_doc( collection_name: str, query: str, - k: int, - r: float, embeddings_function, - reranking_function, + k: int, + reranking_function: Optional[CrossEncoder] = None, + r: Optional[float] = None, ): try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection(name=collection_name) - documents = collection.get() # get all documents - bm25_retriever = BM25Retriever.from_texts( - texts=documents.get("documents"), - metadatas=documents.get("metadatas"), - ) - bm25_retriever.k = k + if reranking_function: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection(name=collection_name) - chroma_retriever = ChromaRetriever( - collection=collection, - embeddings_function=embeddings_function, - top_n=k, - ) + documents = collection.get() # get all documents + bm25_retriever = BM25Retriever.from_texts( + texts=documents.get("documents"), + metadatas=documents.get("metadatas"), + ) + bm25_retriever.k = k - ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] - ) + chroma_retriever = ChromaRetriever( + collection=collection, + embeddings_function=embeddings_function, + top_n=k, + ) - compressor = RerankCompressor( - embeddings_function=embeddings_function, - reranking_function=reranking_function, - r_score=r, - top_n=k, - ) + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] + ) - compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=ensemble_retriever - ) + compressor = RerankCompressor( + embeddings_function=embeddings_function, + reranking_function=reranking_function, + r_score=r, + top_n=k, + ) - result = compression_retriever.invoke(query) - result = { - "distances": [[d.metadata.get("score") for d in result]], - "documents": [[d.page_content for d in result]], - "metadatas": [[d.metadata for d in result]], - } + compression_retriever = ContextualCompressionRetriever( + base_compressor=compressor, base_retriever=ensemble_retriever + ) + result = compression_retriever.invoke(query) + result = { + "distances": [[d.metadata.get("score") for d in result]], + "documents": [[d.page_content for d in result]], + "metadatas": [[d.metadata for d in result]], + } + else: + # if you use docker use the model from the environment variable + query_embeddings = embeddings_function(query) + + log.info(f"query_embeddings_doc {query_embeddings}") + collection = CHROMA_CLIENT.get_collection(name=collection_name) + + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + + log.info(f"query_embeddings_doc:result {result}") return result except Exception as e: raise e