enh: add to vector db support

commit d394f8b7be
parent 325ca98773
Author: Timothy J. Baek
Date:   2024-10-03 06:44:17 -07:00


@@ -637,6 +637,7 @@ def save_docs_to_vector_db(
     metadata: Optional[dict] = None,
     overwrite: bool = False,
     split: bool = True,
+    add: bool = False,
 ) -> bool:
     log.info(f"save_docs_to_vector_db {docs} {collection_name}")
@@ -662,42 +663,44 @@ def save_docs_to_vector_db(
             metadata[key] = str(value)

     try:
-        if overwrite:
-            if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
-                log.info(f"deleting existing collection {collection_name}")
-                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
-
         if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
             log.info(f"collection {collection_name} already exists")
-            return True
-        else:
-            embedding_function = get_embedding_function(
-                app.state.config.RAG_EMBEDDING_ENGINE,
-                app.state.config.RAG_EMBEDDING_MODEL,
-                app.state.sentence_transformer_ef,
-                app.state.config.OPENAI_API_KEY,
-                app.state.config.OPENAI_API_BASE_URL,
-                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
-            )

-            embeddings = embedding_function(
-                list(map(lambda x: x.replace("\n", " "), texts))
-            )
+            if overwrite:
+                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
+                log.info(f"deleting existing collection {collection_name}")

-            VECTOR_DB_CLIENT.insert(
-                collection_name=collection_name,
-                items=[
-                    {
-                        "id": str(uuid.uuid4()),
-                        "text": text,
-                        "vector": embeddings[idx],
-                        "metadata": metadatas[idx],
-                    }
-                    for idx, text in enumerate(texts)
-                ],
-            )
+            if add is False:
+                return True

-            return True
+        log.info(f"adding to collection {collection_name}")
+        embedding_function = get_embedding_function(
+            app.state.config.RAG_EMBEDDING_ENGINE,
+            app.state.config.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.config.OPENAI_API_KEY,
+            app.state.config.OPENAI_API_BASE_URL,
+            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        )
+
+        embeddings = embedding_function(
+            list(map(lambda x: x.replace("\n", " "), texts))
+        )
+
+        VECTOR_DB_CLIENT.insert(
+            collection_name=collection_name,
+            items=[
+                {
+                    "id": str(uuid.uuid4()),
+                    "text": text,
+                    "vector": embeddings[idx],
+                    "metadata": metadatas[idx],
+                }
+                for idx, text in enumerate(texts)
+            ],
+        )
+        return True
     except Exception as e:
         log.exception(e)
         return False
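
For reference, a minimal caller-side sketch of what the new add flag changes; the collection name and docs below are placeholders, not values from this commit:

# Minimal sketch, assuming `docs` is the list of Document chunks and the
# collection "my-collection" already exists in VECTOR_DB_CLIENT.

# Default (add=False): the call logs that the collection exists and returns
# early without inserting anything new.
save_docs_to_vector_db(docs, "my-collection")

# add=True: the early return is skipped, so embeddings for the new docs are
# appended to the existing collection instead.
save_docs_to_vector_db(docs, "my-collection", add=True)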
@@ -715,37 +718,53 @@ def process_file(
 ):
     try:
         file = Files.get_file_by_id(form_data.file_id)
-        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")

         collection_name = form_data.collection_name
         if collection_name is None:
-            with open(file_path, "rb") as f:
-                collection_name = calculate_sha256(f)[:63]
+            collection_name = file.id

         loader = Loader(
             engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
             TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
             PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
         )

-        docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
+        file_path = file.meta.get("path", None)
+        if file_path:
+            docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
+        else:
+            docs = [
+                Document(
+                    page_content=file.data.get("content", ""),
+                    metadata={
+                        "name": file.filename,
+                        "created_by": file.user_id,
+                        **file.meta,
+                    },
+                )
+            ]

         text_content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {text_content}")

         hash = calculate_sha256_string(text_content)

-        Files.update_file_data_by_id(
-            form_data.file_id,
+        res = Files.update_file_data_by_id(
+            file.id,
             {"content": text_content},
         )
+        print(res)

         Files.update_file_hash_by_id(form_data.file_id, hash)

         try:
             result = save_docs_to_vector_db(
-                docs,
-                collection_name,
-                {
+                docs=docs,
+                collection_name=collection_name,
+                metadata={
                     "file_id": form_data.file_id,
                     "name": file.meta.get("name", file.filename),
+                    "hash": hash,
                 },
+                add=(True if form_data.collection_name else False),
             )

             if result:
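
For orientation, this is roughly the shape of one item that now lands in the vector DB for each chunk of a processed file, given the metadata process_file passes above; every value below is made up:

import uuid

# Illustrative item shape only; all values are placeholders.
item = {
    "id": str(uuid.uuid4()),                 # one uuid4 per chunk
    "text": "one chunk of the extracted text",
    "vector": [0.12, -0.03, 0.41],           # embedding of the chunk
    "metadata": {
        "file_id": "hypothetical-file-id",   # form_data.file_id
        "name": "report.pdf",                # file.meta name or filename
        "hash": "sha256-of-extracted-text",  # the new /delete route filters on this
    },
}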
@@ -1184,6 +1203,30 @@ def query_collection_handler(
 ####################################

+class DeleteForm(BaseModel):
+    collection_name: str
+    file_id: str
+
+
+@app.post("/delete")
+def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
+    try:
+        if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
+            file = Files.get_file_by_id(form_data.file_id)
+            hash = file.hash
+
+            VECTOR_DB_CLIENT.delete(
+                collection_name=form_data.collection_name,
+                metadata={"hash": hash},
+            )
+            return {"status": True}
+        else:
+            return {"status": False}
+    except Exception as e:
+        log.exception(e)
+        return {"status": False}
+
+
 @app.post("/reset/db")
 def reset_vector_db(user=Depends(get_admin_user)):
     VECTOR_DB_CLIENT.reset()
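
A hedged client-side sketch of the new /delete route: the base URL prefix, token, and identifiers below are assumptions for illustration; only the JSON fields come from DeleteForm above, and the caller must be an admin user.

import requests

# Sketch only: remove one file's entries from a shared collection by hash.
# BASE_URL and the bearer token are placeholders for your deployment.
BASE_URL = "http://localhost:8080/rag/api/v1"
response = requests.post(
    f"{BASE_URL}/delete",
    json={
        "collection_name": "my-knowledge-base",  # DeleteForm.collection_name
        "file_id": "hypothetical-file-id",       # DeleteForm.file_id
    },
    headers={"Authorization": "Bearer ADMIN_TOKEN"},
)
print(response.json())  # -> {'status': True} when the collection exists and delete ran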