diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 683f42819..a48bbbfc4 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -269,13 +269,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict: def query_collection( collection_names: list[str], queries: list[str], - embedding_function, + user, + ef, + embedding_model, k: int, ) -> dict: results = [] error = False - def process_query_collection(collection_name, query_embedding): + def process_query_collection(collection_name, query_embedding, k): try: if collection_name: result = query_doc( @@ -290,18 +292,30 @@ def query_collection( log.exception(f"Error when querying the collection: {e}") return None, e - # Generate all query embeddings (in one call) - query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) log.debug( f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" ) + from open_webui.models.knowledge import Knowledges with ThreadPoolExecutor() as executor: future_results = [] - for query_embedding in query_embeddings: - for collection_name in collection_names: + for collection_name in collection_names: + rag_config = {} + knowledge_base = Knowledges.get_knowledge_by_id(collection_name) + + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + embedding_model = rag_config.get("embedding_model", embedding_model) + k = rag_config.get("TOP_K", k) + + embedding_function=lambda query, prefix: ef[embedding_model]( + query, prefix=prefix, user=user + ) + # Generate embeddings for each query using the collection's embedding function + query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) + for query_embedding in query_embeddings: result = executor.submit( - process_query_collection, collection_name, query_embedding + process_query_collection, collection_name, query_embedding, k ) future_results.append(result) task_results = [future.result() for future in future_results] @@ -321,12 +335,14 @@ def query_collection( def query_collection_with_hybrid_search( collection_names: list[str], queries: list[str], - embedding_function, + user, + ef, k: int, reranking_function, k_reranker: int, r: float, hybrid_bm25_weight: float, + embedding_model: str, ) -> dict: results = [] error = False @@ -351,13 +367,32 @@ def query_collection_with_hybrid_search( def process_query(collection_name, query): try: + from open_webui.models.knowledge import Knowledges + + # Use Knowledges to get per-collection RAG config + knowledge_base = Knowledges.get_knowledge_by_id(collection_name) + + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): + rag_config = knowledge_base.rag_config + # Use config from rag_config if present, else fallback to global config + embedding_model = rag_config.get("embedding_model", embedding_model) + reranking_model = rag_config.get("reranking_function", reranking_model) + k = rag_config.get("TOP_K", k) + k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker) + r = rag_config.get("RELEVANCE_THRESHOLD", r) + hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight) + + embedding_function=lambda query, prefix: ef[embedding_model]( + query, prefix=prefix, user=user + ), + result = query_doc_with_hybrid_search( collection_name=collection_name, collection_result=collection_results[collection_name], query=query, embedding_function=embedding_function, k=k, - reranking_function=reranking_function, + reranking_function=reranking_function[reranking_model], k_reranker=k_reranker, r=r, hybrid_bm25_weight=hybrid_bm25_weight, @@ -445,7 +480,8 @@ def get_sources_from_files( request, files, queries, - embedding_function, + user, + ef, k, reranking_function, k_reranker, @@ -453,9 +489,10 @@ def get_sources_from_files( hybrid_bm25_weight, hybrid_search, full_context=False, + embedding_model=None ): log.debug( - f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}" ) extracted_collections = [] @@ -563,12 +600,14 @@ def get_sources_from_files( context = query_collection_with_hybrid_search( collection_names=collection_names, queries=queries, - embedding_function=embedding_function, + user=user, + ef=ef, k=k, reranking_function=reranking_function, k_reranker=k_reranker, r=r, hybrid_bm25_weight=hybrid_bm25_weight, + embedding_model=embedding_model, ) except Exception as e: log.debug( @@ -580,8 +619,10 @@ def get_sources_from_files( context = query_collection( collection_names=collection_names, queries=queries, - embedding_function=embedding_function, + user=user, + ef=ef, k=k, + embedding_model=embedding_model ) except Exception as e: log.exception(e) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 775a40f36..5f482ce96 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -644,7 +644,8 @@ async def chat_completion_files_handler( reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL) reranking_function=request.app.state.rf[reranking_model] if reranking_model else None k_reranker=rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER) - r=rag_config.get("RELEVANCE THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD) + r=rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD) + hybrid_bm25_weight=rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT), hybrid_search=rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH) full_context=rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT) embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL) @@ -658,16 +659,16 @@ async def chat_completion_files_handler( request=request, files=files, queries=queries, - embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model]( - query, prefix=prefix, user=user - ), + user=user, + ef=request.app.state.EMBEDDING_FUNCTION, k=k, reranking_function=reranking_function, k_reranker=k_reranker, r=r, - hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, + hybrid_bm25_weight=hybrid_bm25_weight, hybrid_search=hybrid_search, full_context=full_context, + embedding_model=embedding_model, ), ) except Exception as e: