mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: Add abstract base class for vector database integration
- Created `VectorDBBase` as an abstract base class to standardize vector database operations. - Added required methods for common vector database operations: `has_collection`, `delete_collection`, `insert`, `upsert`, `search`, `query`, `get`, `delete`, `reset`. - The base class can now be extended by any vector database implementation (e.g., Qdrant, Pinecone) to ensure a consistent API across different database systems.
This commit is contained in:
committed by
Athanasios Oikonomou
parent
913f8a15f9
commit
1e291aff25
@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
class ChromaClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
|
||||
@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
|
||||
from typing import Optional
|
||||
import ssl
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
ELASTICSEARCH_URL,
|
||||
ELASTICSEARCH_CA_CERTS,
|
||||
@@ -15,7 +20,7 @@ from open_webui.config import (
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchClient:
|
||||
class ElasticsearchClient(VectorDBBase):
|
||||
"""
|
||||
Important:
|
||||
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
||||
|
||||
@@ -4,7 +4,12 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_DB,
|
||||
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
if MILVUS_TOKEN is None:
|
||||
|
||||
@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_SSL,
|
||||
@@ -12,7 +17,7 @@ from open_webui.config import (
|
||||
)
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
class OpenSearchClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui"
|
||||
self.client = OpenSearch(
|
||||
|
||||
@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
||||
class PgvectorClient:
|
||||
class PgvectorClient(VectorDBBase):
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
|
||||
@@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
|
||||
import logging
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
PINECONE_API_KEY,
|
||||
PINECONE_ENVIRONMENT,
|
||||
@@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class PineconeClient:
|
||||
class PineconeClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
QDRANT_URI,
|
||||
QDRANT_API_KEY,
|
||||
@@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
class QdrantClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
|
||||
Reference in New Issue
Block a user