enh: RAG full context mode

This commit is contained in:
Timothy Jaeryang Baek 2025-02-18 21:14:58 -08:00
parent a6a7c548d5
commit 81715f6553
6 changed files with 127 additions and 32 deletions

View File

@ -1578,6 +1578,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
) )
RAG_FULL_CONTEXT = PersistentConfig(
"RAG_FULL_CONTEXT",
"rag.full_context",
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
)
RAG_FILE_MAX_COUNT = PersistentConfig( RAG_FILE_MAX_COUNT = PersistentConfig(
"RAG_FILE_MAX_COUNT", "RAG_FILE_MAX_COUNT",
"rag.file.max_count", "rag.file.max_count",
@ -1929,7 +1935,7 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
RAG_WEB_LOADER_ENGINE = PersistentConfig( RAG_WEB_LOADER_ENGINE = PersistentConfig(
"RAG_WEB_LOADER_ENGINE", "RAG_WEB_LOADER_ENGINE",
"rag.web.loader.engine", "rag.web.loader.engine",
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web") os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
) )
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
@ -1941,7 +1947,7 @@ RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
PLAYWRIGHT_WS_URI = PersistentConfig( PLAYWRIGHT_WS_URI = PersistentConfig(
"PLAYWRIGHT_WS_URI", "PLAYWRIGHT_WS_URI",
"rag.web.loader.engine.playwright.ws.uri", "rag.web.loader.engine.playwright.ws.uri",
os.environ.get("PLAYWRIGHT_WS_URI", None) os.environ.get("PLAYWRIGHT_WS_URI", None),
) )
#################################### ####################################

View File

@ -156,6 +156,7 @@ from open_webui.config import (
# Retrieval # Retrieval
RAG_TEMPLATE, RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE, DEFAULT_RAG_TEMPLATE,
RAG_FULL_CONTEXT,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@ -519,6 +520,8 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION

View File

@ -84,6 +84,19 @@ def query_doc(
raise e raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)
raise e
def query_doc_with_hybrid_search( def query_doc_with_hybrid_search(
collection_name: str, collection_name: str,
query: str, query: str,
@ -137,6 +150,24 @@ def query_doc_with_hybrid_search(
raise e raise e
def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_documents = []
combined_metadatas = []
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
# Create the output dictionary
result = {
"documents": [combined_documents],
"metadatas": [combined_metadatas],
}
return result
def merge_and_sort_query_results( def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False query_results: list[dict], k: int, reverse: bool = False
) -> list[dict]: ) -> list[dict]:
@ -194,6 +225,23 @@ def merge_and_sort_query_results(
return result return result
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection( def query_collection(
collection_names: list[str], collection_names: list[str],
queries: list[str], queries: list[str],
@ -311,8 +359,11 @@ def get_sources_from_files(
reranking_function, reranking_function,
r, r,
hybrid_search, hybrid_search,
full_context=False,
): ):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = [] extracted_collections = []
relevant_contexts = [] relevant_contexts = []
@ -350,36 +401,45 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted") log.debug(f"skipping {file} as it has already been extracted")
continue continue
try: if full_context:
context = None try:
if file.get("type") == "text": context = get_all_items_from_collections(collection_names)
context = file["content"]
else: print("context", context)
if hybrid_search: except Exception as e:
try: log.exception(e)
context = query_collection_with_hybrid_search(
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names, collection_names=collection_names,
queries=queries, queries=queries,
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
reranking_function=reranking_function,
r=r,
) )
except Exception as e: except Exception as e:
log.debug( log.exception(e)
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)

View File

@ -351,6 +351,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"content_extraction": { "content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
@ -463,6 +464,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel): class ConfigUpdateForm(BaseModel):
RAG_FULL_CONTEXT: Optional[bool] = None
pdf_extract_images: Optional[bool] = None pdf_extract_images: Optional[bool] = None
enable_google_drive_integration: Optional[bool] = None enable_google_drive_integration: Optional[bool] = None
file: Optional[FileConfig] = None file: Optional[FileConfig] = None
@ -482,6 +484,12 @@ async def update_rag_config(
else request.app.state.config.PDF_EXTRACT_IMAGES else request.app.state.config.PDF_EXTRACT_IMAGES
) )
request.app.state.config.RAG_FULL_CONTEXT = (
form_data.RAG_FULL_CONTEXT
if form_data.RAG_FULL_CONTEXT is not None
else request.app.state.config.RAG_FULL_CONTEXT
)
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
form_data.enable_google_drive_integration form_data.enable_google_drive_integration
if form_data.enable_google_drive_integration is not None if form_data.enable_google_drive_integration is not None
@ -588,6 +596,7 @@ async def update_rag_config(
return { return {
"status": True, "status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"file": { "file": {
"max_size": request.app.state.config.FILE_MAX_SIZE, "max_size": request.app.state.config.FILE_MAX_SIZE,
"max_count": request.app.state.config.FILE_MAX_COUNT, "max_count": request.app.state.config.FILE_MAX_COUNT,
@ -1379,7 +1388,7 @@ async def process_web_search(
docs, docs,
collection_name, collection_name,
overwrite=True, overwrite=True,
user=user user=user,
) )
return { return {

View File

@ -344,7 +344,7 @@ async def chat_web_search_handler(
"query": searchQuery, "query": searchQuery,
} }
), ),
user=user user=user,
) )
if results: if results:
@ -560,9 +560,9 @@ async def chat_completion_files_handler(
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD, r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
), ),
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View File

@ -27,7 +27,6 @@
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte'; import Switch from '$lib/components/common/Switch.svelte';
import { text } from '@sveltejs/kit';
import Textarea from '$lib/components/common/Textarea.svelte'; import Textarea from '$lib/components/common/Textarea.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -56,6 +55,8 @@
let chunkOverlap = 0; let chunkOverlap = 0;
let pdfExtractImages = true; let pdfExtractImages = true;
let RAG_FULL_CONTEXT = false;
let enableGoogleDriveIntegration = false; let enableGoogleDriveIntegration = false;
let OpenAIUrl = ''; let OpenAIUrl = '';
@ -182,6 +183,7 @@
max_size: fileMaxSize === '' ? null : fileMaxSize, max_size: fileMaxSize === '' ? null : fileMaxSize,
max_count: fileMaxCount === '' ? null : fileMaxCount max_count: fileMaxCount === '' ? null : fileMaxCount
}, },
RAG_FULL_CONTEXT: RAG_FULL_CONTEXT,
chunk: { chunk: {
text_splitter: textSplitter, text_splitter: textSplitter,
chunk_overlap: chunkOverlap, chunk_overlap: chunkOverlap,
@ -242,6 +244,8 @@
chunkSize = res.chunk.chunk_size; chunkSize = res.chunk.chunk_size;
chunkOverlap = res.chunk.chunk_overlap; chunkOverlap = res.chunk.chunk_overlap;
RAG_FULL_CONTEXT = res.RAG_FULL_CONTEXT;
contentExtractionEngine = res.content_extraction.engine; contentExtractionEngine = res.content_extraction.engine;
tikaServerUrl = res.content_extraction.tika_server_url; tikaServerUrl = res.content_extraction.tika_server_url;
showTikaServerUrl = contentExtractionEngine === 'tika'; showTikaServerUrl = contentExtractionEngine === 'tika';
@ -388,6 +392,19 @@
{/if} {/if}
</button> </button>
</div> </div>
<div class=" py-0.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Full Context Mode')}</div>
<div class="flex items-center relative">
<Tooltip
content={RAG_FULL_CONTEXT
? 'Inject entire contents as context for comprehensive processing, this is recommended for complex queries.'
: 'Default to segmented retrieval for focused and relevant content extraction, this is recommended for most cases.'}
>
<Switch bind:state={RAG_FULL_CONTEXT} />
</Tooltip>
</div>
</div>
</div> </div>
<hr class="border-gray-100 dark:border-gray-850" /> <hr class="border-gray-100 dark:border-gray-850" />