diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 898ac1b59..3a346b211 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2336,6 +2336,77 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
     os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
 )
 
+DEFAULT_RAG_SETTINGS = PersistentConfig(
+    "DEFAULT_RAG_SETTINGS",
+    "rag.default_settings",
+    os.getenv("DEFAULT_RAG_SETTINGS", "True").lower() == "true",
+)
+
+
+def _get_json_env(name: str, default: dict) -> dict:
+    """Return the JSON object stored in env var ``name``, or ``default``.
+
+    ``os.getenv`` yields a string whenever the variable is set, but the model
+    registries below are consumed as dicts (e.g. iterated with ``.items()``),
+    so the raw value has to be decoded before it reaches ``PersistentConfig``.
+    """
+    import json
+    import logging
+
+    raw = os.getenv(name)
+    if not raw:
+        return default
+    try:
+        parsed = json.loads(raw)
+    except ValueError:
+        logging.getLogger(__name__).warning(
+            "Ignoring %s: value is not valid JSON", name
+        )
+        return default
+    # Only a JSON object has the engine -> models shape the callers expect.
+    return parsed if isinstance(parsed, dict) else default
+
+
+DOWNLOADED_EMBEDDING_MODELS = PersistentConfig(
+    "DOWNLOADED_EMBEDDING_MODELS",
+    "rag.downloaded_embedding_models",
+    _get_json_env(
+        "DOWNLOADED_EMBEDDING_MODELS",
+        {
+            "": ["sentence-transformers/all-MiniLM-L6-v2"],
+            "openai": ["text-embedding-3-small"],
+            "ollama": [],
+            "azure_openai": [],
+        },
+    ),
+)
+
+DOWNLOADED_RERANKING_MODELS = PersistentConfig(
+    "DOWNLOADED_RERANKING_MODELS",
+    "rag.downloaded_reranking_models",
+    _get_json_env("DOWNLOADED_RERANKING_MODELS", {"": [], "external": []}),
+)
+
+LOADED_EMBEDDING_MODELS = PersistentConfig(
+    "LOADED_EMBEDDING_MODELS",
+    "rag.loaded_embedding_models",
+    _get_json_env(
+        "LOADED_EMBEDDING_MODELS",
+        {
+            "": ["sentence-transformers/all-MiniLM-L6-v2"],
+            "openai": [],
+            "ollama": [],
+            "azure_openai": [],
+        },
+    ),
+)
+
+LOADED_RERANKING_MODELS = PersistentConfig(
+    "LOADED_RERANKING_MODELS",
+    "rag.loaded_reranking_models",
+    _get_json_env("LOADED_RERANKING_MODELS", {"": [], "external": []}),
+)
+
 
 ####################################
 # Web Search (RAG)
