diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 497a5685d..f2f4733c5 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -628,40 +628,26 @@ async def update_query_settings( #################################### -def store_data_in_vector_db( - data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False +def save_docs_to_vector_db( + docs, + collection_name, + metadata: Optional[dict] = None, + overwrite: bool = False, + split: bool = True, ) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - docs = text_splitter.split_documents(data) + log.info(f"save_docs_to_vector_db {docs} {collection_name}") - if len(docs) > 0: - log.info(f"store_data_in_vector_db {docs}") - return store_docs_in_vector_db(docs, collection_name, metadata, overwrite) - else: + if split: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + docs = text_splitter.split_documents(docs) + + if len(docs) == 0: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) - -def store_text_in_vector_db( - text, metadata, collection_name, overwrite: bool = False -) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - docs = text_splitter.create_documents([text], metadatas=[metadata]) - return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite) - - -def store_docs_in_vector_db( - docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False -) -> bool: - log.info(f"store_docs_in_vector_db {docs} {collection_name}") - texts = [doc.page_content for doc in docs] metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] @@ -728,21 +714,24 @@ def process_file( file = Files.get_file_by_id(form_data.file_id) file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}") - loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, - ) - data = loader.load(file.filename, file.meta.get("content_type"), file_path) - collection_name = form_data.collection_name if collection_name is None: with open(file_path, "rb") as f: collection_name = calculate_sha256(f)[:63] + loader = Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + docs = loader.load(file.filename, file.meta.get("content_type"), file_path) + + raw_content = " ".join([doc.page_content for doc in docs]) + print(raw_content) + try: - result = store_data_in_vector_db( - data, + result = save_docs_to_vector_db( + docs, collection_name, { "file_id": form_data.file_id, @@ -790,11 +779,13 @@ def process_text( if collection_name is None: collection_name = calculate_sha256_string(form_data.content) - result = store_text_in_vector_db( - form_data.content, - metadata={"name": form_data.name, "created_by": user.id}, - collection_name=collection_name, - ) + docs = [ + Document( + page_content=form_data.content, + metadata={"name": form_data.name, "created_by": user.id}, + ) + ] + result = save_docs_to_vector_db(docs, collection_name) if result: return {"status": True, "collection_name": collection_name} @@ -822,10 +813,10 @@ def process_docs_dir(user=Depends(get_admin_user)): TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, ) - data = loader.load(filename, file_content_type[0], str(path)) + docs = loader.load(filename, file_content_type[0], str(path)) try: - result = store_data_in_vector_db(data, collection_name) + result = save_docs_to_vector_db(docs, collection_name) if result: sanitized_filename = sanitize_filename(filename) @@ -870,19 +861,19 @@ def process_docs_dir(user=Depends(get_admin_user)): @app.post("/process/youtube") def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + loader = YoutubeLoader.from_youtube_url( form_data.url, add_video_info=True, language=app.state.config.YOUTUBE_LOADER_LANGUAGE, translation=app.state.YOUTUBE_LOADER_TRANSLATION, ) - data = loader.load() + docs = loader.load() - collection_name = form_data.collection_name - if not collection_name: - collection_name = calculate_sha256_string(form_data.url)[:63] - - store_data_in_vector_db(data, collection_name, overwrite=True) + save_docs_to_vector_db(docs, collection_name, overwrite=True) return { "status": True, @@ -900,18 +891,17 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u @app.post("/process/web") def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: + collection_name = form_data.collection_name + if not collection_name: + collection_name = calculate_sha256_string(form_data.url)[:63] + loader = get_web_loader( form_data.url, verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) - data = loader.load() - - collection_name = form_data.collection_name - if not collection_name: - collection_name = calculate_sha256_string(form_data.url)[:63] - - store_data_in_vector_db(data, collection_name, overwrite=True) + docs = loader.load() + save_docs_to_vector_db(docs, collection_name, overwrite=True) return { "status": True, @@ -1060,15 +1050,16 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): ) try: - urls = [result.link for result in web_results] - loader = get_web_loader(urls) - data = loader.load() - collection_name = form_data.collection_name if collection_name == "": collection_name = calculate_sha256_string(form_data.query)[:63] - store_data_in_vector_db(data, collection_name, overwrite=True) + urls = [result.link for result in web_results] + + loader = get_web_loader(urls) + docs = loader.load() + save_docs_to_vector_db(docs, collection_name, overwrite=True) + return { "status": True, "collection_name": collection_name,