Fix: Normalize all database distances to a score in [0, 1]

Marko Henning 2025-03-25 16:46:14 +01:00
parent 8aa6dade41
commit 94d9d3d590
6 changed files with 22 additions and 25 deletions

View File

@@ -175,7 +175,7 @@ def merge_get_results(get_results: list[dict]) -> dict:
 def merge_and_sort_query_results(
-    query_results: list[dict], k: int, reverse: bool = False
+    query_results: list[dict], k: int
 ) -> dict:
     # Initialize lists to store combined data
     combined = dict()  # To store documents with unique document hashes
@@ -196,28 +196,18 @@ def merge_and_sort_query_results(
                 continue  # if doc is new, no further comparison is needed
             # if doc is already in, but new distance is better, update
-            if not reverse and distance < combined[doc_hash][0]:
-                # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
-                # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
-                combined[doc_hash] = (distance, document, metadata)
-            if reverse and distance > combined[doc_hash][0]:
+            if distance > combined[doc_hash][0]:
                 combined[doc_hash] = (distance, document, metadata)
     combined = list(combined.values())
     # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
+    combined.sort(key=lambda x: x[0], reverse=True)
     # Slice to keep only the top k elements
     sorted_distances, sorted_documents, sorted_metadatas = (
         zip(*combined[:k]) if combined else ([], [], [])
     )
-    # if chromaDB, the distance is 0 (best) to 2 (worst)
-    # re-order to -1 (worst) to 1 (best) for relevance score
-    if not reverse:
-        sorted_distances = tuple(-dist for dist in sorted_distances)
-        sorted_distances = tuple(dist + 1 for dist in sorted_distances)
     # Create and return the output dictionary
     return {
         "distances": [list(sorted_distances)],
@@ -267,12 +257,7 @@ def query_collection(
         else:
             pass
-    if VECTOR_DB == "chroma":
-        # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
-        # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
-        return merge_and_sort_query_results(results, k=k, reverse=False)
-    else:
-        return merge_and_sort_query_results(results, k=k, reverse=True)
+    return merge_and_sort_query_results(results, k=k)
 def query_collection_with_hybrid_search(
@@ -308,7 +293,7 @@ def query_collection_with_hybrid_search(
             "Hybrid search failed for all collections. Using Non hybrid search as fallback."
         )
-        return merge_and_sort_query_results(results, k=k, reverse=True)
+        return merge_and_sort_query_results(results, k=k)
 def get_embedding_function(
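
With every backend now emitting a relevance score in [0, 1] where higher is better, the reverse flag becomes redundant: results from any mix of collections can be merged, sorted descending, and cut to the top k with one code path. A minimal standalone sketch of that behavior (not the project code; deduplication by document hash is omitted, and the dict layout mirrors the return value shown in the hunk above):

    # Sketch: merge scored results from two collections and keep the top k.
    # Assumes the [0, 1] "higher is better" convention introduced by this commit.
    results = [
        {"distances": [[0.91, 0.40]], "documents": [["doc A", "doc B"]], "metadatas": [[{}, {}]]},
        {"distances": [[0.75]], "documents": [["doc C"]], "metadatas": [[{}]]},
    ]

    combined = []
    for data in results:
        combined.extend(zip(data["distances"][0], data["documents"][0], data["metadatas"][0]))

    combined.sort(key=lambda x: x[0], reverse=True)  # one sort order for every backend
    top_k = combined[:2]
    print([(score, doc) for score, doc, _ in top_k])  # [(0.91, 'doc A'), (0.75, 'doc C')]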

View File

@@ -75,10 +75,16 @@ class ChromaClient:
                 n_results=limit,
             )
+            # ChromaDB cosine distance is in [0, 2] (0 = best); rescale to a [0, 1] relevance score (1 = best)
+            # https://docs.trychroma.com/docs/collections/configure (cosine equation)
+            distances: list = result["distances"][0]
+            distances = [2 - dist for dist in distances]
+            distances = [[dist / 2 for dist in distances]]
             return SearchResult(
                 **{
                     "ids": result["ids"],
-                    "distances": result["distances"],
+                    "distances": distances,
                     "documents": result["documents"],
                     "metadatas": result["metadatas"],
                 }
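
Taken together, the two comprehensions above implement score = (2 - d) / 2, i.e. 1 - d / 2, and also wrap the result in the nested-list shape the rest of the pipeline expects. A standalone check of the mapping; the helper name is invented for illustration:

    def chroma_distance_to_score(d: float) -> float:
        # ChromaDB cosine distance in [0, 2] -> relevance score in [0, 1]
        return (2 - d) / 2

    assert chroma_distance_to_score(0.0) == 1.0  # identical vectors
    assert chroma_distance_to_score(1.0) == 0.5  # orthogonal vectors
    assert chroma_distance_to_score(2.0) == 0.0  # opposite vectors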

View File

@@ -64,7 +64,10 @@ class MilvusClient:
                 for item in match:
                     _ids.append(item.get("id"))
-                    _distances.append(item.get("distance"))
+                    # normalize the Milvus cosine score from [-1, 1] to the [0, 1] range
+                    # https://milvus.io/docs/de/metric.md
+                    _dist = (item.get("distance") + 1.0) / 2.0
+                    _distances.append(_dist)
                     _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                     _metadatas.append(item.get("entity", {}).get("metadata"))
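
With the COSINE metric, the Milvus "distance" field is a similarity in [-1, 1], so shifting by one and halving lands it on the same [0, 1] scale as the other backends. A standalone check; the helper name is invented for illustration:

    def similarity_to_score(s: float) -> float:
        # cosine similarity in [-1, 1] -> relevance score in [0, 1]
        return (s + 1.0) / 2.0

    assert similarity_to_score(-1.0) == 0.0  # opposite vectors
    assert similarity_to_score(0.0) == 0.5   # orthogonal vectors
    assert similarity_to_score(1.0) == 1.0   # identical vectors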

View File

@@ -120,7 +120,7 @@ class OpenSearchClient:
                 "script_score": {
                     "query": {"match_all": {}},
                     "script": {
-                        "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
+                        "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
                         "params": {
                             "field": "vector",
                             "query_value": vectors[0],

View File

@@ -278,7 +278,9 @@ class PgvectorClient:
             for row in results:
                 qid = int(row.qid)
                 ids[qid].append(row.id)
-                distances[qid].append(row.distance)
+                # normalize pgvector cosine distance from [0, 2] (0 = best) to a [0, 1] score (1 = best)
+                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
+                distances[qid].append((2.0 - row.distance) / 2.0)
                 documents[qid].append(row.text)
                 metadatas[qid].append(row.vmetadata)
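
pgvector's cosine distance operator (<=>, per the linked docs) also yields values in [0, 2], so the conversion mirrors the Chroma one: (2 - d) / 2. A standalone sanity check that closer rows end up with higher scores; the sample distances are made up:

    rows = [("row-1", 0.12), ("row-2", 0.87), ("row-3", 1.95)]  # (id, row.distance)
    scored = [(rid, round((2.0 - d) / 2.0, 3)) for rid, d in rows]
    print(scored)  # [('row-1', 0.94), ('row-2', 0.565), ('row-3', 0.025)]
    assert all(0.0 <= s <= 1.0 for _, s in scored)
    assert scored == sorted(scored, key=lambda x: x[1], reverse=True)  # smaller distance, higher score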

View File

@@ -99,7 +99,8 @@ class QdrantClient:
             ids=get_result.ids,
             documents=get_result.documents,
             metadatas=get_result.metadatas,
-            distances=[[point.score for point in query_response.points]],
+            # Qdrant cosine score is in [-1, 1]; normalize to [0, 1]
+            distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
         )
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
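
Qdrant's point.score for a cosine-distance collection is a similarity in [-1, 1], so it gets the same shift as Milvus and OpenSearch, and the scores are wrapped in one inner list to match the [[...]] layout the merge step consumes. A standalone sketch; SimpleNamespace merely stands in for query_response.points:

    from types import SimpleNamespace

    # Stand-ins for query_response.points; .score mimics Qdrant's cosine similarity.
    points = [SimpleNamespace(score=0.98), SimpleNamespace(score=0.10), SimpleNamespace(score=-0.40)]

    distances = [[(point.score + 1.0) / 2.0 for point in points]]
    print(distances)  # [[0.99, 0.55, 0.3]]
    assert all(0.0 <= s <= 1.0 for s in distances[0])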