From ce9a5d12e0eb6881865617a04b4eee8246e7882f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 27 Apr 2024 15:38:50 -0400 Subject: [PATCH] refac: rag pipeline --- backend/apps/rag/main.py | 100 ++++++++++-------- backend/apps/rag/utils.py | 213 +++++++++++++++++++++----------------- backend/main.py | 20 ++-- 3 files changed, 179 insertions(+), 154 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 405546dd2..715d70b1b 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -47,9 +47,11 @@ from apps.web.models.documents import ( from apps.rag.utils import ( get_model_path, - query_embeddings_doc, - get_embeddings_function, - query_embeddings_collection, + get_embedding_function, + query_doc, + query_doc_with_hybrid_search, + query_collection, + query_collection_with_hybrid_search, ) from utils.misc import ( @@ -147,6 +149,15 @@ update_reranking_model( RAG_RERANKING_MODEL_AUTO_UPDATE, ) + +app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.RAG_EMBEDDING_ENGINE, + app.state.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + app.state.OPENAI_API_KEY, + app.state.OPENAI_API_BASE_URL, +) + origins = ["*"] @@ -227,6 +238,14 @@ async def update_embedding_config( update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) + app.state.EMBEDDING_FUNCTION = get_embedding_function( + app.state.RAG_EMBEDDING_ENGINE, + app.state.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + app.state.OPENAI_API_KEY, + app.state.OPENAI_API_BASE_URL, + ) + return { "status": True, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, @@ -367,27 +386,22 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - embeddings_function = get_embeddings_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, - ) - - return query_embeddings_doc( - collection_name=form_data.collection_name, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, - embeddings_function=embeddings_function, - reranking_function=app.state.sentence_transformer_rf, - hybrid_search=( - form_data.hybrid - if form_data.hybrid - else app.state.ENABLE_RAG_HYBRID_SEARCH - ), - ) + if app.state.ENABLE_RAG_HYBRID_SEARCH: + return query_doc_with_hybrid_search( + collection_name=form_data.collection_name, + query=form_data.query, + embeddings_function=app.state.EMBEDDING_FUNCTION, + reranking_function=app.state.sentence_transformer_rf, + k=form_data.k if form_data.k else app.state.TOP_K, + r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + ) + else: + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + embeddings_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.TOP_K, + ) except Exception as e: log.exception(e) raise HTTPException( @@ -410,27 +424,23 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - embeddings_function = get_embeddings_function( - app.state.RAG_EMBEDDING_ENGINE, - app.state.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.OPENAI_API_KEY, - app.state.OPENAI_API_BASE_URL, - ) + if app.state.ENABLE_RAG_HYBRID_SEARCH: + return query_collection_with_hybrid_search( + collection_names=form_data.collection_names, + query=form_data.query, + embeddings_function=app.state.EMBEDDING_FUNCTION, + reranking_function=app.state.sentence_transformer_rf, + k=form_data.k if form_data.k else app.state.TOP_K, + r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, + ) + else: + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + embeddings_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.TOP_K, + ) - return query_embeddings_collection( - collection_names=form_data.collection_names, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, - embeddings_function=embeddings_function, - reranking_function=app.state.sentence_transformer_rf, - hybrid_search=( - form_data.hybrid - if form_data.hybrid - else app.state.ENABLE_RAG_HYBRID_SEARCH - ), - ) except Exception as e: log.exception(e) raise HTTPException( @@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) - embedding_func = get_embeddings_function( + embedding_func = get_embedding_function( app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 10ab9bbda..eb9d5c84b 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -26,61 +26,72 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def query_embeddings_doc( +def query_doc( collection_name: str, query: str, - embeddings_function, - reranking_function, + embedding_function, k: int, - r: int, - hybrid_search: bool, ): try: collection = CHROMA_CLIENT.get_collection(name=collection_name) + query_embeddings = embedding_function(query) + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) - if hybrid_search: - documents = collection.get() # get all documents - bm25_retriever = BM25Retriever.from_texts( - texts=documents.get("documents"), - metadatas=documents.get("metadatas"), - ) - bm25_retriever.k = k + log.info(f"query_doc:result {result}") + return result + except Exception as e: + raise e - chroma_retriever = ChromaRetriever( - collection=collection, - embeddings_function=embeddings_function, - top_n=k, - ) - ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] - ) +def query_doc_with_hybrid_search( + collection_name: str, + query: str, + embedding_function, + k: int, + reranking_function, + r: int, +): + try: + collection = CHROMA_CLIENT.get_collection(name=collection_name) + documents = collection.get() # get all documents - compressor = RerankCompressor( - embeddings_function=embeddings_function, - reranking_function=reranking_function, - r_score=r, - top_n=k, - ) + bm25_retriever = BM25Retriever.from_texts( + texts=documents.get("documents"), + metadatas=documents.get("metadatas"), + ) + bm25_retriever.k = k - compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=ensemble_retriever - ) + chroma_retriever = ChromaRetriever( + collection=collection, + embedding_function=embedding_function, + top_n=k, + ) - result = compression_retriever.invoke(query) - result = { - "distances": [[d.metadata.get("score") for d in result]], - "documents": [[d.page_content for d in result]], - "metadatas": [[d.metadata for d in result]], - } - else: - query_embeddings = embeddings_function(query) - result = collection.query( - query_embeddings=[query_embeddings], - n_results=k, - ) + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] + ) - log.info(f"query_embeddings_doc:result {result}") + compressor = RerankCompressor( + embedding_function=embedding_function, + reranking_function=reranking_function, + r_score=r, + top_n=k, + ) + + compression_retriever = ContextualCompressionRetriever( + base_compressor=compressor, base_retriever=ensemble_retriever + ) + + result = compression_retriever.invoke(query) + result = { + "distances": [[d.metadata.get("score") for d in result]], + "documents": [[d.page_content for d in result]], + "metadatas": [[d.metadata for d in result]], + } + log.info(f"query_doc_with_hybrid_search:result {result}") return result except Exception as e: raise e @@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False): return result -def query_embeddings_collection( +def query_collection( collection_names: List[str], query: str, + embedding_function, k: int, - r: float, - embeddings_function, - reranking_function, - hybrid_search: bool, ): - results = [] - for collection_name in collection_names: try: - result = query_embeddings_doc( + result = query_doc( collection_name=collection_name, query=query, k=k, - r=r, - embeddings_function=embeddings_function, + embedding_function=embedding_function, + ) + results.append(result) + except: + pass + return merge_and_sort_query_results(results, k=k) + + +def query_collection_with_hybrid_search( + collection_names: List[str], + query: str, + embedding_function, + k: int, + reranking_function, + r: float, +): + + results = [] + for collection_name in collection_names: + try: + result = query_doc_with_hybrid_search( + collection_name=collection_name, + query=query, + embedding_function=embedding_function, + k=k, reranking_function=reranking_function, - hybrid_search=hybrid_search, + r=r, ) results.append(result) except: pass - reverse = hybrid_search and reranking_function is not None - return merge_and_sort_query_results(results, k=k, reverse=reverse) + return merge_and_sort_query_results(results, k=k, reverse=True) def rag_template(template: str, context: str, query: str): @@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str): return template -def get_embeddings_function( +def get_embedding_function( embedding_engine, embedding_model, embedding_function, @@ -204,19 +232,13 @@ def rag_messages( docs, messages, template, + embedding_function, k, + reranking_function, r, hybrid_search, - embedding_engine, - embedding_model, - embedding_function, - reranking_function, - openai_key, - openai_url, ): - log.debug( - f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}" - ) + log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") last_user_message_idx = None for i in range(len(messages) - 1, -1, -1): @@ -243,14 +265,6 @@ def rag_messages( content_type = None query = "" - embeddings_function = get_embeddings_function( - embedding_engine, - embedding_model, - embedding_function, - openai_key, - openai_url, - ) - extracted_collections = [] relevant_contexts = [] @@ -271,26 +285,31 @@ def rag_messages( try: if doc["type"] == "text": context = doc["content"] - elif doc["type"] == "collection": - context = query_embeddings_collection( - collection_names=doc["collection_names"], - query=query, - k=k, - r=r, - embeddings_function=embeddings_function, - reranking_function=reranking_function, - hybrid_search=hybrid_search, - ) else: - context = query_embeddings_doc( - collection_name=doc["collection_name"], - query=query, - k=k, - r=r, - embeddings_function=embeddings_function, - reranking_function=reranking_function, - hybrid_search=hybrid_search, - ) + if hybrid_search: + context = query_collection_with_hybrid_search( + collection_names=( + doc["collection_names"] + if doc["type"] == "collection" + else [doc["collection_name"]] + ), + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + else: + context = query_collection( + collection_names=( + doc["collection_names"] + if doc["type"] == "collection" + else [doc["collection_name"]] + ), + query=query, + embedding_function=embedding_function, + k=k, + ) except Exception as e: log.exception(e) context = None @@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun class ChromaRetriever(BaseRetriever): collection: Any - embeddings_function: Any + embedding_function: Any top_n: int def _get_relevant_documents( @@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever): *, run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: - query_embeddings = self.embeddings_function(query) + query_embeddings = self.embedding_function(query) results = self.collection.query( query_embeddings=[query_embeddings], @@ -445,7 +464,7 @@ from sentence_transformers import util class RerankCompressor(BaseDocumentCompressor): - embeddings_function: Any + embedding_function: Any reranking_function: Any r_score: float top_n: int @@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor): [(query, doc.page_content) for doc in documents] ) else: - query_embedding = self.embeddings_function(query) - document_embedding = self.embeddings_function( + query_embedding = self.embedding_function(query) + document_embedding = self.embedding_function( [doc.page_content for doc in documents] ) scores = util.cos_sim(query_embedding, document_embedding)[0] diff --git a/backend/main.py b/backend/main.py index cbcbc72f3..1b2772627 100644 --- a/backend/main.py +++ b/backend/main.py @@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware): if "docs" in data: data = {**data} data["messages"] = rag_messages( - data["docs"], - data["messages"], - rag_app.state.RAG_TEMPLATE, - rag_app.state.TOP_K, - rag_app.state.RELEVANCE_THRESHOLD, - rag_app.state.ENABLE_RAG_HYBRID_SEARCH, - rag_app.state.RAG_EMBEDDING_ENGINE, - rag_app.state.RAG_EMBEDDING_MODEL, - rag_app.state.sentence_transformer_ef, - rag_app.state.sentence_transformer_rf, - rag_app.state.OPENAI_API_KEY, - rag_app.state.OPENAI_API_BASE_URL, + docs=data["docs"], + messages=data["messages"], + template=rag_app.state.RAG_TEMPLATE, + embedding_function=rag_app.state.EMBEDDING_FUNCTION, + k=rag_app.state.TOP_K, + reranking_function=rag_app.state.sentence_transformer_rf, + r=rag_app.state.RELEVANCE_THRESHOLD, + hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH, ) del data["docs"]