Mirror of https://github.com/open-webui/open-webui (synced 2025-03-21 04:18:56 +00:00)

Merge pull request #5312 from open-webui/multiple-vector-dbs
feat: various vector db support

This commit is contained in: commit c7fc17da69
@@ -96,7 +96,6 @@ from open_webui.utils.misc import (
 from open_webui.utils.utils import get_admin_user, get_verified_user
 from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT

-from chromadb.utils.batch_utils import create_batches

 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import (
     BSHTMLLoader,
@@ -998,14 +997,11 @@ def store_docs_in_vector_db(

     try:
         if overwrite:
-            for collection in VECTOR_DB_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    log.info(f"deleting existing collection {collection_name}")
-                    VECTOR_DB_CLIENT.delete_collection(name=collection_name)
-
-        collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
+            if collection_name in VECTOR_DB_CLIENT.list_collections():
+                log.info(f"deleting existing collection {collection_name}")
+                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)

-        embedding_func = get_embedding_function(
+        embedding_function = get_embedding_function(
             app.state.config.RAG_EMBEDDING_ENGINE,
             app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,

@@ -1014,17 +1010,18 @@ def store_docs_in_vector_db(
             app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         )

-        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
-        embeddings = embedding_func(embedding_texts)
-
-        for batch in create_batches(
-            api=VECTOR_DB_CLIENT,
-            ids=[str(uuid.uuid4()) for _ in texts],
-            metadatas=metadatas,
-            embeddings=embeddings,
-            documents=texts,
-        ):
-            collection.add(*batch)
+        VECTOR_DB_CLIENT.insert(
+            collection_name=collection_name,
+            items=[
+                {
+                    "id": str(uuid.uuid4()),
+                    "text": text,
+                    "vector": embedding_function(text.replace("\n", " ")),
+                    "metadata": metadatas[idx],
+                }
+                for idx, text in enumerate(texts)
+            ],
+        )

         return True
     except Exception as e:
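
A minimal sketch of the item shape the new VECTOR_DB_CLIENT.insert() call expects, assuming hypothetical texts, metadatas and an embedding_function that stand in for the values store_docs_in_vector_db() actually computes:

import uuid

# Hypothetical inputs standing in for the chunked document and its metadata.
texts = ["first chunk of a document", "second chunk of a document"]
metadatas = [{"source": "example.txt"}, {"source": "example.txt"}]

def embedding_function(text: str) -> list[float]:
    # Toy embedder; in the real code this comes from get_embedding_function().
    return [float(len(text)), 0.0, 1.0]

# Same shape as the new store_docs_in_vector_db() body: one dict per chunk.
items = [
    {
        "id": str(uuid.uuid4()),
        "text": text,
        "vector": embedding_function(text.replace("\n", " ")),
        "metadata": metadatas[idx],
    }
    for idx, text in enumerate(texts)
]

# VECTOR_DB_CLIENT.insert(collection_name="example-collection", items=items)
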
@@ -24,6 +24,44 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+
+class VectorSearchRetriever(BaseRetriever):
+    collection_name: Any
+    embedding_function: Any
+    top_k: int
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> list[Document]:
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=self.collection_name,
+            vectors=[self.embedding_function(query)],
+            limit=self.top_k,
+        )
+
+        ids = result["ids"][0]
+        metadatas = result["metadatas"][0]
+        documents = result["documents"][0]
+
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results
+
+
 def query_doc(
     collection_name: str,
     query: str,
@@ -31,15 +69,18 @@ def query_doc(
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.query_collection(
-            name=collection_name,
-            query_embeddings=embedding_function(query),
-            k=k,
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=collection_name,
+            vectors=[embedding_function(query)],
+            limit=k,
         )

+        print("result", result)
+
         log.info(f"query_doc:result {result}")
         return result
     except Exception as e:
+        print(e)
         raise e

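
The renamed search() call returns a QueryResult-shaped mapping (the model appears later in this diff) with one inner list per query vector. A hedged sketch of consuming it; the collection name is illustrative and embedding_function is assumed to be in scope, as it is for query_doc():

from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT

result = VECTOR_DB_CLIENT.search(
    collection_name="example-collection",  # assumed name
    vectors=[embedding_function("what changed in this release?")],
    limit=3,
)

if result:
    # Index [0] selects the hits for the single query vector sent above.
    for distance, metadata, document in zip(
        result["distances"][0], result["metadatas"][0], result["documents"][0]
    ):
        print(distance, metadata, document[:80])
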
@@ -52,25 +93,23 @@ def query_doc_with_hybrid_search(
     r: float,
 ):
     try:
-        collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
-        documents = collection.get()  # get all documents
-
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

         bm25_retriever = BM25Retriever.from_texts(
-            texts=documents.get("documents"),
-            metadatas=documents.get("metadatas"),
+            texts=result.documents,
+            metadatas=result.metadatas,
         )
         bm25_retriever.k = k

-        chroma_retriever = ChromaRetriever(
-            collection=collection,
+        vector_search_retriever = VectorSearchRetriever(
+            collection_name=collection_name,
             embedding_function=embedding_function,
-            top_n=k,
+            top_k=k,
         )

         ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
         )

         compressor = RerankCompressor(
             embedding_function=embedding_function,
             top_n=k,
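
Since VectorSearchRetriever subclasses BaseRetriever, it can also be exercised outside the ensemble. A rough sketch, assuming the same illustrative collection name and embedding_function as above, and relying on the Runnable .invoke() entry point that current langchain-core releases provide:

vector_search_retriever = VectorSearchRetriever(
    collection_name="example-collection",  # assumed name
    embedding_function=embedding_function,
    top_k=4,
)

# BaseRetriever is a Runnable, so a plain query string can be passed to invoke().
docs = vector_search_retriever.invoke("how is milvus configured?")
for doc in docs:
    print(doc.metadata, doc.page_content[:80])
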
@@ -394,45 +433,6 @@ def generate_openai_batch_embeddings(
     return None


-from typing import Any
-
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
-from langchain_core.retrievers import BaseRetriever
-
-
-class ChromaRetriever(BaseRetriever):
-    collection: Any
-    embedding_function: Any
-    top_n: int
-
-    def _get_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
-    ) -> list[Document]:
-        query_embeddings = self.embedding_function(query)
-
-        results = self.collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=self.top_n,
-        )
-
-        ids = results["ids"][0]
-        metadatas = results["metadatas"][0]
-        documents = results["documents"][0]
-
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
-
-
 import operator
 from typing import Optional, Sequence

@@ -1,4 +1,10 @@
-from open_webui.apps.rag.vector.dbs.chroma import Chroma
+from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
+from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
+
+
 from open_webui.config import VECTOR_DB

-VECTOR_DB_CLIENT = Chroma()
+if VECTOR_DB == "milvus":
+    VECTOR_DB_CLIENT = MilvusClient()
+else:
+    VECTOR_DB_CLIENT = ChromaClient()
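
The backend is chosen once at import time, so every caller keeps using the single VECTOR_DB_CLIENT symbol. A hedged illustration of switching to Milvus, assuming VECTOR_DB is read from the environment by open_webui.config (the config hunk further down only shows MILVUS_URI):

import os

# Must be set before open_webui.config and the connector module are imported,
# because VECTOR_DB_CLIENT is constructed at import time.
os.environ["VECTOR_DB"] = "milvus"

from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT

# MilvusClient when VECTOR_DB == "milvus"; any other value falls back to ChromaClient.
print(type(VECTOR_DB_CLIENT).__name__)
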
@@ -1,6 +1,10 @@
 import chromadb
 from chromadb import Settings
+from chromadb.utils.batch_utils import create_batches

+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -12,7 +16,7 @@ from open_webui.config import (
 )


-class Chroma:
+class ChromaClient:
     def __init__(self):
         if CHROMA_HTTP_HOST != "":
             self.client = chromadb.HttpClient(
@@ -32,27 +36,78 @@ class Chroma:
                 database=CHROMA_DATABASE,
             )

-    def query_collection(self, name, query_embeddings, k):
-        collection = self.client.get_collection(name=name)
+    def list_collections(self) -> list[str]:
+        # List all the collections in the database.
+        collections = self.client.list_collections()
+        return [collection.name for collection in collections]
+
+    def delete_collection(self, collection_name: str):
+        # Delete the collection based on the collection name.
+        return self.client.delete_collection(name=collection_name)
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        collection = self.client.get_collection(name=collection_name)
         if collection:
             result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+                query_embeddings=vectors,
+                n_results=limit,
             )
-            return result
+
+            return {
+                "ids": result["ids"],
+                "distances": result["distances"],
+                "documents": result["documents"],
+                "metadatas": result["metadatas"],
+            }
         return None

-    def list_collections(self):
-        return self.client.list_collections()
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        # Get all the items in the collection.
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            return collection.get()
+        return None

-    def create_collection(self, name):
-        return self.client.create_collection(name=name)
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        # Insert the items into the collection, if the collection does not exist, it will be created.
+        collection = self.client.get_or_create_collection(name=collection_name)

-    def get_or_create_collection(self, name):
-        return self.client.get_or_create_collection(name=name)
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadatas = [item["metadata"] for item in items]

-    def delete_collection(self, name):
-        return self.client.delete_collection(name=name)
+        for batch in create_batches(
+            api=self.client,
+            documents=documents,
+            embeddings=embeddings,
+            ids=ids,
+            metadatas=metadatas,
+        ):
+            collection.add(*batch)
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+        collection = self.client.get_or_create_collection(name=collection_name)
+
+        ids = [item["id"] for item in items]
+        documents = [item["text"] for item in items]
+        embeddings = [item["vector"] for item in items]
+        metadata = [item["metadata"] for item in items]
+
+        collection.upsert(
+            ids=ids, documents=documents, embeddings=embeddings, metadata=metadata
+        )
+
+    def delete(self, collection_name: str, ids: list[str]):
+        # Delete the items from the collection based on the ids.
+        collection = self.client.get_collection(name=collection_name)
+        if collection:
+            collection.delete(ids=ids)

     def reset(self):
+        # Resets the database. This will delete all collections and item entries.
         return self.client.reset()
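
Both clients now expose the same small surface (list_collections, insert, upsert, search, get, delete, delete_collection, reset), so callers stay backend-agnostic. A rough round-trip sketch with made-up ids and toy 3-dimensional vectors; a real deployment would use the embedding model's dimension:

from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT

VECTOR_DB_CLIENT.insert(
    collection_name="demo",
    items=[
        {"id": "1", "text": "hello", "vector": [0.1, 0.2, 0.3], "metadata": {"k": "v"}},
        {"id": "2", "text": "world", "vector": [0.3, 0.2, 0.1], "metadata": {"k": "v"}},
    ],
)

hits = VECTOR_DB_CLIENT.search(
    collection_name="demo", vectors=[[0.1, 0.2, 0.3]], limit=1
)
print(hits["ids"][0] if hits else None)

VECTOR_DB_CLIENT.delete(collection_name="demo", ids=["1", "2"])
VECTOR_DB_CLIENT.delete_collection(collection_name="demo")
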
@@ -0,0 +1,175 @@
+from pymilvus import MilvusClient as Client
+from pymilvus import FieldSchema, DataType
+import json
+
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.config import (
+    MILVUS_URI,
+)
+
+
+class MilvusClient:
+    def __init__(self):
+        self.collection_prefix = "open_webui"
+        self.client = Client(uri=MILVUS_URI)
+
+    def _result_to_query_result(self, result) -> QueryResult:
+        print(result)
+
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for match in result:
+            _ids = []
+            _distances = []
+            _documents = []
+            _metadatas = []
+
+            for item in match:
+                _ids.append(item.get("id"))
+                _distances.append(item.get("distance"))
+                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
+                _metadatas.append(item.get("entity", {}).get("metadata"))
+
+            ids.append(_ids)
+            distances.append(_distances)
+            documents.append(_documents)
+            metadatas.append(_metadatas)
+
+        return {
+            "ids": ids,
+            "distances": distances,
+            "documents": documents,
+            "metadatas": metadatas,
+        }
+
+    def _create_collection(self, collection_name: str, dimension: int):
+        schema = self.client.create_schema(
+            auto_id=False,
+            enable_dynamic_field=True,
+        )
+        schema.add_field(
+            field_name="id",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            max_length=65535,
+        )
+        schema.add_field(
+            field_name="vector",
+            datatype=DataType.FLOAT_VECTOR,
+            dim=dimension,
+            description="vector",
+        )
+        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
+        schema.add_field(
+            field_name="metadata", datatype=DataType.JSON, description="metadata"
+        )
+
+        index_params = self.client.prepare_index_params()
+        index_params.add_index(
+            field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
+        )
+
+        self.client.create_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            schema=schema,
+            index_params=index_params,
+        )
+
+    def list_collections(self) -> list[str]:
+        # List all the collections in the database.
+        return [
+            collection[len(self.collection_prefix) :]
+            for collection in self.client.list_collections()
+            if collection.startswith(self.collection_prefix)
+        ]
+
+    def delete_collection(self, collection_name: str):
+        # Delete the collection based on the collection name.
+        return self.client.drop_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        )
+
+    def search(
+        self, collection_name: str, vectors: list[list[float | int]], limit: int
+    ) -> Optional[QueryResult]:
+        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+        result = self.client.search(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=vectors,
+            limit=limit,
+            output_fields=["data", "metadata"],
+        )
+
+        return self._result_to_query_result(result)
+
+    def get(self, collection_name: str) -> Optional[QueryResult]:
+        # Get all the items in the collection.
+        result = self.client.query(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+        )
+        return self._result_to_query_result(result)
+
+    def insert(self, collection_name: str, items: list[VectorItem]):
+        # Insert the items into the collection, if the collection does not exist, it will be created.
+        if not self.client.has_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        ):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.insert(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
+
+    def upsert(self, collection_name: str, items: list[VectorItem]):
+        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+        if not self.client.has_collection(
+            collection_name=f"{self.collection_prefix}_{collection_name}"
+        ):
+            self._create_collection(
+                collection_name=collection_name, dimension=len(items[0]["vector"])
+            )
+
+        return self.client.upsert(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            data=[
+                {
+                    "id": item["id"],
+                    "vector": item["vector"],
+                    "data": {"text": item["text"]},
+                    "metadata": item["metadata"],
+                }
+                for item in items
+            ],
+        )
+
+    def delete(self, collection_name: str, ids: list[str]):
+        # Delete the items from the collection based on the ids.
+
+        return self.client.delete(
+            collection_name=f"{self.collection_prefix}_{collection_name}",
+            ids=ids,
+        )
+
+    def reset(self):
+        # Resets the database. This will delete all collections and item entries.
+
+        collection_names = self.client.list_collections()
+        for collection_name in collection_names:
+            if collection_name.startswith(self.collection_prefix):
+                self.client.drop_collection(collection_name=collection_name)
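
For orientation, a hedged sketch of the shape _result_to_query_result() expects from a pymilvus search and the QueryResult-style mapping it produces (all values made up):

# One outer list per query vector, one hit dict per match, as the search() call
# above receives from pymilvus.
milvus_hits = [
    [
        {
            "id": "doc-1",
            "distance": 0.12,
            "entity": {"data": {"text": "hello"}, "metadata": {"source": "a.txt"}},
        }
    ]
]

# _result_to_query_result() flattens that into the shared QueryResult shape:
expected = {
    "ids": [["doc-1"]],
    "distances": [[0.12]],
    "documents": [["hello"]],
    "metadatas": [[{"source": "a.txt"}]],
}
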
backend/open_webui/apps/rag/vector/main.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from pydantic import BaseModel
+from typing import Optional, List, Any
+
+
+class VectorItem(BaseModel):
+    id: str
+    text: str
+    vector: List[float | int]
+    metadata: Any
+
+
+class QueryResult(BaseModel):
+    ids: Optional[List[List[str]]]
+    distances: Optional[List[List[float | int]]]
+    documents: Optional[List[List[str]]]
+    metadatas: Optional[List[List[Any]]]
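
A brief, hedged illustration of the shared models; model_dump() assumes pydantic v2. The database clients subscript each item (item["id"], item["vector"], ...), so what actually flows through insert/upsert is the dict form:

from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.rag.vector.main import VectorItem, QueryResult

item = VectorItem(
    id="doc-1",
    text="hello world",
    vector=[0.1, 0.2, 0.3],
    metadata={"source": "a.txt"},
)

# Pass dicts so the clients can index the fields.
VECTOR_DB_CLIENT.insert(collection_name="demo", items=[item.model_dump()])

# QueryResult mirrors what ChromaClient.search()/MilvusClient.search() return.
QueryResult(
    ids=[["doc-1"]],
    distances=[[0.0]],
    documents=[["hello world"]],
    metadatas=[[{"source": "a.txt"}]],
)
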
@@ -50,16 +50,17 @@ async def add_memory(
     user=Depends(get_verified_user),
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)
-    memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)

-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-    collection.upsert(
-        documents=[memory.content],
-        ids=[memory.id],
-        embeddings=[memory_embedding],
-        metadatas=[{"created_at": memory.created_at}],
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {"created_at": memory.created_at},
+            }
+        ],
     )

     return memory
@@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel):
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
-    query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )
-
-    results = collection.query(
-        query_embeddings=[query_embedding],
-        n_results=form_data.k,  # how many results to return
+    results = VECTOR_DB_CLIENT.search(
+        collection_name=f"user-memory-{user.id}",
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        limit=form_data.k,
     )

     return results
@@ -100,18 +97,24 @@ async def reset_memory_from_vector_db(
     request: Request, user=Depends(get_verified_user)
 ):
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
-    collection = VECTOR_DB_CLIENT.get_or_create_collection(
-        name=f"user-memory-{user.id}"
-    )

     memories = Memories.get_memories_by_user_id(user.id)
-    for memory in memories:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
-        collection.upsert(
-            documents=[memory.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
+    VECTOR_DB_CLIENT.upsert(
+        collection_name=f"user-memory-{user.id}",
+        items=[
+            {
+                "id": memory.id,
+                "text": memory.content,
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "metadata": {
+                    "created_at": memory.created_at,
+                    "updated_at": memory.updated_at,
+                },
+            }
+            for memory in memories
+        ],
     )

     return True
@@ -151,16 +154,18 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")

     if form_data.content is not None:
-        memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
-        )
-        collection.upsert(
-            documents=[form_data.content],
-            ids=[memory.id],
-            embeddings=[memory_embedding],
-            metadatas=[
-                {"created_at": memory.created_at, "updated_at": memory.updated_at}
+        VECTOR_DB_CLIENT.upsert(
+            collection_name=f"user-memory-{user.id}",
+            items=[
+                {
+                    "id": memory.id,
+                    "text": memory.content,
+                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                    "metadata": {
+                        "created_at": memory.created_at,
+                        "updated_at": memory.updated_at,
+                    },
+                }
             ],
         )
@@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)

     if result:
-        collection = VECTOR_DB_CLIENT.get_or_create_collection(
-            name=f"user-memory-{user.id}"
+        VECTOR_DB_CLIENT.delete(
+            collection_name=f"user-memory-{user.id}", ids=[memory_id]
         )
-        collection.delete(ids=[memory_id])
         return True

     return False
@@ -910,6 +910,10 @@ else:
 CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)

+# Milvus
+
+MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+
 ####################################
 # RAG
 ####################################
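
The default URI points at a file under DATA_DIR, which pymilvus can serve embedded (Milvus Lite); overriding it targets a running Milvus server instead. A small hedged example with an illustrative address, set before open_webui.config is imported:

import os

os.environ["VECTOR_DB"] = "milvus"                   # select the Milvus backend
os.environ["MILVUS_URI"] = "http://localhost:19530"  # example server address
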
@@ -40,6 +40,8 @@ langchain-chroma==0.1.2

 fake-useragent==1.5.1
 chromadb==0.5.5
+pymilvus==2.4.6
+
 sentence-transformers==3.0.1
 pypdf==4.3.1
 docx2txt==0.8
@@ -47,6 +47,7 @@ dependencies = [
     "fake-useragent==1.5.1",
     "chromadb==0.5.5",
+    "pymilvus==2.4.6",
     "sentence-transformers==3.0.1",
     "pypdf==4.3.1",
     "docx2txt==0.8",