# open-webui/backend/open_webui/apps/rag/vector/dbs/chroma.py
# (Python, 123 lines, 4.5 KiB)

2024-09-10 01:27:50 +00:00
import chromadb
from chromadb import Settings
2024-09-10 03:37:06 +00:00
from chromadb.utils.batch_utils import create_batches
2024-09-10 01:27:50 +00:00
2024-09-10 03:37:06 +00:00
from typing import Optional
2024-09-13 05:18:20 +00:00
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
2024-09-10 01:27:50 +00:00
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
)
2024-09-10 03:37:06 +00:00
class ChromaClient:
    """Vector-store adapter backed by a chromadb client.

    An HTTP client is used when ``CHROMA_HTTP_HOST`` is non-empty; otherwise a
    local persistent client rooted at ``CHROMA_DATA_PATH`` is created. Read-side
    methods (``search``, ``get``) return ``None`` and ``delete`` is a no-op when
    the named collection does not exist.
    """

    def __init__(self):
        # One settings object for whichever client is built. allow_reset=True
        # is what permits reset() below to wipe the database.
        settings = Settings(allow_reset=True, anonymized_telemetry=False)
        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists."""
        # Older chromadb releases return Collection objects from
        # list_collections(); newer ones (0.6.x) return bare name strings.
        # getattr(..., "name", c) handles both shapes.
        return any(
            getattr(collection, "name", collection) == collection_name
            for collection in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns whatever the underlying client's ``delete_collection`` returns.
        """
        return self.client.delete_collection(name=collection_name)

    def _get_existing_collection(self, collection_name: str):
        """Return the named collection, or None if it does not exist.

        ``get_collection`` raises for a missing collection rather than
        returning a falsy value. Instead of relying on a specific exception
        class, a failure is followed by an existence check. Only "does not
        exist" maps to None; every other error (connection, auth, ...) is
        re-raised unchanged.
        """
        try:
            return self.client.get_collection(name=collection_name)
        except Exception:
            if not self.has_collection(collection_name):
                return None
            raise

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Find the ``limit`` nearest neighbours of each query vector.

        Args:
            collection_name: Collection to query.
            vectors: One or more query embeddings.
            limit: Number of results requested per query vector.

        Returns:
            A SearchResult built from chromadb's ``ids``, ``distances``,
            ``documents`` and ``metadatas`` query fields, or None if the
            collection does not exist.
        """
        collection = self._get_existing_collection(collection_name)
        if collection is None:
            return None

        result = collection.query(query_embeddings=vectors, n_results=limit)
        return SearchResult(
            ids=result["ids"],
            distances=result["distances"],
            documents=result["documents"],
            metadatas=result["metadatas"],
        )

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every item stored in the collection.

        Returns:
            A GetResult whose ``ids``, ``documents`` and ``metadatas`` are each
            wrapped in one extra outer list, or None if the collection does
            not exist.
        """
        collection = self._get_existing_collection(collection_name)
        if collection is None:
            return None

        result = collection.get()
        # NOTE(review): the extra outer list appears to mirror search()'s
        # per-query nesting; callers presumably index [0]. Keep the wrapping.
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def _batches(self, items: list[VectorItem]):
        """Split ``items`` into batches no larger than the client allows.

        Each batch is a tuple in the positional order that ``Collection.add``
        and ``Collection.upsert`` accept (ids, embeddings, metadatas,
        documents).
        """
        return create_batches(
            api=self.client,
            ids=[item["id"] for item in items],
            documents=[item["text"] for item in items],
            embeddings=[item["vector"] for item in items],
            metadatas=[item["metadata"] for item in items],
        )

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating the collection if needed.

        Each item must provide ``id``, ``text``, ``vector`` and ``metadata``.
        An empty ``items`` list only ensures the collection exists.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        if not items:
            # chromadb's validation rejects an empty ID list; nothing to add.
            return
        for batch in self._batches(items):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if needed.

        Existing ids are updated and new ids are inserted. Batching mirrors
        insert() so large payloads stay within the client's batch limit. An
        empty ``items`` list only ensures the collection exists.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        if not items:
            # chromadb's validation rejects an empty ID list; nothing to upsert.
            return
        for batch in self._batches(items):
            collection.upsert(*batch)

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the items with the given ``ids`` from the collection.

        Does nothing if the collection does not exist.
        """
        collection = self._get_existing_collection(collection_name)
        if collection is not None:
            collection.delete(ids=ids)

    def reset(self):
        """Reset the database, deleting all collections and their items.

        Requires the client to have been created with ``allow_reset=True``,
        which __init__ does.
        """
        return self.client.reset()