diff --git a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py index f096e244a..fa496e7d9 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py @@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any from sqlalchemy import ( cast, column, + create_engine, Column, Integer, select, @@ -10,15 +11,15 @@ from sqlalchemy import ( values, ) from sqlalchemy.sql import true +from sqlalchemy.pool import NullPool -from sqlalchemy.orm import declarative_base, Session +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.ext.declarative import declarative_base -from open_webui.apps.webui.internal.db import Session from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import PGVECTOR_DB_URL VECTOR_LENGTH = 1536 Base = declarative_base() @@ -36,7 +37,19 @@ class DocumentChunk(Base): class PgvectorClient: def __init__(self) -> None: - self.session = Session + + # if no pgvector uri, use the existing database connection + if not PGVECTOR_DB_URL: + from open_webui.apps.webui.internal.db import Session + + self.session = Session + else: + engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool) + SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine, expire_on_commit=False + ) + self.session = scoped_session(SessionLocal) + try: # Ensure the pgvector extension is available self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 316e866bd..7680c75a3 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -932,9 +932,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") -if VECTOR_DB == 'pgvector' and not DATABASE_URL.startswith("postgres"): - raise ValueError("Pgvector requires using Postgres with vector extension as the primary database.") - # Chroma CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) @@ -968,6 +965,11 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False) OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None) OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None) +# Pgvector +PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", None) +if VECTOR_DB == 'pgvector' and not (DATABASE_URL.startswith("postgres") or PGVECTOR_DB_URL): + raise ValueError("Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database.") + #################################### # Information Retrieval (RAG) ####################################