diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2db2cf1ff..d72fde741 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -70,6 +70,7 @@ from config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_HYBRID, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, @@ -91,6 +92,8 @@ app = FastAPI() app.state.TOP_K = RAG_TOP_K app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.HYBRID = RAG_HYBRID + app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -321,6 +324,7 @@ async def get_query_settings(user=Depends(get_admin_user)): "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, } @@ -328,6 +332,7 @@ class QuerySettingsForm(BaseModel): k: Optional[int] = None r: Optional[float] = None template: Optional[str] = None + hybrid: Optional[bool] = None @app.post("/query/settings/update") @@ -337,7 +342,14 @@ async def update_query_settings( app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - return {"status": True, "template": app.state.RAG_TEMPLATE} + app.state.HYBRID = form_data.hybrid if form_data.hybrid else False + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + "k": app.state.TOP_K, + "r": app.state.RELEVANCE_THRESHOLD, + "hybrid": app.state.HYBRID, + } class QueryDocForm(BaseModel): @@ -345,6 +357,7 @@ class QueryDocForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/doc") @@ -368,6 +381,7 @@ def query_doc_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) @@ -382,6 +396,7 @@ class QueryCollectionsForm(BaseModel): query: str k: Optional[int] = None r: Optional[float] = None + hybrid: Optional[bool] = None @app.post("/query/collection") @@ -405,6 +420,7 @@ def query_collection_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, + hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, ) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index da71495bb..0e6e3dd68 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -32,13 +32,13 @@ def query_embeddings_doc( collection_name: str, query: str, embeddings_function, + reranking_function, k: int, - reranking_function: Optional[CrossEncoder] = None, r: Optional[float] = None, + hybrid: Optional[bool] = False, ): try: - - if reranking_function: + if hybrid: # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection(name=collection_name) @@ -142,6 +142,7 @@ def query_embeddings_collection( r: float, embeddings_function, reranking_function, + hybrid: bool, ): results = [] @@ -155,6 +156,7 @@ def query_embeddings_collection( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) results.append(result) except: @@ -211,6 +213,7 @@ def rag_messages( template, k, r, + hybrid, embedding_engine, embedding_model, embedding_function, @@ -283,6 +286,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) else: context = query_embeddings_doc( @@ -292,6 +296,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, + hybrid=hybrid, ) except Exception as e: log.exception(e) diff --git a/backend/config.py b/backend/config.py index 622b95059..e60a789b7 100644 --- a/backend/config.py +++ b/backend/config.py @@ -422,6 +422,7 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) +RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c6695bb6b..9fb7c6775 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -43,7 +43,8 @@ let querySettings = { template: '', r: 0.0, - k: 4 + k: 4, + hybrid: false }; const scanHandler = async () => { @@ -174,6 +175,12 @@ } }; + const toggleHybridSearch = async () => { + querySettings.hybrid = !querySettings.hybrid; + + querySettings = await updateQuerySettings(localStorage.token, querySettings); + }; + onMount(async () => { const res = await getRAGConfig(localStorage.token); @@ -202,6 +209,24 @@