feat: milvus support

This commit is contained in:
Timothy J. Baek 2024-09-12 01:52:19 -04:00
parent 0886b3a0a4
commit 4775fe43d8
6 changed files with 158 additions and 20 deletions

View File

@ -1010,7 +1010,6 @@ def store_docs_in_vector_db(
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
VECTOR_DB_CLIENT.insert(
collection_name=collection_name,
items=[

View File

@ -41,10 +41,6 @@ class ChromaClient:
collections = self.client.list_collections()
return [collection.name for collection in collections]
def create_collection(self, collection_name: str):
    """Create a collection called ``collection_name`` on the wrapped client.

    Returns whatever the client's ``create_collection`` call returns.
    """
    chroma = self.client
    return chroma.create_collection(name=collection_name)
def delete_collection(self, collection_name: str):
    """Delete the collection called ``collection_name`` via the wrapped client.

    Returns whatever the client's ``delete_collection`` call returns.
    """
    chroma = self.client
    return chroma.delete_collection(name=collection_name)
@ -76,7 +72,7 @@ class ChromaClient:
return None
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection.
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items]
@ -94,7 +90,7 @@ class ChromaClient:
collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them.
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items]

View File

@ -1,39 +1,175 @@
from pymilvus import MilvusClient as Milvus
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.config import (
MILVUS_URI,
)
class MilvusClient:
    """Vector-store client backed by Milvus.

    Every collection is stored in Milvus under the name
    ``"<collection_prefix>_<collection_name>"``. Callers always pass and
    receive the un-prefixed name.
    """

    def __init__(self):
        self.collection_prefix = "open_webui"
        self.client = Client(uri=MILVUS_URI)

    def _full_name(self, collection_name: str) -> str:
        # Map a caller-facing collection name to the name used inside Milvus.
        return f"{self.collection_prefix}_{collection_name}"

    def _result_to_query_result(self, result) -> QueryResult:
        """Convert a list of result groups into the column-oriented result dict.

        ``result`` is a list of groups, each group being a list of dict-like
        items. An item is either a search hit, whose stored fields sit under
        an ``"entity"`` key, or a flat row with the fields at the top level.
        Missing values (e.g. ``"distance"`` on flat rows) come back as ``None``.
        """
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            # Search hits nest the stored fields under "entity"; flat rows
            # carry them directly, so fall back to the item itself.
            entities = [item.get("entity", item) for item in match]

            ids.append([item.get("id") for item in match])
            distances.append([item.get("distance") for item in match])
            documents.append(
                [(entity.get("data") or {}).get("text") for entity in entities]
            )
            metadatas.append([entity.get("metadata") for entity in entities])

        return {
            "ids": ids,
            "distances": distances,
            "documents": documents,
            "metadatas": metadatas,
        }

    def _create_collection(self, collection_name: str, dimension: int):
        """Create the prefixed collection with an explicit schema and index.

        The schema has a VARCHAR primary key ``id``, a FLOAT_VECTOR ``vector``
        of size ``dimension``, and JSON fields ``data`` and ``metadata``. The
        vector field gets an HNSW index using the COSINE metric.
        """
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
        )

        self.client.create_collection(
            collection_name=self._full_name(collection_name),
            schema=schema,
            index_params=index_params,
        )

    def _ensure_collection(self, collection_name: str, dimension: int):
        # Create the collection on first use; the vector dimension is only
        # known once there is an item to write.
        if not self.client.has_collection(
            collection_name=self._full_name(collection_name)
        ):
            self._create_collection(
                collection_name=collection_name, dimension=dimension
            )

    @staticmethod
    def _items_to_rows(items: list[VectorItem]) -> list[dict]:
        # Shape the items into rows matching the schema built in _create_collection.
        return [
            {
                "id": item["id"],
                "vector": item["vector"],
                "data": {"text": item["text"]},
                "metadata": item["metadata"],
            }
            for item in items
        ]

    def list_collections(self) -> list[str]:
        """List the un-prefixed names of all collections owned by this client."""
        # Match on "<prefix>_" so unrelated names that merely start with the
        # prefix are ignored and the separator is stripped from the result.
        prefix = f"{self.collection_prefix}_"
        return [
            name[len(prefix) :]
            for name in self.client.list_collections()
            if name.startswith(prefix)
        ]

    def create_collection(self, collection_name: str):
        """Do nothing; kept so callers of the earlier interface keep working.

        Collections are created by ``insert``/``upsert`` because the vector
        dimension is needed to build the schema.
        """
        return None

    def delete_collection(self, collection_name: str):
        """Delete the collection based on the collection name."""
        return self.client.drop_collection(
            collection_name=self._full_name(collection_name)
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[QueryResult]:
        """Search for the nearest items to each vector, ``limit`` results per vector."""
        result = self.client.search(
            collection_name=self._full_name(collection_name),
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )
        return self._result_to_query_result(result)

    def get(self, collection_name: str) -> Optional[QueryResult]:
        """Get all the items in the collection as a single result group."""
        # NOTE(review): the always-true filter is there because an empty
        # filter is expected to be rejected without a limit - confirm against
        # the pymilvus version in use.
        rows = self.client.query(
            collection_name=self._full_name(collection_name),
            filter='id != ""',
        )
        # query() yields one flat list of rows; wrap it so it has the same
        # group-of-items shape the converter expects.
        return self._result_to_query_result([rows])

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Insert the items, creating the collection first if it does not exist.

        Returns the client's insert result, or ``None`` when ``items`` is empty.
        """
        if not items:
            # Nothing to write, and no vector to derive a dimension from.
            return None

        self._ensure_collection(collection_name, dimension=len(items[0]["vector"]))
        return self.client.insert(
            collection_name=self._full_name(collection_name),
            data=self._items_to_rows(items),
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Update the items, inserting those not present; create the collection if missing.

        Returns the client's upsert result, or ``None`` when ``items`` is empty.
        """
        if not items:
            # Nothing to write, and no vector to derive a dimension from.
            return None

        self._ensure_collection(collection_name, dimension=len(items[0]["vector"]))
        return self.client.upsert(
            collection_name=self._full_name(collection_name),
            data=self._items_to_rows(items),
        )

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the items with the given ids from the collection."""
        return self.client.delete(
            collection_name=self._full_name(collection_name),
            ids=ids,
        )

    def reset(self):
        """Reset the database by dropping every collection owned by this client."""
        # Only drop "<prefix>_..." collections; others are not ours to delete.
        prefix = f"{self.collection_prefix}_"
        for name in self.client.list_collections():
            if name.startswith(prefix):
                self.client.drop_collection(collection_name=name)

View File

@ -910,6 +910,10 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you don't use docker or docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
# Milvus
# URI handed to the Milvus client. When the MILVUS_URI environment variable is
# unset, it falls back to a milvus.db path under DATA_DIR/vector_db.
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
####################################
# RAG
####################################

View File

@ -40,6 +40,8 @@ langchain-chroma==0.1.2
fake-useragent==1.5.1
chromadb==0.5.5
pymilvus==2.4.6
sentence-transformers==3.0.1
pypdf==4.3.1
docx2txt==0.8

View File

@ -47,6 +47,7 @@ dependencies = [
"fake-useragent==1.5.1",
"chromadb==0.5.5",
"pymilvus==2.4.6",
"sentence-transformers==3.0.1",
"pypdf==4.3.1",
"docx2txt==0.8",