diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py
index 00dfe33c9..32ca6b7e9 100644
--- a/backend/open_webui/apps/rag/main.py
+++ b/backend/open_webui/apps/rag/main.py
@@ -1010,7 +1010,6 @@ def store_docs_in_vector_db(
             app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         )
 
-        VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
         VECTOR_DB_CLIENT.insert(
             collection_name=collection_name,
             items=[
diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py
index 7ce713d0f..b04dbd6bc 100644
--- a/backend/open_webui/apps/rag/vector/dbs/chroma.py
+++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py
@@ -41,10 +41,6 @@ class ChromaClient:
         collections = self.client.list_collections()
         return [collection.name for collection in collections]
 
-    def create_collection(self, collection_name: str):
-        # Create a new collection based on the collection name.
-        return self.client.create_collection(name=collection_name)
-
     def delete_collection(self, collection_name: str):
         # Delete the collection based on the collection name.
         return self.client.delete_collection(name=collection_name)
@@ -76,7 +72,7 @@ class ChromaClient:
         return None
 
     def insert(self, collection_name: str, items: list[VectorItem]):
-        # Insert the items into the collection.
+        # Insert the items into the collection, if the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(name=collection_name)
 
         ids = [item["id"] for item in items]
@@ -94,7 +90,7 @@ class ChromaClient:
             collection.add(*batch)
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
-        # Update the items in the collection, if the items are not present, insert them.
+        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
         collection = self.client.get_or_create_collection(name=collection_name)
 
         ids = [item["id"] for item in items]
diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py
index 228a45aea..f679b4504 100644
--- a/backend/open_webui/apps/rag/vector/dbs/milvus.py
+++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py
@@ -1,39 +1,194 @@
-from pymilvus import MilvusClient as Milvus
+from pymilvus import MilvusClient as Client
+from pymilvus import FieldSchema, DataType
+import json
 
 from typing import Optional
 
 from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.config import (
+    MILVUS_URI,
+)
 
 
 class MilvusClient:
     def __init__(self):
-        self.client = Milvus()
+        self.collection_prefix = "open_webui"
+        self.client = Client(uri=MILVUS_URI)
+
+    def _full_name(self, collection_name: str) -> str:
+        # Build the physical Milvus collection name for a logical collection.
+        # Milvus only accepts letters, digits and underscores in collection
+        # names, so hyphens are mapped to underscores before prefixing.
+        sanitized = collection_name.replace("-", "_")
+        return f"{self.collection_prefix}_{sanitized}"
+
+    def _result_to_query_result(self, result) -> QueryResult:
+        # Convert a list of result groups (one group per query) into the
+        # column-oriented QueryResult layout. Search hits nest their fields
+        # under "entity"; rows returned by query() are flat, so fall back to
+        # the row itself when there is no "entity" key.
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for match in result:
+            _ids = []
+            _distances = []
+            _documents = []
+            _metadatas = []
+
+            for item in match:
+                entity = item.get("entity", item)
+                _ids.append(item.get("id"))
+                _distances.append(item.get("distance"))
+                _documents.append(entity.get("data", {}).get("text"))
+                _metadatas.append(entity.get("metadata"))
+
+            ids.append(_ids)
+            distances.append(_distances)
+            documents.append(_documents)
+            metadatas.append(_metadatas)
+
+        return {
+            "ids": ids,
+            "distances": distances,
+            "documents": documents,
+            "metadatas": metadatas,
+        }
+
+    def _create_collection(self, collection_name: str, dimension: int):
+        # Create the prefixed collection with an explicit schema and an HNSW
+        # cosine index on the vector field.
+        schema = self.client.create_schema(
+            auto_id=False,
+            enable_dynamic_field=True,
+        )
+        schema.add_field(
+            field_name="id",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            max_length=65535,
+        )
+        schema.add_field(
+            field_name="vector",
+            datatype=DataType.FLOAT_VECTOR,
+            dim=dimension,
+            description="vector",
+        )
+        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
+        schema.add_field(
+            field_name="metadata", datatype=DataType.JSON, description="metadata"
+        )
+
+        index_params = self.client.prepare_index_params()
+        index_params.add_index(
+            field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
+        )
+
+        self.client.create_collection(
+            collection_name=self._full_name(collection_name),
+            schema=schema,
+            index_params=index_params,
+        )
+
+    def _ensure_collection(self, collection_name: str, items: list[VectorItem]):
+        # Create the collection on first write; the vector dimension is taken
+        # from the first item, so callers must pass a non-empty list.
+        if not self.client.has_collection(
+            collection_name=self._full_name(collection_name)
+        ):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+    def _to_rows(self, items: list[VectorItem]) -> list[dict]:
+        # Map VectorItems onto the row layout defined in _create_collection.
+        return [
+            {
+                "id": item["id"],
+                "vector": item["vector"],
+                "data": {"text": item["text"]},
+                "metadata": item["metadata"],
+            }
+            for item in items
+        ]
 
     def list_collections(self) -> list[str]:
-        pass
-
-    def create_collection(self, collection_name: str):
-        pass
+        # List the logical names of all collections owned by this client.
+        # The "_" separator is part of the prefix so it is stripped as well.
+        prefix = f"{self.collection_prefix}_"
+        return [
+            collection[len(prefix) :]
+            for collection in self.client.list_collections()
+            if collection.startswith(prefix)
+        ]
 
     def delete_collection(self, collection_name: str):
-        pass
+        # Delete the collection based on the collection name.
+        return self.client.drop_collection(
+            collection_name=self._full_name(collection_name)
+        )
 
     def search(
         self, collection_name: str, vectors: list[list[float | int]], limit: int
     ) -> Optional[QueryResult]:
-        pass
+        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        result = self.client.search(
+            collection_name=self._full_name(collection_name),
+            data=vectors,
+            limit=limit,
+            output_fields=["data", "metadata"],
+        )
+
+        return self._result_to_query_result(result)
 
     def get(self, collection_name: str) -> Optional[QueryResult]:
-        pass
+        # Get all the items in the collection.
+        # query() needs a filter expression (or ids); 'id != ""' selects
+        # every row that has a non-empty id.
+        result = self.client.query(
+            collection_name=self._full_name(collection_name),
+            filter='id != ""',
+            output_fields=["data", "metadata"],
+        )
+        # query() returns one flat list of rows; wrap it as a single group.
+        return self._result_to_query_result([result])
 
     def insert(self, collection_name: str, items: list[VectorItem]):
-        pass
+        # Insert the items into the collection; if the collection does not exist, it will be created.
+        if not items:
+            # Nothing to write, and no vector to size a new collection from.
+            return None
+
+        self._ensure_collection(collection_name, items)
+        return self.client.insert(
+            collection_name=self._full_name(collection_name),
+            data=self._to_rows(items),
+        )
 
     def upsert(self, collection_name: str, items: list[VectorItem]):
-        pass
+        # Update the items in the collection; items that are not present are inserted. If the collection does not exist, it will be created.
+        if not items:
+            # Nothing to write, and no vector to size a new collection from.
+            return None
+
+        self._ensure_collection(collection_name, items)
+        return self.client.upsert(
+            collection_name=self._full_name(collection_name),
+            data=self._to_rows(items),
+        )
 
     def delete(self, collection_name: str, ids: list[str]):
-        pass
+        # Delete the items from the collection based on the ids.
+        return self.client.delete(
+            collection_name=self._full_name(collection_name),
+            ids=ids,
+        )
 
     def reset(self):
-        pass
+        # Resets the database. This will delete all collections and item entries.
+        prefix = f"{self.collection_prefix}_"
+        for collection_name in self.client.list_collections():
+            if collection_name.startswith(prefix):
+                self.client.drop_collection(collection_name=collection_name)
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index e4fe1a546..019cc8847 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -910,6 +910,10 @@ else:
 CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
+# Milvus
+
+MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+
 ####################################
 # RAG
 ####################################
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 93720cc84..11a742d05 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -40,6 +40,8 @@ langchain-chroma==0.1.2
 fake-useragent==1.5.1
 chromadb==0.5.5
 
+pymilvus==2.4.6
+
 sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8
diff --git a/pyproject.toml b/pyproject.toml
index 057ef1475..b035723f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,7 @@ dependencies = [
 
     "fake-useragent==1.5.1",
     "chromadb==0.5.5",
+    "pymilvus==2.4.6",
     "sentence-transformers==3.0.1",
     "pypdf==4.3.1",
     "docx2txt==0.8",