mirror of
https://github.com/open-webui/open-webui
synced 2024-11-06 16:59:42 +00:00
add qdrant as vector db
This commit is contained in:
parent
3f892583c3
commit
878a570a2c
@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
|
||||
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
else:
|
||||
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
|
178
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
178
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
@ -0,0 +1,178 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel("INFO")
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
|
||||
)
|
||||
|
||||
log.info(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(
|
||||
collection_name=collection_name
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
|
||||
log.info("start search...")
|
||||
query_response = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query=vectors[0],
|
||||
limit=limit,
|
||||
)
|
||||
get_result = self._result_to_get_result(query_response.points)
|
||||
return SearchResult(
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]]
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
try:
|
||||
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
|
||||
|
||||
log.info("start search...")
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=10000000 # default is 10
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self.create_points(items)
|
||||
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self.create_points(items)
|
||||
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
field_conditions = []
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
),
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
),
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(
|
||||
must=field_conditions
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.get_collections().collections
|
||||
for collection_name in collection_names:
|
||||
if collection_name.name.startswith(self.collection_prefix):
|
||||
self.client.delete_collection(collection_name=collection_name.name)
|
||||
|
||||
def create_points(self, items: list[VectorItem]):
|
||||
vectors = [item["vector"] for item in items]
|
||||
log.info("insert points...")
|
||||
points = []
|
||||
for idx, item in enumerate(items):
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=vectors[idx],
|
||||
payload={"text": item["text"], "metadata": item["metadata"]},
|
||||
)
|
||||
)
|
||||
return points
|
@ -901,6 +901,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
|
||||
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
|
||||
|
||||
# Qdrant
|
||||
QDRANT_URI = os.environ.get("QDRANT_URI", None)
|
||||
|
||||
####################################
|
||||
# Information Retrieval (RAG)
|
||||
####################################
|
||||
|
@ -41,6 +41,7 @@ langchain-chroma==0.1.4
|
||||
fake-useragent==1.5.1
|
||||
chromadb==0.5.9
|
||||
pymilvus==2.4.7
|
||||
qdrant-client~=1.12.0
|
||||
|
||||
sentence-transformers==3.0.1
|
||||
colbert-ai==0.2.21
|
||||
|
Loading…
Reference in New Issue
Block a user