mirror of
https://github.com/open-webui/open-webui
synced 2025-05-15 19:16:35 +00:00
refac
This commit is contained in:
parent
b943b7d337
commit
939bfd153e
@ -48,9 +48,9 @@ class VectorSearchRetriever(BaseRetriever):
|
|||||||
limit=self.top_k,
|
limit=self.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
ids = result["ids"][0]
|
ids = result.ids[0]
|
||||||
metadatas = result["metadatas"][0]
|
metadatas = result.metadatas[0]
|
||||||
documents = result["documents"][0]
|
documents = result.documents[0]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for idx in range(len(ids)):
|
for idx in range(len(ids)):
|
||||||
@ -194,7 +194,7 @@ def query_collection(
|
|||||||
k=k,
|
k=k,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error when querying the collection: {e}")
|
log.exception(f"Error when querying the collection: {e}")
|
||||||
else:
|
else:
|
||||||
@ -212,7 +212,7 @@ def query_collection_with_hybrid_search(
|
|||||||
r: float,
|
r: float,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
failed = 0
|
error = False
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
result = query_doc_with_hybrid_search(
|
result = query_doc_with_hybrid_search(
|
||||||
@ -228,12 +228,14 @@ def query_collection_with_hybrid_search(
|
|||||||
log.exception(
|
log.exception(
|
||||||
"Error when querying the collection with " f"hybrid_search: {e}"
|
"Error when querying the collection with " f"hybrid_search: {e}"
|
||||||
)
|
)
|
||||||
failed += 1
|
error = True
|
||||||
if failed == len(collection_names):
|
|
||||||
|
if error:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Hybrid search failed for all collections. Using "
|
"Hybrid search failed for all collections. Using "
|
||||||
"Non hybrid search as fallback."
|
"Non hybrid search as fallback."
|
||||||
)
|
)
|
||||||
|
|
||||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
CHROMA_DATA_PATH,
|
CHROMA_DATA_PATH,
|
||||||
CHROMA_HTTP_HOST,
|
CHROMA_HTTP_HOST,
|
||||||
@ -47,7 +47,7 @@ class ChromaClient:
|
|||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||||
) -> Optional[QueryResult]:
|
) -> Optional[SearchResult]:
|
||||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||||
collection = self.client.get_collection(name=collection_name)
|
collection = self.client.get_collection(name=collection_name)
|
||||||
if collection:
|
if collection:
|
||||||
@ -56,19 +56,31 @@ class ChromaClient:
|
|||||||
n_results=limit,
|
n_results=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return SearchResult(
|
||||||
|
**{
|
||||||
"ids": result["ids"],
|
"ids": result["ids"],
|
||||||
"distances": result["distances"],
|
"distances": result["distances"],
|
||||||
"documents": result["documents"],
|
"documents": result["documents"],
|
||||||
"metadatas": result["metadatas"],
|
"metadatas": result["metadatas"],
|
||||||
}
|
}
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get(self, collection_name: str) -> Optional[QueryResult]:
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
# Get all the items in the collection.
|
# Get all the items in the collection.
|
||||||
collection = self.client.get_collection(name=collection_name)
|
collection = self.client.get_collection(name=collection_name)
|
||||||
if collection:
|
if collection:
|
||||||
return collection.get()
|
|
||||||
|
result = collection.get()
|
||||||
|
|
||||||
|
return GetResult(
|
||||||
|
**{
|
||||||
|
"ids": [result["ids"]],
|
||||||
|
"distances": [result["distances"]],
|
||||||
|
"documents": [result["documents"]],
|
||||||
|
"metadatas": [result["metadatas"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||||
|
@ -4,7 +4,7 @@ import json
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
MILVUS_URI,
|
MILVUS_URI,
|
||||||
)
|
)
|
||||||
@ -15,7 +15,7 @@ class MilvusClient:
|
|||||||
self.collection_prefix = "open_webui"
|
self.collection_prefix = "open_webui"
|
||||||
self.client = Client(uri=MILVUS_URI)
|
self.client = Client(uri=MILVUS_URI)
|
||||||
|
|
||||||
def _result_to_query_result(self, result) -> QueryResult:
|
def _result_to_query_result(self, result) -> SearchResult:
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
ids = []
|
ids = []
|
||||||
@ -40,12 +40,14 @@ class MilvusClient:
|
|||||||
documents.append(_documents)
|
documents.append(_documents)
|
||||||
metadatas.append(_metadatas)
|
metadatas.append(_metadatas)
|
||||||
|
|
||||||
return {
|
return SearchResult(
|
||||||
|
**{
|
||||||
"ids": ids,
|
"ids": ids,
|
||||||
"distances": distances,
|
"distances": distances,
|
||||||
"documents": documents,
|
"documents": documents,
|
||||||
"metadatas": metadatas,
|
"metadatas": metadatas,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _create_collection(self, collection_name: str, dimension: int):
|
def _create_collection(self, collection_name: str, dimension: int):
|
||||||
schema = self.client.create_schema(
|
schema = self.client.create_schema(
|
||||||
@ -94,7 +96,7 @@ class MilvusClient:
|
|||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||||
) -> Optional[QueryResult]:
|
) -> Optional[SearchResult]:
|
||||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||||
result = self.client.search(
|
result = self.client.search(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||||
@ -105,10 +107,11 @@ class MilvusClient:
|
|||||||
|
|
||||||
return self._result_to_query_result(result)
|
return self._result_to_query_result(result)
|
||||||
|
|
||||||
def get(self, collection_name: str) -> Optional[QueryResult]:
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
# Get all the items in the collection.
|
# Get all the items in the collection.
|
||||||
result = self.client.query(
|
result = self.client.query(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||||
|
filter='id != ""',
|
||||||
)
|
)
|
||||||
return self._result_to_query_result(result)
|
return self._result_to_query_result(result)
|
||||||
|
|
||||||
|
@ -9,8 +9,11 @@ class VectorItem(BaseModel):
|
|||||||
metadata: Any
|
metadata: Any
|
||||||
|
|
||||||
|
|
||||||
class QueryResult(BaseModel):
|
class GetResult(BaseModel):
|
||||||
ids: Optional[List[List[str]]]
|
ids: Optional[List[List[str]]]
|
||||||
distances: Optional[List[List[float | int]]]
|
|
||||||
documents: Optional[List[List[str]]]
|
documents: Optional[List[List[str]]]
|
||||||
metadatas: Optional[List[List[Any]]]
|
metadatas: Optional[List[List[Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(GetResult):
|
||||||
|
distances: Optional[List[List[float | int]]]
|
||||||
|
Loading…
Reference in New Issue
Block a user