From 878a570a2c4ec0bdfebc5cd0a6d6632eac13e66d Mon Sep 17 00:00:00 2001 From: Robin Bially Date: Wed, 9 Oct 2024 12:51:43 +0200 Subject: [PATCH] add qdrant as vector db --- .../apps/retrieval/vector/connector.py | 4 + .../apps/retrieval/vector/dbs/qdrant.py | 178 ++++++++++++++++++ backend/open_webui/config.py | 3 + backend/requirements.txt | 1 + 4 files changed, 186 insertions(+) create mode 100644 backend/open_webui/apps/retrieval/vector/dbs/qdrant.py diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index 1f33b1721..c7f00f5fd 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -4,6 +4,10 @@ if VECTOR_DB == "milvus": from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient VECTOR_DB_CLIENT = MilvusClient() +elif VECTOR_DB == "qdrant": + from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient + + VECTOR_DB_CLIENT = QdrantClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py new file mode 100644 index 000000000..6ef00fff0 --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -0,0 +1,178 @@ +import logging +from typing import Optional + +from qdrant_client import QdrantClient as Qclient +from qdrant_client.http.models import PointStruct +from qdrant_client.models import models + +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import QDRANT_URI + +log = logging.getLogger(__name__) +log.setLevel("INFO") + + +class QdrantClient: + def __init__(self): + self.collection_prefix = "open-webui" + self.QDRANT_URI = QDRANT_URI + self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None + + def _result_to_get_result(self, points) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for point in points: + payload = point.payload + ids.append(point.id) + documents.append(payload["text"]) + metadatas.append(payload["metadata"]) + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + + def _create_collection(self, collection_name: str, dimension: int): + collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" + self.client.create_collection( + collection_name=collection_name_with_prefix, + vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), + ) + + log.info(f"collection {collection_name_with_prefix} successfully created!") + + def _create_collection_if_not_exists(self, collection_name, dimension): + if not self.has_collection( + collection_name=collection_name + ): + self._create_collection( + collection_name=collection_name, dimension=dimension + ) + + def has_collection(self, collection_name: str) -> bool: + return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}") + + def delete_collection(self, collection_name: str): + return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}") + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[SearchResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + + log.info("start search...") + query_response = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query=vectors[0], + limit=limit, + ) + get_result = self._result_to_get_result(query_response.points) + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=[[point.score for point in query_response.points]] + ) + + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + # Construct the filter string for querying + if not self.has_collection(collection_name): + return None + try: + + field_conditions = [] + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) + + log.info("start search...") + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query_filter=models.Filter(should=field_conditions), + limit=limit, + ) + return self._result_to_get_result(points.points) + except Exception as e: + print(e) + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + # Get all the items in the collection. + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + limit=10000000 # default is 10 + ) + return self._result_to_get_result(points.points) + + def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection, if the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self.create_points(items) + self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points) + + def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self.create_points(items) + return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) + + def delete( + self, + collection_name: str, + ids: Optional[list[str]] = None, + filter: Optional[dict] = None, + ): + # Delete the items from the collection based on the ids. + field_conditions = [] + + if ids: + for id_value in ids: + field_conditions.append( + models.FieldCondition( + key="metadata.id", + match=models.MatchValue(value=id_value), + ), + ), + elif filter: + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ), + + return self.client.delete( + collection_name=f"{self.collection_prefix}_{collection_name}", + points_selector=models.FilterSelector( + filter=models.Filter( + must=field_conditions + ) + ), + ) + + def reset(self): + # Resets the database. This will delete all collections and item entries. + collection_names = self.client.get_collections().collections + for collection_name in collection_names: + if collection_name.name.startswith(self.collection_prefix): + self.client.delete_collection(collection_name=collection_name.name) + + def create_points(self, items: list[VectorItem]): + vectors = [item["vector"] for item in items] + log.info("insert points...") + points = [] + for idx, item in enumerate(items): + points.append( + PointStruct( + id=item["id"], + vector=vectors[idx], + payload={"text": item["text"], "metadata": item["metadata"]}, + ) + ) + return points diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index bfc9a4ded..67ab39566 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -901,6 +901,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") +# Qdrant +QDRANT_URI = os.environ.get("QDRANT_URI", None) + #################################### # Information Retrieval (RAG) #################################### diff --git a/backend/requirements.txt b/backend/requirements.txt index 80b4d541f..7d601f86a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -41,6 +41,7 @@ langchain-chroma==0.1.4 fake-useragent==1.5.1 chromadb==0.5.9 pymilvus==2.4.7 +qdrant-client~=1.12.0 sentence-transformers==3.0.1 colbert-ai==0.2.21