mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	refac
This commit is contained in:
		
							parent
							
								
									b943b7d337
								
							
						
					
					
						commit
						939bfd153e
					
				@ -48,9 +48,9 @@ class VectorSearchRetriever(BaseRetriever):
 | 
			
		||||
            limit=self.top_k,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ids = result["ids"][0]
 | 
			
		||||
        metadatas = result["metadatas"][0]
 | 
			
		||||
        documents = result["documents"][0]
 | 
			
		||||
        ids = result.ids[0]
 | 
			
		||||
        metadatas = result.metadatas[0]
 | 
			
		||||
        documents = result.documents[0]
 | 
			
		||||
 | 
			
		||||
        results = []
 | 
			
		||||
        for idx in range(len(ids)):
 | 
			
		||||
@ -194,7 +194,7 @@ def query_collection(
 | 
			
		||||
                    k=k,
 | 
			
		||||
                    embedding_function=embedding_function,
 | 
			
		||||
                )
 | 
			
		||||
                results.append(result)
 | 
			
		||||
                results.append(result.model_dump())
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                log.exception(f"Error when querying the collection: {e}")
 | 
			
		||||
        else:
 | 
			
		||||
@ -212,7 +212,7 @@ def query_collection_with_hybrid_search(
 | 
			
		||||
    r: float,
 | 
			
		||||
) -> dict:
 | 
			
		||||
    results = []
 | 
			
		||||
    failed = 0
 | 
			
		||||
    error = False
 | 
			
		||||
    for collection_name in collection_names:
 | 
			
		||||
        try:
 | 
			
		||||
            result = query_doc_with_hybrid_search(
 | 
			
		||||
@ -228,12 +228,14 @@ def query_collection_with_hybrid_search(
 | 
			
		||||
            log.exception(
 | 
			
		||||
                "Error when querying the collection with " f"hybrid_search: {e}"
 | 
			
		||||
            )
 | 
			
		||||
            failed += 1
 | 
			
		||||
    if failed == len(collection_names):
 | 
			
		||||
            error = True
 | 
			
		||||
 | 
			
		||||
    if error:
 | 
			
		||||
        raise Exception(
 | 
			
		||||
            "Hybrid search failed for all collections. Using "
 | 
			
		||||
            "Non hybrid search as fallback."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return merge_and_sort_query_results(results, k=k, reverse=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
 | 
			
		||||
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    CHROMA_DATA_PATH,
 | 
			
		||||
    CHROMA_HTTP_HOST,
 | 
			
		||||
@ -47,7 +47,7 @@ class ChromaClient:
 | 
			
		||||
 | 
			
		||||
    def search(
 | 
			
		||||
        self, collection_name: str, vectors: list[list[float | int]], limit: int
 | 
			
		||||
    ) -> Optional[QueryResult]:
 | 
			
		||||
    ) -> Optional[SearchResult]:
 | 
			
		||||
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
 | 
			
		||||
        collection = self.client.get_collection(name=collection_name)
 | 
			
		||||
        if collection:
 | 
			
		||||
@ -56,19 +56,31 @@ class ChromaClient:
 | 
			
		||||
                n_results=limit,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return {
 | 
			
		||||
                "ids": result["ids"],
 | 
			
		||||
                "distances": result["distances"],
 | 
			
		||||
                "documents": result["documents"],
 | 
			
		||||
                "metadatas": result["metadatas"],
 | 
			
		||||
            }
 | 
			
		||||
            return SearchResult(
 | 
			
		||||
                **{
 | 
			
		||||
                    "ids": result["ids"],
 | 
			
		||||
                    "distances": result["distances"],
 | 
			
		||||
                    "documents": result["documents"],
 | 
			
		||||
                    "metadatas": result["metadatas"],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def get(self, collection_name: str) -> Optional[QueryResult]:
 | 
			
		||||
    def get(self, collection_name: str) -> Optional[GetResult]:
 | 
			
		||||
        # Get all the items in the collection.
 | 
			
		||||
        collection = self.client.get_collection(name=collection_name)
 | 
			
		||||
        if collection:
 | 
			
		||||
            return collection.get()
 | 
			
		||||
 | 
			
		||||
            result = collection.get()
 | 
			
		||||
 | 
			
		||||
            return GetResult(
 | 
			
		||||
                **{
 | 
			
		||||
                    "ids": [result["ids"]],
 | 
			
		||||
                    "distances": [result["distances"]],
 | 
			
		||||
                    "documents": [result["documents"]],
 | 
			
		||||
                    "metadatas": [result["metadatas"]],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def insert(self, collection_name: str, items: list[VectorItem]):
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ import json
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
 | 
			
		||||
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    MILVUS_URI,
 | 
			
		||||
)
 | 
			
		||||
@ -15,7 +15,7 @@ class MilvusClient:
 | 
			
		||||
        self.collection_prefix = "open_webui"
 | 
			
		||||
        self.client = Client(uri=MILVUS_URI)
 | 
			
		||||
 | 
			
		||||
    def _result_to_query_result(self, result) -> QueryResult:
 | 
			
		||||
    def _result_to_query_result(self, result) -> SearchResult:
 | 
			
		||||
        print(result)
 | 
			
		||||
 | 
			
		||||
        ids = []
 | 
			
		||||
@ -40,12 +40,14 @@ class MilvusClient:
 | 
			
		||||
            documents.append(_documents)
 | 
			
		||||
            metadatas.append(_metadatas)
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "ids": ids,
 | 
			
		||||
            "distances": distances,
 | 
			
		||||
            "documents": documents,
 | 
			
		||||
            "metadatas": metadatas,
 | 
			
		||||
        }
 | 
			
		||||
        return SearchResult(
 | 
			
		||||
            **{
 | 
			
		||||
                "ids": ids,
 | 
			
		||||
                "distances": distances,
 | 
			
		||||
                "documents": documents,
 | 
			
		||||
                "metadatas": metadatas,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_collection(self, collection_name: str, dimension: int):
 | 
			
		||||
        schema = self.client.create_schema(
 | 
			
		||||
@ -94,7 +96,7 @@ class MilvusClient:
 | 
			
		||||
 | 
			
		||||
    def search(
 | 
			
		||||
        self, collection_name: str, vectors: list[list[float | int]], limit: int
 | 
			
		||||
    ) -> Optional[QueryResult]:
 | 
			
		||||
    ) -> Optional[SearchResult]:
 | 
			
		||||
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
 | 
			
		||||
        result = self.client.search(
 | 
			
		||||
            collection_name=f"{self.collection_prefix}_{collection_name}",
 | 
			
		||||
@ -105,10 +107,11 @@ class MilvusClient:
 | 
			
		||||
 | 
			
		||||
        return self._result_to_query_result(result)
 | 
			
		||||
 | 
			
		||||
    def get(self, collection_name: str) -> Optional[QueryResult]:
 | 
			
		||||
    def get(self, collection_name: str) -> Optional[GetResult]:
 | 
			
		||||
        # Get all the items in the collection.
 | 
			
		||||
        result = self.client.query(
 | 
			
		||||
            collection_name=f"{self.collection_prefix}_{collection_name}",
 | 
			
		||||
            filter='id != ""',
 | 
			
		||||
        )
 | 
			
		||||
        return self._result_to_query_result(result)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,11 @@ class VectorItem(BaseModel):
 | 
			
		||||
    metadata: Any
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueryResult(BaseModel):
 | 
			
		||||
class GetResult(BaseModel):
 | 
			
		||||
    ids: Optional[List[List[str]]]
 | 
			
		||||
    distances: Optional[List[List[float | int]]]
 | 
			
		||||
    documents: Optional[List[List[str]]]
 | 
			
		||||
    metadatas: Optional[List[List[Any]]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SearchResult(GetResult):
 | 
			
		||||
    distances: Optional[List[List[float | int]]]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user