mirror of
https://github.com/open-webui/open-webui
synced 2024-11-07 09:09:53 +00:00
109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
import chromadb
|
|
from chromadb import Settings
|
|
from chromadb.utils.batch_utils import create_batches
|
|
|
|
from typing import Optional
|
|
|
|
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
|
from open_webui.config import (
|
|
CHROMA_DATA_PATH,
|
|
CHROMA_HTTP_HOST,
|
|
CHROMA_HTTP_PORT,
|
|
CHROMA_HTTP_HEADERS,
|
|
CHROMA_HTTP_SSL,
|
|
CHROMA_TENANT,
|
|
CHROMA_DATABASE,
|
|
)
|
|
|
|
|
|
class ChromaClient:
    """Small wrapper around a chromadb client for collection-level operations.

    The underlying client is chosen at construction time: an HTTP client
    when ``CHROMA_HTTP_HOST`` is a non-empty string, otherwise a persistent
    client rooted at ``CHROMA_DATA_PATH``. Both are created for
    ``CHROMA_TENANT`` / ``CHROMA_DATABASE`` with anonymized telemetry off.
    """

    def __init__(self):
        """Create the chromadb client selected by the configuration values."""
        if CHROMA_HTTP_HOST != "":
            # A host is configured: talk to a remote Chroma server.
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                # NOTE(review): allow_reset=True is presumably what lets
                # ``reset`` succeed -- confirm against the chromadb docs.
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
            )
        else:
            # No host configured: use the on-disk store at CHROMA_DATA_PATH.
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def list_collections(self) -> list[str]:
        """Return the ``name`` of every collection the client lists."""
        collections = self.client.list_collections()
        return [collection.name for collection in collections]

    def create_collection(self, collection_name: str):
        """Create a collection called ``collection_name`` and return the client's result."""
        return self.client.create_collection(name=collection_name)

    def delete_collection(self, collection_name: str):
        """Delete the collection called ``collection_name`` and return the client's result."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[QueryResult]:
        """Query ``collection_name`` with ``vectors`` as the query embeddings.

        Args:
            collection_name: Name of the collection to query.
            vectors: Query embeddings, one inner list per query.
            limit: Passed to the query as ``n_results``.

        Returns:
            A dict with the ``ids``, ``distances``, ``documents`` and
            ``metadatas`` entries of the query result, or ``None`` when the
            looked-up collection object is falsy.
        """
        # NOTE(review): get_collection may raise instead of returning a falsy
        # value for a missing collection -- confirm for the pinned chromadb.
        collection = self.client.get_collection(name=collection_name)
        if collection:
            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )

            # Forward only the four fields callers consume.
            return {
                "ids": result["ids"],
                "distances": result["distances"],
                "documents": result["documents"],
                "metadatas": result["metadatas"],
            }
        return None

    def get(self, collection_name: str) -> Optional[QueryResult]:
        """Return ``collection.get()`` for ``collection_name``.

        Returns ``None`` when the looked-up collection object is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if collection:
            return collection.get()
        return None

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to ``collection_name``, creating the collection if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys. The data is split with ``create_batches`` and
        every batch is added separately.
        """
        collection = self.client.get_or_create_collection(name=collection_name)

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        # Each batch is a tuple that is unpacked positionally into add().
        # NOTE(review): presumably batching keeps each call within the
        # client's maximum batch size -- confirm against chromadb docs.
        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Upsert ``items`` into ``collection_name``, creating the collection if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys and the four resulting lists are sent in a
        single ``collection.upsert`` call.
        """
        collection = self.client.get_or_create_collection(name=collection_name)

        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]

        # The keyword is ``metadatas`` (plural), the same name ``insert``
        # uses; passing ``metadata=`` is rejected as an unknown argument.
        collection.upsert(
            ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
        )

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the entries with the given ``ids`` from ``collection_name``.

        Nothing is deleted when the looked-up collection object is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if collection:
            collection.delete(ids=ids)

    def reset(self):
        """Call ``reset()`` on the underlying client and return its result."""
        return self.client.reset()
|