mirror of
https://github.com/open-webui/open-webui
synced 2025-06-14 18:33:15 +00:00
feat: Add abstract base class for vector database integration
- Created `VectorDBBase` as an abstract base class to standardize vector database operations.
- Added the required abstract methods for common vector database operations: `has_collection`, `delete_collection`, `insert`, `upsert`, `search`, `query`, `get`, `delete`, and `reset`.
- The base class can now be extended by any vector database implementation (e.g., Qdrant, Pinecone) to ensure a consistent API across different database systems.
This commit is contained in:
parent
913f8a15f9
commit
1e291aff25
@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
class ChromaClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
|
@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
|
||||
from typing import Optional
|
||||
import ssl
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
ELASTICSEARCH_URL,
|
||||
ELASTICSEARCH_CA_CERTS,
|
||||
@ -15,7 +20,7 @@ from open_webui.config import (
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchClient:
|
||||
class ElasticsearchClient(VectorDBBase):
|
||||
"""
|
||||
Important:
|
||||
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
||||
|
@ -4,7 +4,12 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_DB,
|
||||
@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
if MILVUS_TOKEN is None:
|
||||
|
@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_SSL,
|
||||
@ -12,7 +17,7 @@ from open_webui.config import (
|
||||
)
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
class OpenSearchClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui"
|
||||
self.client = OpenSearch(
|
||||
|
@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
@ -44,7 +49,7 @@ class DocumentChunk(Base):
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
||||
class PgvectorClient:
|
||||
class PgvectorClient(VectorDBBase):
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
|
@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
|
||||
import logging
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
PINECONE_API_KEY,
|
||||
PINECONE_ENVIRONMENT,
|
||||
@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class PineconeClient:
|
||||
class PineconeClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
|
||||
|
@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
QDRANT_URI,
|
||||
QDRANT_API_KEY,
|
||||
@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
class QdrantClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
|
@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class VectorItem(BaseModel):
|
||||
@ -17,3 +18,69 @@ class GetResult(BaseModel):
|
||||
|
||||
# Result model for similarity searches: inherits every field of GetResult and
# adds the `distances` field declared below.
class SearchResult(GetResult):
|
||||
# Optional nested list of numeric (float or int) values.
# NOTE(review): presumably one inner list per query vector, aligned with the
# inherited GetResult fields — confirm against the backend implementations.
distances: Optional[List[List[float | int]]]
|
||||
|
||||
|
||||
class VectorDBBase(ABC):
    """
    Common interface that every vector database backend must satisfy.

    The abstract methods declared here cover collection management, writing
    and removing vector items, similarity search, and metadata-filtered
    lookup.

    A custom backend integration subclasses this type and overrides every
    abstract method; a subclass that leaves any of them unimplemented cannot
    be instantiated.
    """

    @abstractmethod
    def has_collection(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` exists in the vector DB."""
        ...

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """Remove the collection ``collection_name`` from the vector DB."""
        ...

    @abstractmethod
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Add the vector items in ``items`` to ``collection_name``."""
        ...

    @abstractmethod
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Write ``items`` to ``collection_name``, inserting or updating them."""
        ...

    @abstractmethod
    def search(
        self,
        collection_name: str,
        vectors: List[List[Union[float, int]]],
        limit: int,
    ) -> Optional[SearchResult]:
        """
        Look up vectors in ``collection_name`` similar to the query ``vectors``.

        ``limit`` bounds the number of results (presumably per query vector;
        confirm against the concrete backends). The result may be ``None``.
        """
        ...

    @abstractmethod
    def query(
        self,
        collection_name: str,
        filter: Dict,
        limit: Optional[int] = None,
    ) -> Optional[GetResult]:
        """
        Fetch vectors from ``collection_name`` that match a metadata ``filter``.

        ``limit`` optionally bounds the number of results. The result may be
        ``None``.
        """
        ...

    @abstractmethod
    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return all vectors held in ``collection_name`` (may be ``None``)."""
        ...

    @abstractmethod
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Remove vectors from ``collection_name``, chosen by ``ids`` or ``filter``."""
        ...

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the vector database.

        Removes all collections, or only those matching a condition.
        """
        ...
|
||||
|
Loading…
Reference in New Issue
Block a user