diff --git a/backend/open_webui/internal/migrations/019_add_rag_config_to_knowledge.py b/backend/open_webui/internal/migrations/019_add_rag_config_to_knowledge.py
new file mode 100644
index 000000000..e449b57c4
--- /dev/null
+++ b/backend/open_webui/internal/migrations/019_add_rag_config_to_knowledge.py
@@ -0,0 +1,99 @@
+"""Peewee migrations -- 019_add_rag_config_to_knowledge.py.
+    Some examples (model - class or model name)::
+
+        > Model = migrator.orm['table_name']  # Return model in current state by name
+        > Model = migrator.ModelClass  # Return model in current state by name
+
+        > migrator.sql(sql)  # Run custom SQL
+        > migrator.run(func, *args, **kwargs)  # Run python function with the given args
+        > migrator.create_model(Model)  # Create a model (could be used as decorator)
+        > migrator.remove_model(model, cascade=True)  # Remove a model
+        > migrator.add_fields(model, **fields)  # Add fields to a model
+        > migrator.change_fields(model, **fields)  # Change fields
+        > migrator.remove_fields(model, *field_names, cascade=True)
+        > migrator.rename_field(model, old_field_name, new_field_name)
+        > migrator.rename_table(model, new_table_name)
+        > migrator.add_index(model, *col_names, unique=False)
+        > migrator.add_not_null(model, *field_names)
+        > migrator.add_default(model, field_name, default)
+        > migrator.add_constraint(model, name, sql)
+        > migrator.drop_index(model, *col_names)
+        > migrator.drop_not_null(model, *field_names)
+        > migrator.drop_constraints(model, *constraints)
+
+"""
+"""Add rag_config field to knowledge table if not present."""
+from contextlib import suppress
+from peewee_migrate import Migrator
+import peewee as pw
+import json
+
+# JSONField is only importable when the Postgres driver is installed; keep a
+# None sentinel so the lookup below can never raise NameError.
+PostgresJSONField = None
+with suppress(ImportError):
+    from playhouse.postgres_ext import JSONField as PostgresJSONField
+
+
+# Fallback JSONField for SQLite (stores JSON as text)
+class SQLiteJSONField(pw.TextField):
+    """Text column that serializes values to JSON on write and parses on read."""
+
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
+        return None
+
+
+def get_compatible_json_field(database: pw.Database):
+    """Return a JSON-compatible field for the current database.
+
+    SQLite always gets the text-backed field. Other databases get the native
+    Postgres JSON field when it could be imported, and otherwise fall back to
+    the text-backed field instead of failing on an undefined name.
+    """
+    default = {"DEFAULT_RAG_SETTINGS": True}
+    if isinstance(database, pw.SqliteDatabase) or PostgresJSONField is None:
+        return SQLiteJSONField(null=False, default=default)
+    return PostgresJSONField(null=False, default=default)
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Add rag_config JSON field to knowledge table if not present."""
+    if 'knowledge' not in database.get_tables():
+        print("Knowledge table hasn't been created yet, skipping migration.")
+        return
+
+    # Honor the "if not present" contract so a re-run does not fail on a
+    # duplicate column.
+    if any(col.name == 'rag_config' for col in database.get_columns('knowledge')):
+        print("rag_config column already exists, skipping migration.")
+        return
+
+    class Knowledge(pw.Model):
+        class Meta:
+            table_name = 'knowledge'
+
+    Knowledge._meta.database = database  # bind DB
+
+    migrator.add_fields(
+        Knowledge,
+        rag_config=get_compatible_json_field(database)
+    )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Remove rag_config field from knowledge table."""
+    if 'knowledge' not in database.get_tables():
+        print("Knowledge table hasn't been created yet, skipping migration.")
+        return
+
+    class Knowledge(pw.Model):
+        class Meta:
+            table_name = 'knowledge'
+
+    Knowledge._meta.database = database
+    migrator.remove_fields(Knowledge, 'rag_config')
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8e37d9e53..375c72312 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -250,6 +250,11 @@ from open_webui.config import (
     PDF_EXTRACT_IMAGES,
     YOUTUBE_LOADER_LANGUAGE,
     YOUTUBE_LOADER_PROXY_URL,
+    DEFAULT_RAG_SETTINGS,
+    DOWNLOADED_EMBEDDING_MODELS,
+    DOWNLOADED_RERANKING_MODELS,
+    LOADED_EMBEDDING_MODELS,
+    LOADED_RERANKING_MODELS,
     # Retrieval (Web Search)
     ENABLE_WEB_SEARCH,
     WEB_SEARCH_ENGINE,
@@ -836,6 +841,11 @@ app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
 app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
 app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY
 
+app.state.config.DEFAULT_RAG_SETTINGS = DEFAULT_RAG_SETTINGS
+app.state.config.DOWNLOADED_EMBEDDING_MODELS = DOWNLOADED_EMBEDDING_MODELS
+app.state.config.DOWNLOADED_RERANKING_MODELS = DOWNLOADED_RERANKING_MODELS
+app.state.config.LOADED_EMBEDDING_MODELS = LOADED_EMBEDDING_MODELS
+app.state.config.LOADED_RERANKING_MODELS =
LOADED_RERANKING_MODELS
 app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
 app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
 
@@ -843,62 +853,71 @@ app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
 app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
 app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
 
-app.state.EMBEDDING_FUNCTION = None
-app.state.ef = None
-app.state.rf = None
+app.state.EMBEDDING_FUNCTION = {}
+app.state.ef = {}
+app.state.rf = {}
 
 app.state.YOUTUBE_LOADER_TRANSLATION = None
 
 
 try:
-    app.state.ef = get_ef(
-        app.state.config.RAG_EMBEDDING_ENGINE,
-        app.state.config.RAG_EMBEDDING_MODEL,
-        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-    )
+    # Load all embedding models that are currently in use
+    for engine, model_list in app.state.config.LOADED_EMBEDDING_MODELS.items():
+        for model in model_list:
+            if engine == "azure_openai":
+                # For Azure OpenAI, model is a dict: {model_name: version}
+                model_name, azure_openai_api_version = next(iter(model.items()))
+                model = model_name
+
+            app.state.ef[model] = get_ef(
+                engine,
+                model,
+                RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+            )
+            # Each remote engine has its own endpoint and key; Azure OpenAI
+            # must not fall through to the Ollama settings.
+            app.state.EMBEDDING_FUNCTION[model] = get_embedding_function(
+                engine,
+                model,
+                app.state.ef[model],
+                (
+                    app.state.config.RAG_OPENAI_API_BASE_URL
+                    if engine == "openai"
+                    else (
+                        app.state.config.RAG_OLLAMA_BASE_URL
+                        if engine == "ollama"
+                        else app.state.config.RAG_AZURE_OPENAI_BASE_URL
+                    )
+                ),
+                (
+                    app.state.config.RAG_OPENAI_API_KEY
+                    if engine == "openai"
+                    else (
+                        app.state.config.RAG_OLLAMA_API_KEY
+                        if engine == "ollama"
+                        else app.state.config.RAG_AZURE_OPENAI_API_KEY
+                    )
+                ),
+                app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+                azure_api_version=(
+                    app.state.config.RAG_AZURE_OPENAI_API_VERSION
+                    if engine == "azure_openai"
+                    else None
+                ),
+            )
+    # Load all reranking models that are currently in use
+    for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items():
+        for model in model_list:
+            app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf(
+                engine,
+                model["RAG_RERANKING_MODEL"],
+                model["RAG_EXTERNAL_RERANKER_URL"],
+                model["RAG_EXTERNAL_RERANKER_API_KEY"],
+            )
 
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass
-
-
-app.state.EMBEDDING_FUNCTION = get_embedding_function(
-    app.state.config.RAG_EMBEDDING_ENGINE,
-    app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.ef,
-    (
-        app.state.config.RAG_OPENAI_API_BASE_URL
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else (
-            app.state.config.RAG_OLLAMA_BASE_URL
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
-            else app.state.config.RAG_AZURE_OPENAI_BASE_URL
-        )
-    ),
-    (
-        app.state.config.RAG_OPENAI_API_KEY
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else (
-            app.state.config.RAG_OLLAMA_API_KEY
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
-            else app.state.config.RAG_AZURE_OPENAI_API_KEY
-        )
-    ),
-    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-    azure_api_version=(
-        app.state.config.RAG_AZURE_OPENAI_API_VERSION
-        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
-        else None
-    ),
-)
-
 ########################################
 #
 # CODE EXECUTION
diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py
index bed3d5542..3a72e3713 100644
--- a/backend/open_webui/models/knowledge.py
+++ b/backend/open_webui/models/knowledge.py
@@ -35,6 +35,7 @@ class Knowledge(Base):
 
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
+    rag_config = Column(JSON, nullable=True)  # Configuration for RAG (Retrieval-Augmented Generation) model.
 
     access_control = Column(JSON, nullable=True)  # Controls data access levels.
     # Defines access control rules for this entry.
@@ -68,6 +69,7 @@ class KnowledgeModel(BaseModel):
 
     data: Optional[dict] = None
     meta: Optional[dict] = None
+    rag_config: Optional[dict] = None  # Configuration for RAG (Retrieval-Augmented Generation) model.
access_control: Optional[dict] = None @@ -97,6 +99,7 @@ class KnowledgeForm(BaseModel): description: str data: Optional[dict] = None access_control: Optional[dict] = None + rag_config: Optional[dict] = {} class KnowledgeTable: @@ -217,5 +220,49 @@ class KnowledgeTable: except Exception: return False + def update_rag_config_by_id( + self, id: str, rag_config: dict + ) -> Optional[KnowledgeModel]: + try: + with get_db() as db: + knowledge = self.get_knowledge_by_id(id=id) + db.query(Knowledge).filter_by(id=id).update( + { + "rag_config": rag_config, + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_knowledge_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def is_model_in_use_elsewhere( + self, model: str, model_type: str, id: Optional[str] = None + ) -> bool: + try: + from sqlalchemy import func + with get_db() as db: + if db.bind.dialect.name == "sqlite": + query = db.query(Knowledge).filter( + func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model + ) + elif db.bind.dialect.name == "postgresql": + query = db.query(Knowledge).filter( + Knowledge.rag_config.op("->>")(model_type) == model, + ) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + if id: + query = query.filter(Knowledge.id != id) + + return query.first() is not None + + except Exception as e: + log.exception(f"Error checking model usage elsewhere: {e}") + return False Knowledges = KnowledgeTable() diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 00dd68306..8861dee5a 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -270,13 +270,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict: def query_collection( collection_names: list[str], queries: list[str], - embedding_function, + user, + ef, + embedding_model, k: int, ) -> dict: results = [] error = False - def 
process_query_collection(collection_name, query_embedding): + def process_query_collection(collection_name, query_embedding, k): try: if collection_name: result = query_doc( @@ -291,18 +293,30 @@ def query_collection( log.exception(f"Error when querying the collection: {e}") return None, e - # Generate all query embeddings (in one call) - query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) log.debug( f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" ) + from open_webui.models.knowledge import Knowledges with ThreadPoolExecutor() as executor: future_results = [] - for query_embedding in query_embeddings: - for collection_name in collection_names: + for collection_name in collection_names: + rag_config = {} + knowledge_base = Knowledges.get_knowledge_by_id(collection_name) + + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + embedding_model = rag_config.get("embedding_model", embedding_model) + k = rag_config.get("TOP_K", k) + + embedding_function=lambda query, prefix: ef[embedding_model]( + query, prefix=prefix, user=user + ) + # Generate embeddings for each query using the collection's embedding function + query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) + for query_embedding in query_embeddings: result = executor.submit( - process_query_collection, collection_name, query_embedding + process_query_collection, collection_name, query_embedding, k ) future_results.append(result) task_results = [future.result() for future in future_results] @@ -322,12 +336,14 @@ def query_collection( def query_collection_with_hybrid_search( collection_names: list[str], queries: list[str], - embedding_function, + user, + ef, k: int, reranking_function, k_reranker: int, r: float, hybrid_bm25_weight: float, + embedding_model: str, ) -> dict: results = [] error = False @@ -352,13 +368,32 @@ def 
query_collection_with_hybrid_search( def process_query(collection_name, query): try: + from open_webui.models.knowledge import Knowledges + + # Use Knowledges to get per-collection RAG config + knowledge_base = Knowledges.get_knowledge_by_id(collection_name) + + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + # Use config from rag_config if present, else fallback to global config + embedding_model = rag_config.get("embedding_model", embedding_model) + reranking_model = rag_config.get("reranking_function", reranking_model) + k = rag_config.get("TOP_K", k) + k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker) + r = rag_config.get("RELEVANCE_THRESHOLD", r) + hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight) + + embedding_function=lambda query, prefix: ef[embedding_model]( + query, prefix=prefix, user=user + ), + result = query_doc_with_hybrid_search( collection_name=collection_name, collection_result=collection_results[collection_name], query=query, embedding_function=embedding_function, k=k, - reranking_function=reranking_function, + reranking_function=reranking_function[reranking_model], k_reranker=k_reranker, r=r, hybrid_bm25_weight=hybrid_bm25_weight, @@ -446,7 +481,8 @@ def get_sources_from_files( request, files, queries, - embedding_function, + user, + ef, k, reranking_function, k_reranker, @@ -454,9 +490,10 @@ def get_sources_from_files( hybrid_bm25_weight, hybrid_search, full_context=False, + embedding_model=None ): log.debug( - f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}" ) extracted_collections = [] @@ -564,12 +601,14 @@ def get_sources_from_files( context = query_collection_with_hybrid_search( collection_names=collection_names, queries=queries, - embedding_function=embedding_function, + user=user, + ef=ef, k=k, 
reranking_function=reranking_function, k_reranker=k_reranker, r=r, hybrid_bm25_weight=hybrid_bm25_weight, + embedding_model=embedding_model, ) except Exception as e: log.debug( @@ -581,8 +620,10 @@ def get_sources_from_files( context = query_collection( collection_names=collection_names, queries=queries, - embedding_function=embedding_function, + user=user, + ef=ef, k=k, + embedding_model=embedding_model ) except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index bdf5780fc..8e9ddf89d 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -17,6 +17,7 @@ from fastapi import ( UploadFile, status, Query, + Form ) from fastapi.responses import FileResponse, StreamingResponse from open_webui.constants import ERROR_MESSAGES @@ -90,6 +91,7 @@ def upload_file( process: bool = Query(True), internal: bool = False, user=Depends(get_verified_user), + knowledge_id: Optional[str] = Form(None) ): log.info(f"file.content_type: {file.content_type}") @@ -173,18 +175,18 @@ def upload_file( process_file( request, - ProcessFileForm(file_id=id, content=result.get("text", "")), + ProcessFileForm(file_id=id, content=result.get("text", ""), knowledge_id=knowledge_id), user=user, ) elif (not file.content_type.startswith(("image/", "video/"))) or ( request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" ): - process_file(request, ProcessFileForm(file_id=id), user=user) + process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user) else: log.info( f"File type {file.content_type} is not provided, but trying to process anyway" ) - process_file(request, ProcessFileForm(file_id=id), user=user) + process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user) file_item = Files.get_file_by_id(id=id) except Exception as e: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 
e6e55f4d3..ff2b90ef1 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -241,6 +241,78 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
     )
     return True
 
+@router.post("/reindex/{id}", response_model=bool)
+async def reindex_specific_knowledge_files(request: Request, id: str, user=Depends(get_verified_user)):
+    """Rebuild the vector collection of one knowledge base.
+
+    Deletes the collection named after the knowledge base (when present) and
+    re-runs ``process_file`` for every id in ``data["file_ids"]``. A file that
+    fails is logged and skipped so it cannot abort the remaining files.
+
+    Raises:
+        HTTPException: 404 if no knowledge base with ``id`` exists.
+    """
+    # Imported locally so this handler does not depend on the module's import list.
+    from fastapi import HTTPException, status
+
+    # NOTE(review): any verified user can reindex any knowledge base here;
+    # confirm whether an owner/admin check is required.
+    log.info(f"reindex_specific_knowledge_files called with id={id}")
+    knowledge_base = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge_base:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Knowledge base {id} not found",
+        )
+
+    # A knowledge base without file data is a valid (empty) state: there is
+    # nothing to reindex, and it must not be deleted as a side effect.
+    if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
+        log.warning(
+            f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Nothing to reindex."
+        )
+        return True
+
+    try:
+        file_ids = knowledge_base.data.get("file_ids", [])
+        files = Files.get_files_by_ids(file_ids)
+        try:
+            if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
+                VECTOR_DB_CLIENT.delete_collection(
+                    collection_name=knowledge_base.id
+                )
+        except Exception as e:
+            log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
+
+        failed_files = []
+        for file in files:
+            try:
+                # NOTE(review): ProcessFileForm also accepts knowledge_id (see the
+                # upload path); confirm whether it should be passed here too.
+                process_file(
+                    request,
+                    ProcessFileForm(
+                        file_id=file.id, collection_name=knowledge_base.id
+                    ),
+                    user=user,
+                )
+            except Exception as e:
+                log.error(
+                    f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
+                )
+                failed_files.append({"file_id": file.id, "error": str(e)})
+
+        if failed_files:
+            log.warning(
+                f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
+            )
+            for failed in failed_files:
+                log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
+
+    except Exception as e:
+        log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
+
+    log.info(f"Reindexing completed for knowledge base {knowledge_base.id}")
+    return True
+
 
 ############################
 # GetKnowledgeById
diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py
index 333e9ecc6..efa13a63b 100644
--- a/backend/open_webui/routers/memories.py
+++ b/backend/open_webui/routers/memories.py
@@ -17,7 +17,7 @@ router = APIRouter()
 
 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]("hello world")}
 
 
 ############################
@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                     memory.content, user=user
                 ),
                 "metadata": {"created_at": memory.created_at},
@@ -84,7 +84,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](form_data.content, user=user)],
         limit=form_data.k,
     )
 
@@ -107,7 +107,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                     memory.content, user=user
                 ),
                 "metadata": {
@@ -166,7 +166,7 @@ async def update_memory_by_id(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector":
request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                     memory.content, user=user
                 ),
                 "metadata": {
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 6d888ca99..b3867c51c 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -210,6 +210,10 @@ class SearchForm(BaseModel):
     queries: List[str]
 
 
+class CollectionForm(BaseModel):
+    knowledge_id: Optional[str] = None
+
+
 @router.get("/")
 async def get_status(request: Request):
     return {
@@ -224,26 +228,49 @@ async def get_status(request: Request):
     }
 
 
-@router.get("/embedding")
-async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
+@router.post("/embedding")
+async def get_embedding_config(
+    request: Request,
+    collectionForm: Optional[CollectionForm] = None,
+    user=Depends(get_verified_user),
+):
+    """
+    Retrieve the embedding configuration.
+
+    When the request names a knowledge base whose rag_config has
+    DEFAULT_RAG_SETTINGS set to False, values stored in that rag_config take
+    precedence; every missing key, and every other request (including one
+    without a body), falls back to the global embedding settings.
+    """
+    # NOTE(review): this endpoint was admin-only; it now returns the global
+    # API keys to any verified user. Confirm that exposure is intended.
+    knowledge_id = collectionForm.knowledge_id if collectionForm else None
+    knowledge_base = (
+        Knowledges.get_knowledge_by_id(knowledge_id) if knowledge_id else None
+    )
+    # rag_config is a nullable column, so guard before calling .get() on it.
+    stored_config = (knowledge_base.rag_config or {}) if knowledge_base else {}
+    # Per-knowledge-base values apply only when defaults are explicitly disabled.
+    rag_config = (
+        stored_config if not stored_config.get("DEFAULT_RAG_SETTINGS", True) else {}
+    )
     return {
         "status": True,
-        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
-        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-        "openai_config": {
+        "embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
+        "embedding_model": rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL),
+        "embedding_batch_size": rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE),
+        "openai_config": rag_config.get("openai_config", {
             "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
             "key": request.app.state.config.RAG_OPENAI_API_KEY,
-        },
-        "ollama_config": {
+        }),
+        "ollama_config": rag_config.get("ollama_config", {
             "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
             "key": request.app.state.config.RAG_OLLAMA_API_KEY,
-        },
-        "azure_openai_config": {
+        }),
+        "azure_openai_config": rag_config.get("azure_openai_config", {
             "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
             "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
             "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
-        },
+        }),
     }
 
 
@@ -270,109 +297,267 @@ class EmbeddingModelUpdateForm(BaseModel):
     embedding_engine: str
     embedding_model: str
     embedding_batch_size: Optional[int] = 1
+    knowledge_id: Optional[str] = None
 
 
 @router.post("/embedding/update")
 async def update_embedding_config(
-    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
 ):
