From b1937a9dfa114eaeb4e630aabd7c480657c34d37 Mon Sep 17 00:00:00 2001 From: weberm1 Date: Fri, 9 May 2025 22:44:53 +0200 Subject: [PATCH] Fix: refactoring of variables and fixing issue of process_file logic for individual rag config --- backend/open_webui/routers/retrieval.py | 244 ++++++++++++------------ 1 file changed, 126 insertions(+), 118 deletions(-) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index adb4f24ec..ad45d7285 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -213,7 +213,7 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF Otherwise, return the embedding configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): # Return the embedding configuration from the database rag_config = knowledge_base.data.get("rag_config", {}) return { @@ -256,7 +256,7 @@ async def get_reranking_config(request: Request, collectionForm: CollectionNameF Otherwise, return the reranking configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): # Return the reranking configuration from the database rag_config = knowledge_base.data.get("rag_config", {}) return { @@ -293,6 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel): async def update_embedding_config( request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): + # TODO Update for individual rag config log.info( f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) @@ -367,11 +368,11 @@ async def update_embedding_config( class RerankingModelUpdateForm(BaseModel): reranking_model: str - + collection_name: Optional[str] @router.post("/reranking/update") async def update_reranking_config( - request: Request, form_data: RerankingModelUpdateForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): """ Update the reranking model configuration. @@ -379,9 +380,9 @@ async def update_reranking_config( Otherwise, update the RAG configuration in the database for the user's knowledge base. """ try: - knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - - if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): + knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) + # TODO UPdate reranking accoridngly + if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database rag_config = knowledge_base.data.get("rag_config", {}) rag_config["reranking_model"] = form_data.reranking_model @@ -430,88 +431,90 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u Otherwise, return the RAG configuration stored in the database. """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): # Return the RAG configuration from the database rag_config = knowledge_base.data.get("rag_config", {}) + web_config = rag_config.get("web", {}) return { "status": True, # RAG settings - "RAG_TEMPLATE": rag_config.get("template", request.app.state.config.RAG_TEMPLATE), - "TOP_K": rag_config.get("top_k", request.app.state.config.TOP_K), - "BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL), - "RAG_FULL_CONTEXT": rag_config.get("rag_full_context", request.app.state.config.RAG_FULL_CONTEXT), + "RAG_TEMPLATE": rag_config.get("TEMPLATE", request.app.state.config.RAG_TEMPLATE), + "TOP_K": rag_config.get("TOP_K", request.app.state.config.TOP_K), + "BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL), + "RAG_FULL_CONTEXT": rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT), # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": rag_config.get("enable_rag_hybrid_search", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH), - "TOP_K_RERANKER": rag_config.get("top_k_reranker", request.app.state.config.TOP_K_RERANKER), - "RELEVANCE_THRESHOLD": rag_config.get("relevance_threshold", request.app.state.config.RELEVANCE_THRESHOLD), + "ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH), + "TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER), + "RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD), # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": rag_config.get("content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE), - "PDF_EXTRACT_IMAGES": rag_config.get("pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES), - "TIKA_SERVER_URL": rag_config.get("tika_server_url", request.app.state.config.TIKA_SERVER_URL), - "DOCLING_SERVER_URL": rag_config.get("docling_server_url", request.app.state.config.DOCLING_SERVER_URL), - "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT), - "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY), - "MISTRAL_OCR_API_KEY": rag_config.get("mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY), + "CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE), + "PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES), + "TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL), + "DOCLING_SERVER_URL": rag_config.get("DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL), + "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT), + "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY), + "MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY), # Chunking settings - "TEXT_SPLITTER": rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER), - "CHUNK_SIZE": rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE), - "CHUNK_OVERLAP": rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP), + "TEXT_SPLITTER": rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER), + "CHUNK_SIZE": rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE), + "CHUNK_OVERLAP": rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP), # File upload settings - "FILE_MAX_SIZE": rag_config.get("file_max_size", request.app.state.config.FILE_MAX_SIZE), - "FILE_MAX_COUNT": rag_config.get("file_max_count", request.app.state.config.FILE_MAX_COUNT), + "FILE_MAX_SIZE": rag_config.get("FILE_MAX_SIZE", request.app.state.config.FILE_MAX_SIZE), + "FILE_MAX_COUNT": rag_config.get("FILE_MAX_COUNT", request.app.state.config.FILE_MAX_COUNT), # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("enable_google_drive_integration", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION), + "ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("ENABLE_GOOGLE_DRIVE_INTEGRATION", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION), "ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION), # Web search settings "web": { - "ENABLE_WEB_SEARCH": rag_config.get("enable_web_search", request.app.state.config.ENABLE_WEB_SEARCH), - "WEB_SEARCH_ENGINE": rag_config.get("web_search_engine", request.app.state.config.WEB_SEARCH_ENGINE), - "WEB_SEARCH_TRUST_ENV": rag_config.get("web_search_trust_env", request.app.state.config.WEB_SEARCH_TRUST_ENV), - "WEB_SEARCH_RESULT_COUNT": rag_config.get("web_search_result_count", request.app.state.config.WEB_SEARCH_RESULT_COUNT), - "WEB_SEARCH_CONCURRENT_REQUESTS": rag_config.get("web_search_concurrent_requests", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS), - "WEB_SEARCH_DOMAIN_FILTER_LIST": rag_config.get("web_search_domain_filter_list", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST), - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_web_search_embedding_and_retrieval", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL), - "SEARXNG_QUERY_URL": rag_config.get("searxng_query_url", request.app.state.config.SEARXNG_QUERY_URL), - "YACY_QUERY_URL": rag_config.get("yacy_query_url", request.app.state.config.YACY_QUERY_URL), - "YACY_USERNAME": rag_config.get("yacy_query_username",request.app.state.config.YACY_USERNAME), - "YACY_PASSWORD": rag_config.get("yacy_query_password",request.app.state.config.YACY_PASSWORD), - "GOOGLE_PSE_API_KEY": rag_config.get("google_pse_api_key", request.app.state.config.GOOGLE_PSE_API_KEY), - "GOOGLE_PSE_ENGINE_ID": rag_config.get("google_pse_engine_id", request.app.state.config.GOOGLE_PSE_ENGINE_ID), - "BRAVE_SEARCH_API_KEY": rag_config.get("brave_search_api_key", request.app.state.config.BRAVE_SEARCH_API_KEY), - "KAGI_SEARCH_API_KEY": rag_config.get("kagi_search_api_key", request.app.state.config.KAGI_SEARCH_API_KEY), - "MOJEEK_SEARCH_API_KEY": rag_config.get("mojeek_search_api_key", request.app.state.config.MOJEEK_SEARCH_API_KEY), - "BOCHA_SEARCH_API_KEY": rag_config.get("bocha_search_api_key", request.app.state.config.BOCHA_SEARCH_API_KEY), - "SERPSTACK_API_KEY": rag_config.get("serpstack_api_key", request.app.state.config.SERPSTACK_API_KEY), - "SERPSTACK_HTTPS": rag_config.get("serpstack_https", request.app.state.config.SERPSTACK_HTTPS), - "SERPER_API_KEY": rag_config.get("serper_api_key", request.app.state.config.SERPER_API_KEY), - "SERPLY_API_KEY": rag_config.get("serply_api_key", request.app.state.config.SERPLY_API_KEY), - "TAVILY_API_KEY": rag_config.get("tavily_api_key", request.app.state.config.TAVILY_API_KEY), - "SEARCHAPI_API_KEY": rag_config.get("searchapi_api_key", request.app.state.config.SEARCHAPI_API_KEY), - "SEARCHAPI_ENGINE": rag_config.get("searchapi_engine", request.app.state.config.SEARCHAPI_ENGINE), - "SERPAPI_API_KEY": rag_config.get("serpapi_api_key", request.app.state.config.SERPAPI_API_KEY), - "SERPAPI_ENGINE": rag_config.get("serpapi_engine", request.app.state.config.SERPAPI_ENGINE), - "JINA_API_KEY": rag_config.get("jina_api_key", request.app.state.config.JINA_API_KEY), - "BING_SEARCH_V7_ENDPOINT": rag_config.get("bing_search_v7_endpoint", request.app.state.config.BING_SEARCH_V7_ENDPOINT), - "BING_SEARCH_V7_SUBSCRIPTION_KEY": rag_config.get("bing_search_v7_subscription_key", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY), - "EXA_API_KEY": rag_config.get("exa_api_key", request.app.state.config.EXA_API_KEY), - "PERPLEXITY_API_KEY": rag_config.get("perplexity_api_key", request.app.state.config.PERPLEXITY_API_KEY), - "SOUGOU_API_SID": rag_config.get("sougou_api_sid", request.app.state.config.SOUGOU_API_SID), - "SOUGOU_API_SK": rag_config.get("sougou_api_sk", request.app.state.config.SOUGOU_API_SK), - "WEB_LOADER_ENGINE": rag_config.get("web_loader_engine", request.app.state.config.WEB_LOADER_ENGINE), - "ENABLE_WEB_LOADER_SSL_VERIFICATION": rag_config.get("enable_web_loader_ssl_verification", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION), - "PLAYWRIGHT_WS_URL": rag_config.get("playwright_ws_url", request.app.state.config.PLAYWRIGHT_WS_URL), - "PLAYWRIGHT_TIMEOUT": rag_config.get("playwright_timeout", request.app.state.config.PLAYWRIGHT_TIMEOUT), - "FIRECRAWL_API_KEY": rag_config.get("firecrawl_api_key", request.app.state.config.FIRECRAWL_API_KEY), - "FIRECRAWL_API_BASE_URL": rag_config.get("firecrawl_api_base_url", request.app.state.config.FIRECRAWL_API_BASE_URL), - "TAVILY_EXTRACT_DEPTH": rag_config.get("tavily_extract_depth", request.app.state.config.TAVILY_EXTRACT_DEPTH), - "EXTERNAL_WEB_SEARCH_URL": rag_config.get("web_search_url", request.app.state.config.EXTERNAL_WEB_SEARCH_URL), - "EXTERNAL_WEB_SEARCH_API_KEY": rag_config.get("web_search_key", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY), - "EXTERNAL_WEB_LOADER_URL": rag_config.get("web_loader_url", request.app.state.config.EXTERNAL_WEB_LOADER_URL), - "EXTERNAL_WEB_LOADER_API_KEY": rag_config.get("web_loader_key", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY), - "YOUTUBE_LOADER_LANGUAGE": rag_config.get("youtube_loader_language", request.app.state.config.YOUTUBE_LOADER_LANGUAGE), - "YOUTUBE_LOADER_PROXY_URL": rag_config.get("youtube_loader_proxy_url", request.app.state.config.YOUTUBE_LOADER_PROXY_URL), - "YOUTUBE_LOADER_TRANSLATION": rag_config.get("youtube_loader_translation", request.app.state.config.YOUTUBE_LOADER_TRANSLATION), + "ENABLE_WEB_SEARCH": web_config.get("ENABLE_WEB_SEARCH", request.app.state.config.ENABLE_WEB_SEARCH), + "WEB_SEARCH_ENGINE": web_config.get("WEB_SEARCH_ENGINE", request.app.state.config.WEB_SEARCH_ENGINE), + "WEB_SEARCH_TRUST_ENV": web_config.get("WEB_SEARCH_TRUST_ENV", request.app.state.config.WEB_SEARCH_TRUST_ENV), + "WEB_SEARCH_RESULT_COUNT": web_config.get("WEB_SEARCH_RESULT_COUNT", request.app.state.config.WEB_SEARCH_RESULT_COUNT), + "WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS), + "WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST), + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL), + "SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL), + "YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL), + "YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME), + "YACY_PASSWORD": web_config.get("YACY_QUERY_PASSWORD",request.app.state.config.YACY_PASSWORD), + "GOOGLE_PSE_API_KEY": web_config.get("GOOGLE_PSE_API_KEY", request.app.state.config.GOOGLE_PSE_API_KEY), + "GOOGLE_PSE_ENGINE_ID": web_config.get("GOOGLE_PSE_ENGINE_ID", request.app.state.config.GOOGLE_PSE_ENGINE_ID), + "BRAVE_SEARCH_API_KEY": web_config.get("BRAVE_SEARCH_API_KEY", request.app.state.config.BRAVE_SEARCH_API_KEY), + "KAGI_SEARCH_API_KEY": web_config.get("KAGI_SEARCH_API_KEY", request.app.state.config.KAGI_SEARCH_API_KEY), + "MOJEEK_SEARCH_API_KEY": web_config.get("MOJEEK_SEARCH_API_KEY", request.app.state.config.MOJEEK_SEARCH_API_KEY), + "BOCHA_SEARCH_API_KEY": web_config.get("BOCHA_SEARCH_API_KEY", request.app.state.config.BOCHA_SEARCH_API_KEY), + "SERPSTACK_API_KEY": web_config.get("SERPSTACK_API_KEY", request.app.state.config.SERPSTACK_API_KEY), + "SERPSTACK_HTTPS": web_config.get("SERPSTACK_HTTPS", request.app.state.config.SERPSTACK_HTTPS), + "SERPER_API_KEY": web_config.get("SERPER_API_KEY", request.app.state.config.SERPER_API_KEY), + "SERPLY_API_KEY": web_config.get("SERPLY_API_KEY", request.app.state.config.SERPLY_API_KEY), + "TAVILY_API_KEY": web_config.get("TAVILY_API_KEY", request.app.state.config.TAVILY_API_KEY), + "SEARCHAPI_API_KEY": web_config.get("SEARCHAPI_API_KEY", request.app.state.config.SEARCHAPI_API_KEY), + "SEARCHAPI_ENGINE": web_config.get("SEARCHAPI_ENGINE", request.app.state.config.SEARCHAPI_ENGINE), + "SERPAPI_API_KEY": web_config.get("SERPAPI_API_KEY", request.app.state.config.SERPAPI_API_KEY), + "SERPAPI_ENGINE": web_config.get("SERPAPI_ENGINE", request.app.state.config.SERPAPI_ENGINE), + "JINA_API_KEY": web_config.get("JINA_API_KEY", request.app.state.config.JINA_API_KEY), + "BING_SEARCH_V7_ENDPOINT": web_config.get("BING_SEARCH_V7_ENDPOINT", request.app.state.config.BING_SEARCH_V7_ENDPOINT), + "BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY), + "EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY), + "PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY), + "SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID), + "SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK), + "WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE), + "ENABLE_WEB_LOADER_SSL_VERIFICATION": web_config.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION), + "PLAYWRIGHT_WS_URL": web_config.get("PLAYWRIGHT_WS_URL", request.app.state.config.PLAYWRIGHT_WS_URL), + "PLAYWRIGHT_TIMEOUT": web_config.get("PLAYWRIGHT_TIMEOUT", request.app.state.config.PLAYWRIGHT_TIMEOUT), + "FIRECRAWL_API_KEY": web_config.get("FIRECRAWL_API_KEY", request.app.state.config.FIRECRAWL_API_KEY), + "FIRECRAWL_API_BASE_URL": web_config.get("FIRECRAWL_API_BASE_URL", request.app.state.config.FIRECRAWL_API_BASE_URL), + "TAVILY_EXTRACT_DEPTH": web_config.get("TAVILY_EXTRACT_DEPTH", request.app.state.config.TAVILY_EXTRACT_DEPTH), + "EXTERNAL_WEB_SEARCH_URL": web_config.get("WEB_SEARCH_URL", request.app.state.config.EXTERNAL_WEB_SEARCH_URL), + "EXTERNAL_WEB_SEARCH_API_KEY": web_config.get("WEB_SEARCH_KEY", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY), + "EXTERNAL_WEB_LOADER_URL": web_config.get("WEB_LOADER_URL", request.app.state.config.EXTERNAL_WEB_LOADER_URL), + "EXTERNAL_WEB_LOADER_API_KEY": web_config.get("WEB_LOADER_KEY", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY), + "YOUTUBE_LOADER_LANGUAGE": web_config.get("YOUTUBE_LOADER_LANGUAGE", request.app.state.config.YOUTUBE_LOADER_LANGUAGE), + "YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL), + "YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION), }, + "DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS) } else: # Return default RAG settings @@ -594,6 +597,7 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, + "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS } @@ -696,7 +700,7 @@ async def update_rag_config( """ knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): + if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database rag_config = knowledge_base.data.get("rag_config", {}) @@ -969,9 +973,9 @@ async def update_rag_config( "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, + "YACY_USERNAME": request.app.state.config.YACY_USERNAME, + "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, + "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, @@ -1001,10 +1005,10 @@ async def update_rag_config( "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, + "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, + "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, @@ -1027,7 +1031,7 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, user=None, - knowledge_id: Optional[str] = None, + knowledge_id: Optional[str] = None ) -> bool: def _get_docs_info(docs: list[Document]) -> str: docs_info = set() @@ -1049,20 +1053,18 @@ def save_docs_to_vector_db( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) - # Retrieve the knowledge base using the collection_name - knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) - if not knowledge_base: - raise ValueError(f"Knowledge base not found for collection: {knowledge_base}") - - # Retrieve the RAG configuration rag_config = {} - if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): - rag_config = knowledge_base.data.get("rag_config", {}) - + # Retrieve the knowledge base using the collection_name + if knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) + # Retrieve the RAG configuration + if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.data.get("rag_config", {}) + print("RAG CONFIG: ", rag_config) # Use knowledge-base-specific or default configurations - text_splitter_type = rag_config.get("text_splitter", request.app.state.config.TEXT_SPLITTER) - chunk_size = rag_config.get("chunk_size", request.app.state.config.CHUNK_SIZE) - chunk_overlap = rag_config.get("chunk_overlap", request.app.state.config.CHUNK_OVERLAP) + text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER) + chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE) + chunk_overlap = rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP) embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE) embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL) embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE) @@ -1199,13 +1201,13 @@ class ProcessFileForm(BaseModel): file_id: str content: Optional[str] = None collection_name: Optional[str] = None + knowledge_id: Optional[str] = None @router.post("/process/file") def process_file( request: Request, form_data: ProcessFileForm, - knowledge_id: Optional[str] = None, user=Depends(get_verified_user), ): try: @@ -1215,38 +1217,45 @@ def process_file( if collection_name is None: collection_name = f"file-{file.id}" - - # Retrieve the knowledge base using the collection name - knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) - if not knowledge_base: - raise ValueError(f"Knowledge base not found for collection: {knowledge_base}") - - # Retrieve the RAG configuration + rag_config = {} - if not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): - rag_config = knowledge_base.data.get("rag_config", {}) + # Retrieve the knowledge base using the collection id + if form_data.collection_name: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) + + # Retrieve the RAG configuration + if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.data.get("rag_config", {}) + form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db + + elif form_data.knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + + # Retrieve the RAG configuration + if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.data.get("rag_config", {}) # Use knowledge-base-specific or default configurations content_extraction_engine = rag_config.get( - "content_extraction_engine", request.app.state.config.CONTENT_EXTRACTION_ENGINE + "CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE ) tika_server_url = rag_config.get( - "tika_server_url", request.app.state.config.TIKA_SERVER_URL + "TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL ) docling_server_url = rag_config.get( - "docling_server_url", request.app.state.config.DOCLING_SERVER_URL + "DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL ) pdf_extract_images = rag_config.get( - "pdf_extract_images", request.app.state.config.PDF_EXTRACT_IMAGES + "PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES ) document_intelligence_endpoint = rag_config.get( - "document_intelligence_endpoint", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT + "DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT ) document_intelligence_key = rag_config.get( - "document_intelligence_key", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY + "DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY ) mistral_ocr_api_key = rag_config.get( - "mistral_ocr_api_key", request.app.state.config.MISTRAL_OCR_API_KEY + "MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY ) if form_data.content: @@ -1361,7 +1370,7 @@ def process_file( hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) - if not rag_config.get("bypass_embedding_and_retrieval", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL): + if not rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL): try: result = save_docs_to_vector_db( request, @@ -1374,7 +1383,7 @@ def process_file( }, add=(True if form_data.collection_name else False), user=user, - knowledge_id=knowledge_id, + knowledge_id=form_data.knowledge_id ) if result: @@ -1425,8 +1434,7 @@ class ProcessTextForm(BaseModel): def process_text( request: Request, form_data: ProcessTextForm, - user=Depends(get_verified_user), - knowledge_id: Optional[str] = None, + user=Depends(get_verified_user) ): collection_name = form_data.collection_name if collection_name is None: @@ -1441,7 +1449,7 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(request, docs, collection_name, user=user, knowledge_id=knowledge_id) + result = save_docs_to_vector_db(request, docs, collection_name, user=user) if result: return { "status": True,