import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from typing import Optional

from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
)


class ChromaClient:
    """Wrapper around a ChromaDB client exposing collection-level operations.

    The underlying client is stored on ``self.client`` and is either an HTTP
    client or a local persistent client, chosen from configuration in
    ``__init__``.
    """

    def __init__(self):
        """Create the underlying ChromaDB client from configuration.

        An HTTP client is used when ``CHROMA_HTTP_HOST`` is not the empty
        string; otherwise a persistent client rooted at ``CHROMA_DATA_PATH``
        is used. Both are created with resets allowed and anonymized
        telemetry disabled.
        """
        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def list_collections(self) -> list[str]:
        """Return the ``name`` of every collection reported by the client."""
        collections = self.client.list_collections()
        return [collection.name for collection in collections]

    def create_collection(self, collection_name: str):
        """Create a collection called ``collection_name`` and return the client's result."""
        return self.client.create_collection(name=collection_name)

    def delete_collection(self, collection_name: str):
        """Delete the collection called ``collection_name`` and return the client's result."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[QueryResult]:
        """Query a collection with one or more embedding vectors.

        Args:
            collection_name: Name of the collection to query.
            vectors: Query embeddings, passed as ``query_embeddings``.
            limit: Number of results requested, passed as ``n_results``.

        Returns:
            A dict with the ``ids``, ``distances``, ``documents`` and
            ``metadatas`` entries of the query result, or ``None`` when the
            fetched collection is falsy.
        """
        # NOTE(review): get_collection may raise rather than return a falsy
        # value when the collection does not exist -- confirm against the
        # installed chromadb version.
        collection = self.client.get_collection(name=collection_name)
        if collection:
            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )
            return {
                "ids": result["ids"],
                "distances": result["distances"],
                "documents": result["documents"],
                "metadatas": result["metadatas"],
            }
        return None

    def get(self, collection_name: str) -> Optional[QueryResult]:
        """Return the result of ``collection.get()`` for ``collection_name``.

        Returns ``None`` when the fetched collection is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if collection:
            return collection.get()
        return None

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to a collection, creating the collection if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys. The data is split with ``create_batches`` and
        each batch is added with a separate ``collection.add`` call.
        """
        collection = self.client.get_or_create_collection(name=collection_name)

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Upsert ``items`` into a collection, creating the collection if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys and everything is sent in a single
        ``collection.upsert`` call.
        """
        collection = self.client.get_or_create_collection(name=collection_name)

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        # The keyword is ``metadatas`` (plural), matching what ``insert``
        # passes to ``create_batches``; the singular ``metadata`` is not an
        # accepted argument of ``Collection.upsert``.
        collection.upsert(
            ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
        )

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the entries with the given ``ids`` from ``collection_name``.

        Nothing is deleted when the fetched collection is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if collection:
            collection.delete(ids=ids)

    def reset(self):
        """Call ``reset()`` on the underlying client and return its result."""
        return self.client.reset()