chore: format

This commit is contained in:
Timothy J. Baek
2024-10-20 18:38:06 -07:00
parent 768b7e139c
commit 9936583477
6 changed files with 85 additions and 72 deletions

View File

@@ -110,9 +110,8 @@ class ChromaClient:
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
@@ -131,9 +130,8 @@ class ChromaClient:
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]

View File

@@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
@@ -38,15 +39,15 @@ class QdrantClient:
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(
collection_name=collection_name
):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
@@ -56,22 +57,23 @@ class QdrantClient:
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"]
},
payload={"text": item["text"], "metadata": item["metadata"]},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
@@ -87,7 +89,7 @@ class QdrantClient:
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]]
distances=[[point.score for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@@ -101,7 +103,10 @@ class QdrantClient:
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -117,7 +122,7 @@ class QdrantClient:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
@@ -134,10 +139,10 @@ class QdrantClient:
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
@@ -162,9 +167,7 @@ class QdrantClient:
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(
must=field_conditions
)
filter=models.Filter(must=field_conditions)
),
)