From 81715f6553be7968e454fb8125c27b9e7bf4c9aa Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 18 Feb 2025 21:14:58 -0800 Subject: [PATCH] enh: RAG full context mode --- backend/open_webui/config.py | 10 +- backend/open_webui/main.py | 3 + backend/open_webui/retrieval/utils.py | 112 ++++++++++++++---- backend/open_webui/routers/retrieval.py | 11 +- backend/open_webui/utils/middleware.py | 4 +- .../admin/Settings/Documents.svelte | 19 ++- 6 files changed, 127 insertions(+), 32 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 714c4486c..6e5fb8de6 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1578,6 +1578,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) +RAG_FULL_CONTEXT = PersistentConfig( + "RAG_FULL_CONTEXT", + "rag.full_context", + os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true", +) + RAG_FILE_MAX_COUNT = PersistentConfig( "RAG_FILE_MAX_COUNT", "rag.file.max_count", @@ -1929,7 +1935,7 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( RAG_WEB_LOADER_ENGINE = PersistentConfig( "RAG_WEB_LOADER_ENGINE", "rag.web.loader.engine", - os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web") + os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"), ) RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( @@ -1941,7 +1947,7 @@ RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( PLAYWRIGHT_WS_URI = PersistentConfig( "PLAYWRIGHT_WS_URI", "rag.web.loader.engine.playwright.ws.uri", - os.environ.get("PLAYWRIGHT_WS_URI", None) + os.environ.get("PLAYWRIGHT_WS_URI", None), ) #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 4eafbb533..5cad2ac27 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -156,6 +156,7 @@ from open_webui.config import ( # Retrieval RAG_TEMPLATE, DEFAULT_RAG_TEMPLATE, + RAG_FULL_CONTEXT, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, @@ -519,6 +520,8 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 59490f37f..887d6e02a 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -84,6 +84,19 @@ def query_doc( raise e +def get_doc(collection_name: str, user: UserModel = None): + try: + result = VECTOR_DB_CLIENT.get(collection_name=collection_name) + + if result: + log.info(f"query_doc:result {result.ids} {result.metadatas}") + + return result + except Exception as e: + print(e) + raise e + + def query_doc_with_hybrid_search( collection_name: str, query: str, @@ -137,6 +150,24 @@ def query_doc_with_hybrid_search( raise e +def merge_get_results(get_results: list[dict]) -> dict: + # Initialize lists to store combined data + combined_documents = [] + combined_metadatas = [] + + for data in get_results: + combined_documents.extend(data["documents"][0]) + combined_metadatas.extend(data["metadatas"][0]) + + # Create the output dictionary + result = { + "documents": [combined_documents], + "metadatas": [combined_metadatas], + } + + return result + + def merge_and_sort_query_results( query_results: list[dict], k: int, reverse: bool = False ) -> list[dict]: @@ -194,6 +225,23 @@ def merge_and_sort_query_results( return result +def get_all_items_from_collections(collection_names: list[str]) -> dict: + results = [] + + for collection_name in collection_names: + if collection_name: + try: + result = get_doc(collection_name=collection_name) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass + + return merge_get_results(results) + + def query_collection( collection_names: list[str], queries: list[str], @@ -311,8 +359,11 @@ def get_sources_from_files( reranking_function, r, hybrid_search, + full_context=False, ): - log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") + log.debug( + f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + ) extracted_collections = [] relevant_contexts = [] @@ -350,36 +401,45 @@ def get_sources_from_files( log.debug(f"skipping {file} as it has already been extracted") continue - try: - context = None - if file.get("type") == "text": - context = file["content"] - else: - if hybrid_search: - try: - context = query_collection_with_hybrid_search( + if full_context: + try: + context = get_all_items_from_collections(collection_names) + + print("context", context) + except Exception as e: + log.exception(e) + + else: + try: + context = None + if file.get("type") == "text": + context = file["content"] + else: + if hybrid_search: + try: + context = query_collection_with_hybrid_search( + collection_names=collection_names, + queries=queries, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + except Exception as e: + log.debug( + "Error when using hybrid search, using" + " non hybrid search as fallback." + ) + + if (not hybrid_search) or (context is None): + context = query_collection( collection_names=collection_names, queries=queries, embedding_function=embedding_function, k=k, - reranking_function=reranking_function, - r=r, ) - except Exception as e: - log.debug( - "Error when using hybrid search, using" - " non hybrid search as fallback." - ) - - if (not hybrid_search) or (context is None): - context = query_collection( - collection_names=collection_names, - queries=queries, - embedding_function=embedding_function, - k=k, - ) - except Exception as e: - log.exception(e) + except Exception as e: + log.exception(e) extracted_collections.extend(collection_names) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 4f7d20fa9..e69d2ce96 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -351,6 +351,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, @@ -463,6 +464,7 @@ class WebConfig(BaseModel): class ConfigUpdateForm(BaseModel): + RAG_FULL_CONTEXT: Optional[bool] = None pdf_extract_images: Optional[bool] = None enable_google_drive_integration: Optional[bool] = None file: Optional[FileConfig] = None @@ -482,6 +484,12 @@ async def update_rag_config( else request.app.state.config.PDF_EXTRACT_IMAGES ) + request.app.state.config.RAG_FULL_CONTEXT = ( + form_data.RAG_FULL_CONTEXT + if form_data.RAG_FULL_CONTEXT is not None + else request.app.state.config.RAG_FULL_CONTEXT + ) + request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( form_data.enable_google_drive_integration if form_data.enable_google_drive_integration is not None @@ -588,6 +596,7 @@ async def update_rag_config( return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, "file": { "max_size": request.app.state.config.FILE_MAX_SIZE, "max_count": request.app.state.config.FILE_MAX_COUNT, @@ -1379,7 +1388,7 @@ async def process_web_search( docs, collection_name, overwrite=True, - user=user + user=user, ) return { diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 073a019ed..b624f2a34 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -344,7 +344,7 @@ async def chat_web_search_handler( "query": searchQuery, } ), - user=user + user=user, ) if results: @@ -560,9 +560,9 @@ async def chat_completion_files_handler( reranking_function=request.app.state.rf, r=request.app.state.config.RELEVANCE_THRESHOLD, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + full_context=request.app.state.config.RAG_FULL_CONTEXT, ), ) - except Exception as e: log.exception(e) diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 917e924ae..c7c1f0e8f 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -27,7 +27,6 @@ import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; - import { text } from '@sveltejs/kit'; import Textarea from '$lib/components/common/Textarea.svelte'; const i18n = getContext('i18n'); @@ -56,6 +55,8 @@ let chunkOverlap = 0; let pdfExtractImages = true; + let RAG_FULL_CONTEXT = false; + let enableGoogleDriveIntegration = false; let OpenAIUrl = ''; @@ -182,6 +183,7 @@ max_size: fileMaxSize === '' ? null : fileMaxSize, max_count: fileMaxCount === '' ? null : fileMaxCount }, + RAG_FULL_CONTEXT: RAG_FULL_CONTEXT, chunk: { text_splitter: textSplitter, chunk_overlap: chunkOverlap, @@ -242,6 +244,8 @@ chunkSize = res.chunk.chunk_size; chunkOverlap = res.chunk.chunk_overlap; + RAG_FULL_CONTEXT = res.RAG_FULL_CONTEXT; + contentExtractionEngine = res.content_extraction.engine; tikaServerUrl = res.content_extraction.tika_server_url; showTikaServerUrl = contentExtractionEngine === 'tika'; @@ -388,6 +392,19 @@ {/if} + +
+
{$i18n.t('Full Context Mode')}
+
+ + + +
+