Mirror of https://github.com/open-webui/open-webui
Feature: Adjusted all necessary functions to handle individual RAG configuration; fall back to the default configuration if an individual config is not used
parent 220ad3723e
commit e3a93b24a0
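The pattern applied throughout this change: a knowledge base may opt out of the global settings by storing DEFAULT_RAG_SETTINGS = False together with its own rag_config dict in its data field, and every consumer then reads each value with rag_config.get(key, global_default), so the global request.app.state.config value remains the fallback. A minimal standalone sketch of that lookup rule (the helper name resolve_rag_setting and the sample values are illustrative, not part of this commit):

# Minimal sketch of the per-knowledge-base fallback rule; the helper name and
# sample data are hypothetical, only the lookup order mirrors the commit.
GLOBAL_CHUNK_SIZE = 1000      # stands in for request.app.state.config.CHUNK_SIZE
GLOBAL_CHUNK_OVERLAP = 100    # stands in for request.app.state.config.CHUNK_OVERLAP

def resolve_rag_setting(knowledge_data: dict, key: str, global_default):
    # Only consult the per-knowledge-base rag_config when the knowledge base
    # has opted out of the defaults (DEFAULT_RAG_SETTINGS is False).
    if not knowledge_data.get("DEFAULT_RAG_SETTINGS", True):
        return knowledge_data.get("rag_config", {}).get(key, global_default)
    return global_default

knowledge_data = {"DEFAULT_RAG_SETTINGS": False, "rag_config": {"chunk_size": 512}}
print(resolve_rag_setting(knowledge_data, "chunk_size", GLOBAL_CHUNK_SIZE))        # 512 (individual value)
print(resolve_rag_setting(knowledge_data, "chunk_overlap", GLOBAL_CHUNK_OVERLAP))  # 100 (global fallback)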
@@ -193,8 +193,33 @@ async def get_status(request: Request):
     }


-@router.get("/embedding")
-async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
+@router.post("/embedding")
+async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
+    """
+    Retrieve the embedding configuration.
+    If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
+    Otherwise, return the embedding configuration stored in the database.
+    """
+    knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
+    if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+        # Return the embedding configuration from the database
+        rag_config = knowledge_base.data.get("rag_config", {})
+        return {
+            "status": True,
+            "embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
+            "embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL),
+            "embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE),
+            "openai_config": rag_config.get("openai_config", {
+                "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
+                "key": request.app.state.config.RAG_OPENAI_API_KEY,
+            }),
+            "ollama_config": rag_config.get("ollama_config", {
+                "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
+                "key": request.app.state.config.RAG_OLLAMA_API_KEY,
+            }),
+        }
+    else:
+        # Return default embedding settings
     return {
         "status": True,
         "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
@@ -211,8 +236,23 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
     }


-@router.get("/reranking")
-async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
+@router.post("/reranking")
+async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
+    """
+    Retrieve the reranking configuration.
+    If DEFAULT_RAG_SETTINGS is True, return the default reranking settings.
+    Otherwise, return the reranking configuration stored in the database.
+    """
+    knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
+    if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+        # Return the reranking configuration from the database
+        rag_config = knowledge_base.data.get("rag_config", {})
+        return {
+            "status": True,
+            "reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL),
+        }
+    else:
+        # Return default reranking settings
     return {
         "status": True,
         "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
@@ -321,10 +361,31 @@ class RerankingModelUpdateForm(BaseModel):
 async def update_reranking_config(
     request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
 ):
+    """
+    Update the reranking model configuration.
+    If DEFAULT_RAG_SETTINGS is True, update the global configuration.
+    Otherwise, update the RAG configuration in the database for the user's knowledge base.
+    """
+    try:
+        knowledge_base = Knowledges.get_knowledge_base_by_user_id(user.id)
+
+        if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+            # Update the RAG configuration in the database
+            rag_config = knowledge_base.data.get("rag_config", {})
+            rag_config["reranking_model"] = form_data.reranking_model
+            Knowledges.update_knowledge_data_by_id(
+                id=knowledge_base.id, data={"rag_config": rag_config}
+            )
+            return {
+                "status": True,
+                "reranking_model": rag_config["reranking_model"],
+                "message": "Reranking model updated in the database.",
+            }
+        else:
+            # Update the global configuration
     log.info(
         f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
     )
-    try:
         request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

         try:
@@ -339,6 +400,7 @@ async def update_reranking_config(
         return {
             "status": True,
             "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
+            "message": "Reranking model updated globally.",
         }
     except Exception as e:
         log.exception(f"Problem updating reranking model: {e}")
@@ -348,8 +410,92 @@ async def update_reranking_config(
     )


-@router.get("/config")
-async def get_rag_config(request: Request, user=Depends(get_admin_user)):
+@router.post("/config")
+async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)):
+    """
+    Retrieve the full RAG configuration.
+    If DEFAULT_RAG_SETTINGS is True, return the default settings.
+    Otherwise, return the RAG configuration stored in the database.
+    """
+    knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
+    if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+        # Return the RAG configuration from the database
+        rag_config = knowledge_base.data.get("rag_config", {})
+        return {
+            "status": True,
+            # RAG settings
+            "RAG_TEMPLATE": rag_config.get("template", request.app.state.config.RAG_TEMPLATE),
+            "TOP_K": rag_config.get("top_k", request.app.state.config.TOP_K),
+            "BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL),
+            "RAG_FULL_CONTEXT": rag_config.get("rag_full_context", request.app.state.config.RAG_FULL_CONTEXT),
+            # Hybrid search settings
+            "ENABLE_RAG_HYBRID_SEARCH": rag_config.get("enable_rag_hybrid_search", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
+            "TOP_K_RERANKER": rag_config.get("top_k_reranker", request.app.state.config.TOP_K_RERANKER),
+            "RELEVANCE_THRESHOLD": rag_config.get("relevance_threshold", request.app.state.config.RELEVANCE_THRESHOLD),
+            # Content extraction settings
+            "CONTENT_EXTRACTION_ENGINE": rag_config.get("content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
+            "PDF_EXTRACT_IMAGES": rag_config.get("pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES),
+            "TIKA_SERVER_URL": rag_config.get("tika_server_url", request.app.state.config.TIKA_SERVER_URL),
+            "DOCLING_SERVER_URL": rag_config.get("docling_server_url", request.app.state.config.DOCLING_SERVER_URL),
+            "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
+            "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
+            "MISTRAL_OCR_API_KEY": rag_config.get("mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY),
+            # Chunking settings
+            "TEXT_SPLITTER": rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER),
+            "CHUNK_SIZE": rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE),
+            "CHUNK_OVERLAP": rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP),
+            # File upload settings
+            "FILE_MAX_SIZE": rag_config.get("file_max_size", request.app.state.config.FILE_MAX_SIZE),
+            "FILE_MAX_COUNT": rag_config.get("file_max_count", request.app.state.config.FILE_MAX_COUNT),
+            # Integration settings
+            "ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("enable_google_drive_integration", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION),
+            "ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION),
+            # Web search settings
+            "web": {
+                "ENABLE_WEB_SEARCH": rag_config.get("enable_web_search", request.app.state.config.ENABLE_WEB_SEARCH),
+                "WEB_SEARCH_ENGINE": rag_config.get("web_search_engine", request.app.state.config.WEB_SEARCH_ENGINE),
+                "WEB_SEARCH_TRUST_ENV": rag_config.get("web_search_trust_env", request.app.state.config.WEB_SEARCH_TRUST_ENV),
+                "WEB_SEARCH_RESULT_COUNT": rag_config.get("web_search_result_count", request.app.state.config.WEB_SEARCH_RESULT_COUNT),
+                "WEB_SEARCH_CONCURRENT_REQUESTS": rag_config.get("web_search_concurrent_requests", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
+                "WEB_SEARCH_DOMAIN_FILTER_LIST": rag_config.get("web_search_domain_filter_list", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
+                "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_web_search_embedding_and_retrieval", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
+                "SEARXNG_QUERY_URL": rag_config.get("searxng_query_url", request.app.state.config.SEARXNG_QUERY_URL),
+                "GOOGLE_PSE_API_KEY": rag_config.get("google_pse_api_key", request.app.state.config.GOOGLE_PSE_API_KEY),
+                "GOOGLE_PSE_ENGINE_ID": rag_config.get("google_pse_engine_id", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
+                "BRAVE_SEARCH_API_KEY": rag_config.get("brave_search_api_key", request.app.state.config.BRAVE_SEARCH_API_KEY),
+                "KAGI_SEARCH_API_KEY": rag_config.get("kagi_search_api_key", request.app.state.config.KAGI_SEARCH_API_KEY),
+                "MOJEEK_SEARCH_API_KEY": rag_config.get("mojeek_search_api_key", request.app.state.config.MOJEEK_SEARCH_API_KEY),
+                "BOCHA_SEARCH_API_KEY": rag_config.get("bocha_search_api_key", request.app.state.config.BOCHA_SEARCH_API_KEY),
+                "SERPSTACK_API_KEY": rag_config.get("serpstack_api_key", request.app.state.config.SERPSTACK_API_KEY),
+                "SERPSTACK_HTTPS": rag_config.get("serpstack_https", request.app.state.config.SERPSTACK_HTTPS),
+                "SERPER_API_KEY": rag_config.get("serper_api_key", request.app.state.config.SERPER_API_KEY),
+                "SERPLY_API_KEY": rag_config.get("serply_api_key", request.app.state.config.SERPLY_API_KEY),
+                "TAVILY_API_KEY": rag_config.get("tavily_api_key", request.app.state.config.TAVILY_API_KEY),
+                "SEARCHAPI_API_KEY": rag_config.get("searchapi_api_key", request.app.state.config.SEARCHAPI_API_KEY),
+                "SEARCHAPI_ENGINE": rag_config.get("searchapi_engine", request.app.state.config.SEARCHAPI_ENGINE),
+                "SERPAPI_API_KEY": rag_config.get("serpapi_api_key", request.app.state.config.SERPAPI_API_KEY),
+                "SERPAPI_ENGINE": rag_config.get("serpapi_engine", request.app.state.config.SERPAPI_ENGINE),
+                "JINA_API_KEY": rag_config.get("jina_api_key", request.app.state.config.JINA_API_KEY),
+                "BING_SEARCH_V7_ENDPOINT": rag_config.get("bing_search_v7_endpoint", request.app.state.config.BING_SEARCH_V7_ENDPOINT),
+                "BING_SEARCH_V7_SUBSCRIPTION_KEY": rag_config.get("bing_search_v7_subscription_key", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
+                "EXA_API_KEY": rag_config.get("exa_api_key", request.app.state.config.EXA_API_KEY),
+                "PERPLEXITY_API_KEY": rag_config.get("perplexity_api_key", request.app.state.config.PERPLEXITY_API_KEY),
+                "SOUGOU_API_SID": rag_config.get("sougou_api_sid", request.app.state.config.SOUGOU_API_SID),
+                "SOUGOU_API_SK": rag_config.get("sougou_api_sk", request.app.state.config.SOUGOU_API_SK),
+                "WEB_LOADER_ENGINE": rag_config.get("web_loader_engine", request.app.state.config.WEB_LOADER_ENGINE),
+                "ENABLE_WEB_LOADER_SSL_VERIFICATION": rag_config.get("enable_web_loader_ssl_verification", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION),
+                "PLAYWRIGHT_WS_URL": rag_config.get("playwright_ws_url", request.app.state.config.PLAYWRIGHT_WS_URL),
+                "PLAYWRIGHT_TIMEOUT": rag_config.get("playwright_timeout", request.app.state.config.PLAYWRIGHT_TIMEOUT),
+                "FIRECRAWL_API_KEY": rag_config.get("firecrawl_api_key", request.app.state.config.FIRECRAWL_API_KEY),
+                "FIRECRAWL_API_BASE_URL": rag_config.get("firecrawl_api_base_url", request.app.state.config.FIRECRAWL_API_BASE_URL),
+                "TAVILY_EXTRACT_DEPTH": rag_config.get("tavily_extract_depth", request.app.state.config.TAVILY_EXTRACT_DEPTH),
+                "YOUTUBE_LOADER_LANGUAGE": rag_config.get("youtube_loader_language", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
+                "YOUTUBE_LOADER_PROXY_URL": rag_config.get("youtube_loader_proxy_url", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
+                "YOUTUBE_LOADER_TRANSLATION": rag_config.get("youtube_loader_translation", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
+            },
+        }
+    else:
+        # Return default RAG settings
     return {
         "status": True,
         # RAG settings
@@ -508,8 +654,33 @@ class ConfigForm(BaseModel):

 @router.post("/config/update")
 async def update_rag_config(
-    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
+    request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
 ):
+    """
+    Update the RAG configuration.
+    If DEFAULT_RAG_SETTINGS is True, update the global configuration.
+    Otherwise, update the RAG configuration in the database for the user's knowledge base.
+    """
+    knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
+
+    if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+        # Update the RAG configuration in the database
+        rag_config = knowledge_base.data.get("rag_config", {})
+
+        # Update only the provided fields in the rag_config
+        for field, value in form_data.dict(exclude_unset=True).items():
+            if field == "web" and value is not None:
+                rag_config["web"] = {**rag_config.get("web", {}), **value.dict(exclude_unset=True)}
+            else:
+                rag_config[field] = value
+
+        Knowledges.update_knowledge_data_by_id(
+            id=knowledge_base.id, data={"rag_config": rag_config}
+        )
+
+        return rag_config
+    else:
+        # Update the global configuration
     # RAG settings
     request.app.state.config.RAG_TEMPLATE = (
         form_data.RAG_TEMPLATE
@@ -538,7 +709,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
     # Free up memory if hybrid search is disabled
     if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
         request.app.state.rf = None
-
@@ -822,6 +992,28 @@ def save_docs_to_vector_db(
         f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
     )

+    # Retrieve the knowledge base using the collection_name
+    knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name)
+    if not knowledge_base:
+        raise ValueError(f"Knowledge base not found for collection: {collection_name}")
+
+    # Retrieve the RAG configuration
+    rag_config = {}
+    if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+        rag_config = knowledge_base.data.get("rag_config", {})
+
+    # Use knowledge-base-specific or default configurations
+    text_splitter_type = rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER)
+    chunk_size = rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE)
+    chunk_overlap = rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP)
+    embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
+    embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
+    embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
+    openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
+    openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
+    ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
+    ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
+
     # Check if entries with the same hash (metadata.hash) already exist
     if metadata and "hash" in metadata:
         result = VECTOR_DB_CLIENT.query(
@@ -836,13 +1028,13 @@ def save_docs_to_vector_db(
             raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)

     if split:
-        if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
+        if text_splitter_type in ["", "character"]:
             text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=request.app.state.config.CHUNK_SIZE,
-                chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
                 add_start_index=True,
             )
-        elif request.app.state.config.TEXT_SPLITTER == "token":
+        elif text_splitter_type == "token":
             log.info(
                 f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
             )
@@ -850,8 +1042,8 @@ def save_docs_to_vector_db(
             tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
             text_splitter = TokenTextSplitter(
                 encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
-                chunk_size=request.app.state.config.CHUNK_SIZE,
-                chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
                 add_start_index=True,
             )
         else:
@@ -869,8 +1061,8 @@ def save_docs_to_vector_db(
                 **(metadata if metadata else {}),
                 "embedding_config": json.dumps(
                     {
-                        "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-                        "model": request.app.state.config.RAG_EMBEDDING_MODEL,
+                        "engine": embedding_engine,
+                        "model": embedding_model,
                     }
                 ),
             }
@@ -903,20 +1095,20 @@ def save_docs_to_vector_db(

         log.info(f"adding to collection {collection_name}")
         embedding_function = get_embedding_function(
-            request.app.state.config.RAG_EMBEDDING_ENGINE,
-            request.app.state.config.RAG_EMBEDDING_MODEL,
+            embedding_engine,
+            embedding_model,
             request.app.state.ef,
             (
-                request.app.state.config.RAG_OPENAI_API_BASE_URL
-                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_BASE_URL
+                openai_api_base_url
+                if embedding_engine == "openai"
+                else ollama_base_url
             ),
             (
-                request.app.state.config.RAG_OPENAI_API_KEY
-                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_API_KEY
+                openai_api_key
+                if embedding_engine == "openai"
+                else ollama_api_key
             ),
-            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+            embedding_batch_size,
         )

         embeddings = embedding_function(
@@ -966,6 +1158,39 @@ def process_file(
         if collection_name is None:
             collection_name = f"file-{file.id}"

+        # Retrieve the knowledge base using the collection name
+        knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name)
+        if not knowledge_base:
+            raise ValueError(f"Knowledge base not found for collection: {collection_name}")
+
+        # Retrieve the RAG configuration
+        rag_config = {}
+        if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
+            rag_config = knowledge_base.data.get("rag_config", {})
+
+        # Use knowledge-base-specific or default configurations
+        content_extraction_engine = rag_config.get(
+            "content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE
+        )
+        tika_server_url = rag_config.get(
+            "tika_server_url", request.app.state.config.TIKA_SERVER_URL
+        )
+        docling_server_url = rag_config.get(
+            "docling_server_url", request.app.state.config.DOCLING_SERVER_URL
+        )
+        pdf_extract_images = rag_config.get(
+            "pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES
+        )
+        document_intelligence_endpoint = rag_config.get(
+            "document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT
+        )
+        document_intelligence_key = rag_config.get(
+            "document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
+        )
+        mistral_ocr_api_key = rag_config.get(
+            "mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY
+        )
+
         if form_data.content:
             # Update the content in the file
             # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
@@ -1029,13 +1254,13 @@ def process_file(
             if file_path:
                 file_path = Storage.get_file(file_path)
                 loader = Loader(
-                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
-                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
-                    DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
-                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
-                    DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
-                    DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
-                    MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
+                    engine=content_extraction_engine,
+                    TIKA_SERVER_URL=tika_server_url,
+                    DOCLING_SERVER_URL=docling_server_url,
+                    PDF_EXTRACT_IMAGES=pdf_extract_images,
+                    DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
+                    DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,
+                    MISTRAL_OCR_API_KEY=mistral_ocr_api_key,
                 )
                 docs = loader.load(
                     file.filename, file.meta.get("content_type"), file_path
@@ -1078,7 +1303,7 @@ def process_file(
         hash = calculate_sha256_string(text_content)
         Files.update_file_hash_by_id(file.id, hash)

-        if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
+        if not rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL):
             try:
                 result = save_docs_to_vector_db(
                     request,