Feature: Adjusted all necessary functions to handle an individual RAG configuration per knowledge base, falling back to the default configuration if an individual configuration is not used

This commit is contained in:
Maytown 2025-05-02 09:55:53 +02:00
parent 220ad3723e
commit e3a93b24a0

View File

@ -193,8 +193,33 @@ async def get_status(request: Request):
}
@router.get("/embedding")
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
@router.post("/embedding")
async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
"""
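Retrieve the embedding configuration for the given collection.
If no knowledge base matches the collection name, or its DEFAULT_RAG_SETTINGS flag is True (the default when unset), return the global embedding settings.
Otherwise, return the embedding configuration stored with the knowledge base, falling back to the global value for any key that is not stored.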
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
return {
"status": True,
"embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
"embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL),
"embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE),
"openai_config": rag_config.get("openai_config", {
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.RAG_OPENAI_API_KEY,
}),
"ollama_config": rag_config.get("ollama_config", {
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
}),
}
else:
# Return default embedding settings
return {
"status": True,
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
@ -211,8 +236,23 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
}
@router.get("/reranking")
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
@router.post("/reranking")
async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
"""
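Retrieve the reranking configuration for the given collection.
If no knowledge base matches the collection name, or its DEFAULT_RAG_SETTINGS flag is True (the default when unset), return the global reranking settings.
Otherwise, return the reranking model stored with the knowledge base, falling back to the global value if none is stored.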
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Return the reranking configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
return {
"status": True,
"reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL),
}
else:
# Return default reranking settings
return {
"status": True,
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
@ -321,10 +361,31 @@ class RerankingModelUpdateForm(BaseModel):
async def update_reranking_config(
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
"""
Update the reranking model configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_base_by_user_id(user.id)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config["reranking_model"] = form_data.reranking_model
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
)
return {
"status": True,
"reranking_model": rag_config["reranking_model"],
"message": "Reranking model updated in the database.",
}
else:
# Update the global configuration
log.info(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
try:
@ -339,6 +400,7 @@ async def update_reranking_config(
return {
"status": True,
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
"message": "Reranking model updated globally.",
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
@ -348,8 +410,92 @@ async def update_reranking_config(
)
@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)):
"""
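Retrieve the full RAG configuration for the given collection.
If no knowledge base matches the collection name, or its DEFAULT_RAG_SETTINGS flag is True (the default when unset), return the global settings.
Otherwise, return the RAG configuration stored with the knowledge base, falling back to the global value for any key that is not stored.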
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
return {
"status": True,
# RAG settings
"RAG_TEMPLATE": rag_config.get("template", request.app.state.config.RAG_TEMPLATE),
"TOP_K": rag_config.get("top_k", request.app.state.config.TOP_K),
"BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL),
"RAG_FULL_CONTEXT": rag_config.get("rag_full_context", request.app.state.config.RAG_FULL_CONTEXT),
# Hybrid search settings
"ENABLE_RAG_HYBRID_SEARCH": rag_config.get("enable_rag_hybrid_search", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
"TOP_K_RERANKER": rag_config.get("top_k_reranker", request.app.state.config.TOP_K_RERANKER),
"RELEVANCE_THRESHOLD": rag_config.get("relevance_threshold", request.app.state.config.RELEVANCE_THRESHOLD),
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": rag_config.get("content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
"PDF_EXTRACT_IMAGES": rag_config.get("pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES),
"TIKA_SERVER_URL": rag_config.get("tika_server_url", request.app.state.config.TIKA_SERVER_URL),
"DOCLING_SERVER_URL": rag_config.get("docling_server_url", request.app.state.config.DOCLING_SERVER_URL),
"DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
"DOCUMENT_INTELLIGENCE_KEY": rag_config.get("document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
"MISTRAL_OCR_API_KEY": rag_config.get("mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY),
# Chunking settings
"TEXT_SPLITTER": rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER),
"CHUNK_SIZE": rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE),
"CHUNK_OVERLAP": rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP),
# File upload settings
"FILE_MAX_SIZE": rag_config.get("file_max_size", request.app.state.config.FILE_MAX_SIZE),
"FILE_MAX_COUNT": rag_config.get("file_max_count", request.app.state.config.FILE_MAX_COUNT),
# Integration settings
"ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("enable_google_drive_integration", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION),
"ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION),
# Web search settings
"web": {
"ENABLE_WEB_SEARCH": rag_config.get("enable_web_search", request.app.state.config.ENABLE_WEB_SEARCH),
"WEB_SEARCH_ENGINE": rag_config.get("web_search_engine", request.app.state.config.WEB_SEARCH_ENGINE),
"WEB_SEARCH_TRUST_ENV": rag_config.get("web_search_trust_env", request.app.state.config.WEB_SEARCH_TRUST_ENV),
"WEB_SEARCH_RESULT_COUNT": rag_config.get("web_search_result_count", request.app.state.config.WEB_SEARCH_RESULT_COUNT),
"WEB_SEARCH_CONCURRENT_REQUESTS": rag_config.get("web_search_concurrent_requests", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
"WEB_SEARCH_DOMAIN_FILTER_LIST": rag_config.get("web_search_domain_filter_list", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_web_search_embedding_and_retrieval", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
"SEARXNG_QUERY_URL": rag_config.get("searxng_query_url", request.app.state.config.SEARXNG_QUERY_URL),
"GOOGLE_PSE_API_KEY": rag_config.get("google_pse_api_key", request.app.state.config.GOOGLE_PSE_API_KEY),
"GOOGLE_PSE_ENGINE_ID": rag_config.get("google_pse_engine_id", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
"BRAVE_SEARCH_API_KEY": rag_config.get("brave_search_api_key", request.app.state.config.BRAVE_SEARCH_API_KEY),
"KAGI_SEARCH_API_KEY": rag_config.get("kagi_search_api_key", request.app.state.config.KAGI_SEARCH_API_KEY),
"MOJEEK_SEARCH_API_KEY": rag_config.get("mojeek_search_api_key", request.app.state.config.MOJEEK_SEARCH_API_KEY),
"BOCHA_SEARCH_API_KEY": rag_config.get("bocha_search_api_key", request.app.state.config.BOCHA_SEARCH_API_KEY),
"SERPSTACK_API_KEY": rag_config.get("serpstack_api_key", request.app.state.config.SERPSTACK_API_KEY),
"SERPSTACK_HTTPS": rag_config.get("serpstack_https", request.app.state.config.SERPSTACK_HTTPS),
"SERPER_API_KEY": rag_config.get("serper_api_key", request.app.state.config.SERPER_API_KEY),
"SERPLY_API_KEY": rag_config.get("serply_api_key", request.app.state.config.SERPLY_API_KEY),
"TAVILY_API_KEY": rag_config.get("tavily_api_key", request.app.state.config.TAVILY_API_KEY),
"SEARCHAPI_API_KEY": rag_config.get("searchapi_api_key", request.app.state.config.SEARCHAPI_API_KEY),
"SEARCHAPI_ENGINE": rag_config.get("searchapi_engine", request.app.state.config.SEARCHAPI_ENGINE),
"SERPAPI_API_KEY": rag_config.get("serpapi_api_key", request.app.state.config.SERPAPI_API_KEY),
"SERPAPI_ENGINE": rag_config.get("serpapi_engine", request.app.state.config.SERPAPI_ENGINE),
"JINA_API_KEY": rag_config.get("jina_api_key", request.app.state.config.JINA_API_KEY),
"BING_SEARCH_V7_ENDPOINT": rag_config.get("bing_search_v7_endpoint", request.app.state.config.BING_SEARCH_V7_ENDPOINT),
"BING_SEARCH_V7_SUBSCRIPTION_KEY": rag_config.get("bing_search_v7_subscription_key", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
"EXA_API_KEY": rag_config.get("exa_api_key", request.app.state.config.EXA_API_KEY),
"PERPLEXITY_API_KEY": rag_config.get("perplexity_api_key", request.app.state.config.PERPLEXITY_API_KEY),
"SOUGOU_API_SID": rag_config.get("sougou_api_sid", request.app.state.config.SOUGOU_API_SID),
"SOUGOU_API_SK": rag_config.get("sougou_api_sk", request.app.state.config.SOUGOU_API_SK),
"WEB_LOADER_ENGINE": rag_config.get("web_loader_engine", request.app.state.config.WEB_LOADER_ENGINE),
"ENABLE_WEB_LOADER_SSL_VERIFICATION": rag_config.get("enable_web_loader_ssl_verification", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION),
"PLAYWRIGHT_WS_URL": rag_config.get("playwright_ws_url", request.app.state.config.PLAYWRIGHT_WS_URL),
"PLAYWRIGHT_TIMEOUT": rag_config.get("playwright_timeout", request.app.state.config.PLAYWRIGHT_TIMEOUT),
"FIRECRAWL_API_KEY": rag_config.get("firecrawl_api_key", request.app.state.config.FIRECRAWL_API_KEY),
"FIRECRAWL_API_BASE_URL": rag_config.get("firecrawl_api_base_url", request.app.state.config.FIRECRAWL_API_BASE_URL),
"TAVILY_EXTRACT_DEPTH": rag_config.get("tavily_extract_depth", request.app.state.config.TAVILY_EXTRACT_DEPTH),
"YOUTUBE_LOADER_LANGUAGE": rag_config.get("youtube_loader_language", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
"YOUTUBE_LOADER_PROXY_URL": rag_config.get("youtube_loader_proxy_url", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
"YOUTUBE_LOADER_TRANSLATION": rag_config.get("youtube_loader_translation", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
},
}
else:
# Return default RAG settings
return {
"status": True,
# RAG settings
@ -508,8 +654,33 @@ class ConfigForm(BaseModel):
@router.post("/config/update")
async def update_rag_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
):
"""
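Update the RAG configuration.
If no knowledge base matches the given collection name, or its DEFAULT_RAG_SETTINGS flag is True (the default when unset), update the global configuration.
Otherwise, merge only the fields provided in the request into the rag_config stored for that knowledge base and return the merged configuration.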
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.data.get("rag_config", {})
# Update only the provided fields in the rag_config
for field, value in form_data.dict(exclude_unset=True).items():
if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value.dict(exclude_unset=True)}
else:
rag_config[field] = value
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
)
return rag_config
else:
# Update the global configuration
# RAG settings
request.app.state.config.RAG_TEMPLATE = (
form_data.RAG_TEMPLATE
@ -538,7 +709,6 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
@ -822,6 +992,28 @@ def save_docs_to_vector_db(
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
)
# Retrieve the knowledge base using the collection_name
knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name)
if not knowledge_base:
raise ValueError(f"Knowledge base not found for collection: {collection_name}")
# Retrieve the RAG configuration
rag_config = {}
if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.data.get("rag_config", {})
# Use knowledge-base-specific or default configurations
text_splitter_type = rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER)
chunk_size = rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE)
chunk_overlap = rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP)
embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
result = VECTOR_DB_CLIENT.query(
@ -836,13 +1028,13 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
if text_splitter_type in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True,
)
elif request.app.state.config.TEXT_SPLITTER == "token":
elif text_splitter_type == "token":
log.info(
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
)
@ -850,8 +1042,8 @@ def save_docs_to_vector_db(
tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
text_splitter = TokenTextSplitter(
encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True,
)
else:
@ -869,8 +1061,8 @@ def save_docs_to_vector_db(
**(metadata if metadata else {}),
"embedding_config": json.dumps(
{
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"model": request.app.state.config.RAG_EMBEDDING_MODEL,
"engine": embedding_engine,
"model": embedding_model,
}
),
}
@ -903,20 +1095,20 @@ def save_docs_to_vector_db(
log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
embedding_engine,
embedding_model,
request.app.state.ef,
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
openai_api_base_url
if embedding_engine == "openai"
else ollama_base_url
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
openai_api_key
if embedding_engine == "openai"
else ollama_api_key
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
embedding_batch_size,
)
embeddings = embedding_function(
@ -966,6 +1158,39 @@ def process_file(
if collection_name is None:
collection_name = f"file-{file.id}"
# Retrieve the knowledge base using the collection name
knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name)
if not knowledge_base:
raise ValueError(f"Knowledge base not found for collection: {collection_name}")
# Retrieve the RAG configuration
rag_config = {}
if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.data.get("rag_config", {})
# Use knowledge-base-specific or default configurations
content_extraction_engine = rag_config.get(
"content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE
)
tika_server_url = rag_config.get(
"tika_server_url", request.app.state.config.TIKA_SERVER_URL
)
docling_server_url = rag_config.get(
"docling_server_url", request.app.state.config.DOCLING_SERVER_URL
)
pdf_extract_images = rag_config.get(
"pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES
)
document_intelligence_endpoint = rag_config.get(
"document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT
)
document_intelligence_key = rag_config.get(
"document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
)
mistral_ocr_api_key = rag_config.get(
"mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY
)
if form_data.content:
# Update the content in the file
# Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
@ -1029,13 +1254,13 @@ def process_file(
if file_path:
file_path = Storage.get_file(file_path)
loader = Loader(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
engine=content_extraction_engine,
TIKA_SERVER_URL=tika_server_url,
DOCLING_SERVER_URL=docling_server_url,
PDF_EXTRACT_IMAGES=pdf_extract_images,
DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,
MISTRAL_OCR_API_KEY=mistral_ocr_api_key,
)
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
@ -1078,7 +1303,7 @@ def process_file(
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
if not rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL):
try:
result = save_docs_to_vector_db(
request,