mirror of
https://github.com/open-webui/open-webui
synced 2025-06-16 11:23:56 +00:00
feat: Add abstract base class for vector database integration
- Created `VectorDBBase` as an abstract base class to standardize vector database operations. - Added required methods for common vector database operations: `has_collection`, `delete_collection`, `insert`, `upsert`, `search`, `query`, `get`, `delete`, `reset`. - The base class can now be extended by any vector database implementation (e.g., Qdrant, Pinecone) to ensure a consistent API across different database systems.
This commit is contained in:
parent
913f8a15f9
commit
1e291aff25
@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
CHROMA_DATA_PATH,
|
CHROMA_DATA_PATH,
|
||||||
CHROMA_HTTP_HOST,
|
CHROMA_HTTP_HOST,
|
||||||
@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class ChromaClient:
|
class ChromaClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
settings_dict = {
|
settings_dict = {
|
||||||
"allow_reset": True,
|
"allow_reset": True,
|
||||||
|
@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import ssl
|
import ssl
|
||||||
from elasticsearch.helpers import bulk, scan
|
from elasticsearch.helpers import bulk, scan
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
ELASTICSEARCH_URL,
|
ELASTICSEARCH_URL,
|
||||||
ELASTICSEARCH_CA_CERTS,
|
ELASTICSEARCH_CA_CERTS,
|
||||||
@ -15,7 +20,7 @@ from open_webui.config import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ElasticsearchClient:
|
class ElasticsearchClient(VectorDBBase):
|
||||||
"""
|
"""
|
||||||
Important:
|
Important:
|
||||||
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
||||||
|
@ -4,7 +4,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
MILVUS_URI,
|
MILVUS_URI,
|
||||||
MILVUS_DB,
|
MILVUS_DB,
|
||||||
@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class MilvusClient:
|
class MilvusClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.collection_prefix = "open_webui"
|
self.collection_prefix = "open_webui"
|
||||||
if MILVUS_TOKEN is None:
|
if MILVUS_TOKEN is None:
|
||||||
|
@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
|
|||||||
from opensearchpy.helpers import bulk
|
from opensearchpy.helpers import bulk
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
OPENSEARCH_URI,
|
OPENSEARCH_URI,
|
||||||
OPENSEARCH_SSL,
|
OPENSEARCH_SSL,
|
||||||
@ -12,7 +17,7 @@ from open_webui.config import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenSearchClient:
|
class OpenSearchClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.index_prefix = "open_webui"
|
self.index_prefix = "open_webui"
|
||||||
self.client = OpenSearch(
|
self.client = OpenSearch(
|
||||||
|
@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
|
|||||||
from sqlalchemy.ext.mutable import MutableDict
|
from sqlalchemy.ext.mutable import MutableDict
|
||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
@ -44,7 +49,7 @@ class DocumentChunk(Base):
|
|||||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class PgvectorClient:
|
class PgvectorClient(VectorDBBase):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
||||||
# if no pgvector uri, use the existing database connection
|
# if no pgvector uri, use the existing database connection
|
||||||
|
@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
|
|||||||
import logging
|
import logging
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
PINECONE_API_KEY,
|
PINECONE_API_KEY,
|
||||||
PINECONE_ENVIRONMENT,
|
PINECONE_ENVIRONMENT,
|
||||||
@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class PineconeClient:
|
class PineconeClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.collection_prefix = "open-webui"
|
self.collection_prefix = "open-webui"
|
||||||
|
|
||||||
|
@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
|
|||||||
from qdrant_client.http.models import PointStruct
|
from qdrant_client.http.models import PointStruct
|
||||||
from qdrant_client.models import models
|
from qdrant_client.models import models
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
from open_webui.retrieval.vector.main import (
|
||||||
|
VectorDBBase,
|
||||||
|
VectorItem,
|
||||||
|
SearchResult,
|
||||||
|
GetResult,
|
||||||
|
)
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
QDRANT_URI,
|
QDRANT_URI,
|
||||||
QDRANT_API_KEY,
|
QDRANT_API_KEY,
|
||||||
@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class QdrantClient:
|
class QdrantClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.collection_prefix = "open-webui"
|
self.collection_prefix = "open-webui"
|
||||||
self.QDRANT_URI = QDRANT_URI
|
self.QDRANT_URI = QDRANT_URI
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Any
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class VectorItem(BaseModel):
|
class VectorItem(BaseModel):
|
||||||
@ -17,3 +18,69 @@ class GetResult(BaseModel):
|
|||||||
|
|
||||||
class SearchResult(GetResult):
|
class SearchResult(GetResult):
|
||||||
distances: Optional[List[List[float | int]]]
|
distances: Optional[List[List[float | int]]]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDBBase(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all vector database backends.
|
||||||
|
|
||||||
|
Implementations of this class provide methods for collection management,
|
||||||
|
vector insertion, deletion, similarity search, and metadata filtering.
|
||||||
|
|
||||||
|
Any custom vector database integration must inherit from this class and
|
||||||
|
implement all abstract methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
|
"""Check if the collection exists in the vector DB."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_collection(self, collection_name: str) -> None:
|
||||||
|
"""Delete a collection from the vector DB."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
|
"""Insert a list of vector items into a collection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
|
"""Insert or update vector items in a collection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||||
|
) -> Optional[SearchResult]:
|
||||||
|
"""Search for similar vectors in a collection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def query(
|
||||||
|
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||||
|
) -> Optional[GetResult]:
|
||||||
|
"""Query vectors from a collection using metadata filter."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
|
"""Retrieve all vectors from a collection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(
|
||||||
|
self,
|
||||||
|
collection_name: str,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
|
filter: Optional[Dict] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Delete vectors by ID or filter from a collection."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the vector database by removing all collections or those matching a condition."""
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user