Merge pull request #13098 from athoik/dev

feat: Add abstract base class for vector database integration
This commit is contained in:
Tim Jaeryang Baek 2025-04-22 23:24:08 -07:00 committed by GitHub
commit d3e516934c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 117 additions and 15 deletions

View File

@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,

View File

@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
@ -15,7 +20,7 @@ from open_webui.config import (
)
class ElasticsearchClient:
class ElasticsearchClient(VectorDBBase):
"""
Important:
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

View File

@ -4,7 +4,12 @@ import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None:

View File

@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
@ -12,7 +17,7 @@ from open_webui.config import (
)
class OpenSearchClient:
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(

View File

@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
@ -44,7 +49,7 @@ class DocumentChunk(Base):
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection

View File

@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
import logging
from pinecone import Pinecone, ServerlessSpec
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_ENVIRONMENT,
@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient:
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"

View File

@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel
from typing import Optional, List, Any
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class VectorItem(BaseModel):
@ -17,3 +18,69 @@ class GetResult(BaseModel):
class SearchResult(GetResult):
    """Result model that extends ``GetResult`` with a ``distances`` field."""

    # Optional nested list of numeric values (float or int).
    # NOTE(review): presumably one inner list of distances/scores per query
    # vector, aligned with the fields inherited from GetResult -- confirm
    # against the backends that build this model.
    distances: Optional[List[List[float | int]]]
class VectorDBBase(ABC):
    """
    Contract that every vector database backend has to fulfil.

    A concrete backend supplies collection management, vector insertion
    and deletion, similarity search, and metadata filtering. Custom vector
    database integrations must subclass this class and override every
    abstract method declared below.
    """

    @abstractmethod
    def has_collection(self, collection_name: str) -> bool:
        """Report whether the named collection exists in the vector DB."""
        ...

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """Remove the named collection from the vector DB."""
        ...

    @abstractmethod
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Add the given vector items to a collection."""
        ...

    @abstractmethod
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Add the given vector items to a collection, updating existing ones."""
        ...

    @abstractmethod
    def search(
        self,
        collection_name: str,
        vectors: List[List[Union[float, int]]],
        limit: int,
    ) -> Optional[SearchResult]:
        """Run a similarity search for ``vectors`` within a collection."""
        ...

    @abstractmethod
    def query(
        self,
        collection_name: str,
        filter: Dict,
        limit: Optional[int] = None,
    ) -> Optional[GetResult]:
        """Fetch vectors from a collection that match a metadata filter."""
        ...

    @abstractmethod
    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every vector stored in a collection."""
        ...

    @abstractmethod
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Remove vectors from a collection, selected by ID or by filter."""
        ...

    @abstractmethod
    def reset(self) -> None:
        """Reset the vector database by removing all collections or those matching a condition."""
        ...