diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index dee70fbaa..374cbb8a2 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -637,6 +637,7 @@ def save_docs_to_vector_db( metadata: Optional[dict] = None, overwrite: bool = False, split: bool = True, + add: bool = False, ) -> bool: log.info(f"save_docs_to_vector_db {docs} {collection_name}") @@ -662,42 +663,44 @@ def save_docs_to_vector_db( metadata[key] = str(value) try: - if overwrite: - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"deleting existing collection {collection_name}") - VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): log.info(f"collection {collection_name} already exists") - return True - else: - embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, - ) - embeddings = embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)) - ) + if overwrite: + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) + log.info(f"deleting existing collection {collection_name}") - VECTOR_DB_CLIENT.insert( - collection_name=collection_name, - items=[ - { - "id": str(uuid.uuid4()), - "text": text, - "vector": embeddings[idx], - "metadata": metadatas[idx], - } - for idx, text in enumerate(texts) - ], - ) + if add is False: + return True - return True + log.info(f"adding to collection {collection_name}") + embedding_function = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, + app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + ) + + embeddings = embedding_function( + list(map(lambda x: x.replace("\n", " "), texts)) + ) + + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=[ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embeddings[idx], + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ], + ) + + return True except Exception as e: log.exception(e) return False @@ -715,37 +718,53 @@ def process_file( ): try: file = Files.get_file_by_id(form_data.file_id) - file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}") collection_name = form_data.collection_name if collection_name is None: - with open(file_path, "rb") as f: - collection_name = calculate_sha256(f)[:63] + collection_name = file.id loader = Loader( engine=app.state.config.CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, ) - docs = loader.load(file.filename, file.meta.get("content_type"), file_path) + + file_path = file.meta.get("path", None) + if file_path: + docs = loader.load(file.filename, file.meta.get("content_type"), file_path) + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + "name": file.filename, + "created_by": file.user_id, + **file.meta, + }, + ) + ] + text_content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {text_content}") hash = calculate_sha256_string(text_content) - Files.update_file_data_by_id( - form_data.file_id, + res = Files.update_file_data_by_id( + file.id, {"content": text_content}, ) + print(res) Files.update_file_hash_by_id(form_data.file_id, hash) try: result = save_docs_to_vector_db( - docs, - collection_name, - { + docs=docs, + collection_name=collection_name, + metadata={ "file_id": form_data.file_id, "name": file.meta.get("name", file.filename), + "hash": hash, }, + add=(True if form_data.collection_name else False), ) if result: @@ -1184,6 +1203,30 @@ def query_collection_handler( #################################### +class DeleteForm(BaseModel): + collection_name: str + file_id: str + + +@app.post("/delete") +def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): + try: + if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): + file = Files.get_file_by_id(form_data.file_id) + hash = file.hash + + VECTOR_DB_CLIENT.delete( + collection_name=form_data.collection_name, + metadata={"hash": hash}, + ) + return {"status": True} + else: + return {"status": False} + except Exception as e: + log.exception(e) + return {"status": False} + + @app.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset()