mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
enh: retrieval query generation
This commit is contained in:
@@ -177,35 +177,34 @@ def merge_and_sort_query_results(
|
||||
|
||||
def query_collection(
|
||||
collection_names: list[str],
|
||||
query: str,
|
||||
queries: list[str],
|
||||
embedding_function,
|
||||
k: int,
|
||||
) -> dict:
|
||||
|
||||
results = []
|
||||
query_embedding = embedding_function(query)
|
||||
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = query_doc(
|
||||
collection_name=collection_name,
|
||||
k=k,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
for query in queries:
|
||||
query_embedding = embedding_function(query)
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = query_doc(
|
||||
collection_name=collection_name,
|
||||
k=k,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
return merge_and_sort_query_results(results, k=k)
|
||||
|
||||
|
||||
def query_collection_with_hybrid_search(
|
||||
collection_names: list[str],
|
||||
query: str,
|
||||
queries: list[str],
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
@@ -215,15 +214,16 @@ def query_collection_with_hybrid_search(
|
||||
error = False
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
result = query_doc_with_hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
for query in queries:
|
||||
result = query_doc_with_hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
"Error when querying the collection with " f"hybrid_search: {e}"
|
||||
@@ -309,15 +309,14 @@ def get_embedding_function(
|
||||
|
||||
def get_rag_context(
|
||||
files,
|
||||
messages,
|
||||
queries,
|
||||
embedding_function,
|
||||
k,
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
):
|
||||
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
|
||||
query = get_last_user_message(messages)
|
||||
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
@@ -359,7 +358,7 @@ def get_rag_context(
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
@@ -374,7 +373,7 @@ def get_rag_context(
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user