mirror of
https://github.com/open-webui/open-webui
synced 2024-11-21 23:57:51 +00:00
Merge pull request #6645 from jk-f5/feat/pgvector
feat: Intial support for pgvector as backing vector database
This commit is contained in:
commit
1ee7730fac
@ -16,4 +16,5 @@ _old
|
||||
uploads
|
||||
.ipynb_checkpoints
|
||||
**/*.db
|
||||
_test
|
||||
_test
|
||||
backend/data/*
|
||||
|
@ -12,6 +12,10 @@ elif VECTOR_DB == "opensearch":
|
||||
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = OpenSearchClient()
|
||||
elif VECTOR_DB == "pgvector":
|
||||
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||
|
||||
VECTOR_DB_CLIENT = PgvectorClient()
|
||||
else:
|
||||
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
|
352
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py
Normal file
352
backend/open_webui/apps/retrieval/vector/dbs/pgvector.py
Normal file
@ -0,0 +1,352 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy import (
|
||||
cast,
|
||||
column,
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
values,
|
||||
)
|
||||
from sqlalchemy.sql import true
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import PGVECTOR_DB_URL
|
||||
|
||||
VECTOR_LENGTH = 1536
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||
collection_name = Column(Text, nullable=False)
|
||||
text = Column(Text, nullable=True)
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
||||
class PgvectorClient:
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
if not PGVECTOR_DB_URL:
|
||||
from open_webui.apps.webui.internal.db import Session
|
||||
|
||||
self.session = Session
|
||||
else:
|
||||
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
self.session = scoped_session(SessionLocal)
|
||||
|
||||
try:
|
||||
# Ensure the pgvector extension is available
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
|
||||
# Create the tables if they do not exist
|
||||
# Base.metadata.create_all requires a bind (engine or connection)
|
||||
# Get the connection from the session
|
||||
connection = self.session.connection()
|
||||
Base.metadata.create_all(bind=connection)
|
||||
|
||||
# Create an index on the vector column if it doesn't exist
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
||||
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
||||
)
|
||||
)
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
"ON document_chunk (collection_name);"
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
print("Initialization complete.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
# Adjust vector to have length VECTOR_LENGTH
|
||||
current_length = len(vector)
|
||||
if current_length < VECTOR_LENGTH:
|
||||
# Pad the vector with zeros
|
||||
vector += [0.0] * (VECTOR_LENGTH - current_length)
|
||||
elif current_length > VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
|
||||
)
|
||||
return vector
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
print(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = item["metadata"]
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: List[List[float]],
|
||||
limit: Optional[int] = None,
|
||||
) -> Optional[SearchResult]:
|
||||
try:
|
||||
if not vectors:
|
||||
return None
|
||||
|
||||
# Adjust query vectors to VECTOR_LENGTH
|
||||
vectors = [self.adjust_vector_length(vector) for vector in vectors]
|
||||
num_queries = len(vectors)
|
||||
|
||||
def vector_expr(vector):
|
||||
return cast(array(vector), Vector(VECTOR_LENGTH))
|
||||
|
||||
# Create the values for query vectors
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
)
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
subq = (
|
||||
select(
|
||||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
).label("distance"),
|
||||
)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
subq = subq.lateral("result")
|
||||
|
||||
# Build the main query by joining query_vectors and the lateral subquery
|
||||
stmt = (
|
||||
select(
|
||||
query_vectors.c.qid,
|
||||
subq.c.id,
|
||||
subq.c.text,
|
||||
subq.c.vmetadata,
|
||||
subq.c.distance,
|
||||
)
|
||||
.select_from(query_vectors)
|
||||
.join(subq, true())
|
||||
.order_by(query_vectors.c.qid, subq.c.distance)
|
||||
)
|
||||
|
||||
result_proxy = self.session.execute(stmt)
|
||||
results = result_proxy.all()
|
||||
|
||||
ids = [[] for _ in range(num_queries)]
|
||||
distances = [[] for _ in range(num_queries)]
|
||||
documents = [[] for _ in range(num_queries)]
|
||||
metadatas = [[] for _ in range(num_queries)]
|
||||
|
||||
if not results:
|
||||
return SearchResult(
|
||||
ids=ids,
|
||||
distances=distances,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for row in results:
|
||||
qid = int(row.qid)
|
||||
ids[qid].append(row.id)
|
||||
distances[qid].append(row.distance)
|
||||
documents[qid].append(row.text)
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
ids = [[result.id for result in results]]
|
||||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
return GetResult(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
ids = [[result.id for result in results]]
|
||||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
print(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
if ids:
|
||||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
print(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during delete: {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
print(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
try:
|
||||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
return exists
|
||||
except Exception as e:
|
||||
print(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
self.delete(collection_name)
|
||||
print(f"Collection '{collection_name}' deleted.")
|
@ -20,6 +20,7 @@ from open_webui.env import (
|
||||
WEBUI_FAVICON_URL,
|
||||
WEBUI_NAME,
|
||||
log,
|
||||
DATABASE_URL,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
||||
@ -964,6 +965,11 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
|
||||
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
||||
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
||||
|
||||
# Pgvector
|
||||
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", None)
|
||||
if VECTOR_DB == 'pgvector' and not (DATABASE_URL.startswith("postgres") or PGVECTOR_DB_URL):
|
||||
raise ValueError("Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database.")
|
||||
|
||||
####################################
|
||||
# Information Retrieval (RAG)
|
||||
####################################
|
||||
|
@ -19,6 +19,7 @@ alembic==1.13.2
|
||||
peewee==3.17.6
|
||||
peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
pgvector==0.3.5
|
||||
PyMySQL==1.1.1
|
||||
bcrypt==4.2.0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user