Introduce VectorDBBase, an abstract base class declared in
retrieval/vector/main.py, and make the Chroma, Elasticsearch, Milvus,
OpenSearch, Pgvector, Pinecone and Qdrant clients inherit from it.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Blank context lines were re-inserted so that every hunk matches its
@@ line counts; confirm against the target tree before applying.

diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py
index a6b97df3e..f9adc9c95 100755
--- a/backend/open_webui/retrieval/vector/dbs/chroma.py
+++ b/backend/open_webui/retrieval/vector/dbs/chroma.py
@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
 
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class ChromaClient:
+class ChromaClient(VectorDBBase):
     def __init__(self):
         settings_dict = {
             "allow_reset": True,
diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
index c89628494..18a915e38 100644
--- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
 from elasticsearch.helpers import bulk, scan
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     ELASTICSEARCH_URL,
     ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
 )
 
 
-class ElasticsearchClient:
+class ElasticsearchClient(VectorDBBase):
     """
     Important:
     in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py
index 26b4dd5ed..f116c57f7 100644
--- a/backend/open_webui/retrieval/vector/dbs/milvus.py
+++ b/backend/open_webui/retrieval/vector/dbs/milvus.py
@@ -4,7 +4,12 @@ import json
 import logging
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class MilvusClient:
+class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
         if MILVUS_TOKEN is None:
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 432bcef41..60ef2d906 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
 from opensearchpy.helpers import bulk
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
 )
 
 
-class OpenSearchClient:
+class OpenSearchClient(VectorDBBase):
     def __init__(self):
         self.index_prefix = "open_webui"
         self.client = OpenSearch(
diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py
index c38dbb036..cd875b406 100644
--- a/backend/open_webui/retrieval/vector/dbs/pgvector.py
+++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py
@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 
 from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
     vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
-class PgvectorClient:
+class PgvectorClient(VectorDBBase):
     def __init__(self) -> None:
 
         # if no pgvector uri, use the existing database connection
diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py
index c72c06471..bc9bd8bc3 100644
--- a/backend/open_webui/retrieval/vector/dbs/pinecone.py
+++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py
@@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
 import logging
 from pinecone import Pinecone, ServerlessSpec
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     PINECONE_API_KEY,
     PINECONE_ENVIRONMENT,
@@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class PineconeClient:
+class PineconeClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open-webui"
 
diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py
index a0d602610..dfe297907 100644
--- a/backend/open_webui/retrieval/vector/dbs/qdrant.py
+++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py
@@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     QDRANT_URI,
     QDRANT_API_KEY,
@@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class QdrantClient:
+class QdrantClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.QDRANT_URI = QDRANT_URI
diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py
index f0cf0c038..53f752f57 100644
--- a/backend/open_webui/retrieval/vector/main.py
+++ b/backend/open_webui/retrieval/vector/main.py
@@ -1,5 +1,6 @@
 from pydantic import BaseModel
-from typing import Optional, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
 
 
 class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):
 
 class SearchResult(GetResult):
     distances: Optional[List[List[float | int]]]
+
+
+class VectorDBBase(ABC):
+    """
+    Abstract base class for all vector database backends.
+
+    Implementations of this class provide methods for collection management,
+    vector insertion, deletion, similarity search, and metadata filtering.
+
+    Any custom vector database integration must inherit from this class and
+    implement all abstract methods.
+    """
+
+    @abstractmethod
+    def has_collection(self, collection_name: str) -> bool:
+        """Check if the collection exists in the vector DB."""
+        pass
+
+    @abstractmethod
+    def delete_collection(self, collection_name: str) -> None:
+        """Delete a collection from the vector DB."""
+        pass
+
+    @abstractmethod
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert a list of vector items into a collection."""
+        pass
+
+    @abstractmethod
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert or update vector items in a collection."""
+        pass
+
+    @abstractmethod
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """Search for similar vectors in a collection."""
+        pass
+
+    @abstractmethod
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """Query vectors from a collection using metadata filter."""
+        pass
+
+    @abstractmethod
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """Retrieve all vectors from a collection."""
+        pass
+
+    @abstractmethod
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        """Delete vectors by ID or filter from a collection."""
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Reset the vector database by removing all collections or those matching a condition."""
+        pass