-    log.info(
-        f"Updating embedding
model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" - ) + """ + Update the embedding model configuration. + If DEFAULT_RAG_SETTINGS is True, update the global configuration. + Otherwise, update the RAG configuration in the database for the user's knowledge base. + """ try: - request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - - if request.app.state.config.RAG_EMBEDDING_ENGINE in [ - "ollama", - "openai", - "azure_openai", - ]: - if form_data.openai_config is not None: - request.app.state.config.RAG_OPENAI_API_BASE_URL = ( - form_data.openai_config.url - ) - request.app.state.config.RAG_OPENAI_API_KEY = ( - form_data.openai_config.key - ) - - if form_data.ollama_config is not None: - request.app.state.config.RAG_OLLAMA_BASE_URL = ( - form_data.ollama_config.url - ) - request.app.state.config.RAG_OLLAMA_API_KEY = ( - form_data.ollama_config.key - ) - - if form_data.azure_openai_config is not None: - request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( - form_data.azure_openai_config.url - ) - request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( - form_data.azure_openai_config.key - ) - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( - form_data.azure_openai_config.version - ) - - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( - form_data.embedding_batch_size + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config + log.info( + f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" ) - request.app.state.ef = get_ef( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - ) + # Check if model is in use elsewhere, otherwise free up memory + 
in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get('embedding_model'), model_type="embedding_model", id=form_data.knowledge_id) - request.app.state.EMBEDDING_FUNCTION = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, - ( - request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else ( - request.app.state.config.RAG_OLLAMA_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" - else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL - ) - ), - ( - request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else ( - request.app.state.config.RAG_OLLAMA_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" - else request.app.state.config.RAG_AZURE_OPENAI_API_KEY - ) - ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - azure_api_version=( - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION - if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" - else None - ), - ) + if not in_use and not request.app.state.ef.get(request.app.state.config.RAG_EMBEDDING_MODEL) == rag_config.get("embedding_model") and rag_config.get("embedding_model"): + del request.app.state.ef[rag_config["embedding_model"]] + engine = rag_config["embedding_engine"] + target_model = rag_config["embedding_model"] + models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine] - return { - "status": True, - "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, - "openai_config": { - "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, - "key": request.app.state.config.RAG_OPENAI_API_KEY, - }, - "ollama_config": { - "url": request.app.state.config.RAG_OLLAMA_BASE_URL, 
- "key": request.app.state.config.RAG_OLLAMA_API_KEY, - }, + # Find and remove the dictionary that contains the target model + for model in models_list[:]: # Create a copy of the list for safe iteration + if model == target_model: + models_list.remove(model) + + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + # Update embedding-related fields + rag_config["embedding_engine"] = form_data.embedding_engine + rag_config["embedding_model"] = form_data.embedding_model + rag_config["embedding_batch_size"] = form_data.embedding_batch_size + + # Update OpenAI, Ollama, and Azure OpenAI configurations if provided + if form_data.openai_config is not None: + rag_config["openai_config"] = { + "url": form_data.openai_config.url, + "key": form_data.openai_config.key, + } + + if form_data.ollama_config is not None: + rag_config["ollama_config"] = { + "url": form_data.ollama_config.url, + "key": form_data.ollama_config.key, + } + + if form_data.azure_openai_config is not None: + rag_config["azure_openai_config"] = { + "url": form_data.azure_openai_config.url, + "key": form_data.azure_openai_config.key, + "version": form_data.azure_openai_config.version, + } + + # Update the embedding function + if not rag_config["embedding_model"] in request.app.state.ef: + request.app.state.ef[rag_config["embedding_model"]] = get_ef( + rag_config["embedding_engine"], + rag_config["embedding_model"], + ) + + request.app.state.EMBEDDING_FUNCTION[rag_config["embedding_model"]] = get_embedding_function( + rag_config["embedding_engine"], + rag_config["embedding_model"], + request.app.state.ef[rag_config["embedding_model"]], + ( + rag_config["openai_config"]["url"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["url"] + ), + ( + rag_config["openai_config"]["key"] + if rag_config["embedding_engine"] == "openai" + else rag_config["ollama_config"]["key"] + ), + 
rag_config["embedding_batch_size"], + azure_api_version=( + rag_config["azure_openai_config"]["version"] + if rag_config["embedding_engine"] == "azure_openai" + else None + ) + ) + # add model to state for reloading on startup + if rag_config["embedding_engine"] == "azure_openai": + request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append( + {rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")} + ) + else: + request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + # add model to state for selectable reranking models + if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]: + request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) + request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save() + rag_config["DOWNLOADED_EMBEDDING_MODELS"] = request.app.state.config.DOWNLOADED_EMBEDDING_MODELS + rag_config["LOADED_EMBEDDING_MODELS"] = request.app.state.config.LOADED_EMBEDDING_MODELS + + # Save the updated configuration to the database + Knowledges.update_rag_config_by_id( + id=form_data.knowledge_id, rag_config=rag_config + ) + + return { + "status": True, + "embedding_engine": rag_config["embedding_engine"], + "embedding_model": rag_config["embedding_model"], + "embedding_batch_size": rag_config["embedding_batch_size"], + "openai_config": rag_config.get("openai_config", {}), + "ollama_config": rag_config.get("ollama_config", {}), + "azure_openai_config": rag_config.get("azure_openai_config", {}), + "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"], + "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"], + } + else: + # Update the global configuration + log.info( + f"Updating embedding model: 
{request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + ) + + # Check if model is in use elsewhere, otherwise free up memory + in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_EMBEDDING_MODEL, model_type="embedding_model") + if not in_use: + del request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] + engine = request.app.state.config.RAG_EMBEDDING_ENGINE + target_model = request.app.state.config.RAG_EMBEDDING_MODEL + models_list = request.app.state.config.LOADED_EMBEDDING_MODELS[engine] + + # Find and remove the dictionary that contains the target model + for model in models_list[:]: # Create a copy of the list for safe iteration + if model == target_model: + models_list.remove(model) + + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + + if request.app.state.config.RAG_EMBEDDING_ENGINE in [ + "ollama", + "openai", + "azure_openai", + ]: + if form_data.openai_config is not None: + request.app.state.config.RAG_OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.RAG_OPENAI_API_KEY = ( + form_data.openai_config.key + ) + + if form_data.ollama_config is not None: + request.app.state.config.RAG_OLLAMA_BASE_URL = ( + form_data.ollama_config.url + ) + request.app.state.config.RAG_OLLAMA_API_KEY = ( + form_data.ollama_config.key + ) + + if form_data.azure_openai_config is not None: + request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( + form_data.azure_openai_config.url + ) + request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( + form_data.azure_openai_config.key + ) + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( + form_data.azure_openai_config.version + ) + + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + 
form_data.embedding_batch_size + ) + + # Update the embedding function + if not form_data.embedding_model in request.app.state.ef: + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + ) + + request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL], + ( + request.app.state.config.RAG_OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else ( + request.app.state.config.RAG_OLLAMA_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) + ), + ( + request.app.state.config.RAG_OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else ( + request.app.state.config.RAG_OLLAMA_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_API_KEY + ) + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + azure_api_version=( + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), + ) + # add model to state for reloading on startup + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai": + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append( + {request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION} + ) + else: + request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() + # add model to state for 
selectable embedding models + if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]: + request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) + request.app.state.config._state["DOWNLOADED_EMBEDDING_MODELS"].save() + + return { + "status": True, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "openai_config": { + "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, + "key": request.app.state.config.RAG_OPENAI_API_KEY, + }, + "ollama_config": { + "url": request.app.state.config.RAG_OLLAMA_BASE_URL, + "key": request.app.state.config.RAG_OLLAMA_API_KEY, + }, "azure_openai_config": { "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, }, - } + "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS, + "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS, + } except Exception as e: log.exception(f"Problem updating embedding model: {e}") raise HTTPException( @@ -381,116 +554,134 @@ async def update_embedding_config( ) -@router.get("/config") -async def get_rag_config(request: Request, user=Depends(get_admin_user)): +@router.post("/config") +async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)): + """ + Retrieve the full RAG configuration for the knowledge base given by collectionForm.knowledge_id. + If it is missing or its DEFAULT_RAG_SETTINGS flag is True, return the global settings; otherwise return its stored rag_config, with the global value as the fallback for any key it lacks. + NOTE(review): the response includes API keys and is now served to any verified user, not only admins - confirm this is intended. 
+ """ + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + rag_config = {} + web_config = {} + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Return the RAG configuration from the database (keys below must match the form field names that /config/update stores) + rag_config = knowledge_base.rag_config + web_config = rag_config.get("web", {}) return { "status": True, # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + "RAG_TEMPLATE": rag_config.get("RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE), + "TOP_K": rag_config.get("TOP_K", request.app.state.config.TOP_K), + "BYPASS_EMBEDDING_AND_RETRIEVAL": rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL), + "RAG_FULL_CONTEXT": rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT), # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, - "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, + "ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH), + "TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER), + "RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD), + "HYBRID_BM25_WEIGHT": rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT), # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "DATALAB_MARKER_API_KEY": 
request.app.state.config.DATALAB_MARKER_API_KEY, - "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS, - "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, - "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, - "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, - "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, - "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, - "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, - "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, - "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + "CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE), + "PDF_EXTRACT_IMAGES": 
rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES), + "DATALAB_MARKER_API_KEY": rag_config.get("DATALAB_MARKER_API_KEY", request.app.state.config.DATALAB_MARKER_API_KEY), + "DATALAB_MARKER_LANGS": rag_config.get("DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS), + "DATALAB_MARKER_SKIP_CACHE": rag_config.get("DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE), + "DATALAB_MARKER_FORCE_OCR": rag_config.get("DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR), + "DATALAB_MARKER_PAGINATE": rag_config.get("DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE), + "DATALAB_MARKER_STRIP_EXISTING_OCR": rag_config.get("DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR), + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": rag_config.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION), + "DATALAB_MARKER_USE_LLM": rag_config.get("DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM), + "DATALAB_MARKER_OUTPUT_FORMAT": rag_config.get("DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT), + "EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL), + "EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY), + "TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL), + "DOCLING_SERVER_URL": rag_config.get("DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL), + "DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE), + "DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG), + 
"DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION), + "DOCLING_PICTURE_DESCRIPTION_MODE": rag_config.get("DOCLING_PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE), + "DOCLING_PICTURE_DESCRIPTION_LOCAL": rag_config.get("DOCLING_PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL), + "DOCLING_PICTURE_DESCRIPTION_API": rag_config.get("DOCLING_PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API), + "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT), + "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY), + "MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY), # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + "RAG_RERANKING_MODEL": rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL), + "RAG_RERANKING_ENGINE": rag_config.get("RAG_RERANKING_ENGINE", request.app.state.config.RAG_RERANKING_ENGINE), + "RAG_EXTERNAL_RERANKER_URL": rag_config.get("RAG_EXTERNAL_RERANKER_URL", request.app.state.config.RAG_EXTERNAL_RERANKER_URL), + "RAG_EXTERNAL_RERANKER_API_KEY": rag_config.get("RAG_EXTERNAL_RERANKER_API_KEY", request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY), # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + 
"TEXT_SPLITTER": rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER), + "CHUNK_SIZE": rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE), + "CHUNK_OVERLAP": rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP), # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + "FILE_MAX_SIZE": rag_config.get("FILE_MAX_SIZE", request.app.state.config.FILE_MAX_SIZE), + "FILE_MAX_COUNT": rag_config.get("FILE_MAX_COUNT", request.app.state.config.FILE_MAX_COUNT), + "FILE_IMAGE_COMPRESSION_WIDTH": rag_config.get("FILE_IMAGE_COMPRESSION_WIDTH", request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH), + "FILE_IMAGE_COMPRESSION_HEIGHT": rag_config.get("FILE_IMAGE_COMPRESSION_HEIGHT", request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT), + "ALLOWED_FILE_EXTENSIONS": rag_config.get("ALLOWED_FILE_EXTENSIONS", request.app.state.config.ALLOWED_FILE_EXTENSIONS), # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + "ENABLE_GOOGLE_DRIVE_INTEGRATION": rag_config.get("ENABLE_GOOGLE_DRIVE_INTEGRATION", request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION), + "ENABLE_ONEDRIVE_INTEGRATION": rag_config.get("ENABLE_ONEDRIVE_INTEGRATION", request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION), # Web search settings "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - 
"WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, 
- "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, - "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "ENABLE_WEB_SEARCH": web_config.get("ENABLE_WEB_SEARCH", request.app.state.config.ENABLE_WEB_SEARCH), + "WEB_SEARCH_ENGINE": web_config.get("WEB_SEARCH_ENGINE", request.app.state.config.WEB_SEARCH_ENGINE), + "WEB_SEARCH_TRUST_ENV": web_config.get("WEB_SEARCH_TRUST_ENV", request.app.state.config.WEB_SEARCH_TRUST_ENV), + "WEB_SEARCH_RESULT_COUNT": web_config.get("WEB_SEARCH_RESULT_COUNT", request.app.state.config.WEB_SEARCH_RESULT_COUNT), + "WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", 
request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS), + "WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST), + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL), + "BYPASS_WEB_SEARCH_WEB_LOADER": web_config.get("BYPASS_WEB_SEARCH_WEB_LOADER", request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER), + "SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL), + "YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL), + "YACY_USERNAME": web_config.get("YACY_USERNAME", request.app.state.config.YACY_USERNAME), + "YACY_PASSWORD": web_config.get("YACY_PASSWORD", request.app.state.config.YACY_PASSWORD), + "GOOGLE_PSE_API_KEY": web_config.get("GOOGLE_PSE_API_KEY", request.app.state.config.GOOGLE_PSE_API_KEY), + "GOOGLE_PSE_ENGINE_ID": web_config.get("GOOGLE_PSE_ENGINE_ID", request.app.state.config.GOOGLE_PSE_ENGINE_ID), + "BRAVE_SEARCH_API_KEY": web_config.get("BRAVE_SEARCH_API_KEY", request.app.state.config.BRAVE_SEARCH_API_KEY), + "KAGI_SEARCH_API_KEY": web_config.get("KAGI_SEARCH_API_KEY", request.app.state.config.KAGI_SEARCH_API_KEY), + "MOJEEK_SEARCH_API_KEY": web_config.get("MOJEEK_SEARCH_API_KEY", request.app.state.config.MOJEEK_SEARCH_API_KEY), + "BOCHA_SEARCH_API_KEY": web_config.get("BOCHA_SEARCH_API_KEY", request.app.state.config.BOCHA_SEARCH_API_KEY), + "SERPSTACK_API_KEY": web_config.get("SERPSTACK_API_KEY", request.app.state.config.SERPSTACK_API_KEY), + "SERPSTACK_HTTPS": web_config.get("SERPSTACK_HTTPS", request.app.state.config.SERPSTACK_HTTPS), + "SERPER_API_KEY": web_config.get("SERPER_API_KEY", request.app.state.config.SERPER_API_KEY), + "SERPLY_API_KEY": web_config.get("SERPLY_API_KEY", request.app.state.config.SERPLY_API_KEY), + "TAVILY_API_KEY": 
web_config.get("TAVILY_API_KEY", request.app.state.config.TAVILY_API_KEY), + "SEARCHAPI_API_KEY": web_config.get("SEARCHAPI_API_KEY", request.app.state.config.SEARCHAPI_API_KEY), + "SEARCHAPI_ENGINE": web_config.get("SEARCHAPI_ENGINE", request.app.state.config.SEARCHAPI_ENGINE), + "SERPAPI_API_KEY": web_config.get("SERPAPI_API_KEY", request.app.state.config.SERPAPI_API_KEY), + "SERPAPI_ENGINE": web_config.get("SERPAPI_ENGINE", request.app.state.config.SERPAPI_ENGINE), + "JINA_API_KEY": web_config.get("JINA_API_KEY", request.app.state.config.JINA_API_KEY), + "BING_SEARCH_V7_ENDPOINT": web_config.get("BING_SEARCH_V7_ENDPOINT", request.app.state.config.BING_SEARCH_V7_ENDPOINT), + "BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY), + "EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY), + "PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY), + "PERPLEXITY_MODEL": web_config.get("PERPLEXITY_MODEL", request.app.state.config.PERPLEXITY_MODEL), + "PERPLEXITY_SEARCH_CONTEXT_USAGE": web_config.get("PERPLEXITY_SEARCH_CONTEXT_USAGE", request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE), + "SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID), + "SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK), + "WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE), + "ENABLE_WEB_LOADER_SSL_VERIFICATION": web_config.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION), + "PLAYWRIGHT_WS_URL": web_config.get("PLAYWRIGHT_WS_URL", request.app.state.config.PLAYWRIGHT_WS_URL), + "PLAYWRIGHT_TIMEOUT": web_config.get("PLAYWRIGHT_TIMEOUT", request.app.state.config.PLAYWRIGHT_TIMEOUT), + "FIRECRAWL_API_KEY": web_config.get("FIRECRAWL_API_KEY", 
request.app.state.config.FIRECRAWL_API_KEY), + "FIRECRAWL_API_BASE_URL": web_config.get("FIRECRAWL_API_BASE_URL", request.app.state.config.FIRECRAWL_API_BASE_URL), + "TAVILY_EXTRACT_DEPTH": web_config.get("TAVILY_EXTRACT_DEPTH", request.app.state.config.TAVILY_EXTRACT_DEPTH), + "EXTERNAL_WEB_SEARCH_URL": web_config.get("EXTERNAL_WEB_SEARCH_URL", request.app.state.config.EXTERNAL_WEB_SEARCH_URL), + "EXTERNAL_WEB_SEARCH_API_KEY": web_config.get("EXTERNAL_WEB_SEARCH_API_KEY", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY), + "EXTERNAL_WEB_LOADER_URL": web_config.get("EXTERNAL_WEB_LOADER_URL", request.app.state.config.EXTERNAL_WEB_LOADER_URL), + "EXTERNAL_WEB_LOADER_API_KEY": web_config.get("EXTERNAL_WEB_LOADER_API_KEY", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY), + "YOUTUBE_LOADER_LANGUAGE": web_config.get("YOUTUBE_LOADER_LANGUAGE", request.app.state.config.YOUTUBE_LOADER_LANGUAGE), + "YOUTUBE_LOADER_PROXY_URL": web_config.get("YOUTUBE_LOADER_PROXY_URL", request.app.state.config.YOUTUBE_LOADER_PROXY_URL), + "YOUTUBE_LOADER_TRANSLATION": web_config.get("YOUTUBE_LOADER_TRANSLATION", request.app.state.YOUTUBE_LOADER_TRANSLATION), }, + "DEFAULT_RAG_SETTINGS": rag_config.get("DEFAULT_RAG_SETTINGS", request.app.state.config.DEFAULT_RAG_SETTINGS), + "DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS, + "DOWNLOADED_RERANKING_MODELS": request.app.state.config.DOWNLOADED_RERANKING_MODELS, + "LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS, + "LOADED_RERANKING_MODELS": request.app.state.config.LOADED_RERANKING_MODELS, } @@ -612,481 +803,613 @@ class ConfigForm(BaseModel): # Web search settings web: Optional[WebConfig] = None + # knowledge base ID + knowledge_id: Optional[str] = None @router.post("/config/update") async def update_rag_config( - request: Request, form_data: ConfigForm, user=Depends(get_admin_user) + request: Request, form_data: ConfigForm, user=Depends(get_verified_user) ): - # RAG settings - request.app.state.config.RAG_TEMPLATE 
= ( - form_data.RAG_TEMPLATE - if form_data.RAG_TEMPLATE is not None - else request.app.state.config.RAG_TEMPLATE - ) - request.app.state.config.TOP_K = ( - form_data.TOP_K - if form_data.TOP_K is not None - else request.app.state.config.TOP_K - ) - request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( - form_data.BYPASS_EMBEDDING_AND_RETRIEVAL - if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None - else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ) - request.app.state.config.RAG_FULL_CONTEXT = ( - form_data.RAG_FULL_CONTEXT - if form_data.RAG_FULL_CONTEXT is not None - else request.app.state.config.RAG_FULL_CONTEXT - ) + """ + Update the RAG configuration. + If DEFAULT_RAG_SETTINGS is True, update the global configuration. + Otherwise, update the RAG configuration in the database for the user's knowledge base. + """ - # Hybrid search settings - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( - form_data.ENABLE_RAG_HYBRID_SEARCH - if form_data.ENABLE_RAG_HYBRID_SEARCH is not None - else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH - ) - # Free up memory if hybrid search is disabled - if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: - request.app.state.rf = None + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config - request.app.state.config.TOP_K_RERANKER = ( - form_data.TOP_K_RERANKER - if form_data.TOP_K_RERANKER is not None - else request.app.state.config.TOP_K_RERANKER - ) - request.app.state.config.RELEVANCE_THRESHOLD = ( - form_data.RELEVANCE_THRESHOLD - if form_data.RELEVANCE_THRESHOLD is not None - else request.app.state.config.RELEVANCE_THRESHOLD - ) - request.app.state.config.HYBRID_BM25_WEIGHT = ( - form_data.HYBRID_BM25_WEIGHT - if form_data.HYBRID_BM25_WEIGHT is not None - else 
request.app.state.config.HYBRID_BM25_WEIGHT - ) + # Free up memory if hybrid search is disabled and model is not in use elswhere + in_use = Knowledges.is_model_in_use_elsewhere(model=rag_config.get("RAG_RERANKING_MODEL"), model_type="RAG_RERANKING_MODEL", id=form_data.knowledge_id) - # Content extraction settings - request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( - form_data.CONTENT_EXTRACTION_ENGINE - if form_data.CONTENT_EXTRACTION_ENGINE is not None - else request.app.state.config.CONTENT_EXTRACTION_ENGINE - ) - request.app.state.config.PDF_EXTRACT_IMAGES = ( - form_data.PDF_EXTRACT_IMAGES - if form_data.PDF_EXTRACT_IMAGES is not None - else request.app.state.config.PDF_EXTRACT_IMAGES - ) - request.app.state.config.DATALAB_MARKER_API_KEY = ( + if not form_data.ENABLE_RAG_HYBRID_SEARCH and \ + not in_use and \ + request.app.state.rf.get(rag_config["RAG_RERANKING_MODEL"]): + if rag_config.get("RAG_RERANKING_MODEL"): + del request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] + engine = request.app.state.config.RAG_RERANKING_ENGINE + target_model = rag_config["RAG_RERANKING_MODEL"] + models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine] + + # Find and remove the dictionary that contains the target model + for model_config in models_list[:]: # Create a copy of the list for safe iteration + if model_config["RAG_RERANKING_MODEL"] == target_model: + models_list.remove(model_config) + + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + # Update only the provided fields in the rag_config + for field, value in form_data.model_dump(exclude_unset=True).items(): + if field == "web" and value is not None: + rag_config["web"] = {**rag_config.get("web", {}), **value} + else: + rag_config[field] = value + + + log.info( + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" + ) + try: + try: + if not 
rag_config["RAG_RERANKING_MODEL"] in request.app.state.rf and not rag_config["RAG_RERANKING_MODEL"] == "": + request.app.state.rf[rag_config["RAG_RERANKING_MODEL"]] = get_rf( + rag_config["RAG_RERANKING_ENGINE"], + rag_config["RAG_RERANKING_MODEL"], + rag_config["RAG_EXTERNAL_RERANKER_URL"], + rag_config["RAG_EXTERNAL_RERANKER_API_KEY"], + True, + ) + + # add model to state for reloading on startup + request.app.state.config.LOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append({ + "RAG_RERANKING_MODEL": rag_config["RAG_RERANKING_MODEL"], + "RAG_EXTERNAL_RERANKER_URL": rag_config["RAG_EXTERNAL_RERANKER_URL"], + "RAG_EXTERNAL_RERANKER_API_KEY": rag_config["RAG_EXTERNAL_RERANKER_API_KEY"]}) + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + # add model to state for selectable reranking models + if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]]: + request.app.state.config.DOWNLOADED_RERANKING_MODELS[rag_config["RAG_RERANKING_ENGINE"]].append(rag_config["RAG_RERANKING_MODEL"]) + request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save() + + rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS + rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS + + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + except Exception as e: + log.exception(f"Problem updating reranking model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + Knowledges.update_rag_config_by_id( + id=knowledge_base.id, rag_config=rag_config + ) + + return rag_config + else: + # Update the global configuration + # RAG settings + request.app.state.config.RAG_TEMPLATE = ( + form_data.RAG_TEMPLATE + if form_data.RAG_TEMPLATE is not None + else 
request.app.state.config.RAG_TEMPLATE + ) + request.app.state.config.TOP_K = ( + form_data.TOP_K + if form_data.TOP_K is not None + else request.app.state.config.TOP_K + ) + request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( + form_data.BYPASS_EMBEDDING_AND_RETRIEVAL + if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None + else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.RAG_FULL_CONTEXT = ( + form_data.RAG_FULL_CONTEXT + if form_data.RAG_FULL_CONTEXT is not None + else request.app.state.config.RAG_FULL_CONTEXT + ) + + # Hybrid search settings + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + form_data.ENABLE_RAG_HYBRID_SEARCH + if form_data.ENABLE_RAG_HYBRID_SEARCH is not None + else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH + ) + + # Free up memory if hybrid search is disabled and model is not in use elswhere + in_use = Knowledges.is_model_in_use_elsewhere(model=request.app.state.config.RAG_RERANKING_MODEL, model_type="RAG_RERANKING_MODEL") + + if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and \ + not in_use and \ + request.app.state.rf.get(request.app.state.config.RAG_RERANKING_MODEL): + del request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] + engine = request.app.state.config.RAG_RERANKING_ENGINE + target_model = request.app.state.config.RAG_RERANKING_MODEL + models_list = request.app.state.config.LOADED_RERANKING_MODELS[engine] + + # Find and remove the dictionary that contains the target model + for model_config in models_list[:]: # Create a copy of the list for safe iteration + if model_config["RAG_RERANKING_MODEL"] == target_model: + models_list.remove(model_config) + + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + + request.app.state.config.TOP_K_RERANKER = ( + form_data.TOP_K_RERANKER + if form_data.TOP_K_RERANKER is not None + else 
request.app.state.config.TOP_K_RERANKER + ) + request.app.state.config.RELEVANCE_THRESHOLD = ( + form_data.RELEVANCE_THRESHOLD + if form_data.RELEVANCE_THRESHOLD is not None + else request.app.state.config.RELEVANCE_THRESHOLD + ) + request.app.state.config.HYBRID_BM25_WEIGHT = ( + form_data.HYBRID_BM25_WEIGHT + if form_data.HYBRID_BM25_WEIGHT is not None + else request.app.state.config.HYBRID_BM25_WEIGHT + ) + + # Content extraction settings + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.CONTENT_EXTRACTION_ENGINE + if form_data.CONTENT_EXTRACTION_ENGINE is not None + else request.app.state.config.CONTENT_EXTRACTION_ENGINE + ) + request.app.state.config.PDF_EXTRACT_IMAGES = ( + form_data.PDF_EXTRACT_IMAGES + if form_data.PDF_EXTRACT_IMAGES is not None + else request.app.state.config.PDF_EXTRACT_IMAGES + ) + request.app.state.config.DATALAB_MARKER_API_KEY = ( form_data.DATALAB_MARKER_API_KEY if form_data.DATALAB_MARKER_API_KEY is not None else request.app.state.config.DATALAB_MARKER_API_KEY - ) - request.app.state.config.DATALAB_MARKER_LANGS = ( - form_data.DATALAB_MARKER_LANGS - if form_data.DATALAB_MARKER_LANGS is not None - else request.app.state.config.DATALAB_MARKER_LANGS - ) - request.app.state.config.DATALAB_MARKER_SKIP_CACHE = ( - form_data.DATALAB_MARKER_SKIP_CACHE - if form_data.DATALAB_MARKER_SKIP_CACHE is not None - else request.app.state.config.DATALAB_MARKER_SKIP_CACHE - ) - request.app.state.config.DATALAB_MARKER_FORCE_OCR = ( - form_data.DATALAB_MARKER_FORCE_OCR - if form_data.DATALAB_MARKER_FORCE_OCR is not None - else request.app.state.config.DATALAB_MARKER_FORCE_OCR - ) - request.app.state.config.DATALAB_MARKER_PAGINATE = ( - form_data.DATALAB_MARKER_PAGINATE - if form_data.DATALAB_MARKER_PAGINATE is not None - else request.app.state.config.DATALAB_MARKER_PAGINATE - ) - request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = ( - form_data.DATALAB_MARKER_STRIP_EXISTING_OCR - if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR 
is not None - else request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR - ) - request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = ( - form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION - if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None - else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION - ) - request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = ( - form_data.DATALAB_MARKER_OUTPUT_FORMAT - if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None - else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT - ) - request.app.state.config.DATALAB_MARKER_USE_LLM = ( - form_data.DATALAB_MARKER_USE_LLM - if form_data.DATALAB_MARKER_USE_LLM is not None - else request.app.state.config.DATALAB_MARKER_USE_LLM - ) - request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = ( - form_data.EXTERNAL_DOCUMENT_LOADER_URL - if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None - else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL - ) - request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = ( - form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY - if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None - else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY - ) - request.app.state.config.TIKA_SERVER_URL = ( - form_data.TIKA_SERVER_URL - if form_data.TIKA_SERVER_URL is not None - else request.app.state.config.TIKA_SERVER_URL - ) - request.app.state.config.DOCLING_SERVER_URL = ( - form_data.DOCLING_SERVER_URL - if form_data.DOCLING_SERVER_URL is not None - else request.app.state.config.DOCLING_SERVER_URL - ) - request.app.state.config.DOCLING_OCR_ENGINE = ( - form_data.DOCLING_OCR_ENGINE - if form_data.DOCLING_OCR_ENGINE is not None - else request.app.state.config.DOCLING_OCR_ENGINE - ) - request.app.state.config.DOCLING_OCR_LANG = ( - form_data.DOCLING_OCR_LANG - if form_data.DOCLING_OCR_LANG is not None - else request.app.state.config.DOCLING_OCR_LANG - ) - - request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = ( 
- form_data.DOCLING_DO_PICTURE_DESCRIPTION - if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None - else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION - ) - - request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = ( - form_data.DOCLING_PICTURE_DESCRIPTION_MODE - if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None - else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE - ) - request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = ( - form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL - if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None - else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL - ) - request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = ( - form_data.DOCLING_PICTURE_DESCRIPTION_API - if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None - else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API - ) - - request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( - form_data.DOCUMENT_INTELLIGENCE_ENDPOINT - if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None - else request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT - ) - request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( - form_data.DOCUMENT_INTELLIGENCE_KEY - if form_data.DOCUMENT_INTELLIGENCE_KEY is not None - else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY - ) - request.app.state.config.MISTRAL_OCR_API_KEY = ( - form_data.MISTRAL_OCR_API_KEY - if form_data.MISTRAL_OCR_API_KEY is not None - else request.app.state.config.MISTRAL_OCR_API_KEY - ) - - # Reranking settings - request.app.state.config.RAG_RERANKING_ENGINE = ( - form_data.RAG_RERANKING_ENGINE - if form_data.RAG_RERANKING_ENGINE is not None - else request.app.state.config.RAG_RERANKING_ENGINE - ) - - request.app.state.config.RAG_EXTERNAL_RERANKER_URL = ( - form_data.RAG_EXTERNAL_RERANKER_URL - if form_data.RAG_EXTERNAL_RERANKER_URL is not None - else request.app.state.config.RAG_EXTERNAL_RERANKER_URL - ) - - request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = 
( - form_data.RAG_EXTERNAL_RERANKER_API_KEY - if form_data.RAG_EXTERNAL_RERANKER_API_KEY is not None - else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY - ) - - log.info( - f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" - ) - try: - request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL - - try: - request.app.state.rf = get_rf( - request.app.state.config.RAG_RERANKING_ENGINE, - request.app.state.config.RAG_RERANKING_MODEL, - request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - True, - ) - except Exception as e: - log.error(f"Error loading reranking model: {e}") - request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False - except Exception as e: - log.exception(f"Problem updating reranking model: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(e), + ) + request.app.state.config.DATALAB_MARKER_LANGS = ( + form_data.DATALAB_MARKER_LANGS + if form_data.DATALAB_MARKER_LANGS is not None + else request.app.state.config.DATALAB_MARKER_LANGS + ) + request.app.state.config.DATALAB_MARKER_SKIP_CACHE = ( + form_data.DATALAB_MARKER_SKIP_CACHE + if form_data.DATALAB_MARKER_SKIP_CACHE is not None + else request.app.state.config.DATALAB_MARKER_SKIP_CACHE + ) + request.app.state.config.DATALAB_MARKER_FORCE_OCR = ( + form_data.DATALAB_MARKER_FORCE_OCR + if form_data.DATALAB_MARKER_FORCE_OCR is not None + else request.app.state.config.DATALAB_MARKER_FORCE_OCR + ) + request.app.state.config.DATALAB_MARKER_PAGINATE = ( + form_data.DATALAB_MARKER_PAGINATE + if form_data.DATALAB_MARKER_PAGINATE is not None + else request.app.state.config.DATALAB_MARKER_PAGINATE + ) + request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = ( + form_data.DATALAB_MARKER_STRIP_EXISTING_OCR + if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR is not None + else 
request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR + ) + request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = ( + form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION + if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None + else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION + ) + request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = ( + form_data.DATALAB_MARKER_OUTPUT_FORMAT + if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None + else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT + ) + request.app.state.config.DATALAB_MARKER_USE_LLM = ( + form_data.DATALAB_MARKER_USE_LLM + if form_data.DATALAB_MARKER_USE_LLM is not None + else request.app.state.config.DATALAB_MARKER_USE_LLM + ) + request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = ( + form_data.EXTERNAL_DOCUMENT_LOADER_URL + if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None + else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL + ) + request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = ( + form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY + if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None + else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.TIKA_SERVER_URL + if form_data.TIKA_SERVER_URL is not None + else request.app.state.config.TIKA_SERVER_URL + ) + request.app.state.config.DOCLING_SERVER_URL = ( + form_data.DOCLING_SERVER_URL + if form_data.DOCLING_SERVER_URL is not None + else request.app.state.config.DOCLING_SERVER_URL + ) + request.app.state.config.DOCLING_OCR_ENGINE = ( + form_data.DOCLING_OCR_ENGINE + if form_data.DOCLING_OCR_ENGINE is not None + else request.app.state.config.DOCLING_OCR_ENGINE + ) + request.app.state.config.DOCLING_OCR_LANG = ( + form_data.DOCLING_OCR_LANG + if form_data.DOCLING_OCR_LANG is not None + else request.app.state.config.DOCLING_OCR_LANG ) - # Chunking settings - request.app.state.config.TEXT_SPLITTER = ( - 
form_data.TEXT_SPLITTER - if form_data.TEXT_SPLITTER is not None - else request.app.state.config.TEXT_SPLITTER - ) - request.app.state.config.CHUNK_SIZE = ( - form_data.CHUNK_SIZE - if form_data.CHUNK_SIZE is not None - else request.app.state.config.CHUNK_SIZE - ) - request.app.state.config.CHUNK_OVERLAP = ( - form_data.CHUNK_OVERLAP - if form_data.CHUNK_OVERLAP is not None - else request.app.state.config.CHUNK_OVERLAP - ) - - # File upload settings - request.app.state.config.FILE_MAX_SIZE = form_data.FILE_MAX_SIZE - request.app.state.config.FILE_MAX_COUNT = form_data.FILE_MAX_COUNT - request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = ( - form_data.FILE_IMAGE_COMPRESSION_WIDTH - ) - request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = ( - form_data.FILE_IMAGE_COMPRESSION_HEIGHT - ) - request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( - form_data.ALLOWED_FILE_EXTENSIONS - if form_data.ALLOWED_FILE_EXTENSIONS is not None - else request.app.state.config.ALLOWED_FILE_EXTENSIONS - ) - - # Integration settings - request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( - form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION - if form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION is not None - else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION - ) - request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( - form_data.ENABLE_ONEDRIVE_INTEGRATION - if form_data.ENABLE_ONEDRIVE_INTEGRATION is not None - else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION - ) - - if form_data.web is not None: - # Web search settings - request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH - request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE - request.app.state.config.WEB_SEARCH_TRUST_ENV = ( - form_data.web.WEB_SEARCH_TRUST_ENV - ) - request.app.state.config.WEB_SEARCH_RESULT_COUNT = ( - form_data.web.WEB_SEARCH_RESULT_COUNT - ) - request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( - form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS 
- ) - request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( - form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST - ) - request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( - form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL - ) - request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( - form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER - ) - request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL - request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL - request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME - request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD - request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY - request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( - form_data.web.GOOGLE_PSE_ENGINE_ID - ) - request.app.state.config.BRAVE_SEARCH_API_KEY = ( - form_data.web.BRAVE_SEARCH_API_KEY - ) - request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY - request.app.state.config.MOJEEK_SEARCH_API_KEY = ( - form_data.web.MOJEEK_SEARCH_API_KEY - ) - request.app.state.config.BOCHA_SEARCH_API_KEY = ( - form_data.web.BOCHA_SEARCH_API_KEY - ) - request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY - request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS - request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY - request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY - request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY - request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY - request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE - request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY - request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE - request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY - request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( - 
form_data.web.BING_SEARCH_V7_ENDPOINT - ) - request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( - form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY - ) - request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY - request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY - request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL - request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( - form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE - ) - request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID - request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK - - # Web loader settings - request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE - request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION - ) - request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL - request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT - request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY - request.app.state.config.FIRECRAWL_API_BASE_URL = ( - form_data.web.FIRECRAWL_API_BASE_URL - ) - request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( - form_data.web.EXTERNAL_WEB_SEARCH_URL - ) - request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( - form_data.web.EXTERNAL_WEB_SEARCH_API_KEY - ) - request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( - form_data.web.EXTERNAL_WEB_LOADER_URL - ) - request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( - form_data.web.EXTERNAL_WEB_LOADER_API_KEY - ) - request.app.state.config.TAVILY_EXTRACT_DEPTH = ( - form_data.web.TAVILY_EXTRACT_DEPTH - ) - request.app.state.config.YOUTUBE_LOADER_LANGUAGE = ( - form_data.web.YOUTUBE_LOADER_LANGUAGE - ) - request.app.state.config.YOUTUBE_LOADER_PROXY_URL = ( - form_data.web.YOUTUBE_LOADER_PROXY_URL - ) - request.app.state.YOUTUBE_LOADER_TRANSLATION = ( - 
form_data.web.YOUTUBE_LOADER_TRANSLATION + request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = ( + form_data.DOCLING_DO_PICTURE_DESCRIPTION + if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None + else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION + ) + + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = ( + form_data.DOCLING_PICTURE_DESCRIPTION_MODE + if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE + ) + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = ( + form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL + if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL + ) + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = ( + form_data.DOCLING_PICTURE_DESCRIPTION_API + if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API + ) + + request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( + form_data.DOCUMENT_INTELLIGENCE_ENDPOINT + if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None + else request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT + ) + request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( + form_data.DOCUMENT_INTELLIGENCE_KEY + if form_data.DOCUMENT_INTELLIGENCE_KEY is not None + else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY + ) + request.app.state.config.MISTRAL_OCR_API_KEY = ( + form_data.MISTRAL_OCR_API_KEY + if form_data.MISTRAL_OCR_API_KEY is not None + else request.app.state.config.MISTRAL_OCR_API_KEY ) - return { - "status": True, - # RAG settings - "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, - "TOP_K": request.app.state.config.TOP_K, - "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, - "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, - # Hybrid search settings - "ENABLE_RAG_HYBRID_SEARCH": 
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, - "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, - "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, - # Content extraction settings - "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, - "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, - "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, - "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS, - "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, - "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, - "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, - "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, - "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, - "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, - "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, - "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, - "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, - "DOCLING_PICTURE_DESCRIPTION_API": 
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, - "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, # Reranking settings - "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, - "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, - "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + request.app.state.config.RAG_RERANKING_ENGINE = ( + form_data.RAG_RERANKING_ENGINE + if form_data.RAG_RERANKING_ENGINE is not None + else request.app.state.config.RAG_RERANKING_ENGINE + ) + + request.app.state.config.RAG_EXTERNAL_RERANKER_URL = ( + form_data.RAG_EXTERNAL_RERANKER_URL + if form_data.RAG_EXTERNAL_RERANKER_URL is not None + else request.app.state.config.RAG_EXTERNAL_RERANKER_URL + ) + + request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = ( + form_data.RAG_EXTERNAL_RERANKER_API_KEY + if form_data.RAG_EXTERNAL_RERANKER_API_KEY is not None + else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY + ) + + + log.info( + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" + ) + try: + request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL + + try: + if not request.app.state.config.RAG_RERANKING_MODEL in request.app.state.rf and not request.app.state.config.RAG_RERANKING_MODEL == "": + request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf( + request.app.state.config.RAG_RERANKING_ENGINE, + request.app.state.config.RAG_RERANKING_MODEL, + request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + True, + ) + + # add model to state for reloading on startup + 
request.app.state.config.LOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append({ + "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, + "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY + }) + request.app.state.config._state["LOADED_RERANKING_MODELS"].save() + + # add model to state for selectable reranking models + if rag_config["RAG_RERANKING_MODEL"] not in request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE]: + request.app.state.config.DOWNLOADED_RERANKING_MODELS[request.app.state.config.RAG_RERANKING_ENGINE].append(request.app.state.config.RAG_RERANKING_MODEL) + request.app.state.config._state["DOWNLOADED_RERANKING_MODELS"].save() + + rag_config["LOADED_RERANKING_MODELS"] = request.app.state.config.LOADED_RERANKING_MODELS + rag_config["DOWNLOADED_RERANKING_MODELS"] = request.app.state.config.DOWNLOADED_RERANKING_MODELS + + + except Exception as e: + log.error(f"Error loading reranking model: {e}") + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + except Exception as e: + log.exception(f"Problem updating reranking model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + # Chunking settings - "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, - "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, - "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + request.app.state.config.TEXT_SPLITTER = ( + form_data.TEXT_SPLITTER + if form_data.TEXT_SPLITTER is not None + else request.app.state.config.TEXT_SPLITTER + ) + request.app.state.config.CHUNK_SIZE = ( + form_data.CHUNK_SIZE + if form_data.CHUNK_SIZE is not None + else request.app.state.config.CHUNK_SIZE + ) + request.app.state.config.CHUNK_OVERLAP = ( + form_data.CHUNK_OVERLAP + if form_data.CHUNK_OVERLAP is not 
None + else request.app.state.config.CHUNK_OVERLAP + ) + # File upload settings - "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, - "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, - "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, - "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, - "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + request.app.state.config.FILE_MAX_SIZE = form_data.FILE_MAX_SIZE + request.app.state.config.FILE_MAX_COUNT = form_data.FILE_MAX_COUNT + request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = ( + form_data.FILE_IMAGE_COMPRESSION_WIDTH + ) + request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = ( + form_data.FILE_IMAGE_COMPRESSION_HEIGHT + ) + request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( + form_data.ALLOWED_FILE_EXTENSIONS + if form_data.ALLOWED_FILE_EXTENSIONS is not None + else request.app.state.config.ALLOWED_FILE_EXTENSIONS + ) + # Integration settings - "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, - "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, - # Web search settings - "web": { - "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, - "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, - "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, - "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, - "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, - "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, - 
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, - "YACY_USERNAME": request.app.state.config.YACY_USERNAME, - "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, - "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, - "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, - "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, - "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, - "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, - "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, - "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, - "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, - "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, - "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, - "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, - "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, - "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, - "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, - "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, - "JINA_API_KEY": request.app.state.config.JINA_API_KEY, - "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, - "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "EXA_API_KEY": request.app.state.config.EXA_API_KEY, - "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, - "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, - "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, - "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, - "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, - "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, - "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - 
"PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, - "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, - "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, - "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, - "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, - "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, - "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, - "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, - "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, - "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, - }, - } + request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( + form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION + if form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION is not None + else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION + ) + request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( + form_data.ENABLE_ONEDRIVE_INTEGRATION + if form_data.ENABLE_ONEDRIVE_INTEGRATION is not None + else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION + ) + + if form_data.web is not None: + # Web search settings + request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH + request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE + request.app.state.config.WEB_SEARCH_TRUST_ENV = ( + form_data.web.WEB_SEARCH_TRUST_ENV + ) + request.app.state.config.WEB_SEARCH_RESULT_COUNT = ( + form_data.web.WEB_SEARCH_RESULT_COUNT + ) + request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( + form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS + ) + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( + 
form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST + ) + request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( + form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( + form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER + ) + request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL + request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL + request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME + request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD + request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( + form_data.web.GOOGLE_PSE_ENGINE_ID + ) + request.app.state.config.BRAVE_SEARCH_API_KEY = ( + form_data.web.BRAVE_SEARCH_API_KEY + ) + request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( + form_data.web.MOJEEK_SEARCH_API_KEY + ) + request.app.state.config.BOCHA_SEARCH_API_KEY = ( + form_data.web.BOCHA_SEARCH_API_KEY + ) + request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY + request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS + request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY + request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY + request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY + request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY + request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE + request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY + request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE + request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( + form_data.web.BING_SEARCH_V7_ENDPOINT + ) + 
request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY + ) + request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY + request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY + request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL + request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( + form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE + ) + request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID + request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK + + # Web loader settings + request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE + request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ( + form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION + ) + request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL + request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT + request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY + request.app.state.config.FIRECRAWL_API_BASE_URL = ( + form_data.web.FIRECRAWL_API_BASE_URL + ) + request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( + form_data.web.EXTERNAL_WEB_SEARCH_URL + ) + request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( + form_data.web.EXTERNAL_WEB_SEARCH_API_KEY + ) + request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( + form_data.web.EXTERNAL_WEB_LOADER_URL + ) + request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( + form_data.web.EXTERNAL_WEB_LOADER_API_KEY + ) + request.app.state.config.TAVILY_EXTRACT_DEPTH = ( + form_data.web.TAVILY_EXTRACT_DEPTH + ) + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = ( + form_data.web.YOUTUBE_LOADER_LANGUAGE + ) + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = ( + form_data.web.YOUTUBE_LOADER_PROXY_URL + ) + request.app.state.YOUTUBE_LOADER_TRANSLATION = ( + form_data.web.YOUTUBE_LOADER_TRANSLATION + ) + + return { + "status": True, + # 
RAG settings + "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE, + "TOP_K": request.app.state.config.TOP_K, + "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + # Hybrid search settings + "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, + "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, + # Content extraction settings + "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, + "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS, + "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, + "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, + "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, + "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, + "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, + "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, + "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, + "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, + 
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, + "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, + "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, + "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + # Reranking settings + "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, + "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, + "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + # Chunking settings + "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, + "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, + "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, + # File upload settings + "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, + "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, + "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, + # Integration settings + "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + # Web search settings + "web": { + "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH, + "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE, + "WEB_SEARCH_TRUST_ENV": 
request.app.state.config.WEB_SEARCH_TRUST_ENV, + "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, + "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, + "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, + "YACY_USERNAME": request.app.state.config.YACY_USERNAME, + "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD, + "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY, + "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY, + "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY, + "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY, + "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY, + "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, + "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, + "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, + "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY, + "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, + "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, + "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, + "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, + "JINA_API_KEY": request.app.state.config.JINA_API_KEY, + "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + 
"EXA_API_KEY": request.app.state.config.EXA_API_KEY, + "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, + "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, + "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, + "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, + "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, + "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL, + "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, + "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, + "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, + "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, + "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, + "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, + "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + }, + "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS + } #################################### @@ -1105,6 +1428,7 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, user=None, + knowledge_id: Optional[str] = None ) -> bool: def _get_docs_info(docs: list[Document]) -> str: docs_info = set() @@ -1125,6 +1449,29 @@ def save_docs_to_vector_db( log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) + + rag_config = {} + # Retrieve 
the knowledge base using the knowledge_id + if knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id) + # Use the knowledge base's own RAG configuration only when it exists and opts out of the defaults + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + + # Use knowledge-base-specific or default configurations + text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER) + chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE) + chunk_overlap = rag_config.get("CHUNK_OVERLAP", request.app.state.config.CHUNK_OVERLAP) + embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE) + embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL) + embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE) + openai_api_base_url = rag_config.get("openai_config", {}).get("url", request.app.state.config.RAG_OPENAI_API_BASE_URL) + openai_api_key = rag_config.get("openai_config", {}).get("key", request.app.state.config.RAG_OPENAI_API_KEY) + ollama_base_url = rag_config.get("ollama_config", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL) + ollama_api_key = rag_config.get("ollama_config", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY) + azure_openai_url = rag_config.get("azure_openai", {}).get("url", request.app.state.config.RAG_AZURE_OPENAI_BASE_URL) + azure_openai_key = rag_config.get("azure_openai", {}).get("key", request.app.state.config.RAG_AZURE_OPENAI_API_KEY) + azure_openai_version = rag_config.get("azure_openai", {}).get("version", request.app.state.config.RAG_AZURE_OPENAI_API_VERSION) # Check if entries with the same hash (metadata.hash) already exist if metadata and "hash" in metadata: @@ -1140,13 +1487,13 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - if request.app.state.config.TEXT_SPLITTER
in ["", "character"]: + if text_splitter_type in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=request.app.state.config.CHUNK_SIZE, - chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, add_start_index=True, ) - elif request.app.state.config.TEXT_SPLITTER == "token": + elif text_splitter_type == "token": log.info( f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" ) @@ -1154,8 +1501,8 @@ def save_docs_to_vector_db( tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=request.app.state.config.CHUNK_SIZE, - chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, add_start_index=True, ) else: @@ -1173,8 +1520,8 @@ def save_docs_to_vector_db( **(metadata if metadata else {}), "embedding_config": json.dumps( { - "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "model": request.app.state.config.RAG_EMBEDDING_MODEL, + "engine": embedding_engine, + "model": embedding_model, } ), } @@ -1207,31 +1554,27 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - request.app.state.config.RAG_EMBEDDING_ENGINE, - request.app.state.config.RAG_EMBEDDING_MODEL, - request.app.state.ef, + embedding_engine, + embedding_model, + request.app.state.ef[embedding_model], ( - request.app.state.config.RAG_OPENAI_API_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + openai_api_base_url + if embedding_engine == "openai" else ( - request.app.state.config.RAG_OLLAMA_BASE_URL - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" - else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ollama_base_url + if embedding_engine == "ollama" + else azure_openai_url ) ), ( - 
request.app.state.config.RAG_OPENAI_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else ( - request.app.state.config.RAG_OLLAMA_API_KEY - if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" - else request.app.state.config.RAG_AZURE_OPENAI_API_KEY - ) + openai_api_key + if embedding_engine == "openai" + else (ollama_api_key if embedding_engine == "ollama" else azure_openai_key) ), - request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + embedding_batch_size, azure_api_version=( - request.app.state.config.RAG_AZURE_OPENAI_API_VERSION - if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + azure_openai_version + if embedding_engine == "azure_openai" else None ), ) @@ -1267,6 +1610,7 @@ class ProcessFileForm(BaseModel): file_id: str content: Optional[str] = None collection_name: Optional[str] = None + knowledge_id: Optional[str] = None @router.post("/process/file") @@ -1282,6 +1626,97 @@ def process_file( if collection_name is None: collection_name = f"file-{file.id}" + + rag_config = {} + # Retrieve the knowledge base using the collection id - knowledge_id == collection_name (minimal working solution without logic changes) + if form_data.collection_name: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) + + # Retrieve the RAG configuration (skipped when the collection is not a knowledge base) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db + + elif form_data.knowledge_id: + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + + # Retrieve the RAG configuration (skipped when no such knowledge base exists) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + + # Use knowledge-base-specific or default configurations + content_extraction_engine = rag_config.get( + "CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE + ) + datalab_marker_api_key=rag_config.get( + "DATALAB_MARKER_API_KEY",
request.app.state.config.DATALAB_MARKER_API_KEY + ) + datalab_marker_langs=rag_config.get( + "DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS + ) + datalab_marker_skip_cache=rag_config.get( + "DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE + ) + datalab_marker_force_ocr=rag_config.get( + "DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR + ) + datalab_marker_paginate=rag_config.get( + "DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE + ) + datalab_marker_strip_existing_ocr=rag_config.get( + "DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR + ) + datalab_marker_disable_image_extraction=rag_config.get( + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION + ) + datalab_marker_use_llm=rag_config.get( + "DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM + ) + datalab_marker_output_format=rag_config.get( + "DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT + ) + external_document_loader_url = rag_config.get( + "EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL + ) + external_document_loader_api_key = rag_config.get( + "EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY + ) + tika_server_url = rag_config.get( + "TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL + ) + docling_server_url = rag_config.get( + "DOCLING_SERVER_URL", request.app.state.config.DOCLING_SERVER_URL + ) + docling_ocr_engine=rag_config.get( + "DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE + ) + docling_ocr_lang=rag_config.get( + "DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG + ) + docling_do_picture_description=rag_config.get( + "DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION + ) + 
picture_description_mode = rag_config.get( + "PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE + ) + picture_description_local = rag_config.get( + "PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL + ) + picture_description_api = rag_config.get( + "PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API + ) + pdf_extract_images = rag_config.get( + "PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES + ) + document_intelligence_endpoint = rag_config.get( + "DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT + ) + document_intelligence_key = rag_config.get( + "DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY + ) + mistral_ocr_api_key = rag_config.get( + "MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY + ) if form_data.content: # Update the content in the file @@ -1346,32 +1781,32 @@ def process_file( if file_path: file_path = Storage.get_file(file_path) loader = Loader( - engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, - DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY, - DATALAB_MARKER_LANGS=request.app.state.config.DATALAB_MARKER_LANGS, - DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE, - DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR, - DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE, - DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, - DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM, - DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, - EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, -
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, - DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, + engine=content_extraction_engine, + DATALAB_MARKER_API_KEY=datalab_marker_api_key, + DATALAB_MARKER_LANGS=datalab_marker_langs, + DATALAB_MARKER_SKIP_CACHE=datalab_marker_skip_cache, + DATALAB_MARKER_FORCE_OCR=datalab_marker_force_ocr, + DATALAB_MARKER_PAGINATE=datalab_marker_paginate, + DATALAB_MARKER_STRIP_EXISTING_OCR=datalab_marker_strip_existing_ocr, + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=datalab_marker_disable_image_extraction, + DATALAB_MARKER_USE_LLM=datalab_marker_use_llm, + DATALAB_MARKER_OUTPUT_FORMAT=datalab_marker_output_format, + EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url, + EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key, + TIKA_SERVER_URL=tika_server_url, + DOCLING_SERVER_URL=docling_server_url, DOCLING_PARAMS={ - "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE, - "ocr_lang": request.app.state.config.DOCLING_OCR_LANG, - "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, - "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, - "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, + "ocr_engine": docling_ocr_engine, + "ocr_lang": docling_ocr_lang, + "do_picture_description": docling_do_picture_description, + "picture_description_mode": picture_description_mode, + "picture_description_local": picture_description_local, + "picture_description_api": picture_description_api, }, - PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, - DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, 
- MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, + PDF_EXTRACT_IMAGES=pdf_extract_images, + DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint, + DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key, + MISTRAL_OCR_API_KEY=mistral_ocr_api_key, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1414,7 +1849,7 @@ def process_file( hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) - if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + if not rag_config.get("BYPASS_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL): try: result = save_docs_to_vector_db( request, @@ -1427,6 +1862,7 @@ def process_file( }, add=(True if form_data.collection_name else False), user=user, + knowledge_id=form_data.knowledge_id ) if result: @@ -1477,7 +1913,7 @@ class ProcessTextForm(BaseModel): def process_text( request: Request, form_data: ProcessTextForm, - user=Depends(get_verified_user), + user=Depends(get_verified_user) ): collection_name = form_data.collection_name if collection_name is None: @@ -1975,7 +2411,24 @@ def query_doc_handler( user=Depends(get_verified_user), ): try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + # Try to get individual rag config for this collection + rag_config = {} + knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + + # Use config from rag_config if present, else fallback to global config + enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH) + embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL) + reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL) + top_k = form_data.k if form_data.k else 
rag_config.get("TOP_K", request.app.state.config.TOP_K) + top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER) + relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD) + hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None) + if hybrid_bm25_weight is None: + hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT) + + if enable_hybrid: collection_results = {} collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_name=form_data.collection_name @@ -1984,32 +2437,23 @@ def query_doc_handler( collection_name=form_data.collection_name, collection_result=collection_results[form_data.collection_name], query=form_data.query, - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model]( query, prefix=prefix, user=user ), - k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, - k_reranker=form_data.k_reranker - or request.app.state.config.TOP_K_RERANKER, - r=( - form_data.r - if form_data.r - else request.app.state.config.RELEVANCE_THRESHOLD - ), - hybrid_bm25_weight=( - form_data.hybrid_bm25_weight - if form_data.hybrid_bm25_weight - else request.app.state.config.HYBRID_BM25_WEIGHT - ), + k=top_k, + reranking_function=request.app.state.rf[reranking_model], + k_reranker=top_k_reranker, + r=relevance_threshold, + hybrid_bm25_weight=hybrid_bm25_weight, user=user, ) else: return query_doc( collection_name=form_data.collection_name, - query_embedding=request.app.state.EMBEDDING_FUNCTION( + query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model]( form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user ), - k=form_data.k if form_data.k 
else request.app.state.config.TOP_K, + k=top_k, user=user, ) except Exception as e: @@ -2041,9 +2485,8 @@ def query_collection_handler( return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( - query, prefix=prefix, user=user - ), + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=request.app.state.rf, k_reranker=form_data.k_reranker @@ -2058,14 +2501,16 @@ def query_collection_handler( if form_data.hybrid_bm25_weight else request.app.state.config.HYBRID_BM25_WEIGHT ), + embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL, + reranking_model=request.app.state.config.RAG_RERANKING_MODEL, ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( - query, prefix=prefix, user=user - ), + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, + embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL, k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) diff --git a/backend/open_webui/test/apps/webui/routers/test_individual_rag_config.py b/backend/open_webui/test/apps/webui/routers/test_individual_rag_config.py new file mode 100644 index 000000000..64a20deb5 --- /dev/null +++ b/backend/open_webui/test/apps/webui/routers/test_individual_rag_config.py @@ -0,0 +1,338 @@ +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + +class TestRagConfig(AbstractPostgresTest): + BASE_PATH = "/api/v1/config" + + def setup_class(cls): + super().setup_class() + from open_webui.models.knowledge import Knowledges + + cls.knowledges = Knowledges + + def setup_method(self): + super().setup_method() + # Insert a knowledge base with default settings 
+ self.knowledges.insert_new_knowledge( + id="1", + name="Default KB", + rag_config={ + "DEFAULT_RAG_SETTINGS": True, + "TEMPLATE": "default-template", + "TOP_K": 5, + }, + ) + # Insert a knowledge base with custom RAG config + self.knowledges.insert_new_knowledge( + id="2", + name="Custom KB", + rag_config={ + "DEFAULT_RAG_SETTINGS": False, + "TEMPLATE": "custom-template", + "TOP_K": 10, + "web": { + "ENABLE_WEB_SEARCH": True, + "WEB_SEARCH_ENGINE": "custom-engine" + } + }, + ) + + def test_get_rag_config_default(self): + # Should return default config for knowledge base with DEFAULT_RAG_SETTINGS True + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url(""), + json={"knowledge_id": "1"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["RAG_TEMPLATE"] == "default-template" + assert data["TOP_K"] == 5 + assert data["DEFAULT_RAG_SETTINGS"] is True + + def test_get_rag_config_individual(self): + # Should return custom config for knowledge base with DEFAULT_RAG_SETTINGS False + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url(""), + json={"knowledge_id": "2"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["RAG_TEMPLATE"] == "custom-template" + assert data["TOP_K"] == 10 + assert data["DEFAULT_RAG_SETTINGS"] is False + assert data["web"]["ENABLE_WEB_SEARCH"] is True + assert data["web"]["WEB_SEARCH_ENGINE"] == "custom-engine" + + def test_get_rag_config_unauthorized(self): + # Should return 401 if not authenticated + response = self.fast_api_client.post( + self.create_url(""), + json={"knowledge_id": "1"} + ) + assert response.status_code == 401 + + def test_update_rag_config_default(self): + # Should update the global config for knowledge base with DEFAULT_RAG_SETTINGS True + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + 
self.create_url("/update"), + json={ + "knowledge_id": "1", + "RAG_TEMPLATE": "updated-template", + "TOP_K": 42, + "ENABLE_RAG_HYBRID_SEARCH": False, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["RAG_TEMPLATE"] == "updated-template" + assert data["TOP_K"] == 42 + assert data["ENABLE_RAG_HYBRID_SEARCH"] is False + + def test_update_rag_config_individual(self): + # Should update the config for knowledge base with DEFAULT_RAG_SETTINGS False + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/update"), + json={ + "knowledge_id": "2", + "TEMPLATE": "individual-updated", + "TOP_K": 99, + "web": { + "ENABLE_WEB_SEARCH": False, + "WEB_SEARCH_ENGINE": "updated-engine" + } + } + ) + assert response.status_code == 200 + data = response.json() + assert data["TEMPLATE"] == "individual-updated" + assert data["TOP_K"] == 99 + assert data["web"]["ENABLE_WEB_SEARCH"] is False + assert data["web"]["WEB_SEARCH_ENGINE"] == "updated-engine" + + def test_update_reranking_model_and_states_individual(self): + # Simulate app state for reranking models + app = self.fast_api_client.app + app.state.rf = {} + app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []} + app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []} + + # Update individual config with new reranking model + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/update"), + json={ + "knowledge_id": "2", + "RAG_RERANKING_MODEL": "", + "RAG_RERANKING_ENGINE": "", + "RAG_EXTERNAL_RERANKER_URL": "", + "RAG_EXTERNAL_RERANKER_API_KEY": "", + "ENABLE_RAG_HYBRID_SEARCH": True, + } + ) + assert response.status_code == 200 + data = response.json() + # Model should be in loaded and downloaded models + loaded = app.state.config.LOADED_RERANKING_MODELS[""] + downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""] + assert any(m["RAG_RERANKING_MODEL"] == 
"BBAI/bge-reranker-v2-m3" for m in loaded) + assert "BBAI/bge-reranker-v2-m3" in downloaded + assert "BBAI/bge-reranker-v2-m3" in app.state.rf + + def test_update_reranking_model_and_states_default(self): + # Simulate app state for reranking models + app = self.fast_api_client.app + app.state.rf = {} + app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []} + app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []} + + # Update default config with new reranking model + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/update"), + json={ + "knowledge_id": "1", + "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3", + "RAG_RERANKING_ENGINE": "", + "RAG_EXTERNAL_RERANKER_URL": "", + "RAG_EXTERNAL_RERANKER_API_KEY": "", + "ENABLE_RAG_HYBRID_SEARCH": True, + } + ) + assert response.status_code == 200 + data = response.json() + loaded = app.state.config.LOADED_RERANKING_MODELS[""] + downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""] + assert any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in loaded) + assert "BBAI/bge-reranker-v2-m3" in downloaded + assert "BBAI/bge-reranker-v2-m3" in app.state.rf + + def test_update_rag_config_unauthorized(self): + # Should return 401 if not authenticated + response = self.fast_api_client.post( + self.create_url("/update"), + json={"knowledge_id": "1", "RAG_TEMPLATE": "should-not-update"} + ) + assert response.status_code == 401 + + def test_reranking_model_freed_only_if_not_in_use_elsewhere(self): + """ + Test that the reranking model is only deleted from state if no other knowledge base is using it. 
+ """ + app = self.fast_api_client.app + app.state.rf = {"rerank-model-shared": object()} + app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3"}]} + app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BBAI/bge-reranker-v2-m3"]} + + # Patch is_model_in_use_elsewhere to simulate model still in use + from unittest.mock import patch + + with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=True): + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/update"), + json={ + "knowledge_id": "2", + "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3", + "RAG_RERANKING_ENGINE": "", + "ENABLE_RAG_HYBRID_SEARCH": False, + } + ) + assert response.status_code == 200 + # Model should NOT be deleted from state + assert "rerank-model-shared" in app.state.rf + assert any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""]) + + # Now simulate model NOT in use elsewhere + app.state.rf = {"rerank-model-shared": object()} + app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3"}]} + app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BBAI/bge-reranker-v2-m3"]} + + with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=False): + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/update"), + json={ + "knowledge_id": "2", + "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3", + "RAG_RERANKING_ENGINE": "", + "ENABLE_RAG_HYBRID_SEARCH": False, + } + ) + assert response.status_code == 200 + # Model should be deleted from state + assert "rerank-model-shared" not in app.state.rf + assert not any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""]) + + def test_get_embedding_config_default(self): + # Should return default embedding config for 
knowledge base with DEFAULT_RAG_SETTINGS True + # First, store an embedding config on KB "1" with DEFAULT_RAG_SETTINGS True. NOTE(review): the asserts below expect exactly these stored values back; confirm that is intended for a KB that uses the default settings + self.knowledges.update_rag_config_by_id( + id="1", + rag_config={ + "DEFAULT_RAG_SETTINGS": True, + "embedding_engine": "", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + "embedding_batch_size": 1, + "openai_config": {"url": "https://api.openai.com", "key": "default-key"}, + "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"}, + } + ) + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/embedding"), + json={"knowledge_id": "1"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["embedding_engine"] == "" + assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2" + assert data["embedding_batch_size"] == 1 + assert data["openai_config"]["url"] == "https://api.openai.com" + assert data["openai_config"]["key"] == "default-key" + assert data["ollama_config"]["url"] == "http://localhost:11434" + assert data["ollama_config"]["key"] == "ollama-key" + + def test_get_embedding_config_individual(self): + # Stores a custom embedding config on KB "2" with DEFAULT_RAG_SETTINGS False and expects /embedding to return those custom values + self.knowledges.update_rag_config_by_id( + id="2", + rag_config={ + "DEFAULT_RAG_SETTINGS": False, + "embedding_engine": "", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + "embedding_batch_size": 2, + "openai_config": {"url": "https://custom.openai.com", "key": "custom-key"}, + "ollama_config": {"url": "http://custom-ollama:11434", "key": "custom-ollama-key"}, + } + ) + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/embedding"), + json={"knowledge_id": "2"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["embedding_engine"] == "" + assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2" + assert 
data["embedding_batch_size"] == 2 + assert data["openai_config"]["url"] == "https://custom.openai.com" + assert data["openai_config"]["key"] == "custom-key" + assert data["ollama_config"]["url"] == "http://custom-ollama:11434" + assert data["ollama_config"]["key"] == "custom-ollama-key" + + def test_update_embedding_config_default(self): + # Should update the global embedding config for knowledge base with DEFAULT_RAG_SETTINGS True + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/embedding/update"), + json={ + "knowledge_id": "1", + "embedding_engine": "", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + "embedding_batch_size": 4, + "openai_config": {"url": "https://api.openai.com/v2", "key": "updated-key"}, + "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"}, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["embedding_engine"] == "" + assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2" + assert data["embedding_batch_size"] == 4 + assert data["openai_config"]["url"] == "https://api.openai.com/v2" + assert data["openai_config"]["key"] == "updated-key" + + def test_update_embedding_config_individual(self): + # Should update the embedding config for knowledge base with DEFAULT_RAG_SETTINGS False + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/embedding/update"), + json={ + "knowledge_id": "2", + "embedding_engine": "", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + "embedding_batch_size": 8, + "openai_config": {"url": "https://custom.openai.com/v2", "key": "custom-key-2"}, + "ollama_config": {"url": "http://custom-ollama:11434/v2", "key": "custom-ollama-key-2"}, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] is True + assert data["embedding_engine"] == "" + assert data["embedding_model"] == 
"sentence-transformers/all-MiniLM-L6-v2" + assert data["embedding_batch_size"] == 8 + assert data["openai_config"]["url"] == "https://custom.openai.com/v2" + assert data["openai_config"]["key"] == "custom-key-2" + assert data["ollama_config"]["url"] == "http://custom-ollama:11434/v2" + assert data["ollama_config"]["key"] == "custom-ollama-key-2" \ No newline at end of file diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index b1e69db26..d5c7d0b77 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -602,7 +602,7 @@ async def chat_image_generation_handler( async def chat_completion_files_handler( - request: Request, body: dict, user: UserModel + request: Request, body: dict, user: UserModel, model_knowledge=None ) -> tuple[dict, dict[str, list]]: sources = [] @@ -640,6 +640,21 @@ async def chat_completion_files_handler( queries = [get_last_user_message(body["messages"])] try: + # Per-knowledge-base RAG overrides: applied only when the first attached knowledge item opts out of DEFAULT_RAG_SETTINGS; with an empty rag_config every lookup below falls back to the global config + rag_config = {} + if model_knowledge and not (model_knowledge[0].get("rag_config") or {}).get("DEFAULT_RAG_SETTINGS", True): # "rag_config" may be missing or None on a knowledge item, so never chain .get() on it directly + rag_config = model_knowledge[0].get("rag_config") + + k = rag_config.get("TOP_K", request.app.state.config.TOP_K) + reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL) + reranking_function = request.app.state.rf.get(reranking_model) if reranking_model else None # .get(): a reranker that is configured but not present in rf yields None instead of raising KeyError + k_reranker = rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER) + r = rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD) + hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT) # no trailing comma here: it would wrap the weight in a 1-tuple + hybrid_search = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH) + full_context = rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT) + embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", 
request.app.state.config.RAG_EMBEDDING_MODEL) + # Offload get_sources_from_files to a separate thread loop = asyncio.get_running_loop() with ThreadPoolExecutor() as executor: @@ -649,16 +664,16 @@ async def chat_completion_files_handler( request=request, files=files, queries=queries, - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( - query, prefix=prefix, user=user - ), - k=request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, - k_reranker=request.app.state.config.TOP_K_RERANKER, - r=request.app.state.config.RELEVANCE_THRESHOLD, - hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, - hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, - full_context=request.app.state.config.RAG_FULL_CONTEXT, + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, + k=k, + reranking_function=reranking_function, + k_reranker=k_reranker, + r=r, + hybrid_bm25_weight=hybrid_bm25_weight, + hybrid_search=hybrid_search, + full_context=full_context, + embedding_model=embedding_model, ), ) except Exception as e: @@ -917,7 +932,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): log.exception(e) try: - form_data, flags = await chat_completion_files_handler(request, form_data, user) + form_data, flags = await chat_completion_files_handler(request, form_data, user, model_knowledge) sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) @@ -958,20 +973,24 @@ async def process_chat_payload(request, form_data, user, metadata, model): f"With a 0 relevancy threshold for RAG, the context cannot be empty" ) + # Use the first knowledge item's own RAG_TEMPLATE when it opts out of DEFAULT_RAG_SETTINGS; otherwise keep the global template + rag_template_config = request.app.state.config.RAG_TEMPLATE + + if model_knowledge and not (model_knowledge[0].get("rag_config") or {}).get("DEFAULT_RAG_SETTINGS", True): + rag_template_config = model_knowledge[0].get("rag_config").get( + "RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE + ) + + #
Workaround for Ollama 2.0+ system prompt issue # TODO: replace with add_or_update_system_message if model.get("owned_by") == "ollama": form_data["messages"] = prepend_to_first_user_message_content( - rag_template( - request.app.state.config.RAG_TEMPLATE, context_string, prompt - ), + rag_template(rag_template_config, context_string, prompt), form_data["messages"], ) else: form_data["messages"] = add_or_update_system_message( - rag_template( - request.app.state.config.RAG_TEMPLATE, context_string, prompt - ), + rag_template(rag_template_config, context_string, prompt), form_data["messages"], ) diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts index a58d7cb93..aefdcc949 100644 --- a/src/lib/apis/files/index.ts +++ b/src/lib/apis/files/index.ts @@ -1,12 +1,15 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const uploadFile = async (token: string, file: File, metadata?: object | null) => { +export const uploadFile = async (token: string, file: File, metadata?: object | null, knowledge_id?: string) => { const data = new FormData(); data.append('file', file); if (metadata) { data.append('metadata', JSON.stringify(metadata)); } + if (knowledge_id) { // optional: the form field is only appended when a knowledge_id is passed, so callers that omit it post the same form as before + data.append('knowledge_id', knowledge_id); + } let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, { diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index c01c986a2..d2381cb3d 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -4,7 +4,8 @@ export const createNewKnowledge = async ( token: string, name: string, description: string, - accessControl: null | object + accessControl: null | object, // NOTE(review): the rag_config parameter added on the next line is required (no "?" or default), so every existing createNewKnowledge caller must be updated; confirm + rag_config: null | object ) => { let error = null; @@ -18,7 +19,8 @@ export const createNewKnowledge = async ( body: JSON.stringify({ name: name, description: description, - access_control: accessControl + access_control: accessControl, + rag_config: rag_config }) }) .then(async (res) => { @@ -137,6 +139,7 @@ 
type KnowledgeUpdateForm = { description?: string; data?: object; access_control?: null | object; + rag_config?: object; }; export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -153,7 +156,8 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, data: form?.data ? form.data : undefined, - access_control: form.access_control + access_control: form.access_control, + rag_config: form?.rag_config ? form.rag_config : undefined }) }) .then(async (res) => { @@ -373,3 +377,31 @@ export const reindexKnowledgeFiles = async (token: string) => { return res; }; + +export const reindexSpecificKnowledgeFiles = async (token: string, id: string) => { // POSTs to /knowledge/reindex/{id}; resolves with the parsed JSON body and throws the response's `detail` when the status is not OK + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/reindex/${id}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; // NOTE(review): a rejection without a `detail` field (e.g. a network failure) leaves error unset, so the function resolves to null instead of throwing + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; \ No newline at end of file diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 6df927fec..dd0f5a162 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -1,14 +1,17 @@ import { RETRIEVAL_API_BASE_URL } from '$lib/constants'; -export const getRAGConfig = async (token: string) => { +export const getRAGConfig = async (token: string, collectionForm?: CollectionForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, { - method: 'GET', + method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` - } + }, + body: JSON.stringify( // NOTE(review): CollectionForm is not declared in any hunk shown for this file; confirm it is defined or imported. An empty object is sent when no form is given + collectionForm ? 
{collectionForm: collectionForm} : {} + ) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -57,6 +60,7 @@ type RAGConfigForm = { content_extraction?: ContentExtractConfigForm; web_loader_ssl_verification?: boolean; youtube?: YoutubeConfigForm; + knowledge_id?: string; }; export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => { @@ -152,15 +156,18 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings return res; }; -export const getEmbeddingConfig = async (token: string) => { +export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, { - method: 'GET', + method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` - } + }, + body: JSON.stringify( // NOTE(review): the payload is nested under "collectionForm", while the backend tests in this patch post {"knowledge_id": ...} at the top level to /embedding; confirm the server accepts this shape + collectionForm ? {collectionForm: collectionForm} : {} + ) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -196,6 +203,7 @@ type EmbeddingModelUpdateForm = { embedding_engine: string; embedding_model: string; embedding_batch_size?: number; + knowledge_id?: string; }; export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { diff --git a/src/lib/components/workspace/Knowledge/CreateKnowledgeBase.svelte b/src/lib/components/workspace/Knowledge/CreateKnowledgeBase.svelte index e7c1248f5..094c5337d 100644 --- a/src/lib/components/workspace/Knowledge/CreateKnowledgeBase.svelte +++ b/src/lib/components/workspace/Knowledge/CreateKnowledgeBase.svelte @@ -1,114 +1,355 @@