feat: Add abstract base class for vector database integration

- Created `VectorDBBase` as an abstract base class to standardize vector database operations.
- Added required methods for common vector database operations: `has_collection`, `delete_collection`, `insert`, `upsert`, `search`, `query`, `get`, `delete`, `reset`.
- The base class can now be extended by any vector database implementation (e.g., Qdrant, Pinecone) to ensure a consistent API across different database systems.
This commit is contained in:
Athanasios Oikonomou
2025-04-21 08:26:08 +03:00
committed by Athanasios Oikonomou
parent 913f8a15f9
commit 1e291aff25
8 changed files with 117 additions and 15 deletions

View File

@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,

View File

@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
)
class ElasticsearchClient:
class ElasticsearchClient(VectorDBBase):
"""
Important:
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

View File

@@ -4,7 +4,12 @@ import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None:

View File

@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
)
class OpenSearchClient:
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(

View File

@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection

View File

@@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
import logging
from pinecone import Pinecone, ServerlessSpec
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_ENVIRONMENT,
@@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient:
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"

View File

@@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
@@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI