mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	Fix: Normalize all database distances to a score in [0, 1]
This commit is contained in:
		
							parent
							
								
									8aa6dade41
								
							
						
					
					
						commit
						94d9d3d590
					
				@ -175,7 +175,7 @@ def merge_get_results(get_results: list[dict]) -> dict:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_and_sort_query_results(
 | 
			
		||||
    query_results: list[dict], k: int, reverse: bool = False
 | 
			
		||||
    query_results: list[dict], k: int
 | 
			
		||||
) -> dict:
 | 
			
		||||
    # Initialize lists to store combined data
 | 
			
		||||
    combined = dict()  # To store documents with unique document hashes
 | 
			
		||||
@ -196,28 +196,18 @@ def merge_and_sort_query_results(
 | 
			
		||||
                    continue  # if doc is new, no further comparison is needed
 | 
			
		||||
 | 
			
		||||
                # if doc is already in, but new distance is better, update
 | 
			
		||||
                if not reverse and distance < combined[doc_hash][0]:
 | 
			
		||||
                    # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
 | 
			
		||||
                    # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
 | 
			
		||||
                    combined[doc_hash] = (distance, document, metadata)
 | 
			
		||||
                if reverse and distance > combined[doc_hash][0]:
 | 
			
		||||
                if distance > combined[doc_hash][0]:
 | 
			
		||||
                    combined[doc_hash] = (distance, document, metadata)
 | 
			
		||||
 | 
			
		||||
    combined = list(combined.values())
 | 
			
		||||
    # Sort the list based on distances
 | 
			
		||||
    combined.sort(key=lambda x: x[0], reverse=reverse)
 | 
			
		||||
    combined.sort(key=lambda x: x[0], reverse=True)
 | 
			
		||||
 | 
			
		||||
    # Slice to keep only the top k elements
 | 
			
		||||
    sorted_distances, sorted_documents, sorted_metadatas = (
 | 
			
		||||
        zip(*combined[:k]) if combined else ([], [], [])
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # if chromaDB, the distance is 0 (best) to 2 (worst)
 | 
			
		||||
    # re-order to -1 (worst) to 1 (best) for relevance score
 | 
			
		||||
    if not reverse:
 | 
			
		||||
        sorted_distances = tuple(-dist for dist in sorted_distances)
 | 
			
		||||
        sorted_distances = tuple(dist + 1 for dist in sorted_distances)
 | 
			
		||||
 | 
			
		||||
    # Create and return the output dictionary
 | 
			
		||||
    return {
 | 
			
		||||
        "distances": [list(sorted_distances)],
 | 
			
		||||
@ -267,12 +257,7 @@ def query_collection(
 | 
			
		||||
            else:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
    if VECTOR_DB == "chroma":
 | 
			
		||||
        # Chroma uses unconventional cosine similarity, so we don't need to reverse the results
 | 
			
		||||
        # https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
 | 
			
		||||
        return merge_and_sort_query_results(results, k=k, reverse=False)
 | 
			
		||||
    else:
 | 
			
		||||
        return merge_and_sort_query_results(results, k=k, reverse=True)
 | 
			
		||||
    return merge_and_sort_query_results(results, k=k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def query_collection_with_hybrid_search(
 | 
			
		||||
@ -308,7 +293,7 @@ def query_collection_with_hybrid_search(
 | 
			
		||||
            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return merge_and_sort_query_results(results, k=k, reverse=True)
 | 
			
		||||
    return merge_and_sort_query_results(results, k=k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_embedding_function(
 | 
			
		||||
 | 
			
		||||
@ -75,10 +75,16 @@ class ChromaClient:
 | 
			
		||||
                    n_results=limit,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to a score of 0 (worst) -> 1 (best)
 | 
			
		||||
                # https://docs.trychroma.com/docs/collections/configure cosine equation
 | 
			
		||||
                distances: list = result["distances"][0]
 | 
			
		||||
                distances = [2 - dist for dist in distances]
 | 
			
		||||
                distances = [[dist/2 for dist in distances]]
 | 
			
		||||
 | 
			
		||||
                return SearchResult(
 | 
			
		||||
                    **{
 | 
			
		||||
                        "ids": result["ids"],
 | 
			
		||||
                        "distances": result["distances"],
 | 
			
		||||
                        "distances": distances,
 | 
			
		||||
                        "documents": result["documents"],
 | 
			
		||||
                        "metadatas": result["metadatas"],
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
@ -64,7 +64,10 @@ class MilvusClient:
 | 
			
		||||
 | 
			
		||||
            for item in match:
 | 
			
		||||
                _ids.append(item.get("id"))
 | 
			
		||||
                _distances.append(item.get("distance"))
 | 
			
		||||
                # normalize milvus score from [-1, 1] to [0, 1] range
 | 
			
		||||
                # https://milvus.io/docs/de/metric.md
 | 
			
		||||
                _dist = (item.get("distance") + 1.0)/2.0
 | 
			
		||||
                _distances.append(_dist)
 | 
			
		||||
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
 | 
			
		||||
                _metadatas.append(item.get("entity", {}).get("metadata"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -120,7 +120,7 @@ class OpenSearchClient:
 | 
			
		||||
                    "script_score": {
 | 
			
		||||
                        "query": {"match_all": {}},
 | 
			
		||||
                        "script": {
 | 
			
		||||
                            "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
 | 
			
		||||
                            "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
 | 
			
		||||
                            "params": {
 | 
			
		||||
                                "field": "vector",
 | 
			
		||||
                                "query_value": vectors[0],
 | 
			
		||||
 | 
			
		||||
@ -278,7 +278,9 @@ class PgvectorClient:
 | 
			
		||||
            for row in results:
 | 
			
		||||
                qid = int(row.qid)
 | 
			
		||||
                ids[qid].append(row.id)
 | 
			
		||||
                distances[qid].append(row.distance)
 | 
			
		||||
                # normalize and re-order pgvector distance from [2, 0] to a [0, 1] score range
 | 
			
		||||
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
 | 
			
		||||
                distances[qid].append((2.0 - row.distance)/2.0)
 | 
			
		||||
                documents[qid].append(row.text)
 | 
			
		||||
                metadatas[qid].append(row.vmetadata)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -99,7 +99,8 @@ class QdrantClient:
 | 
			
		||||
            ids=get_result.ids,
 | 
			
		||||
            documents=get_result.documents,
 | 
			
		||||
            metadatas=get_result.metadatas,
 | 
			
		||||
            distances=[[point.score for point in query_response.points]],
 | 
			
		||||
            # qdrant distance is [-1, 1], normalize to [0, 1]
 | 
			
		||||
            distances=[[(point.score + 1.0)/2.0 for point in query_response.points]],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user