diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 5cb47373f..521649dbe 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -210,6 +210,10 @@ class SearchForm(BaseModel): queries: List[str] +class CollectionForm(BaseModel): + knowledge_id: Optional[str] = None + + @router.get("/") async def get_status(request: Request): return { @@ -224,21 +228,32 @@ async def get_status(request: Request): } -@router.get("/embedding") -async def get_embedding_config(request: Request, user=Depends(get_admin_user)): +@router.post("/embedding") +async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)): + """ + Retrieve the embedding configuration. + If DEFAULT_RAG_SETTINGS is True, return the default embedding settings. + Otherwise, return the embedding configuration stored in the database. + """ + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + rag_config = {} + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Return the embedding configuration from the database + rag_config = knowledge_base.rag_config return { "status": True, - "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { + "embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE), + "embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL), + "embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE), + "openai_config": rag_config.get("openai_config", { "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, "key": request.app.state.config.RAG_OPENAI_API_KEY, - }, - "ollama_config": { + }), + "ollama_config": rag_config.get("ollama_config", { "url": request.app.state.config.RAG_OLLAMA_BASE_URL, "key": request.app.state.config.RAG_OLLAMA_API_KEY, - }, + }), } @@ -258,76 +273,207 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 + knowledge_id: Optional[str] = None @router.post("/embedding/update") async def update_embedding_config( - request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user) ): - log.info( - f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" - ) + """ + Update the embedding model configuration. + If DEFAULT_RAG_SETTINGS is True, update the global configuration. + Otherwise, update the RAG configuration in the database for the user's knowledge base. + """ try: - request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - - if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: - if form_data.openai_config is not None: - request.app.state.config.RAG_OPENAI_API_BASE_URL = ( - form_data.openai_config.url - ) - request.app.state.config.RAG_OPENAI_API_KEY = ( - form_data.openai_config.key - ) - - if form_data.ollama_config is not None: - request.app.state.config.RAG_OLLAMA_BASE_URL = ( - form_data.ollama_config.url - ) - request.app.state.config.RAG_OLLAMA_API_KEY = ( - form_data.ollama_config.key - ) - - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( - form_data.embedding_batch_size + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config + log.info( + f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" ) - request.app.state.ef = get_ef( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - ) + # Check if model is in use elsewhere, otherwise free up memory + in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id) - request.app.state.EMBEDDING_FUNCTION = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, - ( - request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL - ), - ( - request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY - ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - ) + if not in_use and not request.app.state.ef.get(request.app.state.config.RAG_EMBEDDING_MODEL) == rag_config.get("embedding_model") and rag_config.get("embedding_model"): + del request.app.state.ef[rag_config["embedding_model"]] + engine = rag_config["embedding_engine"] + target_model = rag_config["embedding_model"] + models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine] - return { - "status": True, - "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { - "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, - "key": request.app.state.config.RAG_OPENAI_API_KEY, - }, - "ollama_config": { - "url": request.app.state.config.RAG_OLLAMA_BASE_URL, - "key": request.app.state.config.RAG_OLLAMA_API_KEY, - }, - } + # Find and remove the dictionary that contains the target model + for model in models_list[:]: # Create a copy of the list for safe iteration + if model == target_model: + models_list.remove(model) + + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + # Update embedding-related fields + rag_config["embedding_engine"] = form_data.embedding_engine + rag_config["embedding_model"] = form_data.embedding_model + rag_config["embedding_batch_size"] = form_data.embedding_batch_size + + + rag_config["openai_config"] = { + "url": form_data.openai_config.url, + "key": form_data.openai_config.key, + } + + rag_config["ollama_config"] = { + "url": form_data.ollama_config.url, + "key": form_data.ollama_config.key, + } + # Update the embedding function + if not rag_config["embedding_model"] in request.app.state.ef: + request.app.state.ef[rag_config["embedding_model"]] = get_ef( + rag_config["embedding_engine"], + rag_config["embedding_model"], + ) + + request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function( + rag_config["embedding_engine"], + rag_config["embedding_model"], + request.app.state.ef[rag_config["embedding_model"]], + ( + rag_config["openai_config"]["url"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["url"] + ), + ( + rag_config["openai_config"]["key"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["key"] + ), + rag_config["embedding_batch_size"] + ) + # add model to state for reloading on startup + request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + # add model to state for selectable reranking models + if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]: + request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) + request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save() + rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS + rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS + + # Save the updated configuration to the database + Knowledges.update_rag_config_by_id( + id=form_data.knowledge_id, rag_config=rag_config + ) + + return { + "status": True, + "embedding_engine": rag_config["embedding_engine"], + "embedding_model": rag_config["embedding_model"], + "embedding_batch_size": rag_config["embedding_batch_size"], + "openai_config": rag_config.get("openai_config", {}), + "ollama_config": rag_config.get("ollama_config", {}), + "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"], + "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"], + "message": "Embedding configuration updated in the database.", + } + else: + # Update the global configuration + log.info( + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + ) + + # Check if model is in use elsewhere, otherwise free up memory + in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model") + if not in_use: + del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] + engine = request.app.state.config.RAG_EMBEDDING_ENGINE + target_model = request.app.state.config.RAG_EMBEDDING_MODEL + models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine] + + # Find and remove the dictionary that contains the target model + for model in models_list[:]: # Create a copy of the list for safe iteration + if model == target_model: + models_list.remove(model) + + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if form_data.openai_config is not None: + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) + + if form_data.ollama_config is not None: + request.app.state.config.RAG_OLLAMA_BASE_URL = ( + form_data.ollama_config.url + ) + request.app.state.config.RAG_OLLAMA_API_KEY = ( + form_data.ollama_config.key + ) + + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) + + # Update the embedding function + if not form_data.embedding_model in request.app.state.ef: + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) + + request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL], + ( + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_BASE_URL + ), + ( + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.RAG_OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ) + # add model to state for reloading on startup + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + # add model to state for selectable embedding models + if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]: + request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save() + + return { + "status": True, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "openai_config": { + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, + }, + "ollama_config": { + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, + }, + "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS, + "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS, + "message": "Embedding configuration updated globally.", + } except Exception as e: log.exception(f"Problem updating embedding model: {e}") raise HTTPException( @@ -336,98 +482,116 @@ async def update_embedding_config( ) -@router.get("/config") -async def get_rag_config(request: Request, user=Depends(get_admin_user)): +@router.post("/config") +async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)): + """ + Retrieve the full RAG configuration. + If DEFAULT_RAG_SETTINGS is True, return the default settings. + Otherwise, return the RAG configuration stored in the database. + """ + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + rag_config = {} + web_config = {} + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Return the RAG configuration from the database + rag_config = knowledge_base.rag_config + web_config = rag_config.get("web", {}) return { "status": True, # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + "RAG_TEMPLATE": rag_config.get("TEMPLATE", request.app.state.config.RAG_TEMPLATE), + "TOP_K": rag_config.get("TOP_K", request.app.state.config.TOP_K), + "BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL), + "RAG_FULL_CONTEXT": rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT), # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH), + "TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER), + "RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD), # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, - "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, - "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + "CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE), + "PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES), + "EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL), + "EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY), + "TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL), + "DOCLING_SERVER_URL": rag_config.get("DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL), + "DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE), + "DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG), + "DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION), + "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT), + "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY), + "MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY), # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + "RAG_RERANKING_MODEL": rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL), + "RAG_RERANKING_ENGINE": rag_config.get("RAG_RERANKING_ENGINE", request.app.state.config.RAG_RERANKING_ENGINE), + "RAG_EXTERNAL_RERANKER_URL": rag_config.get("RAG_EXTERNAL_RERANKER_URL", request.app.state.config.RAG_EXTERNAL_RERANKER_URL), + "RAG_EXTERNAL_RERANKER_API_KEY": rag_config.get("RAG_EXTERNAL_RERANKER_API_KEY", request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY), # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + "TEXT_SPLITTER": rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER), + "CHUNK_SIZE": rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE), + "CHUNK_OVERLAP": rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP), # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + "FILE_MAX_SIZE": rag_config.get("FILE_MAX_SIZE", request.app.state.config.FILE_MAX_SIZE), + "FILE_MAX_COUNT": rag_config.get("FILE_MAX_COUNT", request.app.state.config.FILE_MAX_COUNT), + "ALLOWED_FILE_EXTENSIONS": rag_config.get("ALLOWED_FILE_EXTENSIONS", request.app.state.config.ALLOWED_FILE_EXTENSIONS), # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + "ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("ENABLE_GOOGLE_DRIVE_INTEGRATION", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION), + "ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("enable_onedrive_integration", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION), # Web search settings "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, - "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "ENABLE_WEB_SEARCH": web_config.get("ENABLE_WEB_SEARCH", request.app.state.config.ENABLE_WEB_SEARCH), + "WEB_SEARCH_ENGINE": web_config.get("WEB_SEARCH_ENGINE", request.app.state.config.WEB_SEARCH_ENGINE), + "WEB_SEARCH_TRUST_ENV": web_config.get("WEB_SEARCH_TRUST_ENV", request.app.state.config.WEB_SEARCH_TRUST_ENV), + "WEB_SEARCH_RESULT_COUNT": web_config.get("WEB_SEARCH_RESULT_COUNT", request.app.state.config.WEB_SEARCH_RESULT_COUNT), + "WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS), + "WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST), + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL), + "SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL), + "YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL), + "YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME), + "YACY_PASSWORD": web_config.get("YACY_QUERY_PASSWORD",request.app.state.config.YACY_PASSWORD), + "GOOGLE_PSE_API_KEY": web_config.get("GOOGLE_PSE_API_KEY", request.app.state.config.GOOGLE_PSE_API_KEY), + "GOOGLE_PSE_ENGINE_ID": web_config.get("GOOGLE_PSE_ENGINE_ID", request.app.state.config.GOOGLE_PSE_ENGINE_ID), + "BRAVE_SEARCH_API_KEY": web_config.get("BRAVE_SEARCH_API_KEY", request.app.state.config.BRAVE_SEARCH_API_KEY), + "KAGI_SEARCH_API_KEY": web_config.get("KAGI_SEARCH_API_KEY", request.app.state.config.KAGI_SEARCH_API_KEY), + "MOJEEK_SEARCH_API_KEY": web_config.get("MOJEEK_SEARCH_API_KEY", request.app.state.config.MOJEEK_SEARCH_API_KEY), + "BOCHA_SEARCH_API_KEY": web_config.get("BOCHA_SEARCH_API_KEY", request.app.state.config.BOCHA_SEARCH_API_KEY), + "SERPSTACK_API_KEY": web_config.get("SERPSTACK_API_KEY", request.app.state.config.SERPSTACK_API_KEY), + "SERPSTACK_HTTPS": web_config.get("SERPSTACK_HTTPS", request.app.state.config.SERPSTACK_HTTPS), + "SERPER_API_KEY": web_config.get("SERPER_API_KEY", request.app.state.config.SERPER_API_KEY), + "SERPLY_API_KEY": web_config.get("SERPLY_API_KEY", request.app.state.config.SERPLY_API_KEY), + "TAVILY_API_KEY": web_config.get("TAVILY_API_KEY", request.app.state.config.TAVILY_API_KEY), + "SEARCHAPI_API_KEY": web_config.get("SEARCHAPI_API_KEY", request.app.state.config.SEARCHAPI_API_KEY), + "SEARCHAPI_ENGINE": web_config.get("SEARCHAPI_ENGINE", request.app.state.config.SEARCHAPI_ENGINE), + "SERPAPI_API_KEY": web_config.get("SERPAPI_API_KEY", request.app.state.config.SERPAPI_API_KEY), + "SERPAPI_ENGINE": web_config.get("SERPAPI_ENGINE", request.app.state.config.SERPAPI_ENGINE), + "JINA_API_KEY": web_config.get("JINA_API_KEY", request.app.state.config.JINA_API_KEY), + "BING_SEARCH_V7_ENDPOINT": web_config.get("BING_SEARCH_V7_ENDPOINT", request.app.state.config.BING_SEARCH_V7_ENDPOINT), + "BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY), + "EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY), + "PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY), + "SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID), + "SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK), + "WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE), + "ENABLE_WEB_LOADER_SSL_VERIFICATION": web_config.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION), + "PLAYWRIGHT_WS_URL": web_config.get("PLAYWRIGHT_WS_URL", request.app.state.config.PLAYWRIGHT_WS_URL), + "PLAYWRIGHT_TIMEOUT": web_config.get("PLAYWRIGHT_TIMEOUT", request.app.state.config.PLAYWRIGHT_TIMEOUT), + "FIRECRAWL_API_KEY": web_config.get("FIRECRAWL_API_KEY", request.app.state.config.FIRECRAWL_API_KEY), + "FIRECRAWL_API_BASE_URL": web_config.get("FIRECRAWL_API_BASE_URL", request.app.state.config.FIRECRAWL_API_BASE_URL), + "TAVILY_EXTRACT_DEPTH": web_config.get("TAVILY_EXTRACT_DEPTH", request.app.state.config.TAVILY_EXTRACT_DEPTH), + "EXTERNAL_WEB_SEARCH_URL": web_config.get("WEB_SEARCH_URL", request.app.state.config.EXTERNAL_WEB_SEARCH_URL), + "EXTERNAL_WEB_SEARCH_API_KEY": web_config.get("WEB_SEARCH_KEY", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY), + "EXTERNAL_WEB_LOADER_URL": web_config.get("WEB_LOADER_URL", request.app.state.config.EXTERNAL_WEB_LOADER_URL), + "EXTERNAL_WEB_LOADER_API_KEY": web_config.get("WEB_LOADER_KEY", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY), + "YOUTUBE_LOADER_LANGUAGE": web_config.get("YOUTUBE_LOADER_LANGUAGE", request.app.state.config.YOUTUBE_LOADER_LANGUAGE), + "YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL), + "YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.config.YOUTUBE_LOADER_TRANSLATION), }, + "DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS), + "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS, + "DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS, + "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS, + "LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS, } @@ -531,392 +695,524 @@ class ConfigForm(BaseModel): # Web search settings web: Optional[WebConfig] = None + # knowledge base ID + knowledge_id: Optional[str] = None @router.post("/config/update") async def update_rag_config( - request: Request, form_data: ConfigForm, user=Depends(get_admin_user) + request: Request, form_data: ConfigForm, user=Depends(get_verified_user) ): - # RAG settings - request.app.state.config.RAG_TEMPLATE = ( - form_data.RAG_TEMPLATE - if form_data.RAG_TEMPLATE is not None - else request.app.state.config.RAG_TEMPLATE - ) - request.app.state.config.TOP_K = ( - form_data.TOP_K - if form_data.TOP_K is not None - else request.app.state.config.TOP_K - ) - request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( - form_data.BYPASS_EMBEDDING_AND_RETRIEVAL - if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None - else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ) - request.app.state.config.RAG_FULL_CONTEXT = ( - form_data.RAG_FULL_CONTEXT - if form_data.RAG_FULL_CONTEXT is not None - else request.app.state.config.RAG_FULL_CONTEXT - ) + """ + Update the RAG configuration. + If DEFAULT_RAG_SETTINGS is True, update the global configuration. + Otherwise, update the RAG configuration in the database for the user's knowledge base. + """ - # Hybrid search settings - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( - form_data.ENABLE_RAG_HYBRID_SEARCH - if form_data.ENABLE_RAG_HYBRID_SEARCH is not None - else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH - ) - # Free up memory if hybrid search is disabled - if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: - request.app.state.rf = None + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config - request.app.state.config.TOP_K_RERANKER = ( - form_data.TOP_K_RERANKER - if form_data.TOP_K_RERANKER is not None - else request.app.state.config.TOP_K_RERANKER - ) - request.app.state.config.RELEVANCE_THRESHOLD = ( - form_data.RELEVANCE_THRESHOLD - if form_data.RELEVANCE_THRESHOLD is not None - else request.app.state.config.RELEVANCE_THRESHOLD - ) + # Free up memory if hybrid search is disabled and model is not in use elswhere + in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("RAG_RERANKING_MODEL"), model_type="RAG_RERANKING_MODEL", id=form_data.knowledge_id) - # Content extraction settings - request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( - form_data.CONTENT_EXTRACTION_ENGINE - if form_data.CONTENT_EXTRACTION_ENGINE is not None - else request.app.state.config.CONTENT_EXTRACTION_ENGINE - ) - request.app.state.config.PDF_EXTRACT_IMAGES = ( - form_data.PDF_EXTRACT_IMAGES - if form_data.PDF_EXTRACT_IMAGES is not None - else request.app.state.config.PDF_EXTRACT_IMAGES - ) - request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = ( - form_data.EXTERNAL_DOCUMENT_LOADER_URL - if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None - else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL - ) - request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = ( - form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY - if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None - else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY - ) - request.app.state.config.TIKA_SERVER_URL = ( - form_data.TIKA_SERVER_URL - if form_data.TIKA_SERVER_URL is not None - else request.app.state.config.TIKA_SERVER_URL - ) - request.app.state.config.DOCLING_SERVER_URL = ( - form_data.DOCLING_SERVER_URL - if form_data.DOCLING_SERVER_URL is not None - else request.app.state.config.DOCLING_SERVER_URL - ) - request.app.state.config.DOCLING_OCR_ENGINE = ( - form_data.DOCLING_OCR_ENGINE - if form_data.DOCLING_OCR_ENGINE is not None - else request.app.state.config.DOCLING_OCR_ENGINE - ) - request.app.state.config.DOCLING_OCR_LANG = ( - form_data.DOCLING_OCR_LANG - if form_data.DOCLING_OCR_LANG is not None - else request.app.state.config.DOCLING_OCR_LANG - ) + if not form_data.ENABLE_RAG_HYBRID_SEARCH and \ + not in_use and \ + request.app.state.rf.get(rag_config["RAG_RERANKING_MODEL"]): + if rag_config.get("RAG_RERANKING_MODEL"): + del request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] + engine = request.app.state.config.RAG_RERANKING_ENGINE + target_model = rag_config["RAG_RERANKING_MODEL"] + models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine] - request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = ( - form_data.DOCLING_DO_PICTURE_DESCRIPTION - if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None - else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION - ) + # Find and remove the dictionary that contains the target model + for model_config in models_list[:]: # Create a copy of the list for safe iteration + if model_config["RAG_RERANKING_MODEL"] == target_model: + models_list.remove(model_config) + + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() - request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( - form_data.DOCUMENT_INTELLIGENCE_ENDPOINT - if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None - else request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT - ) - request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( - form_data.DOCUMENT_INTELLIGENCE_KEY - if form_data.DOCUMENT_INTELLIGENCE_KEY is not None - else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY - ) - request.app.state.config.MISTRAL_OCR_API_KEY = ( - form_data.MISTRAL_OCR_API_KEY - if form_data.MISTRAL_OCR_API_KEY is not None - else request.app.state.config.MISTRAL_OCR_API_KEY - ) + import gc + import torch + gc.collect() + torch.cuda.empty_cache() - # Reranking settings - request.app.state.config.RAG_RERANKING_ENGINE = ( - form_data.RAG_RERANKING_ENGINE - if form_data.RAG_RERANKING_ENGINE is not None - else request.app.state.config.RAG_RERANKING_ENGINE - ) + # Update only the provided fields in the rag_config + for field, value in form_data.model_dump(exclude_unset=True).items(): + if field == "web" and value is not None: + rag_config["web"] = {**rag_config.get("web", {}), **value} + else: + rag_config[field] = value - request.app.state.config.RAG_EXTERNAL_RERANKER_URL = ( - form_data.RAG_EXTERNAL_RERANKER_URL - if form_data.RAG_EXTERNAL_RERANKER_URL is not None - else request.app.state.config.RAG_EXTERNAL_RERANKER_URL - ) - - request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = ( - form_data.RAG_EXTERNAL_RERANKER_API_KEY - if form_data.RAG_EXTERNAL_RERANKER_API_KEY is not None - else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY - ) - - log.info( - f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" - ) - try: - request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL + log.info( + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" + ) try: - request.app.state.rf = get_rf( - request.app.state.config.RAG_RERANKING_ENGINE, - request.app.state.config.RAG_RERANKING_MODEL, - request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - True, - ) + try: + if not rag_config["RAG_RERANKING_MODEL"] in request.app.state.rf and not rag_config["RAG_RERANKING_MODEL"] == "": + request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] = get_rf( + rag_config["RAG_RERANKING_ENGINE"], + rag_config["RAG_RERANKING_MODEL"], + rag_config["RAG_EXTERNAL_RERANKER_URL"], + rag_config["RAG_EXTERNAL_RERANKER_API_KEY"], + True, + ) + + # add model to state for reloading on startup + request.app.state.config.LOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append({ + "RAG_RERANKING_MODEL": rag_config["RAG_RERANKING_MODEL"], + "RAG_EXTERNAL_RERANKER_URL": rag_config["RAG_EXTERNAL_RERANKER_URL"], + "RAG_EXTERNAL_RERANKER_API_KEY": rag_config["RAG_EXTERNAL_RERANKER_API_KEY"]}) + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + # add model to state for selectable reranking models + if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]]: + request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append(rag_config["RAG_RERANKING_MODEL"]) + request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save() + + rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS + rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS + + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False except Exception as e: - log.error(f"Error loading reranking model: {e}") - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - except Exception as e: - log.exception(f"Problem updating reranking model: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(e), + log.exception(f"Problem updating reranking model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + Knowledges.update_rag_config_by_id( + id=knowledge_base.id, rag_config=rag_config ) - # Chunking settings - request.app.state.config.TEXT_SPLITTER = ( - form_data.TEXT_SPLITTER - if form_data.TEXT_SPLITTER is not None - else request.app.state.config.TEXT_SPLITTER - ) - request.app.state.config.CHUNK_SIZE = ( - form_data.CHUNK_SIZE - if form_data.CHUNK_SIZE is not None - else request.app.state.config.CHUNK_SIZE - ) - request.app.state.config.CHUNK_OVERLAP = ( - form_data.CHUNK_OVERLAP - if form_data.CHUNK_OVERLAP is not None - else request.app.state.config.CHUNK_OVERLAP - ) - - # File upload settings - request.app.state.config.FILE_MAX_SIZE = ( - form_data.FILE_MAX_SIZE - if form_data.FILE_MAX_SIZE is not None - else request.app.state.config.FILE_MAX_SIZE - ) - request.app.state.config.FILE_MAX_COUNT = ( - form_data.FILE_MAX_COUNT - if form_data.FILE_MAX_COUNT is not None - else request.app.state.config.FILE_MAX_COUNT - ) - request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( - form_data.ALLOWED_FILE_EXTENSIONS - if form_data.ALLOWED_FILE_EXTENSIONS is not None - else request.app.state.config.ALLOWED_FILE_EXTENSIONS - ) - - # Integration settings - request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( - form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION - if form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION is not None - else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION - ) - request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( - form_data.ENABLE_ONEDRIVE_INTEGRATION - if form_data.ENABLE_ONEDRIVE_INTEGRATION is not None - else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION - ) - - if form_data.web is not None: - # Web search settings - request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH - request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE - request.app.state.config.WEB_SEARCH_TRUST_ENV = ( - form_data.web.WEB_SEARCH_TRUST_ENV - ) - request.app.state.config.WEB_SEARCH_RESULT_COUNT = ( - form_data.web.WEB_SEARCH_RESULT_COUNT - ) - request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( - form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS - ) - request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( - form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST - ) - request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( - form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL - ) - request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL - request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL - request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME - request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD - request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY - request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( - form_data.web.GOOGLE_PSE_ENGINE_ID - ) - request.app.state.config.BRAVE_SEARCH_API_KEY = ( - form_data.web.BRAVE_SEARCH_API_KEY - ) - request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY - request.app.state.config.MOJEEK_SEARCH_API_KEY = ( - form_data.web.MOJEEK_SEARCH_API_KEY - ) - request.app.state.config.BOCHA_SEARCH_API_KEY = ( - form_data.web.BOCHA_SEARCH_API_KEY - ) - request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY - request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS - request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY - request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY - request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY - request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY - request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE - request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY - request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE - request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY - request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( - form_data.web.BING_SEARCH_V7_ENDPOINT - ) - request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( - form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY - ) - request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY - request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY - request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID - request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK - - # Web loader settings - request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE - request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION - ) - request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL - request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT - request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY - request.app.state.config.FIRECRAWL_API_BASE_URL = ( - form_data.web.FIRECRAWL_API_BASE_URL - ) - request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( - form_data.web.EXTERNAL_WEB_SEARCH_URL - ) - request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( - form_data.web.EXTERNAL_WEB_SEARCH_API_KEY - ) - request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( - form_data.web.EXTERNAL_WEB_LOADER_URL - ) - request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( - form_data.web.EXTERNAL_WEB_LOADER_API_KEY - ) - request.app.state.config.TAVILY_EXTRACT_DEPTH = ( - form_data.web.TAVILY_EXTRACT_DEPTH - ) - request.app.state.config.YOUTUBE_LOADER_LANGUAGE = ( - form_data.web.YOUTUBE_LOADER_LANGUAGE - ) - request.app.state.config.YOUTUBE_LOADER_PROXY_URL = ( - form_data.web.YOUTUBE_LOADER_PROXY_URL - ) - request.app.state.YOUTUBE_LOADER_TRANSLATION = ( - form_data.web.YOUTUBE_LOADER_TRANSLATION - ) - - return { - "status": True, + return rag_config + else: + # Update the global configuration # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + request.app.state.config.RAG_TEMPLATE = ( + form_data.RAG_TEMPLATE + if form_data.RAG_TEMPLATE is not None + else request.app.state.config.RAG_TEMPLATE + ) + request.app.state.config.TOP_K = ( + form_data.TOP_K + if form_data.TOP_K is not None + else request.app.state.config.TOP_K + ) + request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( + form_data.BYPASS_EMBEDDING_AND_RETRIEVAL + if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None + else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.RAG_FULL_CONTEXT = ( + form_data.RAG_FULL_CONTEXT + if form_data.RAG_FULL_CONTEXT is not None + else request.app.state.config.RAG_FULL_CONTEXT + ) + # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + form_data.ENABLE_RAG_HYBRID_SEARCH + if form_data.ENABLE_RAG_HYBRID_SEARCH is not None + else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH + ) + + # Free up memory if hybrid search is disabled and model is not in use elswhere + in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="RAG_RERANKING_MODEL") + + if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and \ + not in_use and \ + request.app.state.rf.get(request.app.state.config.RAG_RERANKING_MODEL): + del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] + engine = request.app.state.config.RAG_RERANKING_ENGINE + target_model = request.app.state.config.RAG_RERANKING_MODEL + models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine] + + # Find and remove the dictionary that contains the target model + for model_config in models_list[:]: # Create a copy of the list for safe iteration + if model_config["RAG_RERANKING_MODEL"] == target_model: + models_list.remove(model_config) + + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + request.app.state.config.TOP_K_RERANKER = ( + form_data.TOP_K_RERANKER + if form_data.TOP_K_RERANKER is not None + else request.app.state.config.TOP_K_RERANKER + ) + request.app.state.config.RELEVANCE_THRESHOLD = ( + form_data.RELEVANCE_THRESHOLD + if form_data.RELEVANCE_THRESHOLD is not None + else request.app.state.config.RELEVANCE_THRESHOLD + ) + # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, - "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, - "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.CONTENT_EXTRACTION_ENGINE + if form_data.CONTENT_EXTRACTION_ENGINE is not None + else request.app.state.config.CONTENT_EXTRACTION_ENGINE + ) + request.app.state.config.PDF_EXTRACT_IMAGES = ( + form_data.PDF_EXTRACT_IMAGES + if form_data.PDF_EXTRACT_IMAGES is not None + else request.app.state.config.PDF_EXTRACT_IMAGES + ) + request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = ( + form_data.EXTERNAL_DOCUMENT_LOADER_URL + if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None + else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL + ) + request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = ( + form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY + if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None + else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.TIKA_SERVER_URL + if form_data.TIKA_SERVER_URL is not None + else request.app.state.config.TIKA_SERVER_URL + ) + request.app.state.config.DOCLING_SERVER_URL = ( + form_data.DOCLING_SERVER_URL + if form_data.DOCLING_SERVER_URL is not None + else request.app.state.config.DOCLING_SERVER_URL + ) + request.app.state.config.DOCLING_OCR_ENGINE = ( + form_data.DOCLING_OCR_ENGINE + if form_data.DOCLING_OCR_ENGINE is not None + else request.app.state.config.DOCLING_OCR_ENGINE + ) + request.app.state.config.DOCLING_OCR_LANG = ( + form_data.DOCLING_OCR_LANG + if form_data.DOCLING_OCR_LANG is not None + else request.app.state.config.DOCLING_OCR_LANG + ) + + request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = ( + form_data.DOCLING_DO_PICTURE_DESCRIPTION + if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None + else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION + ) + + request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( + form_data.DOCUMENT_INTELLIGENCE_ENDPOINT + if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None + else request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT + ) + request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( + form_data.DOCUMENT_INTELLIGENCE_KEY + if form_data.DOCUMENT_INTELLIGENCE_KEY is not None + else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY + ) + request.app.state.config.MISTRAL_OCR_API_KEY = ( + form_data.MISTRAL_OCR_API_KEY + if form_data.MISTRAL_OCR_API_KEY is not None + else request.app.state.config.MISTRAL_OCR_API_KEY + ) + # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + request.app.state.config.RAG_RERANKING_ENGINE = ( + form_data.RAG_RERANKING_ENGINE + if form_data.RAG_RERANKING_ENGINE is not None + else request.app.state.config.RAG_RERANKING_ENGINE + ) + + request.app.state.config.RAG_EXTERNAL_RERANKER_URL = ( + form_data.RAG_EXTERNAL_RERANKER_URL + if form_data.RAG_EXTERNAL_RERANKER_URL is not None + else request.app.state.config.RAG_EXTERNAL_RERANKER_URL + ) + + request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = ( + form_data.RAG_EXTERNAL_RERANKER_API_KEY + if form_data.RAG_EXTERNAL_RERANKER_API_KEY is not None + else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY + ) + + + log.info( + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" + ) + try: + request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL + + try: + if not request.app.state.config.RAG_RERANKING_MODEL in request.app.state.rf and not request.app.state.config.RAG_RERANKING_MODEL == "": + request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf( + request.app.state.config.RAG_RERANKING_ENGINE, + request.app.state.config.RAG_RERANKING_MODEL, + request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + True, + ) + + # add model to state for reloading on startup + request.app.state.config.LOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append({ + "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, + "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY + }) + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + # add model to state for selectable reranking models + if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE]: + request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append(request.app.state.config.RAG_RERANKING_MODEL) + request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save() + + rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS + rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS + + + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + except Exception as e: + log.exception(f"Problem updating reranking model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + request.app.state.config.TEXT_SPLITTER = ( + form_data.TEXT_SPLITTER + if form_data.TEXT_SPLITTER is not None + else request.app.state.config.TEXT_SPLITTER + ) + request.app.state.config.CHUNK_SIZE = ( + form_data.CHUNK_SIZE + if form_data.CHUNK_SIZE is not None + else request.app.state.config.CHUNK_SIZE + ) + request.app.state.config.CHUNK_OVERLAP = ( + form_data.CHUNK_OVERLAP + if form_data.CHUNK_OVERLAP is not None + else request.app.state.config.CHUNK_OVERLAP + ) + # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + request.app.state.config.FILE_MAX_SIZE = ( + form_data.FILE_MAX_SIZE + if form_data.FILE_MAX_SIZE is not None + else request.app.state.config.FILE_MAX_SIZE + ) + request.app.state.config.FILE_MAX_COUNT = ( + form_data.FILE_MAX_COUNT + if form_data.FILE_MAX_COUNT is not None + else request.app.state.config.FILE_MAX_COUNT + ) + request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( + form_data.ALLOWED_FILE_EXTENSIONS + if form_data.ALLOWED_FILE_EXTENSIONS is not None + else request.app.state.config.ALLOWED_FILE_EXTENSIONS + ) + # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, - # Web search settings - "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, - "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, - }, - } + request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( + form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION + if form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION is not None + else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION + ) + request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( + form_data.ENABLE_ONEDRIVE_INTEGRATION + if form_data.ENABLE_ONEDRIVE_INTEGRATION is not None + else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION + ) + + if form_data.web is not None: + # Web search settings + request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH + request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE + request.app.state.config.WEB_SEARCH_TRUST_ENV = ( + form_data.web.WEB_SEARCH_TRUST_ENV + ) + request.app.state.config.WEB_SEARCH_RESULT_COUNT = ( + form_data.web.WEB_SEARCH_RESULT_COUNT + ) + request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( + form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS + ) + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( + form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST + ) + request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( + form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL + request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL + request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME + request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD + request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( + form_data.web.GOOGLE_PSE_ENGINE_ID + ) + request.app.state.config.BRAVE_SEARCH_API_KEY = ( + form_data.web.BRAVE_SEARCH_API_KEY + ) + request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( + form_data.web.MOJEEK_SEARCH_API_KEY + ) + request.app.state.config.BOCHA_SEARCH_API_KEY = ( + form_data.web.BOCHA_SEARCH_API_KEY + ) + request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY + request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS + request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY + request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY + request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY + request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY + request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE + request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY + request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE + request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( + form_data.web.BING_SEARCH_V7_ENDPOINT + ) + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY + ) + request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY + request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY + request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID + request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK + + # Web loader settings + request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE + request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ( + form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION + ) + request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL + request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT + request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY + request.app.state.config.FIRECRAWL_API_BASE_URL = ( + form_data.web.FIRECRAWL_API_BASE_URL + ) + request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( + form_data.web.EXTERNAL_WEB_SEARCH_URL + ) + request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( + form_data.web.EXTERNAL_WEB_SEARCH_API_KEY + ) + request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( + form_data.web.EXTERNAL_WEB_LOADER_URL + ) + request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( + form_data.web.EXTERNAL_WEB_LOADER_API_KEY + ) + request.app.state.config.TAVILY_EXTRACT_DEPTH = ( + form_data.web.TAVILY_EXTRACT_DEPTH + ) + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = ( + form_data.web.YOUTUBE_LOADER_LANGUAGE + ) + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = ( + form_data.web.YOUTUBE_LOADER_PROXY_URL + ) + request.app.state.YOUTUBE_LOADER_TRANSLATION = ( + form_data.web.YOUTUBE_LOADER_TRANSLATION + ) + + return { + "status": True, + # RAG settings + "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, + "TOP_K": request.app.state.config.TOP_K, + "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + # Hybrid search settings + "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, + "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + # Content extraction settings + "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, + "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, + "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, + "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, + "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + # Reranking settings + "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, + "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, + "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + # Chunking settings + "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, + "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, + "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + # File upload settings + "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, + "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, + "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + # Integration settings + "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + # Web search settings + "web": { + "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, + "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, + "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, + "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, + "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, + "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, + "YACY_USERNAME": request.app.state.config.YACY_USERNAME, + "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, + "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, + "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, + "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, + "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, + "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, + "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, + "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, + "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, + "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, + "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, + "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, + "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, + "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, + "JINA_API_KEY": request.app.state.config.JINA_API_KEY, + "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "EXA_API_KEY": request.app.state.config.EXA_API_KEY, + "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, + "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, + "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, + "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, + "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, + "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, + "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, + "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, + "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, + "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, + "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + }, + "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS + } #################################### @@ -935,6 +1231,7 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, user=None, + knowledge_id: Optional[str] = None ) -> bool: def _get_docs_info(docs: list[Document]) -> str: docs_info = set() @@ -955,7 +1252,27 @@ def save_docs_to_vector_db( log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) + + rag_config = {} + # Retrieve the knowledge base using the collection_name + if knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) + # Retrieve the RAG configuration + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + # Use knowledge-base-specific or default configurations + text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER) + chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE) + chunk_overlap = rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP) + embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE) + embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL) + embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE) + openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL) + openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY) + ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL) + ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY) + # Check if entries with the same hash (metadata.hash) already exist if metadata and "hash" in metadata: result = VECTOR_DB_CLIENT.query( @@ -970,13 +1287,13 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - if request.app.state.config.TEXT_SPLITTER in ["", "character"]: + if text_splitter_type in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=request.app.state.config.CHUNK_SIZE, - chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, add_start_index=True, ) - elif request.app.state.config.TEXT_SPLITTER == "token": + elif text_splitter_type == "token": log.info( f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" ) @@ -984,8 +1301,8 @@ def save_docs_to_vector_db( tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=request.app.state.config.CHUNK_SIZE, - chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, add_start_index=True, ) else: @@ -1003,8 +1320,8 @@ def save_docs_to_vector_db( **(metadata if metadata else {}), "embedding_config": json.dumps( { - "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "model": request.app.state.config.RAG_EMBEDDING_MODEL, + "engine": embedding_engine, + "model": embedding_model, } ), } @@ -1037,20 +1354,20 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, + embedding_engine, + embedding_model, + request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL), ( - request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL + openai_api_base_url + if embedding_engine == "openai" + else ollama_base_url ), ( - request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + openai_api_key + if embedding_engine == "openai" + else ollama_api_key ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + embedding_batch_size, ) embeddings = embedding_function( @@ -1084,6 +1401,7 @@ class ProcessFileForm(BaseModel): file_id: str content: Optional[str] = None collection_name: Optional[str] = None + knowledge_id: Optional[str] = None @router.post("/process/file") @@ -1099,6 +1417,61 @@ def process_file( if collection_name is None: collection_name = f"file-{file.id}" + + rag_config = {} + # Retrieve the knowledge base using the collection id - knowledge_id == collection_name (minimal working solution without logic changes) + if form_data.collection_name: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) + + # Retrieve the RAG configuration + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db + + elif form_data.knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + + # Retrieve the RAG configuration + if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + + # Use knowledge-base-specific or default configurations + content_extraction_engine = rag_config.get( + "CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE + ) + external_document_loader_url = rag_config.get( + "EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL + ) + external_document_loader_api_key = rag_config.get( + "EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY + ) + tika_server_url = rag_config.get( + "TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL + ) + docling_server_url = rag_config.get( + "DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL + ) + docling_ocr_engine=rag_config.get( + "DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE + ) + docling_ocr_lang=rag_config.get( + "DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG + ) + docling_do_picture_description=rag_config.get( + "DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION + ) + pdf_extract_images = rag_config.get( + "PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES + ) + document_intelligence_endpoint = rag_config.get( + "DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT + ) + document_intelligence_key = rag_config.get( + "DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY + ) + mistral_ocr_api_key = rag_config.get( + "MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY + ) if form_data.content: # Update the content in the file @@ -1163,18 +1536,18 @@ def process_file( if file_path: file_path = Storage.get_file(file_path) loader = Loader( - engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, - EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, - DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, - DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE, - DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG, - DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, - DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, + engine=content_extraction_engine, + EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url, + EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key, + TIKA_SERVER_URL=tika_server_url, + DOCLING_SERVER_URL=docling_server_url, + DOCLING_OCR_ENGINE=docling_ocr_engine, + DOCLING_OCR_LANG=docling_ocr_lang, + DOCLING_DO_PICTURE_DESCRIPTION=docling_do_picture_description, + PDF_EXTRACT_IMAGES=pdf_extract_images, + DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint, + DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key, + MISTRAL_OCR_API_KEY=mistral_ocr_api_key, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1217,7 +1590,7 @@ def process_file( hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) - if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + if not rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL): try: result = save_docs_to_vector_db( request, @@ -1230,6 +1603,7 @@ def process_file( }, add=(True if form_data.collection_name else False), user=user, + knowledge_id=form_data.knowledge_id ) if result: @@ -1280,7 +1654,7 @@ class ProcessTextForm(BaseModel): def process_text( request: Request, form_data: ProcessTextForm, - user=Depends(get_verified_user), + user=Depends(get_verified_user) ): collection_name = form_data.collection_name if collection_name is None: @@ -1762,11 +2136,11 @@ def query_doc_handler( collection_name=form_data.collection_name, collection_result=collection_results[form_data.collection_name], query=form_data.query, - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -1779,7 +2153,7 @@ def query_doc_handler( else: return query_doc( collection_name=form_data.collection_name, - query_embedding=request.app.state.EMBEDDING_FUNCTION( + query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, @@ -1813,11 +2187,11 @@ def query_collection_handler( return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -1830,7 +2204,7 @@ def query_collection_handler( return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K,