From 522afbb0a00aa3c688fe5560a48b53cff8f110c3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 10 Sep 2024 04:37:06 +0100 Subject: [PATCH 1/3] refac --- backend/open_webui/apps/rag/main.py | 36 +++--- backend/open_webui/apps/rag/utils.py | 104 +++++++++--------- .../open_webui/apps/rag/vector/connector.py | 10 +- .../open_webui/apps/rag/vector/dbs/chroma.py | 78 ++++++++++--- .../open_webui/apps/rag/vector/dbs/milvus.py | 39 +++++++ backend/open_webui/apps/rag/vector/main.py | 16 +++ .../open_webui/apps/webui/routers/memories.py | 84 +++++++------- 7 files changed, 240 insertions(+), 127 deletions(-) create mode 100644 backend/open_webui/apps/rag/vector/main.py diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 30e5a1106..00dfe33c9 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -96,7 +96,6 @@ from open_webui.utils.misc import ( from open_webui.utils.utils import get_admin_user, get_verified_user from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT -from chromadb.utils.batch_utils import create_batches from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( BSHTMLLoader, @@ -998,14 +997,11 @@ def store_docs_in_vector_db( try: if overwrite: - for collection in VECTOR_DB_CLIENT.list_collections(): - if collection_name == collection.name: - log.info(f"deleting existing collection {collection_name}") - VECTOR_DB_CLIENT.delete_collection(name=collection_name) + if collection_name in VECTOR_DB_CLIENT.list_collections(): + log.info(f"deleting existing collection {collection_name}") + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - collection = VECTOR_DB_CLIENT.create_collection(name=collection_name) - - embedding_func = get_embedding_function( + embedding_function = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, @@ -1014,17 +1010,19 @@ def store_docs_in_vector_db( app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, ) - embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = embedding_func(embedding_texts) - - for batch in create_batches( - api=VECTOR_DB_CLIENT, - ids=[str(uuid.uuid4()) for _ in texts], - metadatas=metadatas, - embeddings=embeddings, - documents=texts, - ): - collection.add(*batch) + VECTOR_DB_CLIENT.create_collection(collection_name=collection_name) + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=[ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embedding_function(text.replace("\n", " ")), + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ], + ) return True except Exception as e: diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py index 035fefc60..43fbd1596 100644 --- a/backend/open_webui/apps/rag/utils.py +++ b/backend/open_webui/apps/rag/utils.py @@ -24,6 +24,44 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +from typing import Any + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.retrievers import BaseRetriever + + +class VectorSearchRetriever(BaseRetriever): + collection_name: Any + embedding_function: Any + top_k: int + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> list[Document]: + result = VECTOR_DB_CLIENT.search( + collection_name=self.collection_name, + vectors=[self.embedding_function(query)], + limit=self.top_k, + ) + + ids = result["ids"][0] + metadatas = result["metadatas"][0] + documents = result["documents"][0] + + results = [] + for idx in range(len(ids)): + results.append( + Document( + metadata=metadatas[idx], + page_content=documents[idx], + ) + ) + return results + + def query_doc( collection_name: str, query: str, @@ -31,15 +69,18 @@ def query_doc( k: int, ): try: - result = VECTOR_DB_CLIENT.query_collection( - name=collection_name, - query_embeddings=embedding_function(query), - k=k, + result = VECTOR_DB_CLIENT.search( + collection_name=collection_name, + vectors=[embedding_function(query)], + limit=k, ) + print("result", result) + log.info(f"query_doc:result {result}") return result except Exception as e: + print(e) raise e @@ -52,25 +93,23 @@ def query_doc_with_hybrid_search( r: float, ): try: - collection = VECTOR_DB_CLIENT.get_collection(name=collection_name) - documents = collection.get() # get all documents + result = VECTOR_DB_CLIENT.get(collection_name=collection_name) bm25_retriever = BM25Retriever.from_texts( - texts=documents.get("documents"), - metadatas=documents.get("metadatas"), + texts=result.documents, + metadatas=result.metadatas, ) bm25_retriever.k = k - chroma_retriever = ChromaRetriever( - collection=collection, + vector_search_retriever = VectorSearchRetriever( + collection_name=collection_name, embedding_function=embedding_function, - top_n=k, + top_k=k, ) ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] + retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] ) - compressor = RerankCompressor( embedding_function=embedding_function, top_n=k, @@ -394,45 +433,6 @@ def generate_openai_batch_embeddings( return None -from typing import Any - -from langchain_core.callbacks import CallbackManagerForRetrieverRun -from langchain_core.retrievers import BaseRetriever - - -class ChromaRetriever(BaseRetriever): - collection: Any - embedding_function: Any - top_n: int - - def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - ) -> list[Document]: - query_embeddings = self.embedding_function(query) - - results = self.collection.query( - query_embeddings=[query_embeddings], - n_results=self.top_n, - ) - - ids = results["ids"][0] - metadatas = results["metadatas"][0] - documents = results["documents"][0] - - results = [] - for idx in range(len(ids)): - results.append( - Document( - metadata=metadatas[idx], - page_content=documents[idx], - ) - ) - return results - - import operator from typing import Optional, Sequence diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py index d7ca615bf..073becdbe 100644 --- a/backend/open_webui/apps/rag/vector/connector.py +++ b/backend/open_webui/apps/rag/vector/connector.py @@ -1,4 +1,10 @@ -from open_webui.apps.rag.vector.dbs.chroma import Chroma +from open_webui.apps.rag.vector.dbs.chroma import ChromaClient +from open_webui.apps.rag.vector.dbs.milvus import MilvusClient + + from open_webui.config import VECTOR_DB -VECTOR_DB_CLIENT = Chroma() +if VECTOR_DB == "milvus": + VECTOR_DB_CLIENT = MilvusClient() +else: + VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py index 1fd560642..ea82ccdb6 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -1,6 +1,10 @@ import chromadb from chromadb import Settings +from chromadb.utils.batch_utils import create_batches +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, QueryResult from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, @@ -12,7 +16,7 @@ from open_webui.config import ( ) -class Chroma: +class ChromaClient: def __init__(self): if CHROMA_HTTP_HOST != "": self.client = chromadb.HttpClient( @@ -32,27 +36,73 @@ class Chroma: database=CHROMA_DATABASE, ) - def query_collection(self, name, query_embeddings, k): - collection = self.client.get_collection(name=name) + def list_collections(self) -> list[str]: + collections = self.client.list_collections() + return [collection.name for collection in collections] + + def create_collection(self, collection_name: str): + return self.client.create_collection(name=collection_name) + + def delete_collection(self, collection_name: str): + return self.client.delete_collection(name=collection_name) + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[QueryResult]: + collection = self.client.get_collection(name=collection_name) if collection: result = collection.query( - query_embeddings=[query_embeddings], - n_results=k, + query_embeddings=vectors, + n_results=limit, ) - return result + + return { + "ids": result["ids"], + "distances": result["distances"], + "documents": result["documents"], + "metadatas": result["metadatas"], + } return None - def list_collections(self): - return self.client.list_collections() + def get(self, collection_name: str) -> Optional[QueryResult]: + collection = self.client.get_collection(name=collection_name) + if collection: + return collection.get() + return None - def create_collection(self, name): - return self.client.create_collection(name=name) + def insert(self, collection_name: str, items: list[VectorItem]): + collection = self.client.get_or_create_collection(name=collection_name) - def get_or_create_collection(self, name): - return self.client.get_or_create_collection(name=name) + ids = [item["id"] for item in items] + documents = [item["text"] for item in items] + embeddings = [item["vector"] for item in items] + metadatas = [item["metadata"] for item in items] - def delete_collection(self, name): - return self.client.delete_collection(name=name) + for batch in create_batches( + api=self.client, + documents=documents, + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + ): + collection.add(*batch) + + def upsert(self, collection_name: str, items: list[VectorItem]): + collection = self.client.get_or_create_collection(name=collection_name) + + ids = [item["id"] for item in items] + documents = [item["text"] for item in items] + embeddings = [item["vector"] for item in items] + metadata = [item["metadata"] for item in items] + + collection.upsert( + ids=ids, documents=documents, embeddings=embeddings, metadata=metadata + ) + + def delete(self, collection_name: str, ids: list[str]): + collection = self.client.get_collection(name=collection_name) + if collection: + collection.delete(ids=ids) def reset(self): return self.client.reset() diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py index e69de29bb..228a45aea 100644 --- a/backend/open_webui/apps/rag/vector/dbs/milvus.py +++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py @@ -0,0 +1,39 @@ +from pymilvus import MilvusClient as Milvus + +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, QueryResult + + +class MilvusClient: + def __init__(self): + self.client = Milvus() + + def list_collections(self) -> list[str]: + pass + + def create_collection(self, collection_name: str): + pass + + def delete_collection(self, collection_name: str): + pass + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[QueryResult]: + pass + + def get(self, collection_name: str) -> Optional[QueryResult]: + pass + + def insert(self, collection_name: str, items: list[VectorItem]): + pass + + def upsert(self, collection_name: str, items: list[VectorItem]): + pass + + def delete(self, collection_name: str, ids: list[str]): + pass + + def reset(self): + pass diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/rag/vector/main.py new file mode 100644 index 000000000..5b5a8ea38 --- /dev/null +++ b/backend/open_webui/apps/rag/vector/main.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel +from typing import Optional, List, Any + + +class VectorItem(BaseModel): + id: str + text: str + vector: List[float | int] + metadata: Any + + +class QueryResult(BaseModel): + ids: Optional[List[List[str]]] + distances: Optional[List[List[float | int]]] + documents: Optional[List[List[str]]] + metadatas: Optional[List[List[Any]]] diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py index 1b44063e7..4680f27ab 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/apps/webui/routers/memories.py @@ -50,16 +50,17 @@ async def add_memory( user=Depends(get_verified_user), ): memory = Memories.insert_new_memory(user.id, form_data.content) - memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) - collection = VECTOR_DB_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" - ) - collection.upsert( - documents=[memory.content], - ids=[memory.id], - embeddings=[memory_embedding], - metadatas=[{"created_at": memory.created_at}], + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": {"created_at": memory.created_at}, + } + ], ) return memory @@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel): async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) ): - query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = VECTOR_DB_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" - ) - - results = collection.query( - query_embeddings=[query_embedding], - n_results=form_data.k, # how many results to return + results = VECTOR_DB_CLIENT.search( + name=f"user-memory-{user.id}", + vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)], + limit=form_data.k, ) return results @@ -100,18 +97,24 @@ async def reset_memory_from_vector_db( request: Request, user=Depends(get_verified_user) ): VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") - collection = VECTOR_DB_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" - ) memories = Memories.get_memories_by_user_id(user.id) - for memory in memories: - memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) - collection.upsert( - documents=[memory.content], - ids=[memory.id], - embeddings=[memory_embedding], - ) + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": { + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }, + } + for memory in memories + ], + ) + return True @@ -151,16 +154,18 @@ async def update_memory_by_id( raise HTTPException(status_code=404, detail="Memory not found") if form_data.content is not None: - memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = VECTOR_DB_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" - ) - collection.upsert( - documents=[form_data.content], - ids=[memory.id], - embeddings=[memory_embedding], - metadatas=[ - {"created_at": memory.created_at, "updated_at": memory.updated_at} + VECTOR_DB_CLIENT.upsert( + collection_name=f"user-memory-{user.id}", + items=[ + { + "id": memory.id, + "text": memory.content, + "vector": request.app.state.EMBEDDING_FUNCTION(memory.content), + "metadata": { + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }, + } ], ) @@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: - collection = VECTOR_DB_CLIENT.get_or_create_collection( - name=f"user-memory-{user.id}" + VECTOR_DB_CLIENT.delete( + collection_name=f"user-memory-{user.id}", ids=[memory_id] ) - collection.delete(ids=[memory_id]) return True return False From 0886b3a0a467517c462abae73cbf1dc6ee94d179 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 10 Sep 2024 04:46:40 +0100 Subject: [PATCH 2/3] refac: comments --- backend/open_webui/apps/rag/vector/dbs/chroma.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py index ea82ccdb6..7ce713d0f 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -37,18 +37,22 @@ class ChromaClient: ) def list_collections(self) -> list[str]: + # List all the collections in the database. collections = self.client.list_collections() return [collection.name for collection in collections] def create_collection(self, collection_name: str): + # Create a new collection based on the collection name. return self.client.create_collection(name=collection_name) def delete_collection(self, collection_name: str): + # Delete the collection based on the collection name. return self.client.delete_collection(name=collection_name) def search( self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[QueryResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. collection = self.client.get_collection(name=collection_name) if collection: result = collection.query( @@ -65,12 +69,14 @@ class ChromaClient: return None def get(self, collection_name: str) -> Optional[QueryResult]: + # Get all the items in the collection. collection = self.client.get_collection(name=collection_name) if collection: return collection.get() return None def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection. collection = self.client.get_or_create_collection(name=collection_name) ids = [item["id"] for item in items] @@ -88,6 +94,7 @@ class ChromaClient: collection.add(*batch) def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. collection = self.client.get_or_create_collection(name=collection_name) ids = [item["id"] for item in items] @@ -100,9 +107,11 @@ class ChromaClient: ) def delete(self, collection_name: str, ids: list[str]): + # Delete the items from the collection based on the ids. collection = self.client.get_collection(name=collection_name) if collection: collection.delete(ids=ids) def reset(self): + # Resets the database. This will delete all collections and item entries. return self.client.reset() From 4775fe43d8d7e1088e9e3730a82e2caf911622be Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 12 Sep 2024 01:52:19 -0400 Subject: [PATCH 3/3] feat: milvus support --- backend/open_webui/apps/rag/main.py | 1 - .../open_webui/apps/rag/vector/dbs/chroma.py | 8 +- .../open_webui/apps/rag/vector/dbs/milvus.py | 162 ++++++++++++++++-- backend/open_webui/config.py | 4 + backend/requirements.txt | 2 + pyproject.toml | 1 + 6 files changed, 158 insertions(+), 20 deletions(-) diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 00dfe33c9..32ca6b7e9 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -1010,7 +1010,6 @@ def store_docs_in_vector_db( app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, ) - VECTOR_DB_CLIENT.create_collection(collection_name=collection_name) VECTOR_DB_CLIENT.insert( collection_name=collection_name, items=[ diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py index 7ce713d0f..b04dbd6bc 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -41,10 +41,6 @@ class ChromaClient: collections = self.client.list_collections() return [collection.name for collection in collections] - def create_collection(self, collection_name: str): - # Create a new collection based on the collection name. - return self.client.create_collection(name=collection_name) - def delete_collection(self, collection_name: str): # Delete the collection based on the collection name. return self.client.delete_collection(name=collection_name) @@ -76,7 +72,7 @@ class ChromaClient: return None def insert(self, collection_name: str, items: list[VectorItem]): - # Insert the items into the collection. + # Insert the items into the collection, if the collection does not exist, it will be created. collection = self.client.get_or_create_collection(name=collection_name) ids = [item["id"] for item in items] @@ -94,7 +90,7 @@ class ChromaClient: collection.add(*batch) def upsert(self, collection_name: str, items: list[VectorItem]): - # Update the items in the collection, if the items are not present, insert them. + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. collection = self.client.get_or_create_collection(name=collection_name) ids = [item["id"] for item in items] diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py index 228a45aea..f679b4504 100644 --- a/backend/open_webui/apps/rag/vector/dbs/milvus.py +++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py @@ -1,39 +1,175 @@ -from pymilvus import MilvusClient as Milvus +from pymilvus import MilvusClient as Client +from pymilvus import FieldSchema, DataType +import json from typing import Optional from open_webui.apps.rag.vector.main import VectorItem, QueryResult +from open_webui.config import ( + MILVUS_URI, +) class MilvusClient: def __init__(self): - self.client = Milvus() + self.collection_prefix = "open_webui" + self.client = Client(uri=MILVUS_URI) + + def _result_to_query_result(self, result) -> QueryResult: + print(result) + + ids = [] + distances = [] + documents = [] + metadatas = [] + + for match in result: + _ids = [] + _distances = [] + _documents = [] + _metadatas = [] + + for item in match: + _ids.append(item.get("id")) + _distances.append(item.get("distance")) + _documents.append(item.get("entity", {}).get("data", {}).get("text")) + _metadatas.append(item.get("entity", {}).get("metadata")) + + ids.append(_ids) + distances.append(_distances) + documents.append(_documents) + metadatas.append(_metadatas) + + return { + "ids": ids, + "distances": distances, + "documents": documents, + "metadatas": metadatas, + } + + def _create_collection(self, collection_name: str, dimension: int): + schema = self.client.create_schema( + auto_id=False, + enable_dynamic_field=True, + ) + schema.add_field( + field_name="id", + datatype=DataType.VARCHAR, + is_primary=True, + max_length=65535, + ) + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=dimension, + description="vector", + ) + schema.add_field(field_name="data", datatype=DataType.JSON, description="data") + schema.add_field( + field_name="metadata", datatype=DataType.JSON, description="metadata" + ) + + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", index_type="HNSW", metric_type="COSINE", params={} + ) + + self.client.create_collection( + collection_name=f"{self.collection_prefix}_{collection_name}", + schema=schema, + index_params=index_params, + ) def list_collections(self) -> list[str]: - pass - - def create_collection(self, collection_name: str): - pass + # List all the collections in the database. + return [ + collection[len(self.collection_prefix) :] + for collection in self.client.list_collections() + if collection.startswith(self.collection_prefix) + ] def delete_collection(self, collection_name: str): - pass + # Delete the collection based on the collection name. + return self.client.drop_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) def search( self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[QueryResult]: - pass + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + result = self.client.search( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=vectors, + limit=limit, + output_fields=["data", "metadata"], + ) + + return self._result_to_query_result(result) def get(self, collection_name: str) -> Optional[QueryResult]: - pass + # Get all the items in the collection. + result = self.client.query( + collection_name=f"{self.collection_prefix}_{collection_name}", + ) + return self._result_to_query_result(result) def insert(self, collection_name: str, items: list[VectorItem]): - pass + # Insert the items into the collection, if the collection does not exist, it will be created. + if not self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ): + self._create_collection( + collection_name=collection_name, dimension=len(items[0]["vector"]) + ) + + return self.client.insert( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=[ + { + "id": item["id"], + "vector": item["vector"], + "data": {"text": item["text"]}, + "metadata": item["metadata"], + } + for item in items + ], + ) def upsert(self, collection_name: str, items: list[VectorItem]): - pass + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + if not self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ): + self._create_collection( + collection_name=collection_name, dimension=len(items[0]["vector"]) + ) + + return self.client.upsert( + collection_name=f"{self.collection_prefix}_{collection_name}", + data=[ + { + "id": item["id"], + "vector": item["vector"], + "data": {"text": item["text"]}, + "metadata": item["metadata"], + } + for item in items + ], + ) def delete(self, collection_name: str, ids: list[str]): - pass + # Delete the items from the collection based on the ids. + + return self.client.delete( + collection_name=f"{self.collection_prefix}_{collection_name}", + ids=ids, + ) def reset(self): - pass + # Resets the database. This will delete all collections and item entries. + + collection_names = self.client.list_collections() + for collection_name in collection_names: + if collection_name.startswith(self.collection_prefix): + self.client.drop_collection(collection_name=collection_name) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index e4fe1a546..019cc8847 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -910,6 +910,10 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) +# Milvus + +MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") + #################################### # RAG #################################### diff --git a/backend/requirements.txt b/backend/requirements.txt index 93720cc84..11a742d05 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -40,6 +40,8 @@ langchain-chroma==0.1.2 fake-useragent==1.5.1 chromadb==0.5.5 +pymilvus==2.4.6 + sentence-transformers==3.0.1 pypdf==4.3.1 docx2txt==0.8 diff --git a/pyproject.toml b/pyproject.toml index 057ef1475..b035723f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "fake-useragent==1.5.1", "chromadb==0.5.5", + "pymilvus==2.4.6", "sentence-transformers==3.0.1", "pypdf==4.3.1", "docx2txt==0.8",