Mirror of https://github.com/open-webui/open-webui, synced 2025-06-26 18:26:48 +00:00
Fix: Handle the embedding function of each knowledge base's individual RAG config in the query_doc-related functions
commit 5f43d42cfa
parent 4c19aaaa64
@@ -269,13 +269,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
 def query_collection(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
+    embedding_model,
     k: int,
 ) -> dict:
     results = []
     error = False
 
-    def process_query_collection(collection_name, query_embedding):
+    def process_query_collection(collection_name, query_embedding, k):
         try:
             if collection_name:
                 result = query_doc(
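
The hunk above replaces the single prebuilt embedding_function parameter with user, ef, and embedding_model. A minimal sketch of the calling convention this implies, assuming ef is a mapping from embedding-model names to embedding callables; the model name and toy embedder below are hypothetical:

# Hypothetical sketch of the `ef` registry the new signature assumes:
# model name -> embedding callable taking (query, prefix=..., user=...).
def embed_small(query, prefix=None, user=None):
    # Toy embedder: one fixed-size vector per input string.
    if isinstance(query, list):
        return [[float(len(q))] * 4 for q in query]
    return [float(len(query))] * 4

ef = {"text-embedding-small": embed_small}

# A caller now passes the whole registry plus a default model name:
# query_collection(collection_names, queries, user=user, ef=ef,
#                  embedding_model="text-embedding-small", k=5)

Passing the registry lets each collection resolve its own model later, instead of being locked to one prebuilt callable.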
@@ -290,18 +292,30 @@ def query_collection(
             log.exception(f"Error when querying the collection: {e}")
             return None, e
 
-    # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
 
+    from open_webui.models.knowledge import Knowledges
     with ThreadPoolExecutor() as executor:
         future_results = []
-        for query_embedding in query_embeddings:
-            for collection_name in collection_names:
+        for collection_name in collection_names:
+            rag_config = {}
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                k = rag_config.get("TOP_K", k)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+            # Generate embeddings for each query using the collection's embedding function
+            query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+            for query_embedding in query_embeddings:
                 result = executor.submit(
-                    process_query_collection, collection_name, query_embedding
+                    process_query_collection, collection_name, query_embedding, k
                 )
                 future_results.append(result)
         task_results = [future.result() for future in future_results]
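
The per-collection branch above only honors a knowledge base's own rag_config when DEFAULT_RAG_SETTINGS is false. A self-contained sketch of that resolution rule, with a hypothetical stand-in for a Knowledges record:

def resolve_collection_settings(knowledge_base, embedding_model, k):
    # Keep the global defaults unless the knowledge base opts out of
    # DEFAULT_RAG_SETTINGS with its own per-collection rag_config.
    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
        rag_config = knowledge_base.rag_config
        embedding_model = rag_config.get("embedding_model", embedding_model)
        k = rag_config.get("TOP_K", k)
    return embedding_model, k

class KB:  # hypothetical stand-in for a Knowledges record
    rag_config = {"DEFAULT_RAG_SETTINGS": False, "embedding_model": "custom-model", "TOP_K": 3}

print(resolve_collection_settings(KB(), "global-model", 5))   # ('custom-model', 3)
print(resolve_collection_settings(None, "global-model", 5))   # ('global-model', 5)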
@@ -321,12 +335,14 @@ def query_collection(
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
     k: int,
     reranking_function,
     k_reranker: int,
     r: float,
     hybrid_bm25_weight: float,
+    embedding_model: str,
 ) -> dict:
     results = []
     error = False
@@ -351,13 +367,32 @@ def query_collection_with_hybrid_search(
 
     def process_query(collection_name, query):
         try:
+            from open_webui.models.knowledge import Knowledges
+
+            # Use Knowledges to get the per-collection RAG config
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                # Use config from rag_config if present, else fall back to the global config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                reranking_model = rag_config.get("reranking_function", reranking_model)
+                k = rag_config.get("TOP_K", k)
+                k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker)
+                r = rag_config.get("RELEVANCE_THRESHOLD", r)
+                hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+
             result = query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,
                 embedding_function=embedding_function,
                 k=k,
-                reranking_function=reranking_function,
+                reranking_function=reranking_function[reranking_model],
                 k_reranker=k_reranker,
                 r=r,
                 hybrid_bm25_weight=hybrid_bm25_weight,
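
In the hybrid path, reranking_function is now indexed by the resolved reranking_model, i.e. it is treated as a registry of rerankers rather than a single callable. A toy sketch under that assumption (the model name and scoring below are made up):

# `reranking_function` as a registry: model name -> callable(query, docs) -> scores.
def overlap_reranker(query, docs):
    # Toy reranker: score documents by word overlap with the query.
    words = set(query.split())
    return [len(words & set(doc.split())) for doc in docs]

reranking_function = {"toy-reranker": overlap_reranker}
reranking_model = "toy-reranker"

print(reranking_function[reranking_model]("red apple", ["green apple", "red car"]))  # [1, 1]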
@@ -445,7 +480,8 @@ def get_sources_from_files(
     request,
     files,
     queries,
-    embedding_function,
+    user,
+    ef,
     k,
     reranking_function,
     k_reranker,
@@ -453,9 +489,10 @@ def get_sources_from_files(
     hybrid_bm25_weight,
     hybrid_search,
     full_context=False,
+    embedding_model=None
 ):
     log.debug(
-        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+        f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}"
     )
 
     extracted_collections = []
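
With ef and embedding_model now passed separately, the hunks above rebuild the old (query, prefix) callable per collection by closing over the selected model. A standalone sketch of that closure pattern; the factory function and toy embedder are illustrative, not part of the codebase:

def make_embedding_function(ef, embedding_model, user):
    # Bind the chosen model and user once, so downstream code keeps the
    # old (query, prefix) -> embeddings interface.
    def embedding_function(query, prefix=None):
        return ef[embedding_model](query, prefix=prefix, user=user)
    return embedding_function

ef = {"model-a": lambda q, prefix=None, user=None: [[0.0] * 4 for _ in q]}
embed = make_embedding_function(ef, "model-a", user="u1")
print(len(embed(["hello", "world"])))  # 2

Binding through a factory like this also sidesteps the late-binding surprise of defining a bare lambda inside a loop.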
@@ -563,12 +600,14 @@ def get_sources_from_files(
                                 context = query_collection_with_hybrid_search(
                                     collection_names=collection_names,
                                     queries=queries,
-                                    embedding_function=embedding_function,
+                                    user=user,
+                                    ef=ef,
                                     k=k,
                                     reranking_function=reranking_function,
                                     k_reranker=k_reranker,
                                     r=r,
                                     hybrid_bm25_weight=hybrid_bm25_weight,
+                                    embedding_model=embedding_model,
                                 )
                             except Exception as e:
                                 log.debug(
@@ -580,8 +619,10 @@ def get_sources_from_files(
                             context = query_collection(
                                 collection_names=collection_names,
                                 queries=queries,
-                                embedding_function=embedding_function,
+                                user=user,
+                                ef=ef,
                                 k=k,
+                                embedding_model=embedding_model
                             )
                 except Exception as e:
                     log.exception(e)
@@ -644,7 +685,8 @@ async def chat_completion_files_handler(
             reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
             reranking_function = request.app.state.rf[reranking_model] if reranking_model else None
             k_reranker = rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
-            r = rag_config.get("RELEVANCE THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
+            r = rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
             hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
             hybrid_search = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
             full_context = rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
+            embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
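
Each setting above is resolved from the knowledge base's rag_config with the app-level config as fallback. A compact, runnable sketch of that fallback pattern, with a plain namespace standing in for request.app.state.config and made-up values:

from types import SimpleNamespace

# Stand-in for request.app.state.config with hypothetical defaults.
app_config = SimpleNamespace(TOP_K_RERANKER=10, RELEVANCE_THRESHOLD=0.0, HYBRID_BM25_WEIGHT=0.5)
rag_config = {"RELEVANCE_THRESHOLD": 0.2}  # per-knowledge-base overrides

k_reranker = rag_config.get("TOP_K_RERANKER", app_config.TOP_K_RERANKER)
r = rag_config.get("RELEVANCE_THRESHOLD", app_config.RELEVANCE_THRESHOLD)
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", app_config.HYBRID_BM25_WEIGHT)
print(k_reranker, r, hybrid_bm25_weight)  # 10 0.2 0.5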
@@ -658,16 +700,16 @@ async def chat_completion_files_handler(
                         request=request,
                         files=files,
                         queries=queries,
-                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
-                            query, prefix=prefix, user=user
-                        ),
+                        user=user,
+                        ef=request.app.state.EMBEDDING_FUNCTION,
                         k=k,
                         reranking_function=reranking_function,
                         k_reranker=k_reranker,
                         r=r,
-                        hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+                        hybrid_bm25_weight=hybrid_bm25_weight,
                         hybrid_search=hybrid_search,
                         full_context=full_context,
+                        embedding_model=embedding_model,
                     ),
                 )
         except Exception as e:
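
Putting it together, a hedged sketch of how a caller now wires the embedding registry into get_sources_from_files; apart from the parameter names taken from the diff, every value below is a stand-in:

# EMBEDDING_FUNCTION stands in for request.app.state.EMBEDDING_FUNCTION.
EMBEDDING_FUNCTION = {"model-a": lambda q, prefix=None, user=None: [[0.0] * 4 for _ in q]}

kwargs = dict(
    user="u1",
    ef=EMBEDDING_FUNCTION,      # whole registry instead of one prebuilt lambda
    k=5,
    reranking_function=None,
    k_reranker=10,
    r=0.0,
    hybrid_bm25_weight=0.5,
    hybrid_search=False,
    full_context=False,
    embedding_model="model-a",  # selects the registry entry inside the helpers
)
# get_sources_from_files(request=request, files=files, queries=queries, **kwargs)
print(sorted(kwargs))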