From 94d9d3d59088bd45664d89f7ec9ec033e2bdbc17 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 25 Mar 2025 16:46:14 +0100 Subject: [PATCH] Fix: Normalze all database distances to score in [0, 1] --- backend/open_webui/retrieval/utils.py | 25 ++++--------------- .../open_webui/retrieval/vector/dbs/chroma.py | 8 +++++- .../open_webui/retrieval/vector/dbs/milvus.py | 5 +++- .../retrieval/vector/dbs/opensearch.py | 2 +- .../retrieval/vector/dbs/pgvector.py | 4 ++- .../open_webui/retrieval/vector/dbs/qdrant.py | 3 ++- 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index b05057b28..a3974f15b 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -175,7 +175,7 @@ def merge_get_results(get_results: list[dict]) -> dict: def merge_and_sort_query_results( - query_results: list[dict], k: int, reverse: bool = False + query_results: list[dict], k: int ) -> dict: # Initialize lists to store combined data combined = dict() # To store documents with unique document hashes @@ -196,28 +196,18 @@ def merge_and_sort_query_results( continue # if doc is new, no further comparison is needed # if doc is alredy in, but new distance is better, update - if not reverse and distance < combined[doc_hash][0]: - # Chroma uses unconventional cosine similarity, so we don't need to reverse the results - # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections - combined[doc_hash] = (distance, document, metadata) - if reverse and distance > combined[doc_hash][0]: + if distance > combined[doc_hash][0]: combined[doc_hash] = (distance, document, metadata) combined = list(combined.values()) # Sort the list based on distances - combined.sort(key=lambda x: x[0], reverse=reverse) + combined.sort(key=lambda x: x[0], reverse=True) # Slice to keep only the top k elements sorted_distances, sorted_documents, sorted_metadatas = ( zip(*combined[:k]) if combined else ([], [], []) ) - # if chromaDB, the distance is 0 (best) to 2 (worse) - # re-order to -1 (worst) to 1 (best) for relevance score - if not reverse: - sorted_distances = tuple(-dist for dist in sorted_distances) - sorted_distances = tuple(dist + 1 for dist in sorted_distances) - # Create and return the output dictionary return { "distances": [list(sorted_distances)], @@ -267,12 +257,7 @@ def query_collection( else: pass - if VECTOR_DB == "chroma": - # Chroma uses unconventional cosine similarity, so we don't need to reverse the results - # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections - return merge_and_sort_query_results(results, k=k, reverse=False) - else: - return merge_and_sort_query_results(results, k=k, reverse=True) + return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( @@ -308,7 +293,7 @@ def query_collection_with_hybrid_search( "Hybrid search failed for all collections. Using Non hybrid search as fallback." ) - return merge_and_sort_query_results(results, k=k, reverse=True) + return merge_and_sort_query_results(results, k=k) def get_embedding_function( diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 006ee2076..3543cd545 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -75,10 +75,16 @@ class ChromaClient: n_results=limit, ) + # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1 + # https://docs.trychroma.com/docs/collections/configure cosine equation + distances: list = result["distances"][0] + distances = [2 - dist for dist in distances] + distances = [[dist/2 for dist in distances]] + return SearchResult( **{ "ids": result["ids"], - "distances": result["distances"], + "distances": distances, "documents": result["documents"], "metadatas": result["metadatas"], } diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index ad05f9422..4d0da57ac 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -64,7 +64,10 @@ class MilvusClient: for item in match: _ids.append(item.get("id")) - _distances.append(item.get("distance")) + # normalize milvus score from [-1, 1] to [0, 1] range + # https://milvus.io/docs/de/metric.md + _dist = (item.get("distance") + 1.0)/2.0 + _distances.append(_dist) _documents.append(item.get("entity", {}).get("data", {}).get("text")) _metadatas.append(item.get("entity", {}).get("metadata")) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 99567c84e..432bcef41 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -120,7 +120,7 @@ class OpenSearchClient: "script_score": { "query": {"match_all": {}}, "script": { - "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0", + "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0", "params": { "field": "vector", "query_value": vectors[0], diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index eab02232f..0ddf48d1e 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -278,7 +278,9 @@ class PgvectorClient: for row in results: qid = int(row.qid) ids[qid].append(row.id) - distances[qid].append(row.distance) + # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range + # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying + distances[qid].append((2.0 - row.distance)/2.0) documents[qid].append(row.text) metadatas[qid].append(row.vmetadata) diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index 28f0b3779..070bf3de5 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -99,7 +99,8 @@ class QdrantClient: ids=get_result.ids, documents=get_result.documents, metadatas=get_result.metadatas, - distances=[[point.score for point in query_response.points]], + # qdrant distance is [-1, 1], normalize to [0, 1] + distances=[[(point.score + 1.0)/2.0 for point in query_response.points]], ) def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):