mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac
This commit is contained in:
@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
||||
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
@@ -47,7 +47,7 @@ class ChromaClient:
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[QueryResult]:
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
@@ -56,19 +56,31 @@ class ChromaClient:
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[QueryResult]:
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
return collection.get()
|
||||
|
||||
result = collection.get()
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"distances": [result["distances"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
||||
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
)
|
||||
@@ -15,7 +15,7 @@ class MilvusClient:
|
||||
self.collection_prefix = "open_webui"
|
||||
self.client = Client(uri=MILVUS_URI)
|
||||
|
||||
def _result_to_query_result(self, result) -> QueryResult:
|
||||
def _result_to_query_result(self, result) -> SearchResult:
|
||||
print(result)
|
||||
|
||||
ids = []
|
||||
@@ -40,12 +40,14 @@ class MilvusClient:
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return {
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
schema = self.client.create_schema(
|
||||
@@ -94,7 +96,7 @@ class MilvusClient:
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[QueryResult]:
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
@@ -105,10 +107,11 @@ class MilvusClient:
|
||||
|
||||
return self._result_to_query_result(result)
|
||||
|
||||
def get(self, collection_name: str) -> Optional[QueryResult]:
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
result = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter='id != ""',
|
||||
)
|
||||
return self._result_to_query_result(result)
|
||||
|
||||
|
||||
@@ -9,8 +9,11 @@ class VectorItem(BaseModel):
|
||||
metadata: Any
|
||||
|
||||
|
||||
class QueryResult(BaseModel):
|
||||
class GetResult(BaseModel):
|
||||
ids: Optional[List[List[str]]]
|
||||
distances: Optional[List[List[float | int]]]
|
||||
documents: Optional[List[List[str]]]
|
||||
metadatas: Optional[List[List[Any]]]
|
||||
|
||||
|
||||
class SearchResult(GetResult):
|
||||
distances: Optional[List[List[float | int]]]
|
||||
|
||||
Reference in New Issue
Block a user