This commit is contained in:
Timothy J. Baek 2024-09-10 04:37:06 +01:00
parent d5f13dd9e0
commit 522afbb0a0
7 changed files with 240 additions and 127 deletions

View File

@ -96,7 +96,6 @@ from open_webui.utils.misc import (
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from chromadb.utils.batch_utils import create_batches
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
BSHTMLLoader,
@ -998,14 +997,11 @@ def store_docs_in_vector_db(
try:
if overwrite:
for collection in VECTOR_DB_CLIENT.list_collections():
if collection_name == collection.name:
log.info(f"deleting existing collection {collection_name}")
VECTOR_DB_CLIENT.delete_collection(name=collection_name)
if collection_name in VECTOR_DB_CLIENT.list_collections():
log.info(f"deleting existing collection {collection_name}")
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
embedding_function = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
@ -1014,17 +1010,19 @@ def store_docs_in_vector_db(
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
embeddings = embedding_func(embedding_texts)
for batch in create_batches(
api=VECTOR_DB_CLIENT,
ids=[str(uuid.uuid4()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
):
collection.add(*batch)
VECTOR_DB_CLIENT.create_collection(collection_name=collection_name)
VECTOR_DB_CLIENT.insert(
collection_name=collection_name,
items=[
{
"id": str(uuid.uuid4()),
"text": text,
"vector": embedding_function(text.replace("\n", " ")),
"metadata": metadatas[idx],
}
for idx, text in enumerate(texts)
],
)
return True
except Exception as e:

View File

@ -24,6 +24,44 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
    """LangChain retriever backed by VECTOR_DB_CLIENT similarity search.

    Embeds the incoming query with ``embedding_function`` and returns the
    ``top_k`` nearest documents from ``collection_name``.
    """

    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        # A single query vector goes in, so only the first row of each
        # result field is relevant.
        search_result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query)],
            limit=self.top_k,
        )

        hit_ids = search_result["ids"][0]
        hit_metadatas = search_result["metadatas"][0]
        hit_documents = search_result["documents"][0]

        # Index by the id count (as the original did) so a field-length
        # mismatch surfaces as an error rather than being silently truncated.
        return [
            Document(
                metadata=hit_metadatas[idx],
                page_content=hit_documents[idx],
            )
            for idx in range(len(hit_ids))
        ]
def query_doc(
collection_name: str,
query: str,
@ -31,15 +69,18 @@ def query_doc(
k: int,
):
try:
result = VECTOR_DB_CLIENT.query_collection(
name=collection_name,
query_embeddings=embedding_function(query),
k=k,
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[embedding_function(query)],
limit=k,
)
print("result", result)
log.info(f"query_doc:result {result}")
return result
except Exception as e:
print(e)
raise e
@ -52,25 +93,23 @@ def query_doc_with_hybrid_search(
r: float,
):
try:
collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
bm25_retriever = BM25Retriever.from_texts(
texts=documents.get("documents"),
metadatas=documents.get("metadatas"),
texts=result.documents,
metadatas=result.metadatas,
)
bm25_retriever.k = k
chroma_retriever = ChromaRetriever(
collection=collection,
vector_search_retriever = VectorSearchRetriever(
collection_name=collection_name,
embedding_function=embedding_function,
top_n=k,
top_k=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
@ -394,45 +433,6 @@ def generate_openai_batch_embeddings(
return None
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class ChromaRetriever(BaseRetriever):
collection: Any
embedding_function: Any
top_n: int
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
query_embeddings = self.embedding_function(query)
results = self.collection.query(
query_embeddings=[query_embeddings],
n_results=self.top_n,
)
ids = results["ids"][0]
metadatas = results["metadatas"][0]
documents = results["documents"][0]
results = []
for idx in range(len(ids)):
results.append(
Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
)
return results
import operator
from typing import Optional, Sequence

View File

@ -1,4 +1,10 @@
from open_webui.apps.rag.vector.dbs.chroma import Chroma
from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
from open_webui.config import VECTOR_DB
VECTOR_DB_CLIENT = Chroma()
if VECTOR_DB == "milvus":
VECTOR_DB_CLIENT = MilvusClient()
else:
VECTOR_DB_CLIENT = ChromaClient()

View File

@ -1,6 +1,10 @@
import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@ -12,7 +16,7 @@ from open_webui.config import (
)
class Chroma:
class ChromaClient:
def __init__(self):
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
@ -32,27 +36,73 @@ class Chroma:
database=CHROMA_DATABASE,
)
def query_collection(self, name, query_embeddings, k):
collection = self.client.get_collection(name=name)
def list_collections(self) -> list[str]:
    """Return the names of all collections known to the Chroma client."""
    return [col.name for col in self.client.list_collections()]
def create_collection(self, collection_name: str):
    """Create a new Chroma collection named *collection_name* and return it."""
    new_collection = self.client.create_collection(name=collection_name)
    return new_collection
def delete_collection(self, collection_name: str):
    """Drop the Chroma collection named *collection_name*."""
    outcome = self.client.delete_collection(name=collection_name)
    return outcome
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
query_embeddings=vectors,
n_results=limit,
)
return result
return {
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
return None
def list_collections(self):
return self.client.list_collections()
def get(self, collection_name: str) -> Optional[QueryResult]:
    # Fetch everything stored in *collection_name*; None if the collection
    # cannot be retrieved.
    collection = self.client.get_collection(name=collection_name)
    if collection:
        # NOTE(review): chromadb's collection.get() returns a plain dict
        # with flat lists (keys like "ids"/"documents"/"metadatas"), not a
        # QueryResult model — callers that use attribute access (e.g.
        # result.documents) would fail here. Confirm the intended shape.
        return collection.get()
    return None
def create_collection(self, name):
return self.client.create_collection(name=name)
def insert(self, collection_name: str, items: list[VectorItem]):
collection = self.client.get_or_create_collection(name=collection_name)
def get_or_create_collection(self, name):
return self.client.get_or_create_collection(name=name)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
def delete_collection(self, name):
return self.client.delete_collection(name=name)
for batch in create_batches(
api=self.client,
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
):
collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]):
    """Insert-or-update *items* in the collection, creating it if needed.

    Each item supplies its own id, raw text, embedding vector, and
    metadata dict (see VectorItem).
    """
    collection = self.client.get_or_create_collection(name=collection_name)

    ids = [item["id"] for item in items]
    documents = [item["text"] for item in items]
    embeddings = [item["vector"] for item in items]
    metadatas = [item["metadata"] for item in items]

    # Bug fix: chromadb's Collection.upsert expects the keyword
    # ``metadatas`` (plural); the previous ``metadata=`` keyword would
    # raise a TypeError on every call.
    collection.upsert(
        ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
    )
def delete(self, collection_name: str, ids: list[str]):
    """Remove the items with the given *ids* from *collection_name*.

    Silently does nothing when the collection is falsy/not found.
    """
    target = self.client.get_collection(name=collection_name)
    if not target:
        return
    target.delete(ids=ids)
def reset(self):
    """Wipe the entire Chroma database, delegating to the underlying client."""
    outcome = self.client.reset()
    return outcome

View File

@ -0,0 +1,39 @@
from pymilvus import MilvusClient as Milvus
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
class MilvusClient:
    """Milvus-backed implementation of the vector-DB client interface.

    Mirrors the method surface of ChromaClient so VECTOR_DB_CLIENT can be
    swapped via the VECTOR_DB config setting. All methods are currently
    unimplemented stubs that silently return None.
    """

    def __init__(self):
        # NOTE(review): connects with pymilvus defaults — confirm whether a
        # configurable URI/credentials should be passed here.
        self.client = Milvus()

    def list_collections(self) -> list[str]:
        # TODO: enumerate Milvus collections by name.
        pass

    def create_collection(self, collection_name: str):
        # TODO: create a Milvus collection.
        pass

    def delete_collection(self, collection_name: str):
        # TODO: drop a Milvus collection.
        pass

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[QueryResult]:
        # TODO: nearest-neighbour search over *vectors*, up to *limit* hits.
        pass

    def get(self, collection_name: str) -> Optional[QueryResult]:
        # TODO: fetch all items stored in the collection.
        pass

    def insert(self, collection_name: str, items: list[VectorItem]):
        # TODO: insert new items.
        pass

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # TODO: insert-or-update items.
        pass

    def delete(self, collection_name: str, ids: list[str]):
        # TODO: delete items by id.
        pass

    def reset(self):
        # TODO: wipe the Milvus database.
        pass

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel
from typing import Optional, List, Any
class VectorItem(BaseModel):
    """One record destined for a vector store: id, source text, its
    embedding, and arbitrary metadata."""

    id: str
    text: str
    vector: list[float | int]
    metadata: Any
class QueryResult(BaseModel):
    """Result of a vector search, batched per query: the outer list indexes
    the query, the inner list the ranked hits. Each field is nullable."""

    ids: list[list[str]] | None
    distances: list[list[float | int]] | None
    documents: list[list[str]] | None
    metadatas: list[list[Any]] | None

View File

@ -50,16 +50,17 @@ async def add_memory(
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[{"created_at": memory.created_at}],
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {"created_at": memory.created_at},
}
],
)
return memory
@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel):
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
results = collection.query(
query_embeddings=[query_embedding],
n_results=form_data.k, # how many results to return
results = VECTOR_DB_CLIENT.search(
name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
limit=form_data.k,
)
return results
@ -100,18 +97,24 @@ async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
},
}
for memory in memories
],
)
return True
@ -151,16 +154,18 @@ async def update_memory_by_id(
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[form_data.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[
{"created_at": memory.created_at, "updated_at": memory.updated_at}
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
},
}
],
)
@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
VECTOR_DB_CLIENT.delete(
collection_name=f"user-memory-{user.id}", ids=[memory_id]
)
collection.delete(ids=[memory_id])
return True
return False