From 9755cd5baa367620f6b1f08ef0565498c505e10b Mon Sep 17 00:00:00 2001 From: Steven Kreitzer Date: Thu, 25 Apr 2024 17:31:21 -0500 Subject: [PATCH] feat: toggle hybrid search --- backend/apps/rag/main.py | 18 +- backend/apps/rag/utils.py | 11 +- backend/config.py | 1 + .../documents/Settings/General.svelte | 191 ++++++++++-------- 4 files changed, 133 insertions(+), 88 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2db2cf1ff..d72fde741 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -70,6 +70,7 @@ from config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_HYBRID, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, @@ -91,6 +92,8 @@ app = FastAPI() app.state.TOP_K = RAG_TOP_K app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.HYBRID = RAG_HYBRID + app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -321,6 +324,7 @@ async def get_query_settings(user=Depends(get_admin_user)): "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, } @@ -328,6 +332,7 @@ class QuerySettingsForm(BaseModel): k: Optional[int] = None r: Optional[float] = None template: Optional[str] = None + hybrid: Optional[bool] = None @app.post("/query/settings/update") @@ -337,7 +342,14 @@ async def update_query_settings( app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - return {"status": True, "template": app.state.RAG_TEMPLATE} + app.state.HYBRID = form_data.hybrid if form_data.hybrid else False + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + "k": app.state.TOP_K, + "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, + } class QueryDocForm(BaseModel): @@ -345,6 +357,7 @@ class QueryDocForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/doc") @@ -368,6 +381,7 @@ def query_doc_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) @@ -382,6 +396,7 @@ class QueryCollectionsForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/collection") @@ -405,6 +420,7 @@ def query_collection_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index da71495bb..0e6e3dd68 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -32,13 +32,13 @@ def query_embeddings_doc( collection_name: str, query: str, embeddings_function, + reranking_function, k: int, - reranking_function: Optional[CrossEncoder] = None, r: Optional[float] = None, + hybrid: Optional[bool] = False, ): try: - - if reranking_function: + if hybrid: # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection(name=collection_name) @@ -142,6 +142,7 @@ def query_embeddings_collection( r: float, embeddings_function, reranking_function, + hybrid: bool, ): results = [] @@ -155,6 +156,7 @@ def query_embeddings_collection( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) results.append(result) except: @@ -211,6 +213,7 @@ def rag_messages( template, k, r, + hybrid, embedding_engine, embedding_model, embedding_function, @@ -283,6 +286,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) else: context = query_embeddings_doc( @@ -292,6 +296,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) except Exception as e: log.exception(e) diff --git a/backend/config.py b/backend/config.py index 622b95059..e60a789b7 100644 --- a/backend/config.py +++ b/backend/config.py @@ -422,6 +422,7 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) +RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c6695bb6b..9fb7c6775 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -43,7 +43,8 @@ let querySettings = { template: '', r: 0.0, - k: 4 + k: 4, + hybrid: false }; const scanHandler = async () => { @@ -174,6 +175,12 @@ } }; + const toggleHybridSearch = async () => { + querySettings.hybrid = !querySettings.hybrid; + + querySettings = await updateQuerySettings(localStorage.token, querySettings); + }; + onMount(async () => { const res = await getRAGConfig(localStorage.token); @@ -202,6 +209,24 @@
{$i18n.t('General Settings')}
+
+
{$i18n.t('Hybrid Search')}
+ + +
+
{$i18n.t('Embedding Model Engine')}
@@ -386,78 +411,74 @@
-
-
{$i18n.t('Update Reranking Model')}
+ {#if querySettings.hybrid === true} +
+
{$i18n.t('Update Reranking Model')}
-
-
- -
-
- {:else} - - - - - {/if} - + + + + {/if} + +
-
-
- {$i18n.t( - 'Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.' - )} -
- -
+
+ {/if}
@@ -583,25 +604,27 @@
-
-
-
- {$i18n.t('Relevance Threshold')} -
+ {#if querySettings.hybrid === true} +
+
+
+ {$i18n.t('Relevance Threshold')} +
-
- +
+ +
-
+ {/if}
{$i18n.t('RAG Template')}