improvements

This commit is contained in:
Robin Bially 2024-10-09 13:10:23 +02:00
parent b185524a8c
commit b56f77ed47

View File

@ -1,4 +1,3 @@
import logging
from typing import Optional from typing import Optional
from qdrant_client import QdrantClient as Qclient from qdrant_client import QdrantClient as Qclient
@ -8,10 +7,6 @@ from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI from open_webui.config import QDRANT_URI
log = logging.getLogger(__name__)
log.setLevel("INFO")
class QdrantClient: class QdrantClient:
def __init__(self): def __init__(self):
self.collection_prefix = "open-webui" self.collection_prefix = "open-webui"
@ -44,7 +39,7 @@ class QdrantClient:
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
) )
log.info(f"collection {collection_name_with_prefix} successfully created!") print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension): def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection( if not self.has_collection(
@ -65,7 +60,6 @@ class QdrantClient:
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
log.info("start search...")
query_response = self.client.query_points( query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
query=vectors[0], query=vectors[0],
@ -90,7 +84,6 @@ class QdrantClient:
field_conditions.append( field_conditions.append(
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
log.info("start search...")
points = self.client.query_points( points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
query_filter=models.Filter(should=field_conditions), query_filter=models.Filter(should=field_conditions),
@ -164,15 +157,16 @@ class QdrantClient:
self.client.delete_collection(collection_name=collection_name.name) self.client.delete_collection(collection_name=collection_name.name)
def create_points(self, items: list[VectorItem]): def create_points(self, items: list[VectorItem]):
vectors = [item["vector"] for item in items]
log.info("insert points...")
points = [] points = []
for idx, item in enumerate(items): for idx, item in enumerate(items):
points.append( points.append(
PointStruct( PointStruct(
id=item["id"], id=item["id"],
vector=vectors[idx], vector=item["vector"],
payload={"text": item["text"], "metadata": item["metadata"]}, payload={
"text": item["text"],
"metadata": item["metadata"]
},
) )
) )
return points return points