diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py
index 30e5a1106..00dfe33c9 100644
--- a/backend/open_webui/apps/rag/main.py
+++ b/backend/open_webui/apps/rag/main.py
@@ -96,7 +96,6 @@ from open_webui.utils.misc import (
 from open_webui.utils.utils import get_admin_user, get_verified_user
 from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
 
-from chromadb.utils.batch_utils import create_batches
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import (
     BSHTMLLoader,
@@ -998,14 +997,11 @@ def store_docs_in_vector_db(
     try:
         if overwrite:
-            for collection in VECTOR_DB_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    log.info(f"deleting existing collection {collection_name}")
-                    VECTOR_DB_CLIENT.delete_collection(name=collection_name)
+            if collection_name in VECTOR_DB_CLIENT.list_collections():
+                log.info(f"deleting existing collection {collection_name}")
+                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
 
-        collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
-
-        embedding_func = get_embedding_function(
+        embedding_function = get_embedding_function(
             app.state.config.RAG_EMBEDDING_ENGINE,
             app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
@@ -1014,17 +1010,19 @@ def store_docs_in_vector_db(
             app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         )
 
-        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
-        embeddings = embedding_func(embedding_texts)
-
-        for batch in create_batches(
-            api=VECTOR_DB_CLIENT,
-            ids=[str(uuid.uuid4()) for _ in texts],
-            metadatas=metadatas,
-            embeddings=embeddings,
-            documents=texts,
-        ):
-            collection.add(*batch)
+        VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
+        VECTOR_DB_CLIENT.insert(
+            collection_name=collection_name,
+            items=[
+                {
+                    "id": str(uuid.uuid4()),
+                    "text": text,
+                    "vector": embedding_function(text.replace("\n", " ")),
+                    "metadata": metadatas[idx],
+                }
+                for idx, text in enumerate(texts)
+            ],
+        )
 
         return True
     except Exception as e:
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index 035fefc60..43fbd1596 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -24,6 +24,44 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+
+class VectorSearchRetriever(BaseRetriever):
+    collection_name: Any
+    embedding_function: Any
+    top_k: int
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> list[Document]:
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=self.collection_name,
+            vectors=[self.embedding_function(query)],
+            limit=self.top_k,
+        )
+
+        ids = result["ids"][0]
+        metadatas = result["metadatas"][0]
+        documents = result["documents"][0]
+
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results
+
+
 def query_doc(
     collection_name: str,
     query: str,
@@ -31,15 +69,18 @@ def query_doc(
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.query_collection(
-            name=collection_name,
-            query_embeddings=embedding_function(query),
-            k=k,
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=collection_name,
+            vectors=[embedding_function(query)],
+            limit=k,
         )
 
+        print("result", result)
+        log.info(f"query_doc:result {result}")
         return result
     except Exception as e:
+        print(e)
         raise e
@@ -52,25 +93,23 @@ def query_doc_with_hybrid_search(
     r: float,
 ):
     try:
-        collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
-        documents = collection.get()  # get all documents
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
 
         bm25_retriever = BM25Retriever.from_texts(
-            texts=documents.get("documents"),
-            metadatas=documents.get("metadatas"),
+            texts=result["documents"],
+            metadatas=result["metadatas"],
         )
         bm25_retriever.k = k
 
-        chroma_retriever = ChromaRetriever(
-            collection=collection,
+        vector_search_retriever = VectorSearchRetriever(
+            collection_name=collection_name,
             embedding_function=embedding_function,
-            top_n=k,
+            top_k=k,
         )
 
         ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
         )
-
         compressor = RerankCompressor(
             embedding_function=embedding_function,
             top_n=k,
@@ -394,45 +433,6 @@ def generate_openai_batch_embeddings(
     return None
 
 
-from typing import Any
-
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
-from langchain_core.retrievers import BaseRetriever
-
-
-class ChromaRetriever(BaseRetriever):
-    collection: Any
-    embedding_function: Any
-    top_n: int
-
-    def _get_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
-    ) -> list[Document]:
-        query_embeddings = self.embedding_function(query)
-
-        results = self.collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=self.top_n,
-        )
-
-        ids = results["ids"][0]
-        metadatas = results["metadatas"][0]
-        documents = results["documents"][0]
-
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
-
-
 import operator
 
 from typing import Optional, Sequence
diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py
index d7ca615bf..073becdbe 100644
--- a/backend/open_webui/apps/rag/vector/connector.py
+++ b/backend/open_webui/apps/rag/vector/connector.py
@@ -1,4 +1,10 @@
-from open_webui.apps.rag.vector.dbs.chroma import Chroma
+from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
+from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
+
+
 from open_webui.config import VECTOR_DB
 
-VECTOR_DB_CLIENT = Chroma()
+if VECTOR_DB == "milvus":
+    VECTOR_DB_CLIENT = MilvusClient()
+else:
+    VECTOR_DB_CLIENT = ChromaClient()
diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py
index 1fd560642..ea82ccdb6 100644
--- a/backend/open_webui/apps/rag/vector/dbs/chroma.py
+++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py
@@ -1,6 +1,10 @@
 import chromadb
 from chromadb import Settings
+from chromadb.utils.batch_utils import create_batches
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
 
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -12,7 +16,7 @@ from open_webui.config import (
 )
 
 
-class Chroma:
+class ChromaClient:
     def __init__(self):
         if CHROMA_HTTP_HOST != "":
             self.client = chromadb.HttpClient(
@@ -32,27 +36,73 @@ class Chroma:
                 database=CHROMA_DATABASE,
             )
 
-    def query_collection(self, name, query_embeddings, k):
-        collection = self.client.get_collection(name=name)
+    def list_collections(self) -> list[str]:
+        collections = self.client.list_collections()
+        return [collection.name for collection in collections]
+
+    def create_collection(self, collection_name: str):
+        return self.client.create_collection(name=collection_name)
+
+    def delete_collection(self, collection_name: str):
+        return self.client.delete_collection(name=collection_name)
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        collection = self.client.get_collection(name=collection_name)
         if collection:
             result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+                query_embeddings=vectors,
+                n_results=limit,
             )
-            return result
+
+            return {
+                "ids": result["ids"],
+                "distances": result["distances"],
+                "documents": result["documents"],
+                "metadatas": result["metadatas"],
+            }
         return None
 
-    def list_collections(self):
-        return self.client.list_collections()
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            return collection.get()
+        return None
 
-    def create_collection(self, name):
-        return self.client.create_collection(name=name)
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        collection = self.client.get_or_create_collection(name=collection_name)
 
-    def get_or_create_collection(self, name):
-        return self.client.get_or_create_collection(name=name)
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadatas = [item["metadata"] for item in items]
 
-    def delete_collection(self, name):
-        return self.client.delete_collection(name=name)
+        for batch in create_batches(
+            api=self.client,
+            documents=documents,
+            embeddings=embeddings,
+            ids=ids,
+            metadatas=metadatas,
+        ):
+            collection.add(*batch)
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        collection = self.client.get_or_create_collection(name=collection_name)
+
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadatas = [item["metadata"] for item in items]
+
+        collection.upsert(
+            ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
+        )
+
+    def delete(self, collection_name: str, ids: list[str]):
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            collection.delete(ids=ids)
 
     def reset(self):
         return self.client.reset()
diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py
index e69de29bb..228a45aea 100644
--- a/backend/open_webui/apps/rag/vector/dbs/milvus.py
+++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py
@@ -0,0 +1,39 @@
+from pymilvus import MilvusClient as Milvus
+
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+
+
+class MilvusClient:
+    def __init__(self):
+        self.client = Milvus()
+
+    def list_collections(self) -> list[str]:
+        pass
+
+    def create_collection(self, collection_name: str):
+        pass
+
+    def delete_collection(self, collection_name: str):
+        pass
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        pass
+
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        pass
+
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        pass
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        pass
+
+    def delete(self, collection_name: str, ids: list[str]):
+        pass
+
+    def reset(self):
+        pass
diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/rag/vector/main.py
new file mode 100644
index 000000000..5b5a8ea38
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/main.py
@@ -0,0 +1,16 @@
+from pydantic import BaseModel
+from typing import Optional, List, Any
+
+
+class VectorItem(BaseModel):
+    id: str
+    text: str
+    vector: List[float | int]
+    metadata: Any
+
+
+class QueryResult(BaseModel):
+    ids: Optional[List[List[str]]]
+    distances: Optional[List[List[float | int]]]
+    documents: Optional[List[List[str]]]
+    metadatas: Optional[List[List[Any]]]
diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py
index 1b44063e7..4680f27ab 100644
--- a/backend/open_webui/apps/webui/routers/memories.py
+++ b/backend/open_webui/apps/webui/routers/memories.py
@@ -50,16 +50,17 @@ async def add_memory(
     user=Depends(get_verified_user),
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)
-    memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
-
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-    collection.upsert(
-        documents=[memory.content],
-        ids=[memory.id],
-        embeddings=[memory_embedding],
-        metadatas=[{"created_at": memory.created_at}],
+
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {"created_at": memory.created_at},
+            }
+        ],
     )
 
     return memory
@@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel):
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
-    query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-
-    results = collection.query(
-        query_embeddings=[query_embedding],
-        n_results=form_data.k,  # how many results to return
+    results = VECTOR_DB_CLIENT.search(
+        collection_name=f"user-memory-{user.id}",
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        limit=form_data.k,
     )
 
     return results
@@ -100,18 +97,24 @@ async def reset_memory_from_vector_db(
     request: Request, user=Depends(get_verified_user)
 ):
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
 
     memories = Memories.get_memories_by_user_id(user.id)
-
-    for memory in memories:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
-        collection.upsert(
-            documents=[memory.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
-        )
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {
+                    "created_at": memory.created_at,
+                    "updated_at": memory.updated_at,
+                },
+            }
+            for memory in memories
+        ],
+    )
+
     return True
@@ -151,16 +154,18 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")
 
     if form_data.content is not None:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
-        )
-        collection.upsert(
-            documents=[form_data.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
-            metadatas=[
-                {"created_at": memory.created_at, "updated_at": memory.updated_at}
+        VECTOR_DB_CLIENT.upsert(
+            collection_name=f"user-memory-{user.id}",
+            items=[
+                {
+                    "id": memory.id,
+                    "text": memory.content,
+                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                    "metadata": {
+                        "created_at": memory.created_at,
+                        "updated_at": memory.updated_at,
+                    },
+                }
             ],
         )
 
@@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
     if result:
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
+        VECTOR_DB_CLIENT.delete(
+            collection_name=f"user-memory-{user.id}", ids=[memory_id]
         )
-        collection.delete(ids=[memory_id])
 
         return True
 
     return False
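
Reviewer note: the snippet below is a minimal usage sketch of the unified VECTOR_DB_CLIENT interface this diff introduces (create_collection, insert, search, delete_collection), assuming the default Chroma backend. The collection name, sample texts, and the embedding_function placeholder are illustrative and do not come from the repository.

    import uuid

    from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT


    def embedding_function(text: str) -> list[float]:
        # Placeholder embedding; in the app this comes from get_embedding_function(...).
        return [float(len(text)), 0.0, 0.0]


    collection_name = "example-docs"  # hypothetical collection name
    texts = ["first chunk of a document", "second chunk of a document"]

    VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
    VECTOR_DB_CLIENT.insert(
        collection_name=collection_name,
        items=[
            {
                "id": str(uuid.uuid4()),
                "text": text,
                "vector": embedding_function(text),
                "metadata": {"source": "example"},
            }
            for text in texts
        ],
    )

    # search returns a dict shaped like QueryResult: nested lists keyed by
    # "ids", "distances", "documents", and "metadatas" (one inner list per query vector).
    result = VECTOR_DB_CLIENT.search(
        collection_name=collection_name,
        vectors=[embedding_function("first chunk")],
        limit=1,
    )
    print(result["documents"][0])

    VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)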
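
The new MilvusClient is intentionally left as a stub (every method is pass). As a rough sketch of how insert and search could be filled in later against pymilvus's MilvusClient quick-start API: the string primary key setup, the dynamic "text"/"metadata" fields, and the mapping back into the QueryResult shape below are assumptions for illustration, not part of this diff.

    from pymilvus import MilvusClient as Milvus

    # Assumes a local Milvus instance; the connection URI is not configured by this diff.
    client = Milvus(uri="http://localhost:19530")


    def insert(collection_name: str, items: list[dict]):
        # Quick-start collections use "id" and "vector" field names; extra fields such as
        # "text" and "metadata" rely on the dynamic-field schema that quick setup enables.
        if not client.has_collection(collection_name):
            client.create_collection(
                collection_name=collection_name,
                dimension=len(items[0]["vector"]),
                id_type="string",
                max_length=65_535,
            )
        client.insert(collection_name=collection_name, data=items)


    def search(collection_name: str, vectors: list[list[float]], limit: int) -> dict:
        # client.search returns one hit list per query vector; each hit is a dict with
        # "id", "distance", and the requested "entity" output fields.
        results = client.search(
            collection_name=collection_name,
            data=vectors,
            limit=limit,
            output_fields=["text", "metadata"],
        )
        # Re-shape into the nested-list layout that QueryResult expects.
        return {
            "ids": [[hit["id"] for hit in hits] for hits in results],
            "distances": [[hit["distance"] for hit in hits] for hits in results],
            "documents": [[hit["entity"]["text"] for hit in hits] for hits in results],
            "metadatas": [[hit["entity"]["metadata"] for hit in hits] for hits in results],
        }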