Feat: Adjusted to handle individual rag config - adjusted user settings to handle individual rag config; adjusted to update/delete used embedders/rerankers; adjusted process file to handle individual rag config without changing logic

This commit is contained in:
weberm1 2025-05-23 10:46:56 +02:00
parent bbd312325c
commit 4189459ae2

View File

@ -210,6 +210,10 @@ class SearchForm(BaseModel):
queries: List[str]
class CollectionForm(BaseModel):
    """Request body that identifies a knowledge base (collection) by ID.

    Used by the retrieval config endpoints (e.g. ``POST /embedding``,
    ``POST /config``) to look up a per-knowledge-base RAG configuration;
    when ``knowledge_id`` is ``None`` or the knowledge base has
    ``DEFAULT_RAG_SETTINGS`` enabled, the endpoints fall back to the
    global settings on ``request.app.state.config``.
    """

    # None => no specific knowledge base targeted; use global defaults.
    knowledge_id: Optional[str] = None
@router.get("/")
async def get_status(request: Request):
return {
@ -224,21 +228,32 @@ async def get_status(request: Request):
}
@router.get("/embedding")
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
@router.post("/embedding")
async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
"""
Retrieve the embedding configuration.
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
Otherwise, return the embedding configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
rag_config = {}
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database
rag_config = knowledge_base.rag_config
return {
"status": True,
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
"embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL),
"embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE),
"openai_config": rag_config.get("openai_config", {
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.RAG_OPENAI_API_KEY,
},
"ollama_config": {
}),
"ollama_config": rag_config.get("ollama_config", {
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
}),
}
@ -258,18 +273,137 @@ class EmbeddingModelUpdateForm(BaseModel):
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
knowledge_id: Optional[str] = None
@router.post("/embedding/update")
async def update_embedding_config(
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
):
"""
Update the embedding model configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
log.info(
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
)
# Check if model is in use elsewhere, otherwise free up memory
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id)
if not in_use and not request.app.state.ef.get(request.app.state.config.RAG_EMBEDDING_MODEL) == rag_config.get("embedding_model") and rag_config.get("embedding_model"):
del request.app.state.ef[rag_config["embedding_model"]]
engine = rag_config["embedding_engine"]
target_model = rag_config["embedding_model"]
models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine]
# Find and remove the dictionary that contains the target model
for model in models_list[:]: # Create a copy of the list for safe iteration
if model == target_model:
models_list.remove(model)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
# Update embedding-related fields
rag_config["embedding_engine"] = form_data.embedding_engine
rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
# Update the embedding function
if not rag_config["embedding_model"] in request.app.state.ef:
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
rag_config["embedding_engine"],
rag_config["embedding_model"],
)
request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function(
rag_config["embedding_engine"],
rag_config["embedding_model"],
request.app.state.ef[rag_config["embedding_model"]],
(
rag_config["openai_config"]["url"]
if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["url"]
),
(
rag_config["openai_config"]["key"]
if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["key"]
),
rag_config["embedding_batch_size"]
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable reranking models
if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS
rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS
# Save the updated configuration to the database
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
return {
"status": True,
"embedding_engine": rag_config["embedding_engine"],
"embedding_model": rag_config["embedding_model"],
"embedding_batch_size": rag_config["embedding_batch_size"],
"openai_config": rag_config.get("openai_config", {}),
"ollama_config": rag_config.get("ollama_config", {}),
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
"message": "Embedding configuration updated in the database.",
}
else:
# Update the global configuration
log.info(
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
# Check if model is in use elsewhere, otherwise free up memory
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model")
if not in_use:
del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL]
engine = request.app.state.config.RAG_EMBEDDING_ENGINE
target_model = request.app.state.config.RAG_EMBEDDING_MODEL
models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine]
# Find and remove the dictionary that contains the target model
for model in models_list[:]: # Create a copy of the list for safe iteration
if model == target_model:
models_list.remove(model)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config is not None:
@ -292,15 +426,17 @@ async def update_embedding_config(
form_data.embedding_batch_size
)
request.app.state.ef = get_ef(
# Update the embedding function
if not form_data.embedding_model in request.app.state.ef:
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
)
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef,
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL],
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@ -313,6 +449,13 @@ async def update_embedding_config(
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable embedding models
if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
return {
"status": True,
@ -327,6 +470,9 @@ async def update_embedding_config(
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"message": "Embedding configuration updated globally.",
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@ -336,98 +482,116 @@ async def update_embedding_config(
)
@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
"""
Retrieve the full RAG configuration.
If DEFAULT_RAG_SETTINGS is True, return the default settings.
Otherwise, return the RAG configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
rag_config = {}
web_config = {}
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database
rag_config = knowledge_base.rag_config
web_config = rag_config.get("web", {})
return {
"status": True,
# RAG settings
"RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE,
"TOP_K": request.app.state.config.TOP_K,
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"RAG_TEMPLATE": rag_config.get("TEMPLATE", request.app.state.config.RAG_TEMPLATE),
"TOP_K": rag_config.get("TOP_K", request.app.state.config.TOP_K),
"BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL),
"RAG_FULL_CONTEXT": rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT),
# Hybrid search settings
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
"TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
"RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
"ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
"TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER),
"RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD),
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
"CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
"PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES),
"EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL),
"EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY),
"TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL),
"DOCLING_SERVER_URL": rag_config.get("DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL),
"DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE),
"DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG),
"DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION),
"DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
"DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
"MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY),
# Reranking settings
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
"RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
"RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
"RAG_RERANKING_MODEL": rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL),
"RAG_RERANKING_ENGINE": rag_config.get("RAG_RERANKING_ENGINE", request.app.state.config.RAG_RERANKING_ENGINE),
"RAG_EXTERNAL_RERANKER_URL": rag_config.get("RAG_EXTERNAL_RERANKER_URL", request.app.state.config.RAG_EXTERNAL_RERANKER_URL),
"RAG_EXTERNAL_RERANKER_API_KEY": rag_config.get("RAG_EXTERNAL_RERANKER_API_KEY", request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY),
# Chunking settings
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
"CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP,
"TEXT_SPLITTER": rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER),
"CHUNK_SIZE": rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE),
"CHUNK_OVERLAP": rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP),
# File upload settings
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
"FILE_MAX_SIZE": rag_config.get("FILE_MAX_SIZE", request.app.state.config.FILE_MAX_SIZE),
"FILE_MAX_COUNT": rag_config.get("FILE_MAX_COUNT", request.app.state.config.FILE_MAX_COUNT),
"ALLOWED_FILE_EXTENSIONS": rag_config.get("ALLOWED_FILE_EXTENSIONS", request.app.state.config.ALLOWED_FILE_EXTENSIONS),
# Integration settings
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
"ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("ENABLE_GOOGLE_DRIVE_INTEGRATION", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION),
"ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION),
# Web search settings
"web": {
"ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH,
"WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE,
"WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
"WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
"YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
"GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
"GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
"BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
"KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
"MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY,
"BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY,
"SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY,
"SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS,
"SERPER_API_KEY": request.app.state.config.SERPER_API_KEY,
"SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY,
"TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY,
"SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY,
"SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE,
"SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY,
"SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE,
"JINA_API_KEY": request.app.state.config.JINA_API_KEY,
"BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
"ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
"PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL,
"PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT,
"FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
"FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
"TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
"EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
"EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
"EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
"EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
"ENABLE_WEB_SEARCH": web_config.get("ENABLE_WEB_SEARCH", request.app.state.config.ENABLE_WEB_SEARCH),
"WEB_SEARCH_ENGINE": web_config.get("WEB_SEARCH_ENGINE", request.app.state.config.WEB_SEARCH_ENGINE),
"WEB_SEARCH_TRUST_ENV": web_config.get("WEB_SEARCH_TRUST_ENV", request.app.state.config.WEB_SEARCH_TRUST_ENV),
"WEB_SEARCH_RESULT_COUNT": web_config.get("WEB_SEARCH_RESULT_COUNT", request.app.state.config.WEB_SEARCH_RESULT_COUNT),
"WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
"WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
"SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL),
"YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL),
"YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME),
"YACY_PASSWORD": web_config.get("YACY_QUERY_PASSWORD",request.app.state.config.YACY_PASSWORD),
"GOOGLE_PSE_API_KEY": web_config.get("GOOGLE_PSE_API_KEY", request.app.state.config.GOOGLE_PSE_API_KEY),
"GOOGLE_PSE_ENGINE_ID": web_config.get("GOOGLE_PSE_ENGINE_ID", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
"BRAVE_SEARCH_API_KEY": web_config.get("BRAVE_SEARCH_API_KEY", request.app.state.config.BRAVE_SEARCH_API_KEY),
"KAGI_SEARCH_API_KEY": web_config.get("KAGI_SEARCH_API_KEY", request.app.state.config.KAGI_SEARCH_API_KEY),
"MOJEEK_SEARCH_API_KEY": web_config.get("MOJEEK_SEARCH_API_KEY", request.app.state.config.MOJEEK_SEARCH_API_KEY),
"BOCHA_SEARCH_API_KEY": web_config.get("BOCHA_SEARCH_API_KEY", request.app.state.config.BOCHA_SEARCH_API_KEY),
"SERPSTACK_API_KEY": web_config.get("SERPSTACK_API_KEY", request.app.state.config.SERPSTACK_API_KEY),
"SERPSTACK_HTTPS": web_config.get("SERPSTACK_HTTPS", request.app.state.config.SERPSTACK_HTTPS),
"SERPER_API_KEY": web_config.get("SERPER_API_KEY", request.app.state.config.SERPER_API_KEY),
"SERPLY_API_KEY": web_config.get("SERPLY_API_KEY", request.app.state.config.SERPLY_API_KEY),
"TAVILY_API_KEY": web_config.get("TAVILY_API_KEY", request.app.state.config.TAVILY_API_KEY),
"SEARCHAPI_API_KEY": web_config.get("SEARCHAPI_API_KEY", request.app.state.config.SEARCHAPI_API_KEY),
"SEARCHAPI_ENGINE": web_config.get("SEARCHAPI_ENGINE", request.app.state.config.SEARCHAPI_ENGINE),
"SERPAPI_API_KEY": web_config.get("SERPAPI_API_KEY", request.app.state.config.SERPAPI_API_KEY),
"SERPAPI_ENGINE": web_config.get("SERPAPI_ENGINE", request.app.state.config.SERPAPI_ENGINE),
"JINA_API_KEY": web_config.get("JINA_API_KEY", request.app.state.config.JINA_API_KEY),
"BING_SEARCH_V7_ENDPOINT": web_config.get("BING_SEARCH_V7_ENDPOINT", request.app.state.config.BING_SEARCH_V7_ENDPOINT),
"BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
"EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY),
"PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY),
"SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID),
"SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK),
"WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE),
"ENABLE_WEB_LOADER_SSL_VERIFICATION": web_config.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION),
"PLAYWRIGHT_WS_URL": web_config.get("PLAYWRIGHT_WS_URL", request.app.state.config.PLAYWRIGHT_WS_URL),
"PLAYWRIGHT_TIMEOUT": web_config.get("PLAYWRIGHT_TIMEOUT", request.app.state.config.PLAYWRIGHT_TIMEOUT),
"FIRECRAWL_API_KEY": web_config.get("FIRECRAWL_API_KEY", request.app.state.config.FIRECRAWL_API_KEY),
"FIRECRAWL_API_BASE_URL": web_config.get("FIRECRAWL_API_BASE_URL", request.app.state.config.FIRECRAWL_API_BASE_URL),
"TAVILY_EXTRACT_DEPTH": web_config.get("TAVILY_EXTRACT_DEPTH", request.app.state.config.TAVILY_EXTRACT_DEPTH),
"EXTERNAL_WEB_SEARCH_URL": web_config.get("WEB_SEARCH_URL", request.app.state.config.EXTERNAL_WEB_SEARCH_URL),
"EXTERNAL_WEB_SEARCH_API_KEY": web_config.get("WEB_SEARCH_KEY", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY),
"EXTERNAL_WEB_LOADER_URL": web_config.get("WEB_LOADER_URL", request.app.state.config.EXTERNAL_WEB_LOADER_URL),
"EXTERNAL_WEB_LOADER_API_KEY": web_config.get("WEB_LOADER_KEY", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY),
"YOUTUBE_LOADER_LANGUAGE": web_config.get("YOUTUBE_LOADER_LANGUAGE", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
"YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
"YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
},
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS),
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
}
@ -531,11 +695,102 @@ class ConfigForm(BaseModel):
# Web search settings
web: Optional[WebConfig] = None
# knowledge base ID
knowledge_id: Optional[str] = None
@router.post("/config/update")
async def update_rag_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
request: Request, form_data: ConfigForm, user=Depends(get_verified_user)
):
"""
Update the RAG configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
# Free up memory if hybrid search is disabled and model is not in use elswhere
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("RAG_RERANKING_MODEL"), model_type="RAG_RERANKING_MODEL", id=form_data.knowledge_id)
if not form_data.ENABLE_RAG_HYBRID_SEARCH and \
not in_use and \
request.app.state.rf.get(rag_config["RAG_RERANKING_MODEL"]):
if rag_config.get("RAG_RERANKING_MODEL"):
del request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]]
engine = request.app.state.config.RAG_RERANKING_ENGINE
target_model = rag_config["RAG_RERANKING_MODEL"]
models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine]
# Find and remove the dictionary that contains the target model
for model_config in models_list[:]: # Create a copy of the list for safe iteration
if model_config["RAG_RERANKING_MODEL"] == target_model:
models_list.remove(model_config)
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
# Update only the provided fields in the rag_config
for field, value in form_data.model_dump(exclude_unset=True).items():
if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value}
else:
rag_config[field] = value
log.info(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
)
try:
try:
if not rag_config["RAG_RERANKING_MODEL"] in request.app.state.rf and not rag_config["RAG_RERANKING_MODEL"] == "":
request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] = get_rf(
rag_config["RAG_RERANKING_ENGINE"],
rag_config["RAG_RERANKING_MODEL"],
rag_config["RAG_EXTERNAL_RERANKER_URL"],
rag_config["RAG_EXTERNAL_RERANKER_API_KEY"],
True,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append({
"RAG_RERANKING_MODEL": rag_config["RAG_RERANKING_MODEL"],
"RAG_EXTERNAL_RERANKER_URL": rag_config["RAG_EXTERNAL_RERANKER_URL"],
"RAG_EXTERNAL_RERANKER_API_KEY": rag_config["RAG_EXTERNAL_RERANKER_API_KEY"]})
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
# add model to state for selectable reranking models
if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]]:
request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append(rag_config["RAG_RERANKING_MODEL"])
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
Knowledges.update_rag_config_by_id(
id=knowledge_base.id, rag_config=rag_config
)
return rag_config
else:
# Update the global configuration
# RAG settings
request.app.state.config.RAG_TEMPLATE = (
form_data.RAG_TEMPLATE
@ -564,9 +819,29 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
# Free up memory if hybrid search is disabled and model is not in use elswhere
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="RAG_RERANKING_MODEL")
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and \
not in_use and \
request.app.state.rf.get(request.app.state.config.RAG_RERANKING_MODEL):
del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL]
engine = request.app.state.config.RAG_RERANKING_ENGINE
target_model = request.app.state.config.RAG_RERANKING_MODEL
models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine]
# Find and remove the dictionary that contains the target model
for model_config in models_list[:]: # Create a copy of the list for safe iteration
if model_config["RAG_RERANKING_MODEL"] == target_model:
models_list.remove(model_config)
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER
@ -662,6 +937,7 @@ async def update_rag_config(
else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
)
log.info(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
)
@ -669,13 +945,32 @@ async def update_rag_config(
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
try:
request.app.state.rf = get_rf(
if not request.app.state.config.RAG_RERANKING_MODEL in request.app.state.rf and not request.app.state.config.RAG_RERANKING_MODEL == "":
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
# add model to state for reloading on startup
request.app.state.config.LOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append({
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
"RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
"RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
})
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
# add model to state for selectable reranking models
if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE]:
request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append(request.app.state.config.RAG_RERANKING_MODEL)
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@ -916,6 +1211,7 @@ async def update_rag_config(
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
}
@ -935,6 +1231,7 @@ def save_docs_to_vector_db(
split: bool = True,
add: bool = False,
user=None,
knowledge_id: Optional[str] = None
) -> bool:
def _get_docs_info(docs: list[Document]) -> str:
docs_info = set()
@ -956,6 +1253,26 @@ def save_docs_to_vector_db(
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
)
rag_config = {}
# Retrieve the knowledge base using the collection_name
if knowledge_id:
knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
# Retrieve the RAG configuration
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use knowledge-base-specific or default configurations
text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER)
chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)
chunk_overlap = rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP)
embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
result = VECTOR_DB_CLIENT.query(
@ -970,13 +1287,13 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
if text_splitter_type in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True,
)
elif request.app.state.config.TEXT_SPLITTER == "token":
elif text_splitter_type == "token":
log.info(
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
)
@ -984,8 +1301,8 @@ def save_docs_to_vector_db(
tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
text_splitter = TokenTextSplitter(
encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True,
)
else:
@ -1003,8 +1320,8 @@ def save_docs_to_vector_db(
**(metadata if metadata else {}),
"embedding_config": json.dumps(
{
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"model": request.app.state.config.RAG_EMBEDDING_MODEL,
"engine": embedding_engine,
"model": embedding_model,
}
),
}
@ -1037,20 +1354,20 @@ def save_docs_to_vector_db(
log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef,
embedding_engine,
embedding_model,
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
openai_api_base_url
if embedding_engine == "openai"
else ollama_base_url
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
openai_api_key
if embedding_engine == "openai"
else ollama_api_key
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
embedding_batch_size,
)
embeddings = embedding_function(
@ -1084,6 +1401,7 @@ class ProcessFileForm(BaseModel):
file_id: str
content: Optional[str] = None
collection_name: Optional[str] = None
knowledge_id: Optional[str] = None
@router.post("/process/file")
@ -1100,6 +1418,61 @@ def process_file(
if collection_name is None:
collection_name = f"file-{file.id}"
rag_config = {}
# Retrieve the knowledge base using the collection id - knowledge_id == collection_name (minimal working solution without logic changes)
if form_data.collection_name:
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
# Retrieve the RAG configuration
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db
elif form_data.knowledge_id:
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
# Retrieve the RAG configuration
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use knowledge-base-specific or default configurations
content_extraction_engine = rag_config.get(
"CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE
)
external_document_loader_url = rag_config.get(
"EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
)
external_document_loader_api_key = rag_config.get(
"EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY
)
tika_server_url = rag_config.get(
"TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL
)
docling_server_url = rag_config.get(
"DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL
)
docling_ocr_engine=rag_config.get(
"DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE
)
docling_ocr_lang=rag_config.get(
"DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG
)
docling_do_picture_description=rag_config.get(
"DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
pdf_extract_images = rag_config.get(
"PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES
)
document_intelligence_endpoint = rag_config.get(
"DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT
)
document_intelligence_key = rag_config.get(
"DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
)
mistral_ocr_api_key = rag_config.get(
"MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY
)
if form_data.content:
# Update the content in the file
# Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
@ -1163,18 +1536,18 @@ def process_file(
if file_path:
file_path = Storage.get_file(file_path)
loader = Loader(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
engine=content_extraction_engine,
EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url,
EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
TIKA_SERVER_URL=tika_server_url,
DOCLING_SERVER_URL=docling_server_url,
DOCLING_OCR_ENGINE=docling_ocr_engine,
DOCLING_OCR_LANG=docling_ocr_lang,
DOCLING_DO_PICTURE_DESCRIPTION=docling_do_picture_description,
PDF_EXTRACT_IMAGES=pdf_extract_images,
DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,
MISTRAL_OCR_API_KEY=mistral_ocr_api_key,
)
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
@ -1217,7 +1590,7 @@ def process_file(
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
if not rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL):
try:
result = save_docs_to_vector_db(
request,
@ -1230,6 +1603,7 @@ def process_file(
},
add=(True if form_data.collection_name else False),
user=user,
knowledge_id=form_data.knowledge_id
)
if result:
@ -1280,7 +1654,7 @@ class ProcessTextForm(BaseModel):
def process_text(
request: Request,
form_data: ProcessTextForm,
user=Depends(get_verified_user),
user=Depends(get_verified_user)
):
collection_name = form_data.collection_name
if collection_name is None:
@ -1762,11 +2136,11 @@ def query_doc_handler(
collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name],
query=form_data.query,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@ -1779,7 +2153,7 @@ def query_doc_handler(
else:
return query_doc(
collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION(
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
@ -1813,11 +2187,11 @@ def query_collection_handler(
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@ -1830,7 +2204,7 @@ def query_collection_handler(
return query_collection(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,