From 9755cd5baa367620f6b1f08ef0565498c505e10b Mon Sep 17 00:00:00 2001 From: Steven Kreitzer Date: Thu, 25 Apr 2024 17:31:21 -0500 Subject: [PATCH 1/3] feat: toggle hybrid search --- backend/apps/rag/main.py | 18 +- backend/apps/rag/utils.py | 11 +- backend/config.py | 1 + .../documents/Settings/General.svelte | 191 ++++++++++-------- 4 files changed, 133 insertions(+), 88 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2db2cf1ff..d72fde741 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -70,6 +70,7 @@ from config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_HYBRID, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, @@ -91,6 +92,8 @@ app = FastAPI() app.state.TOP_K = RAG_TOP_K app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.HYBRID = RAG_HYBRID + app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -321,6 +324,7 @@ async def get_query_settings(user=Depends(get_admin_user)): "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, } @@ -328,6 +332,7 @@ class QuerySettingsForm(BaseModel): k: Optional[int] = None r: Optional[float] = None template: Optional[str] = None + hybrid: Optional[bool] = None @app.post("/query/settings/update") @@ -337,7 +342,14 @@ async def update_query_settings( app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - return {"status": True, "template": app.state.RAG_TEMPLATE} + app.state.HYBRID = form_data.hybrid if form_data.hybrid else False + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + "k": app.state.TOP_K, + "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, + } class QueryDocForm(BaseModel): @@ -345,6 +357,7 @@ class QueryDocForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/doc") @@ -368,6 +381,7 @@ def query_doc_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) @@ -382,6 +396,7 @@ class QueryCollectionsForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/collection") @@ -405,6 +420,7 @@ def query_collection_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index da71495bb..0e6e3dd68 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -32,13 +32,13 @@ def query_embeddings_doc( collection_name: str, query: str, embeddings_function, + reranking_function, k: int, - reranking_function: Optional[CrossEncoder] = None, r: Optional[float] = None, + hybrid: Optional[bool] = False, ): try: - - if reranking_function: + if hybrid: # if you use docker use the model from the environment variable collection = 
CHROMA_CLIENT.get_collection(name=collection_name) @@ -142,6 +142,7 @@ def query_embeddings_collection( r: float, embeddings_function, reranking_function, + hybrid: bool, ): results = [] @@ -155,6 +156,7 @@ def query_embeddings_collection( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) results.append(result) except: @@ -211,6 +213,7 @@ def rag_messages( template, k, r, + hybrid, embedding_engine, embedding_model, embedding_function, @@ -283,6 +286,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) else: context = query_embeddings_doc( @@ -292,6 +296,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) except Exception as e: log.exception(e) diff --git a/backend/config.py b/backend/config.py index 622b95059..e60a789b7 100644 --- a/backend/config.py +++ b/backend/config.py @@ -422,6 +422,7 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) +RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c6695bb6b..9fb7c6775 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -43,7 +43,8 @@ let querySettings = { template: '', r: 0.0, - k: 4 + k: 4, + hybrid: false }; const scanHandler = async () => { @@ -174,6 +175,12 @@ } }; + const toggleHybridSearch = async () => { + querySettings.hybrid = !querySettings.hybrid; + + querySettings = await updateQuerySettings(localStorage.token, querySettings); + }; + onMount(async () => { const res = await getRAGConfig(localStorage.token); @@ -202,6 +209,24 @@
 {$i18n.t('General Settings')}
+ {$i18n.t('Hybrid Search')} toggle, wired to the new toggleHybridSearch handler
 {$i18n.t('Embedding Model Engine')}
@@ -386,78 +411,74 @@
+ {#if querySettings.hybrid === true}
   {$i18n.t('Update Reranking Model')} input and its update button, plus the
   {$i18n.t('Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.')}
   hint, now rendered only when hybrid search is enabled
+ {/if}
@@ -583,25 +604,27 @@
+ {#if querySettings.hybrid === true}
   {$i18n.t('Relevance Threshold')} input, likewise gated on the hybrid toggle
+ {/if}
{$i18n.t('RAG Template')}
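The Svelte changes above only flip the `hybrid` flag; the retrieval work happens in `query_embeddings_doc`, which the second patch below rewires to run a BM25 keyword retriever and the Chroma vector retriever together and then rescore the merged hits. Below is a condensed, hedged sketch of that pipeline using the same building blocks the patches reference (BM25Retriever, EnsembleRetriever, a CrossEncoder reranker); the Chroma path, collection name, and model names are placeholders rather than values from the patch.

# Hedged sketch of the hybrid retrieval path; the Chroma path, collection name,
# embedding model and cross-encoder model are placeholders, not patch values.
from chromadb import PersistentClient
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from sentence_transformers import CrossEncoder

client = PersistentClient(path="data/vector_db")        # stands in for CHROMA_CLIENT
collection = client.get_collection(name="my-docs")
embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

# Keyword side: BM25 over every document in the collection, mirroring the
# patch's collection.get() call.
docs = collection.get()
bm25 = BM25Retriever.from_texts(texts=docs["documents"], metadatas=docs["metadatas"])
bm25.k = 5

# Vector side: the same collection through LangChain's Chroma wrapper.
vector = Chroma(
    client=client, collection_name="my-docs", embedding_function=embedding
).as_retriever(search_kwargs={"k": 5})

# Run both retrievers, then rescore the merged candidates with a cross-encoder,
# which is the role RerankCompressor / reranking_function plays in the patch.
ensemble = EnsembleRetriever(retrievers=[bm25, vector], weights=[0.5, 0.5])
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "how do I enable hybrid search?"
candidates = ensemble.invoke(query)
scores = reranker.predict([(query, d.page_content) for d in candidates])
top_k = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)[:5]

The second patch also makes merge_and_sort_query_results sort in descending order whenever a reranking function supplied the scores, since cross-encoder relevance scores are higher-is-better while raw Chroma distances are lower-is-better.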
From 69822e4c25f038e7aace0a1f029c40009836c267 Mon Sep 17 00:00:00 2001 From: Steven Kreitzer Date: Thu, 25 Apr 2024 20:00:47 -0500 Subject: [PATCH 2/3] fix: sort ranking hybrid --- backend/apps/rag/utils.py | 29 ++++++++++++----------------- backend/main.py | 1 + 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 0e6e3dd68..62c29b2be 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -18,8 +18,6 @@ from langchain.retrievers import ( EnsembleRetriever, ) -from sentence_transformers import CrossEncoder - from typing import Optional from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -34,14 +32,13 @@ def query_embeddings_doc( embeddings_function, reranking_function, k: int, - r: Optional[float] = None, - hybrid: Optional[bool] = False, + r: int, + hybrid: bool, ): try: - if hybrid: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection(name=collection_name) + collection = CHROMA_CLIENT.get_collection(name=collection_name) + if hybrid: documents = collection.get() # get all documents bm25_retriever = BM25Retriever.from_texts( texts=documents.get("documents"), @@ -77,24 +74,19 @@ def query_embeddings_doc( "metadatas": [[d.metadata for d in result]], } else: - # if you use docker use the model from the environment variable query_embeddings = embeddings_function(query) - - log.info(f"query_embeddings_doc {query_embeddings}") - collection = CHROMA_CLIENT.get_collection(name=collection_name) - result = collection.query( query_embeddings=[query_embeddings], n_results=k, ) - log.info(f"query_embeddings_doc:result {result}") + log.info(f"query_embeddings_doc:result {result}") return result except Exception as e: raise e -def merge_and_sort_query_results(query_results, k): +def merge_and_sort_query_results(query_results, k, reverse=False): # Initialize lists to store combined data combined_distances = [] combined_documents = [] @@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k): combined = list(zip(combined_distances, combined_documents, combined_metadatas)) # Sort the list based on distances - combined.sort(key=lambda x: x[0]) + combined.sort(key=lambda x: x[0], reverse=reverse) # We don't have anything :-( if not combined: @@ -162,7 +154,8 @@ def query_embeddings_collection( except: pass - return merge_and_sort_query_results(results, k) + reverse = hybrid and reranking_function is not None + return merge_and_sort_query_results(results, k=k, reverse=reverse) def rag_template(template: str, context: str, query: str): @@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor): (d, s) for d, s in docs_with_scores if s >= self.r_score ] - result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) + reverse = self.reranking_function is not None + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse) + final_results = [] for doc, doc_score in result[: self.top_n]: metadata = doc.metadata diff --git a/backend/main.py b/backend/main.py index 1b92ae733..284d83719 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware): rag_app.state.RAG_TEMPLATE, rag_app.state.TOP_K, rag_app.state.RELEVANCE_THRESHOLD, + rag_app.state.HYBRID, rag_app.state.RAG_EMBEDDING_ENGINE, rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.sentence_transformer_ef, From cebf733b9d0e188e1fd903707a2342678008f4ff Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Fri, 26 Apr 2024 14:41:39 -0400 Subject: [PATCH 3/3] refac: naming convention --- backend/apps/rag/main.py | 23 ++++++++++++++++------- backend/apps/rag/utils.py | 14 +++++++------- backend/config.py | 5 ++++- backend/main.py | 2 +- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index d72fde741..654b2481b 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -70,7 +70,7 @@ from config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_HYBRID, + ENABLE_RAG_HYBRID_SEARCH, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, @@ -92,7 +92,8 @@ app = FastAPI() app.state.TOP_K = RAG_TOP_K app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.HYBRID = RAG_HYBRID + +app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -324,7 +325,7 @@ async def get_query_settings(user=Depends(get_admin_user)): "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.HYBRID, + "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, } @@ -342,13 +343,13 @@ async def update_query_settings( app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.HYBRID = form_data.hybrid if form_data.hybrid else False + app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False return { "status": True, "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.HYBRID, + "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, } @@ -381,7 +382,11 @@ def query_doc_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, - hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, + hybrid_search=( + form_data.hybrid + if form_data.hybrid + else app.state.ENABLE_RAG_HYBRID_SEARCH + ), ) except Exception as e: log.exception(e) @@ -420,7 +425,11 @@ def query_collection_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, - hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, + hybrid_search=( + form_data.hybrid + if form_data.hybrid + else app.state.ENABLE_RAG_HYBRID_SEARCH + ), ) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 62c29b2be..e9fe8319f 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -33,12 +33,12 @@ def query_embeddings_doc( reranking_function, k: int, r: int, - hybrid: bool, + hybrid_search: bool, ): try: collection = CHROMA_CLIENT.get_collection(name=collection_name) - if hybrid: + if hybrid_search: documents = collection.get() # get all documents bm25_retriever = BM25Retriever.from_texts( texts=documents.get("documents"), @@ -134,7 +134,7 @@ def query_embeddings_collection( r: float, embeddings_function, reranking_function, - hybrid: bool, + hybrid_search: bool, ): results = [] @@ -148,7 +148,7 @@ def query_embeddings_collection( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - 
hybrid=hybrid, + hybrid_search=hybrid_search, ) results.append(result) except: @@ -206,7 +206,7 @@ def rag_messages( template, k, r, - hybrid, + hybrid_search, embedding_engine, embedding_model, embedding_function, @@ -279,7 +279,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - hybrid=hybrid, + hybrid_search=hybrid_search, ) else: context = query_embeddings_doc( @@ -289,7 +289,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - hybrid=hybrid, + hybrid_search=hybrid_search, ) except Exception as e: log.exception(e) diff --git a/backend/config.py b/backend/config.py index e60a789b7..f67fd0172 100644 --- a/backend/config.py +++ b/backend/config.py @@ -422,7 +422,10 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) -RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" + +ENABLE_RAG_HYBRID_SEARCH = ( + os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" +) RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") diff --git a/backend/main.py b/backend/main.py index 284d83719..b0dc3a7fd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,7 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware): rag_app.state.RAG_TEMPLATE, rag_app.state.TOP_K, rag_app.state.RELEVANCE_THRESHOLD, - rag_app.state.HYBRID, + rag_app.state.ENABLE_RAG_HYBRID_SEARCH, rag_app.state.RAG_EMBEDDING_ENGINE, rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.sentence_transformer_ef,
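After the rename in the third patch, the same switch is reachable from three layers: the ENABLE_RAG_HYBRID_SEARCH environment variable at startup, the admin /query/settings/update endpoint, and the optional hybrid field on the per-request query forms. The sketch below shows one way a client might exercise it; the base URL, router prefix, bearer-token auth, and collection name are assumptions, not taken from these patches.

# Hedged sketch of exercising the new setting over HTTP; the base URL, prefix,
# bearer token, and collection name are assumptions, not from the patches.
import requests

BASE = "http://localhost:8080/rag/api/v1"            # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed auth scheme

# Persist the toggle; QuerySettingsForm treats k, r, template and hybrid as optional.
resp = requests.post(
    f"{BASE}/query/settings/update",
    json={"k": 5, "r": 0.0, "hybrid": True},
    headers=HEADERS,
)
resp.raise_for_status()
print(resp.json())  # {"status": True, "template": ..., "k": 5, "r": 0.0, "hybrid": True}

# Query one collection; omitting "hybrid" falls back to the stored
# ENABLE_RAG_HYBRID_SEARCH value on app.state.
resp = requests.post(
    f"{BASE}/query/doc",
    json={"collection_name": "my-docs", "query": "what does hybrid search change?"},
    headers=HEADERS,
)
print(resp.json())

Note that, as committed, the handlers use form_data.hybrid if form_data.hybrid else app.state.ENABLE_RAG_HYBRID_SEARCH, so an explicit "hybrid": false in a request falls through to the stored default; the per-request flag can only switch hybrid search on.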