From 701f40aedd580e012ffca837bb4de4fec58410de Mon Sep 17 00:00:00 2001 From: Jason Kidd Date: Mon, 4 Nov 2024 12:33:58 -0800 Subject: [PATCH] feat: Initial support for pgvector --- .dockerignore | 3 +- .../apps/retrieval/vector/connector.py | 4 + .../apps/retrieval/vector/dbs/pgvector.py | 339 ++++++++++++++++++ backend/open_webui/config.py | 4 + backend/requirements.txt | 1 + 5 files changed, 350 insertions(+), 1 deletion(-) create mode 100644 backend/open_webui/apps/retrieval/vector/dbs/pgvector.py diff --git a/.dockerignore b/.dockerignore index d7e716758..2b4f7b5fc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,4 +16,5 @@ _old uploads .ipynb_checkpoints **/*.db -_test \ No newline at end of file +_test +backend/data/* diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index b95b55783..528835b56 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -12,6 +12,10 @@ elif VECTOR_DB == "opensearch": from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient VECTOR_DB_CLIENT = OpenSearchClient() +elif VECTOR_DB == "pgvector": + from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient + + VECTOR_DB_CLIENT = PgvectorClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py new file mode 100644 index 000000000..f096e244a --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/pgvector.py @@ -0,0 +1,339 @@ +from typing import Optional, List, Dict, Any +from sqlalchemy import ( + cast, + column, + Column, + Integer, + select, + text, + Text, + values, +) +from sqlalchemy.sql import true + +from sqlalchemy.orm import declarative_base, Session +from sqlalchemy.dialects.postgresql import JSONB, array +from pgvector.sqlalchemy import Vector +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.ext.declarative import declarative_base + +from open_webui.apps.webui.internal.db import Session +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult + +VECTOR_LENGTH = 1536 +Base = declarative_base() + + +class DocumentChunk(Base): + __tablename__ = "document_chunk" + + id = Column(Text, primary_key=True) + vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) + collection_name = Column(Text, nullable=False) + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + +class PgvectorClient: + def __init__(self) -> None: + self.session = Session + try: + # Ensure the pgvector extension is available + self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + + # Create the tables if they do not exist + # Base.metadata.create_all requires a bind (engine or connection) + # Get the connection from the session + connection = self.session.connection() + Base.metadata.create_all(bind=connection) + + # Create an index on the vector column if it doesn't exist + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " + "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" + ) + ) + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " + "ON document_chunk (collection_name);" + ) + ) + self.session.commit() + print("Initialization complete.") + except Exception as e: + self.session.rollback() + print(f"Error during initialization: {e}") + raise + + def adjust_vector_length(self, vector: List[float]) -> List[float]: + # Adjust vector to have length VECTOR_LENGTH + current_length = len(vector) + if current_length < VECTOR_LENGTH: + # Pad the vector with zeros + vector += [0.0] * (VECTOR_LENGTH - current_length) + elif current_length > VECTOR_LENGTH: + raise Exception( + f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}" + ) + return vector + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + print( + f"Inserted {len(new_items)} items into collection '{collection_name}'." + ) + except Exception as e: + self.session.rollback() + print(f"Error during insert: {e}") + raise + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() + ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = item["metadata"] + existing.collection_name = ( + collection_name # Update collection_name if necessary + ) + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + self.session.add(new_chunk) + self.session.commit() + print(f"Upserted {len(items)} items into collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during upsert: {e}") + raise + + def search( + self, + collection_name: str, + vectors: List[List[float]], + limit: Optional[int] = None, + ) -> Optional[SearchResult]: + try: + if not vectors: + return None + + # Adjust query vectors to VECTOR_LENGTH + vectors = [self.adjust_vector_length(vector) for vector in vectors] + num_queries = len(vectors) + + def vector_expr(vector): + return cast(array(vector), Vector(VECTOR_LENGTH)) + + # Create the values for query vectors + qid_col = column("qid", Integer) + q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + query_vectors = ( + values(qid_col, q_vector_col) + .data( + [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] + ) + .alias("query_vectors") + ) + + # Build the lateral subquery for each query vector + subq = ( + select( + DocumentChunk.id, + DocumentChunk.text, + DocumentChunk.vmetadata, + ( + DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) + ).label("distance"), + ) + .where(DocumentChunk.collection_name == collection_name) + .order_by( + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) + ) + ) + if limit is not None: + subq = subq.limit(limit) + subq = subq.lateral("result") + + # Build the main query by joining query_vectors and the lateral subquery + stmt = ( + select( + query_vectors.c.qid, + subq.c.id, + subq.c.text, + subq.c.vmetadata, + subq.c.distance, + ) + .select_from(query_vectors) + .join(subq, true()) + .order_by(query_vectors.c.qid, subq.c.distance) + ) + + result_proxy = self.session.execute(stmt) + results = result_proxy.all() + + ids = [[] for _ in range(num_queries)] + distances = [[] for _ in range(num_queries)] + documents = [[] for _ in range(num_queries)] + metadatas = [[] for _ in range(num_queries)] + + if not results: + return SearchResult( + ids=ids, + distances=distances, + documents=documents, + metadatas=metadatas, + ) + + for row in results: + qid = int(row.qid) + ids[qid].append(row.id) + distances[qid].append(row.distance) + documents[qid].append(row.text) + metadatas[qid].append(row.vmetadata) + + return SearchResult( + ids=ids, distances=distances, documents=documents, metadatas=metadatas + ) + except Exception as e: + print(f"Error during search: {e}") + return None + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + + for key, value in filter.items(): + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult( + ids=ids, + documents=documents, + metadatas=metadatas, + ) + except Exception as e: + print(f"Error during query: {e}") + return None + + def get( + self, collection_name: str, limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + except Exception as e: + print(f"Error during get: {e}") + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ) -> None: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = query.filter(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) + self.session.commit() + print(f"Deleted {deleted} items from collection '{collection_name}'.") + except Exception as e: + self.session.rollback() + print(f"Error during delete: {e}") + raise + + def reset(self) -> None: + try: + deleted = self.session.query(DocumentChunk).delete() + self.session.commit() + print( + f"Reset complete. Deleted {deleted} items from 'document_chunk' table." + ) + except Exception as e: + self.session.rollback() + print(f"Error during reset: {e}") + raise + + def close(self) -> None: + pass + + def has_collection(self, collection_name: str) -> bool: + try: + exists = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.collection_name == collection_name) + .first() + is not None + ) + return exists + except Exception as e: + print(f"Error checking collection existence: {e}") + return False + + def delete_collection(self, collection_name: str) -> None: + self.delete(collection_name) + print(f"Collection '{collection_name}' deleted.") diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f7221eeaf..316e866bd 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -20,6 +20,7 @@ from open_webui.env import ( WEBUI_FAVICON_URL, WEBUI_NAME, log, + DATABASE_URL, ) from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func @@ -931,6 +932,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") +if VECTOR_DB == 'pgvector' and not DATABASE_URL.startswith("postgres"): + raise ValueError("Pgvector requires using Postgres with vector extension as the primary database.") + # Chroma CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) diff --git a/backend/requirements.txt b/backend/requirements.txt index 122bc71d2..1a82347dd 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -19,6 +19,7 @@ alembic==1.13.2 peewee==3.17.6 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 +pgvector==0.3.5 PyMySQL==1.1.1 bcrypt==4.2.0