refac: "rag" endpoints renamed to "retrieval"

This commit is contained in:
Timothy J. Baek
2024-09-28 01:27:46 +02:00
parent 6e9db3e3c8
commit e1103305f5
27 changed files with 41 additions and 34 deletions

View File

@@ -0,0 +1,10 @@
from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
VECTOR_DB_CLIENT = MilvusClient()
else:
VECTOR_DB_CLIENT = ChromaClient()

View File

@@ -0,0 +1,122 @@
import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
)
class ChromaClient:
def __init__(self):
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
else:
self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collections = self.client.list_collections()
return collection_name in [collection.name for collection in collections]
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
return self.client.delete_collection(name=collection_name)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
}
)
return None
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
for batch in create_batches(
api=self.client,
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
):
collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
def delete(self, collection_name: str, ids: list[str]):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name)
if collection:
collection.delete(ids=ids)
def reset(self):
# Resets the database. This will delete all collections and item entries.
return self.client.reset()

View File

@@ -0,0 +1,204 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
)
class MilvusClient:
def __init__(self):
self.collection_prefix = "open_webui"
self.client = Client(uri=MILVUS_URI)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for match in result:
_ids = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
"documents": documents,
"metadatas": metadatas,
}
)
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for match in result:
_ids = []
_distances = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_distances.append(item.get("distance"))
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
}
)
def _create_collection(self, collection_name: str, dimension: int):
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=True,
)
schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=dimension,
description="vector",
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 100},
)
self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}",
schema=schema,
index_params=index_params,
)
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
)
return self._result_to_search_result(result)
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""',
)
return self._result_to_get_result([result])
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": item["metadata"],
}
for item in items
],
)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": item["metadata"],
}
for item in items
],
)
def delete(self, collection_name: str, ids: list[str]):
# Delete the items from the collection based on the ids.
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.list_collections()
for collection_name in collection_names:
if collection_name.startswith(self.collection_prefix):
self.client.drop_collection(collection_name=collection_name)

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import Optional, List, Any
class VectorItem(BaseModel):
id: str
text: str
vector: List[float | int]
metadata: Any
class GetResult(BaseModel):
ids: Optional[List[List[str]]]
documents: Optional[List[List[str]]]
metadatas: Optional[List[List[Any]]]
class SearchResult(GetResult):
distances: Optional[List[List[float | int]]]