Mirror of https://github.com/open-webui/open-webui, synced 2025-06-23 02:16:52 +00:00
Fix: Handle embedding functions of individual RAG configs correctly in query_doc-related functions
This commit is contained in:
parent 4c19aaaa64
commit 5f43d42cfa
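The change threads per-knowledge-base RAG settings through the retrieval helpers: instead of receiving one pre-bound embedding_function, the functions now take user, an ef mapping from embedding model name to embedding callable, and a default embedding_model, so each collection can bind its own callable via ef[embedding_model]. A minimal sketch of that registry shape, assuming hypothetical model names and a stub embedder (none of this code is from the diff itself):

from typing import Callable

def make_embedder(model_name: str) -> Callable:
    # Stub embedder: a real one would encode the texts with `model_name`.
    def embed(query, prefix=None, user=None):
        texts = query if isinstance(query, list) else [query]
        return [[0.0] * 384 for _ in texts]  # fake 384-dim vectors
    return embed

# `ef` plays the role of request.app.state.EMBEDDING_FUNCTION in the diff:
# a mapping keyed by embedding model name (a plain dict in this sketch).
ef = {
    "all-MiniLM-L6-v2": make_embedder("all-MiniLM-L6-v2"),
    "text-embedding-3-small": make_embedder("text-embedding-3-small"),
}

embedding_model = "all-MiniLM-L6-v2"
# Per-collection binding, shaped like the one the patched functions build:
embedding_function = lambda query, prefix: ef[embedding_model](query, prefix=prefix, user=None)
print(len(embedding_function(["hello"], prefix=None)[0]))  # 384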
@@ -269,13 +269,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
 def query_collection(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
+    embedding_model,
     k: int,
 ) -> dict:
     results = []
     error = False
 
-    def process_query_collection(collection_name, query_embedding):
+    def process_query_collection(collection_name, query_embedding, k):
         try:
             if collection_name:
                 result = query_doc(
@@ -290,18 +292,30 @@ def query_collection(
             log.exception(f"Error when querying the collection: {e}")
             return None, e
 
-    # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
 
+    from open_webui.models.knowledge import Knowledges
     with ThreadPoolExecutor() as executor:
         future_results = []
-        for query_embedding in query_embeddings:
-            for collection_name in collection_names:
+        for collection_name in collection_names:
+            rag_config = {}
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                k = rag_config.get("TOP_K", k)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+            # Generate embeddings for each query using the collection's embedding function
+            query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+            for query_embedding in query_embeddings:
                 result = executor.submit(
-                    process_query_collection, collection_name, query_embedding
+                    process_query_collection, collection_name, query_embedding, k
                 )
                 future_results.append(result)
         task_results = [future.result() for future in future_results]
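In the patched query_collection, each collection first consults its knowledge base: when the rag_config opts out of DEFAULT_RAG_SETTINGS, the per-collection embedding_model and TOP_K replace the globals before that collection's queries are embedded. A standalone restatement of that fallback logic, with a hypothetical helper name and return shape that are not part of the commit:

def resolve_collection_settings(knowledge_base, embedding_model, k):
    # Start from the global values; override only when the knowledge base
    # opts out of the default RAG settings.
    rag_config = {}
    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
        rag_config = knowledge_base.rag_config
    return (
        rag_config.get("embedding_model", embedding_model),  # per-collection model
        rag_config.get("TOP_K", k),                          # per-collection top-k
    )

Returning fresh values rather than rebinding the caller's variables keeps the same logic safe to reuse inside nested functions, which becomes relevant in the hybrid-search hunk below.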
@@ -321,12 +335,14 @@ def query_collection(
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
     k: int,
     reranking_function,
     k_reranker: int,
     r: float,
     hybrid_bm25_weight: float,
+    embedding_model: str,
 ) -> dict:
     results = []
     error = False
@@ -351,13 +367,32 @@ def query_collection_with_hybrid_search(
 
     def process_query(collection_name, query):
         try:
+            from open_webui.models.knowledge import Knowledges
+
+            # Use Knowledges to get the per-collection RAG config
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                # Use config from rag_config if present, else fall back to global config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                reranking_model = rag_config.get("reranking_function", reranking_model)
+                k = rag_config.get("TOP_K", k)
+                k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker)
+                r = rag_config.get("RELEVANCE_THRESHOLD", r)
+                hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+
             result = query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,
                 embedding_function=embedding_function,
                 k=k,
-                reranking_function=reranking_function,
+                reranking_function=reranking_function[reranking_model],
                 k_reranker=k_reranker,
                 r=r,
                 hybrid_bm25_weight=hybrid_bm25_weight,
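One caveat with process_query above: it assigns to names such as embedding_model, reranking_model and k that it also needs to read from the enclosing scope. In Python, assigning to a name anywhere in a function makes that name local to it, so evaluating the old value on the right-hand side of the same assignment raises UnboundLocalError; the commit avoids this in process_query_collection by passing k in as a parameter. A minimal standalone illustration, not taken from the diff:

k = 4  # stands in for the enclosing function's parameter

def override_broken(rag_config):
    # `k` is assigned below, so Python treats it as local to this function;
    # calling this raises UnboundLocalError when the right-hand side reads `k`.
    k = rag_config.get("TOP_K", k)
    return k

def override_fixed(rag_config, k=k):
    # Binding the outer value as a parameter default (or using `nonlocal` in
    # a closure) makes the fallback read well-defined.
    k = rag_config.get("TOP_K", k)
    return k

print(override_fixed({"TOP_K": 10}))  # 10
print(override_fixed({}))             # 4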
@@ -445,7 +480,8 @@ def get_sources_from_files(
     request,
     files,
     queries,
-    embedding_function,
+    user,
+    ef,
     k,
     reranking_function,
     k_reranker,
@@ -453,9 +489,10 @@ def get_sources_from_files(
     hybrid_bm25_weight,
     hybrid_search,
     full_context=False,
+    embedding_model=None
 ):
     log.debug(
-        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+        f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}"
     )
 
     extracted_collections = []
@@ -563,12 +600,14 @@ def get_sources_from_files(
                         context = query_collection_with_hybrid_search(
                             collection_names=collection_names,
                             queries=queries,
-                            embedding_function=embedding_function,
+                            user=user,
+                            ef=ef,
                             k=k,
                             reranking_function=reranking_function,
                             k_reranker=k_reranker,
                             r=r,
                             hybrid_bm25_weight=hybrid_bm25_weight,
+                            embedding_model=embedding_model,
                         )
                     except Exception as e:
                         log.debug(
@@ -580,8 +619,10 @@ def get_sources_from_files(
                     context = query_collection(
                         collection_names=collection_names,
                         queries=queries,
-                        embedding_function=embedding_function,
+                        user=user,
+                        ef=ef,
                         k=k,
+                        embedding_model=embedding_model
                     )
         except Exception as e:
             log.exception(e)
@@ -644,7 +644,8 @@ async def chat_completion_files_handler(
         reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
         reranking_function = request.app.state.rf[reranking_model] if reranking_model else None
         k_reranker = rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
-        r = rag_config.get("RELEVANCE THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
+        r = rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
         hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
         hybrid_search = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
         full_context = rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
+        embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
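chat_completion_files_handler resolves every retrieval setting the same way: check the per-request rag_config first, then fall back to the global request.app.state.config. The one-character key fix in this hunk matters because a misspelled key can never match, so the override is silently skipped and the global always wins. A short sketch with a stand-in dict for the global config (the values are made up):

global_config = {"RELEVANCE_THRESHOLD": 0.0}  # stand-in for request.app.state.config
rag_config = {"RELEVANCE_THRESHOLD": 0.25}    # per-knowledge-base override

r = rag_config.get("RELEVANCE_THRESHOLD", global_config["RELEVANCE_THRESHOLD"])
print(r)  # 0.25: the override is honored

r_buggy = rag_config.get("RELEVANCE THRESHOLD", global_config["RELEVANCE_THRESHOLD"])
print(r_buggy)  # 0.0: the misspelled key silently falls back to the global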
@@ -658,16 +659,16 @@ async def chat_completion_files_handler(
                         request=request,
                         files=files,
                         queries=queries,
-                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
-                            query, prefix=prefix, user=user
-                        ),
+                        user=user,
+                        ef=request.app.state.EMBEDDING_FUNCTION,
                         k=k,
                         reranking_function=reranking_function,
                         k_reranker=k_reranker,
                         r=r,
-                        hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+                        hybrid_bm25_weight=hybrid_bm25_weight,
                         hybrid_search=hybrid_search,
                         full_context=full_context,
+                        embedding_model=embedding_model,
                     ),
                 )
         except Exception as e: