mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Feat: Adjusted to handle individual rag config - adjusted user settings to handle individual rag config; adjusted to update/delete used embedders/rerankers; adjusted process file to handle individual rag config without changing logic
This commit is contained in:
parent
bbd312325c
commit
4189459ae2
@ -210,6 +210,10 @@ class SearchForm(BaseModel):
|
||||
queries: List[str]
|
||||
|
||||
|
||||
class CollectionForm(BaseModel):
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_status(request: Request):
|
||||
return {
|
||||
@ -224,21 +228,32 @@ async def get_status(request: Request):
|
||||
}
|
||||
|
||||
|
||||
@router.get("/embedding")
|
||||
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
||||
@router.post("/embedding")
|
||||
async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
|
||||
"""
|
||||
Retrieve the embedding configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
|
||||
Otherwise, return the embedding configuration stored in the database.
|
||||
"""
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
|
||||
rag_config = {}
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the embedding configuration from the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
|
||||
"embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL),
|
||||
"embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE),
|
||||
"openai_config": rag_config.get("openai_config", {
|
||||
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OPENAI_API_KEY,
|
||||
},
|
||||
"ollama_config": {
|
||||
}),
|
||||
"ollama_config": rag_config.get("ollama_config", {
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
@ -258,18 +273,137 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/embedding/update")
|
||||
async def update_embedding_config(
|
||||
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
Update the embedding model configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
try:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
log.info(
|
||||
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
|
||||
)
|
||||
|
||||
# Check if model is in use elsewhere, otherwise free up memory
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id)
|
||||
|
||||
if not in_use and not request.app.state.ef.get(request.app.state.config.RAG_EMBEDDING_MODEL) == rag_config.get("embedding_model") and rag_config.get("embedding_model"):
|
||||
del request.app.state.ef[rag_config["embedding_model"]]
|
||||
engine = rag_config["embedding_engine"]
|
||||
target_model = rag_config["embedding_model"]
|
||||
models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine]
|
||||
|
||||
# Find and remove every entry equal to the target model name
|
||||
for model in models_list[:]: # Create a copy of the list for safe iteration
|
||||
if model == target_model:
|
||||
models_list.remove(model)
|
||||
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Update embedding-related fields
|
||||
rag_config["embedding_engine"] = form_data.embedding_engine
|
||||
rag_config["embedding_model"] = form_data.embedding_model
|
||||
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
||||
|
||||
|
||||
rag_config["openai_config"] = {
|
||||
"url": form_data.openai_config.url,
|
||||
"key": form_data.openai_config.key,
|
||||
}
|
||||
|
||||
rag_config["ollama_config"] = {
|
||||
"url": form_data.ollama_config.url,
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
# Update the embedding function
|
||||
if not rag_config["embedding_model"] in request.app.state.ef:
|
||||
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
||||
rag_config["embedding_engine"],
|
||||
rag_config["embedding_model"],
|
||||
)
|
||||
|
||||
request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function(
|
||||
rag_config["embedding_engine"],
|
||||
rag_config["embedding_model"],
|
||||
request.app.state.ef[rag_config["embedding_model"]],
|
||||
(
|
||||
rag_config["openai_config"]["url"]
|
||||
if rag_config["embedding_engine"] == "openai"
|
||||
else rag_config["ollama_config"]["url"]
|
||||
),
|
||||
(
|
||||
rag_config["openai_config"]["key"]
|
||||
if rag_config["embedding_engine"] == "openai"
|
||||
else rag_config["ollama_config"]["key"]
|
||||
),
|
||||
rag_config["embedding_batch_size"]
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable embedding models
|
||||
if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
|
||||
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
|
||||
rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS
|
||||
rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS
|
||||
|
||||
# Save the updated configuration to the database
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=form_data.knowledge_id, rag_config=rag_config
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": rag_config["embedding_engine"],
|
||||
"embedding_model": rag_config["embedding_model"],
|
||||
"embedding_batch_size": rag_config["embedding_batch_size"],
|
||||
"openai_config": rag_config.get("openai_config", {}),
|
||||
"ollama_config": rag_config.get("ollama_config", {}),
|
||||
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
|
||||
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
|
||||
"message": "Embedding configuration updated in the database.",
|
||||
}
|
||||
else:
|
||||
# Update the global configuration
|
||||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
try:
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
# Check if model is in use elsewhere, otherwise free up memory
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model")
|
||||
if not in_use:
|
||||
del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL]
|
||||
engine = request.app.state.config.RAG_EMBEDDING_ENGINE
|
||||
target_model = request.app.state.config.RAG_EMBEDDING_MODEL
|
||||
models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine]
|
||||
|
||||
# Find and remove every entry equal to the target model name
|
||||
for model in models_list[:]: # Create a copy of the list for safe iteration
|
||||
if model == target_model:
|
||||
models_list.remove(model)
|
||||
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
if form_data.openai_config is not None:
|
||||
@ -292,15 +426,17 @@ async def update_embedding_config(
|
||||
form_data.embedding_batch_size
|
||||
)
|
||||
|
||||
request.app.state.ef = get_ef(
|
||||
# Update the embedding function
|
||||
if not form_data.embedding_model in request.app.state.ef:
|
||||
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef,
|
||||
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL],
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
@ -313,6 +449,13 @@ async def update_embedding_config(
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable embedding models
|
||||
if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
|
||||
request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save()
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@ -327,6 +470,9 @@ async def update_embedding_config(
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"message": "Embedding configuration updated globally.",
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
@ -336,98 +482,116 @@ async def update_embedding_config(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
@router.post("/config")
|
||||
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
|
||||
"""
|
||||
Retrieve the full RAG configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, return the default settings.
|
||||
Otherwise, return the RAG configuration stored in the database.
|
||||
"""
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
|
||||
rag_config = {}
|
||||
web_config = {}
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the RAG configuration from the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
web_config = rag_config.get("web", {})
|
||||
return {
|
||||
"status": True,
|
||||
# RAG settings
|
||||
"RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE,
|
||||
"TOP_K": request.app.state.config.TOP_K,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"RAG_TEMPLATE": rag_config.get("TEMPLATE", request.app.state.config.RAG_TEMPLATE),
|
||||
"TOP_K": rag_config.get("TOP_K", request.app.state.config.TOP_K),
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL),
|
||||
"RAG_FULL_CONTEXT": rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT),
|
||||
# Hybrid search settings
|
||||
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
"TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
|
||||
"RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
"ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
|
||||
"TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER),
|
||||
"RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD),
|
||||
# Content extraction settings
|
||||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
|
||||
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
|
||||
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||
"CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
|
||||
"PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES),
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL),
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY),
|
||||
"TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL),
|
||||
"DOCLING_SERVER_URL": rag_config.get("DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL),
|
||||
"DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE),
|
||||
"DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG),
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION),
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
|
||||
"DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
|
||||
"MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY),
|
||||
# Reranking settings
|
||||
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
||||
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
"RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
"RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
"RAG_RERANKING_MODEL": rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL),
|
||||
"RAG_RERANKING_ENGINE": rag_config.get("RAG_RERANKING_ENGINE", request.app.state.config.RAG_RERANKING_ENGINE),
|
||||
"RAG_EXTERNAL_RERANKER_URL": rag_config.get("RAG_EXTERNAL_RERANKER_URL", request.app.state.config.RAG_EXTERNAL_RERANKER_URL),
|
||||
"RAG_EXTERNAL_RERANKER_API_KEY": rag_config.get("RAG_EXTERNAL_RERANKER_API_KEY", request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY),
|
||||
# Chunking settings
|
||||
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
||||
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
||||
"CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP,
|
||||
"TEXT_SPLITTER": rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER),
|
||||
"CHUNK_SIZE": rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE),
|
||||
"CHUNK_OVERLAP": rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP),
|
||||
# File upload settings
|
||||
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
|
||||
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
|
||||
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
|
||||
"FILE_MAX_SIZE": rag_config.get("FILE_MAX_SIZE", request.app.state.config.FILE_MAX_SIZE),
|
||||
"FILE_MAX_COUNT": rag_config.get("FILE_MAX_COUNT", request.app.state.config.FILE_MAX_COUNT),
|
||||
"ALLOWED_FILE_EXTENSIONS": rag_config.get("ALLOWED_FILE_EXTENSIONS", request.app.state.config.ALLOWED_FILE_EXTENSIONS),
|
||||
# Integration settings
|
||||
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("ENABLE_GOOGLE_DRIVE_INTEGRATION", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION),
|
||||
"ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION),
|
||||
# Web search settings
|
||||
"web": {
|
||||
"ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH,
|
||||
"WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE,
|
||||
"WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
"WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
|
||||
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
|
||||
"YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
|
||||
"GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
"GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
|
||||
"BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
"KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
"MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
"BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
"SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY,
|
||||
"SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS,
|
||||
"SERPER_API_KEY": request.app.state.config.SERPER_API_KEY,
|
||||
"SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY,
|
||||
"TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY,
|
||||
"SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY,
|
||||
"SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE,
|
||||
"SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY,
|
||||
"SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE,
|
||||
"JINA_API_KEY": request.app.state.config.JINA_API_KEY,
|
||||
"BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
|
||||
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
|
||||
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
|
||||
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
|
||||
"ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
"PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL,
|
||||
"PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT,
|
||||
"FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
|
||||
"FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
|
||||
"TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
|
||||
"EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
|
||||
"EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
|
||||
"EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
|
||||
"EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
|
||||
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
||||
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
"ENABLE_WEB_SEARCH": web_config.get("ENABLE_WEB_SEARCH", request.app.state.config.ENABLE_WEB_SEARCH),
|
||||
"WEB_SEARCH_ENGINE": web_config.get("WEB_SEARCH_ENGINE", request.app.state.config.WEB_SEARCH_ENGINE),
|
||||
"WEB_SEARCH_TRUST_ENV": web_config.get("WEB_SEARCH_TRUST_ENV", request.app.state.config.WEB_SEARCH_TRUST_ENV),
|
||||
"WEB_SEARCH_RESULT_COUNT": web_config.get("WEB_SEARCH_RESULT_COUNT", request.app.state.config.WEB_SEARCH_RESULT_COUNT),
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
|
||||
"SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL),
|
||||
"YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL),
|
||||
"YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME),
|
||||
"YACY_PASSWORD": web_config.get("YACY_QUERY_PASSWORD",request.app.state.config.YACY_PASSWORD),
|
||||
"GOOGLE_PSE_API_KEY": web_config.get("GOOGLE_PSE_API_KEY", request.app.state.config.GOOGLE_PSE_API_KEY),
|
||||
"GOOGLE_PSE_ENGINE_ID": web_config.get("GOOGLE_PSE_ENGINE_ID", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
|
||||
"BRAVE_SEARCH_API_KEY": web_config.get("BRAVE_SEARCH_API_KEY", request.app.state.config.BRAVE_SEARCH_API_KEY),
|
||||
"KAGI_SEARCH_API_KEY": web_config.get("KAGI_SEARCH_API_KEY", request.app.state.config.KAGI_SEARCH_API_KEY),
|
||||
"MOJEEK_SEARCH_API_KEY": web_config.get("MOJEEK_SEARCH_API_KEY", request.app.state.config.MOJEEK_SEARCH_API_KEY),
|
||||
"BOCHA_SEARCH_API_KEY": web_config.get("BOCHA_SEARCH_API_KEY", request.app.state.config.BOCHA_SEARCH_API_KEY),
|
||||
"SERPSTACK_API_KEY": web_config.get("SERPSTACK_API_KEY", request.app.state.config.SERPSTACK_API_KEY),
|
||||
"SERPSTACK_HTTPS": web_config.get("SERPSTACK_HTTPS", request.app.state.config.SERPSTACK_HTTPS),
|
||||
"SERPER_API_KEY": web_config.get("SERPER_API_KEY", request.app.state.config.SERPER_API_KEY),
|
||||
"SERPLY_API_KEY": web_config.get("SERPLY_API_KEY", request.app.state.config.SERPLY_API_KEY),
|
||||
"TAVILY_API_KEY": web_config.get("TAVILY_API_KEY", request.app.state.config.TAVILY_API_KEY),
|
||||
"SEARCHAPI_API_KEY": web_config.get("SEARCHAPI_API_KEY", request.app.state.config.SEARCHAPI_API_KEY),
|
||||
"SEARCHAPI_ENGINE": web_config.get("SEARCHAPI_ENGINE", request.app.state.config.SEARCHAPI_ENGINE),
|
||||
"SERPAPI_API_KEY": web_config.get("SERPAPI_API_KEY", request.app.state.config.SERPAPI_API_KEY),
|
||||
"SERPAPI_ENGINE": web_config.get("SERPAPI_ENGINE", request.app.state.config.SERPAPI_ENGINE),
|
||||
"JINA_API_KEY": web_config.get("JINA_API_KEY", request.app.state.config.JINA_API_KEY),
|
||||
"BING_SEARCH_V7_ENDPOINT": web_config.get("BING_SEARCH_V7_ENDPOINT", request.app.state.config.BING_SEARCH_V7_ENDPOINT),
|
||||
"BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
|
||||
"EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY),
|
||||
"PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY),
|
||||
"SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID),
|
||||
"SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK),
|
||||
"WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE),
|
||||
"ENABLE_WEB_LOADER_SSL_VERIFICATION": web_config.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION),
|
||||
"PLAYWRIGHT_WS_URL": web_config.get("PLAYWRIGHT_WS_URL", request.app.state.config.PLAYWRIGHT_WS_URL),
|
||||
"PLAYWRIGHT_TIMEOUT": web_config.get("PLAYWRIGHT_TIMEOUT", request.app.state.config.PLAYWRIGHT_TIMEOUT),
|
||||
"FIRECRAWL_API_KEY": web_config.get("FIRECRAWL_API_KEY", request.app.state.config.FIRECRAWL_API_KEY),
|
||||
"FIRECRAWL_API_BASE_URL": web_config.get("FIRECRAWL_API_BASE_URL", request.app.state.config.FIRECRAWL_API_BASE_URL),
|
||||
"TAVILY_EXTRACT_DEPTH": web_config.get("TAVILY_EXTRACT_DEPTH", request.app.state.config.TAVILY_EXTRACT_DEPTH),
|
||||
"EXTERNAL_WEB_SEARCH_URL": web_config.get("WEB_SEARCH_URL", request.app.state.config.EXTERNAL_WEB_SEARCH_URL),
|
||||
"EXTERNAL_WEB_SEARCH_API_KEY": web_config.get("WEB_SEARCH_KEY", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY),
|
||||
"EXTERNAL_WEB_LOADER_URL": web_config.get("WEB_LOADER_URL", request.app.state.config.EXTERNAL_WEB_LOADER_URL),
|
||||
"EXTERNAL_WEB_LOADER_API_KEY": web_config.get("WEB_LOADER_KEY", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY),
|
||||
"YOUTUBE_LOADER_LANGUAGE": web_config.get("YOUTUBE_LOADER_LANGUAGE", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
|
||||
"YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
|
||||
"YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
|
||||
},
|
||||
"DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS),
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS,
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS,
|
||||
}
|
||||
|
||||
|
||||
@ -531,11 +695,102 @@ class ConfigForm(BaseModel):
|
||||
# Web search settings
|
||||
web: Optional[WebConfig] = None
|
||||
|
||||
# knowledge base ID
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_rag_config(
|
||||
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: ConfigForm, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
Update the RAG configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Free up memory if hybrid search is disabled and model is not in use elsewhere
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("RAG_RERANKING_MODEL"), model_type="RAG_RERANKING_MODEL", id=form_data.knowledge_id)
|
||||
|
||||
if not form_data.ENABLE_RAG_HYBRID_SEARCH and \
|
||||
not in_use and \
|
||||
request.app.state.rf.get(rag_config["RAG_RERANKING_MODEL"]):
|
||||
if rag_config.get("RAG_RERANKING_MODEL"):
|
||||
del request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]]
|
||||
engine = request.app.state.config.RAG_RERANKING_ENGINE
|
||||
target_model = rag_config["RAG_RERANKING_MODEL"]
|
||||
models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine]
|
||||
|
||||
# Find and remove the dictionary that contains the target model
|
||||
for model_config in models_list[:]: # Create a copy of the list for safe iteration
|
||||
if model_config["RAG_RERANKING_MODEL"] == target_model:
|
||||
models_list.remove(model_config)
|
||||
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Update only the provided fields in the rag_config
|
||||
for field, value in form_data.model_dump(exclude_unset=True).items():
|
||||
if field == "web" and value is not None:
|
||||
rag_config["web"] = {**rag_config.get("web", {}), **value}
|
||||
else:
|
||||
rag_config[field] = value
|
||||
|
||||
|
||||
log.info(
|
||||
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
|
||||
)
|
||||
try:
|
||||
try:
|
||||
if not rag_config["RAG_RERANKING_MODEL"] in request.app.state.rf and not rag_config["RAG_RERANKING_MODEL"] == "":
|
||||
request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] = get_rf(
|
||||
rag_config["RAG_RERANKING_ENGINE"],
|
||||
rag_config["RAG_RERANKING_MODEL"],
|
||||
rag_config["RAG_EXTERNAL_RERANKER_URL"],
|
||||
rag_config["RAG_EXTERNAL_RERANKER_API_KEY"],
|
||||
True,
|
||||
)
|
||||
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append({
|
||||
"RAG_RERANKING_MODEL": rag_config["RAG_RERANKING_MODEL"],
|
||||
"RAG_EXTERNAL_RERANKER_URL": rag_config["RAG_EXTERNAL_RERANKER_URL"],
|
||||
"RAG_EXTERNAL_RERANKER_API_KEY": rag_config["RAG_EXTERNAL_RERANKER_API_KEY"]})
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
|
||||
# add model to state for selectable reranking models
|
||||
if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]]:
|
||||
request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append(rag_config["RAG_RERANKING_MODEL"])
|
||||
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
|
||||
|
||||
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
|
||||
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating reranking model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
Knowledges.update_rag_config_by_id(
|
||||
id=knowledge_base.id, rag_config=rag_config
|
||||
)
|
||||
|
||||
return rag_config
|
||||
else:
|
||||
# Update the global configuration
|
||||
# RAG settings
|
||||
request.app.state.config.RAG_TEMPLATE = (
|
||||
form_data.RAG_TEMPLATE
|
||||
@ -564,9 +819,29 @@ async def update_rag_config(
|
||||
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
||||
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
)
|
||||
# Free up memory if hybrid search is disabled
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
request.app.state.rf = None
|
||||
|
||||
# Free up memory if hybrid search is disabled and model is not in use elsewhere
|
||||
in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="RAG_RERANKING_MODEL")
|
||||
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and \
|
||||
not in_use and \
|
||||
request.app.state.rf.get(request.app.state.config.RAG_RERANKING_MODEL):
|
||||
del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL]
|
||||
engine = request.app.state.config.RAG_RERANKING_ENGINE
|
||||
target_model = request.app.state.config.RAG_RERANKING_MODEL
|
||||
models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine]
|
||||
|
||||
# Find and remove the dictionary that contains the target model
|
||||
for model_config in models_list[:]: # Create a copy of the list for safe iteration
|
||||
if model_config["RAG_RERANKING_MODEL"] == target_model:
|
||||
models_list.remove(model_config)
|
||||
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
|
||||
import gc
|
||||
import torch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
request.app.state.config.TOP_K_RERANKER = (
|
||||
form_data.TOP_K_RERANKER
|
||||
@ -662,6 +937,7 @@ async def update_rag_config(
|
||||
else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
|
||||
)
|
||||
|
||||
|
||||
log.info(
|
||||
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
|
||||
)
|
||||
@ -669,13 +945,32 @@ async def update_rag_config(
|
||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
|
||||
|
||||
try:
|
||||
request.app.state.rf = get_rf(
|
||||
if not request.app.state.config.RAG_RERANKING_MODEL in request.app.state.rf and not request.app.state.config.RAG_RERANKING_MODEL == "":
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
True,
|
||||
)
|
||||
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append({
|
||||
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
||||
"RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||
"RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
|
||||
})
|
||||
request.app.state.config._state["LOADED_RERANKING_MODELS"].save()
|
||||
|
||||
# add model to state for selectable reranking models
|
||||
if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE]:
|
||||
request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append(request.app.state.config.RAG_RERANKING_MODEL)
|
||||
request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save()
|
||||
|
||||
rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS
|
||||
rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
@ -916,6 +1211,7 @@ async def update_rag_config(
|
||||
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
|
||||
}
|
||||
|
||||
|
||||
@ -935,6 +1231,7 @@ def save_docs_to_vector_db(
|
||||
split: bool = True,
|
||||
add: bool = False,
|
||||
user=None,
|
||||
knowledge_id: Optional[str] = None
|
||||
) -> bool:
|
||||
def _get_docs_info(docs: list[Document]) -> str:
|
||||
docs_info = set()
|
||||
@ -956,6 +1253,26 @@ def save_docs_to_vector_db(
|
||||
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
|
||||
)
|
||||
|
||||
rag_config = {}
|
||||
# Retrieve the knowledge base using the knowledge_id (if one was provided)
|
||||
if knowledge_id:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Use knowledge-base-specific or default configurations
|
||||
text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER)
|
||||
chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)
|
||||
chunk_overlap = rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP)
|
||||
embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
|
||||
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
|
||||
openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
|
||||
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
|
||||
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
|
||||
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
|
||||
|
||||
# Check if entries with the same hash (metadata.hash) already exist
|
||||
if metadata and "hash" in metadata:
|
||||
result = VECTOR_DB_CLIENT.query(
|
||||
@ -970,13 +1287,13 @@ def save_docs_to_vector_db(
|
||||
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
|
||||
|
||||
if split:
|
||||
if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
|
||||
if text_splitter_type in ["", "character"]:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=request.app.state.config.CHUNK_SIZE,
|
||||
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
add_start_index=True,
|
||||
)
|
||||
elif request.app.state.config.TEXT_SPLITTER == "token":
|
||||
elif text_splitter_type == "token":
|
||||
log.info(
|
||||
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
|
||||
)
|
||||
@ -984,8 +1301,8 @@ def save_docs_to_vector_db(
|
||||
tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
|
||||
text_splitter = TokenTextSplitter(
|
||||
encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
|
||||
chunk_size=request.app.state.config.CHUNK_SIZE,
|
||||
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
add_start_index=True,
|
||||
)
|
||||
else:
|
||||
@ -1003,8 +1320,8 @@ def save_docs_to_vector_db(
|
||||
**(metadata if metadata else {}),
|
||||
"embedding_config": json.dumps(
|
||||
{
|
||||
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"engine": embedding_engine,
|
||||
"model": embedding_model,
|
||||
}
|
||||
),
|
||||
}
|
||||
@ -1037,20 +1354,20 @@ def save_docs_to_vector_db(
|
||||
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
embedding_function = get_embedding_function(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef,
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||
openai_api_base_url
|
||||
if embedding_engine == "openai"
|
||||
else ollama_base_url
|
||||
),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
openai_api_key
|
||||
if embedding_engine == "openai"
|
||||
else ollama_api_key
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
embedding_batch_size,
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
@ -1084,6 +1401,7 @@ class ProcessFileForm(BaseModel):
|
||||
file_id: str
|
||||
content: Optional[str] = None
|
||||
collection_name: Optional[str] = None
|
||||
knowledge_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/process/file")
|
||||
@ -1100,6 +1418,61 @@ def process_file(
|
||||
if collection_name is None:
|
||||
collection_name = f"file-{file.id}"
|
||||
|
||||
rag_config = {}
|
||||
# Retrieve the knowledge base using the collection id - knowledge_id == collection_name (minimal working solution without logic changes)
|
||||
if form_data.collection_name:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
|
||||
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db
|
||||
|
||||
elif form_data.knowledge_id:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Use knowledge-base-specific or default configurations
|
||||
content_extraction_engine = rag_config.get(
|
||||
"CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE
|
||||
)
|
||||
external_document_loader_url = rag_config.get(
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
|
||||
)
|
||||
external_document_loader_api_key = rag_config.get(
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||
)
|
||||
tika_server_url = rag_config.get(
|
||||
"TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL
|
||||
)
|
||||
docling_server_url = rag_config.get(
|
||||
"DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL
|
||||
)
|
||||
docling_ocr_engine=rag_config.get(
|
||||
"DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE
|
||||
)
|
||||
docling_ocr_lang=rag_config.get(
|
||||
"DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG
|
||||
)
|
||||
docling_do_picture_description=rag_config.get(
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
|
||||
)
|
||||
pdf_extract_images = rag_config.get(
|
||||
"PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
document_intelligence_endpoint = rag_config.get(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
)
|
||||
document_intelligence_key = rag_config.get(
|
||||
"DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
|
||||
)
|
||||
mistral_ocr_api_key = rag_config.get(
|
||||
"MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY
|
||||
)
|
||||
|
||||
if form_data.content:
|
||||
# Update the content in the file
|
||||
# Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
|
||||
@ -1163,18 +1536,18 @@ def process_file(
|
||||
if file_path:
|
||||
file_path = Storage.get_file(file_path)
|
||||
loader = Loader(
|
||||
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
|
||||
DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
|
||||
DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
|
||||
DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||
engine=content_extraction_engine,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
|
||||
TIKA_SERVER_URL=tika_server_url,
|
||||
DOCLING_SERVER_URL=docling_server_url,
|
||||
DOCLING_OCR_ENGINE=docling_ocr_engine,
|
||||
DOCLING_OCR_LANG=docling_ocr_lang,
|
||||
DOCLING_DO_PICTURE_DESCRIPTION=docling_do_picture_description,
|
||||
PDF_EXTRACT_IMAGES=pdf_extract_images,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
|
||||
DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,
|
||||
MISTRAL_OCR_API_KEY=mistral_ocr_api_key,
|
||||
)
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
@ -1217,7 +1590,7 @@ def process_file(
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
if not rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL):
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
@ -1230,6 +1603,7 @@ def process_file(
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
user=user,
|
||||
knowledge_id=form_data.knowledge_id
|
||||
)
|
||||
|
||||
if result:
|
||||
@ -1280,7 +1654,7 @@ class ProcessTextForm(BaseModel):
|
||||
def process_text(
|
||||
request: Request,
|
||||
form_data: ProcessTextForm,
|
||||
user=Depends(get_verified_user),
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name is None:
|
||||
@ -1762,11 +2136,11 @@ def query_doc_handler(
|
||||
collection_name=form_data.collection_name,
|
||||
collection_result=collection_results[form_data.collection_name],
|
||||
query=form_data.query,
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
@ -1779,7 +2153,7 @@ def query_doc_handler(
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
@ -1813,11 +2187,11 @@ def query_collection_handler(
|
||||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
@ -1830,7 +2204,7 @@ def query_collection_handler(
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
|
Loading…
Reference in New Issue
Block a user