mirror of
https://github.com/open-webui/open-webui
synced 2025-04-07 22:25:05 +00:00
Fix: Normalze all database distances to score in [0, 1]
This commit is contained in:
parent
8aa6dade41
commit
94d9d3d590
@ -175,7 +175,7 @@ def merge_get_results(get_results: list[dict]) -> dict:
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
query_results: list[dict], k: int
|
||||
) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined = dict() # To store documents with unique document hashes
|
||||
@ -196,28 +196,18 @@ def merge_and_sort_query_results(
|
||||
continue # if doc is new, no further comparison is needed
|
||||
|
||||
# if doc is alredy in, but new distance is better, update
|
||||
if not reverse and distance < combined[doc_hash][0]:
|
||||
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
combined[doc_hash] = (distance, document, metadata)
|
||||
if reverse and distance > combined[doc_hash][0]:
|
||||
if distance > combined[doc_hash][0]:
|
||||
combined[doc_hash] = (distance, document, metadata)
|
||||
|
||||
combined = list(combined.values())
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
combined.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Slice to keep only the top k elements
|
||||
sorted_distances, sorted_documents, sorted_metadatas = (
|
||||
zip(*combined[:k]) if combined else ([], [], [])
|
||||
)
|
||||
|
||||
# if chromaDB, the distance is 0 (best) to 2 (worse)
|
||||
# re-order to -1 (worst) to 1 (best) for relevance score
|
||||
if not reverse:
|
||||
sorted_distances = tuple(-dist for dist in sorted_distances)
|
||||
sorted_distances = tuple(dist + 1 for dist in sorted_distances)
|
||||
|
||||
# Create and return the output dictionary
|
||||
return {
|
||||
"distances": [list(sorted_distances)],
|
||||
@ -267,12 +257,7 @@ def query_collection(
|
||||
else:
|
||||
pass
|
||||
|
||||
if VECTOR_DB == "chroma":
|
||||
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
return merge_and_sort_query_results(results, k=k, reverse=False)
|
||||
else:
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
return merge_and_sort_query_results(results, k=k)
|
||||
|
||||
|
||||
def query_collection_with_hybrid_search(
|
||||
@ -308,7 +293,7 @@ def query_collection_with_hybrid_search(
|
||||
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
|
||||
)
|
||||
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
return merge_and_sort_query_results(results, k=k)
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
|
@ -75,10 +75,16 @@ class ChromaClient:
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
|
||||
# https://docs.trychroma.com/docs/collections/configure cosine equation
|
||||
distances: list = result["distances"][0]
|
||||
distances = [2 - dist for dist in distances]
|
||||
distances = [[dist/2 for dist in distances]]
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"distances": distances,
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
|
@ -64,7 +64,10 @@ class MilvusClient:
|
||||
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_distances.append(item.get("distance"))
|
||||
# normalize milvus score from [-1, 1] to [0, 1] range
|
||||
# https://milvus.io/docs/de/metric.md
|
||||
_dist = (item.get("distance") + 1.0)/2.0
|
||||
_distances.append(_dist)
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
|
||||
|
@ -120,7 +120,7 @@ class OpenSearchClient:
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
|
||||
"source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
|
||||
"params": {
|
||||
"field": "vector",
|
||||
"query_value": vectors[0],
|
||||
|
@ -278,7 +278,9 @@ class PgvectorClient:
|
||||
for row in results:
|
||||
qid = int(row.qid)
|
||||
ids[qid].append(row.id)
|
||||
distances[qid].append(row.distance)
|
||||
# normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
|
||||
# https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
|
||||
distances[qid].append((2.0 - row.distance)/2.0)
|
||||
documents[qid].append(row.text)
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
|
@ -99,7 +99,8 @@ class QdrantClient:
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]],
|
||||
# qdrant distance is [-1, 1], normalize to [0, 1]
|
||||
distances=[[(point.score + 1.0)/2.0 for point in query_response.points]],
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
|
Loading…
Reference in New Issue
Block a user