This commit is contained in:
Timothy J. Baek
2024-09-10 04:37:06 +01:00
parent d5f13dd9e0
commit 522afbb0a0
7 changed files with 240 additions and 127 deletions

View File

@@ -1,6 +1,10 @@
import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@@ -12,7 +16,7 @@ from open_webui.config import (
)
class Chroma:
class ChromaClient:
def __init__(self):
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
@@ -32,27 +36,73 @@ class Chroma:
database=CHROMA_DATABASE,
)
def query_collection(self, name, query_embeddings, k):
collection = self.client.get_collection(name=name)
def list_collections(self) -> list[str]:
collections = self.client.list_collections()
return [collection.name for collection in collections]
def create_collection(self, collection_name: str):
return self.client.create_collection(name=collection_name)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(name=collection_name)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
query_embeddings=vectors,
n_results=limit,
)
return result
return {
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
return None
def list_collections(self):
return self.client.list_collections()
def get(self, collection_name: str) -> Optional[QueryResult]:
collection = self.client.get_collection(name=collection_name)
if collection:
return collection.get()
return None
def create_collection(self, name):
return self.client.create_collection(name=name)
def insert(self, collection_name: str, items: list[VectorItem]):
collection = self.client.get_or_create_collection(name=collection_name)
def get_or_create_collection(self, name):
return self.client.get_or_create_collection(name=name)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
def delete_collection(self, name):
return self.client.delete_collection(name=name)
for batch in create_batches(
api=self.client,
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
):
collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]):
collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadata = [item["metadata"] for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadata=metadata
)
def delete(self, collection_name: str, ids: list[str]):
collection = self.client.get_collection(name=collection_name)
if collection:
collection.delete(ids=ids)
def reset(self):
return self.client.reset()

View File

@@ -0,0 +1,39 @@
from pymilvus import MilvusClient as Milvus
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
class MilvusClient:
def __init__(self):
self.client = Milvus()
def list_collections(self) -> list[str]:
pass
def create_collection(self, collection_name: str):
pass
def delete_collection(self, collection_name: str):
pass
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]:
pass
def get(self, collection_name: str) -> Optional[QueryResult]:
pass
def insert(self, collection_name: str, items: list[VectorItem]):
pass
def upsert(self, collection_name: str, items: list[VectorItem]):
pass
def delete(self, collection_name: str, ids: list[str]):
pass
def reset(self):
pass