mirror of
https://github.com/open-webui/open-webui
synced 2025-03-03 10:52:09 +00:00
feat: milvus support
This commit is contained in:
parent
0886b3a0a4
commit
4775fe43d8
@ -1010,7 +1010,6 @@ def store_docs_in_vector_db(
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
|
||||
VECTOR_DB_CLIENT.insert(
|
||||
collection_name=collection_name,
|
||||
items=[
|
||||
|
@ -41,10 +41,6 @@ class ChromaClient:
|
||||
collections = self.client.list_collections()
|
||||
return [collection.name for collection in collections]
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
# Create a new collection based on the collection name.
|
||||
return self.client.create_collection(name=collection_name)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
return self.client.delete_collection(name=collection_name)
|
||||
@ -76,7 +72,7 @@ class ChromaClient:
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection.
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(name=collection_name)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
@ -94,7 +90,7 @@ class ChromaClient:
|
||||
collection.add(*batch)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them.
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(name=collection_name)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
|
@ -1,39 +1,175 @@
|
||||
from pymilvus import MilvusClient as Milvus
|
||||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
)
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
def __init__(self):
|
||||
self.client = Milvus()
|
||||
self.collection_prefix = "open_webui"
|
||||
self.client = Client(uri=MILVUS_URI)
|
||||
|
||||
def _result_to_query_result(self, result) -> QueryResult:
|
||||
print(result)
|
||||
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_distances = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_distances.append(item.get("distance"))
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return {
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
schema = self.client.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=True,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
description="vector",
|
||||
)
|
||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
||||
schema.add_field(
|
||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
||||
)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
|
||||
def list_collections(self) -> list[str]:
|
||||
pass
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
pass
|
||||
# List all the collections in the database.
|
||||
return [
|
||||
collection[len(self.collection_prefix) :]
|
||||
for collection in self.client.list_collections()
|
||||
if collection.startswith(self.collection_prefix)
|
||||
]
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
pass
|
||||
# Delete the collection based on the collection name.
|
||||
return self.client.drop_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[QueryResult]:
|
||||
pass
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
)
|
||||
|
||||
return self._result_to_query_result(result)
|
||||
|
||||
def get(self, collection_name: str) -> Optional[QueryResult]:
|
||||
pass
|
||||
# Get all the items in the collection.
|
||||
result = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
)
|
||||
return self._result_to_query_result(result)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
pass
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
pass
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
pass
|
||||
# Delete the items from the collection based on the ids.
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
|
||||
collection_names = self.client.list_collections()
|
||||
for collection_name in collection_names:
|
||||
if collection_name.startswith(self.collection_prefix):
|
||||
self.client.drop_collection(collection_name=collection_name)
|
||||
|
@ -910,6 +910,10 @@ else:
|
||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
# Milvus
|
||||
|
||||
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
|
||||
|
||||
####################################
|
||||
# RAG
|
||||
####################################
|
||||
|
@ -40,6 +40,8 @@ langchain-chroma==0.1.2
|
||||
|
||||
fake-useragent==1.5.1
|
||||
chromadb==0.5.5
|
||||
pymilvus==2.4.6
|
||||
|
||||
sentence-transformers==3.0.1
|
||||
pypdf==4.3.1
|
||||
docx2txt==0.8
|
||||
|
@ -47,6 +47,7 @@ dependencies = [
|
||||
|
||||
"fake-useragent==1.5.1",
|
||||
"chromadb==0.5.5",
|
||||
"pymilvus==2.4.6",
|
||||
"sentence-transformers==3.0.1",
|
||||
"pypdf==4.3.1",
|
||||
"docx2txt==0.8",
|
||||
|
Loading…
Reference in New Issue
Block a user