feat: Add ability to set URI for pgvector

This commit is contained in:
Jason Kidd 2024-11-04 13:34:05 -08:00
parent 701f40aedd
commit 319ea8cb7f
No known key found for this signature in database
GPG Key ID: 72BF942827539044
2 changed files with 22 additions and 7 deletions

View File

@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any
from sqlalchemy import (
cast,
column,
create_engine,
Column,
Integer,
select,
@ -10,15 +11,15 @@ from sqlalchemy import (
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.ext.declarative import declarative_base
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL
VECTOR_LENGTH = 1536
Base = declarative_base()
@ -36,7 +37,19 @@ class DocumentChunk(Base):
class PgvectorClient:
def __init__(self) -> None:
self.session = Session
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.apps.webui.internal.db import Session
self.session = Session
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

View File

@ -932,9 +932,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
if VECTOR_DB == 'pgvector' and not DATABASE_URL.startswith("postgres"):
raise ValueError("Pgvector requires using Postgres with vector extension as the primary database.")
# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
@ -968,6 +965,11 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
# Pgvector
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", None)
if VECTOR_DB == 'pgvector' and not (DATABASE_URL.startswith("postgres") or PGVECTOR_DB_URL):
raise ValueError("Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database.")
####################################
# Information Retrieval (RAG)
####################################