mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: Add abstract base class for vector database integration
- Created `VectorDBBase` as an abstract base class to standardize vector database operations. - Added required methods for common vector database operations: `has_collection`, `delete_collection`, `insert`, `upsert`, `search`, `query`, `get`, `delete`, `reset`. - The base class can now be extended by any vector database implementation (e.g., Qdrant, Pinecone) to ensure a consistent API across different database systems.
This commit is contained in:
		
							parent
							
								
									913f8a15f9
								
							
						
					
					
						commit
						1e291aff25
					
				@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    CHROMA_DATA_PATH,
 | 
			
		||||
    CHROMA_HTTP_HOST,
 | 
			
		||||
@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChromaClient:
 | 
			
		||||
class ChromaClient(VectorDBBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        settings_dict = {
 | 
			
		||||
            "allow_reset": True,
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
 | 
			
		||||
from typing import Optional
 | 
			
		||||
import ssl
 | 
			
		||||
from elasticsearch.helpers import bulk, scan
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    ELASTICSEARCH_URL,
 | 
			
		||||
    ELASTICSEARCH_CA_CERTS,
 | 
			
		||||
@ -15,7 +20,7 @@ from open_webui.config import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ElasticsearchClient:
 | 
			
		||||
class ElasticsearchClient(VectorDBBase):
 | 
			
		||||
    """
 | 
			
		||||
    Important:
 | 
			
		||||
    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,12 @@ import json
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    MILVUS_URI,
 | 
			
		||||
    MILVUS_DB,
 | 
			
		||||
@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MilvusClient:
 | 
			
		||||
class MilvusClient(VectorDBBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.collection_prefix = "open_webui"
 | 
			
		||||
        if MILVUS_TOKEN is None:
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
 | 
			
		||||
from opensearchpy.helpers import bulk
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    OPENSEARCH_URI,
 | 
			
		||||
    OPENSEARCH_SSL,
 | 
			
		||||
@ -12,7 +17,7 @@ from open_webui.config import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OpenSearchClient:
 | 
			
		||||
class OpenSearchClient(VectorDBBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.index_prefix = "open_webui"
 | 
			
		||||
        self.client = OpenSearch(
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
 | 
			
		||||
from sqlalchemy.ext.mutable import MutableDict
 | 
			
		||||
from sqlalchemy.exc import NoSuchTableError
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 | 
			
		||||
 | 
			
		||||
from open_webui.env import SRC_LOG_LEVELS
 | 
			
		||||
@ -44,7 +49,7 @@ class DocumentChunk(Base):
 | 
			
		||||
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PgvectorClient:
 | 
			
		||||
class PgvectorClient(VectorDBBase):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
 | 
			
		||||
        # if no pgvector uri, use the existing database connection
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
 | 
			
		||||
import logging
 | 
			
		||||
from pinecone import Pinecone, ServerlessSpec
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    PINECONE_API_KEY,
 | 
			
		||||
    PINECONE_ENVIRONMENT,
 | 
			
		||||
@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PineconeClient:
 | 
			
		||||
class PineconeClient(VectorDBBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.collection_prefix = "open-webui"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
 | 
			
		||||
from qdrant_client.http.models import PointStruct
 | 
			
		||||
from qdrant_client.models import models
 | 
			
		||||
 | 
			
		||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 | 
			
		||||
from open_webui.retrieval.vector.main import (
 | 
			
		||||
    VectorDBBase,
 | 
			
		||||
    VectorItem,
 | 
			
		||||
    SearchResult,
 | 
			
		||||
    GetResult,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    QDRANT_URI,
 | 
			
		||||
    QDRANT_API_KEY,
 | 
			
		||||
@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QdrantClient:
 | 
			
		||||
class QdrantClient(VectorDBBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.collection_prefix = "open-webui"
 | 
			
		||||
        self.QDRANT_URI = QDRANT_URI
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from typing import Optional, List, Any
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Any, Dict, List, Optional, Union
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VectorItem(BaseModel):
 | 
			
		||||
@ -17,3 +18,69 @@ class GetResult(BaseModel):
 | 
			
		||||
 | 
			
		||||
class SearchResult(GetResult):
 | 
			
		||||
    distances: Optional[List[List[float | int]]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VectorDBBase(ABC):
 | 
			
		||||
    """
 | 
			
		||||
    Abstract base class for all vector database backends.
 | 
			
		||||
 | 
			
		||||
    Implementations of this class provide methods for collection management,
 | 
			
		||||
    vector insertion, deletion, similarity search, and metadata filtering.
 | 
			
		||||
 | 
			
		||||
    Any custom vector database integration must inherit from this class and
 | 
			
		||||
    implement all abstract methods.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def has_collection(self, collection_name: str) -> bool:
 | 
			
		||||
        """Check if the collection exists in the vector DB."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def delete_collection(self, collection_name: str) -> None:
 | 
			
		||||
        """Delete a collection from the vector DB."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
 | 
			
		||||
        """Insert a list of vector items into a collection."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
 | 
			
		||||
        """Insert or update vector items in a collection."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def search(
 | 
			
		||||
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
 | 
			
		||||
    ) -> Optional[SearchResult]:
 | 
			
		||||
        """Search for similar vectors in a collection."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def query(
 | 
			
		||||
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
 | 
			
		||||
    ) -> Optional[GetResult]:
 | 
			
		||||
        """Query vectors from a collection using metadata filter."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get(self, collection_name: str) -> Optional[GetResult]:
 | 
			
		||||
        """Retrieve all vectors from a collection."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def delete(
 | 
			
		||||
        self,
 | 
			
		||||
        collection_name: str,
 | 
			
		||||
        ids: Optional[List[str]] = None,
 | 
			
		||||
        filter: Optional[Dict] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Delete vectors by ID or filter from a collection."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def reset(self) -> None:
 | 
			
		||||
        """Reset the vector database by removing all collections or those matching a condition."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user