diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 2f052ca6e..08a30815b 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2862,6 +2862,12 @@ RAG_TEXT_SPLITTER = PersistentConfig(
     os.environ.get("RAG_TEXT_SPLITTER", ""),
 )
 
+ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = PersistentConfig(
+    "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER",
+    "rag.enable_markdown_header_text_splitter",
+    os.environ.get("ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", "True").lower() == "true",
+)
+
 TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
 
 TIKTOKEN_ENCODING_NAME = PersistentConfig(
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 7e5a19498..68c90c7ee 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -282,6 +282,7 @@ from open_webui.config import (
     MISTRAL_OCR_API_BASE_URL,
     MISTRAL_OCR_API_KEY,
     RAG_TEXT_SPLITTER,
+    ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER,
     TIKTOKEN_ENCODING_NAME,
     PDF_EXTRACT_IMAGES,
     YOUTUBE_LOADER_LANGUAGE,
@@ -888,6 +889,10 @@ app.state.config.MINERU_API_TIMEOUT = MINERU_API_TIMEOUT
 app.state.config.MINERU_PARAMS = MINERU_PARAMS
 
 app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
+app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = (
+    ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER
+)
+
 app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
 
 app.state.config.CHUNK_SIZE = CHUNK_SIZE
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index c24ebf8d5..6f6b97fb9 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1312,6 +1312,31 @@ def save_docs_to_vector_db(
                 raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
 
     if split:
+        if request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER:
+            log.info("Using markdown header text splitter")
+
+            # Define headers to split on - covering most common markdown header levels
+            markdown_splitter = MarkdownHeaderTextSplitter(
+                headers_to_split_on=[
+                    ("#", "Header 1"),
+                    ("##", "Header 2"),
+                    ("###", "Header 3"),
+                    ("####", "Header 4"),
+                    ("#####", "Header 5"),
+                    ("######", "Header 6"),
+                ],
+                strip_headers=False,  # Keep headers in content for context
+            )
+
+            split_docs = []
+            for doc in docs:
+                for chunk in markdown_splitter.split_text(doc.page_content):
+                    # split_text() only attaches header metadata to each chunk, so
+                    # merge the source document's metadata back in before storing.
+                    chunk.metadata = {**doc.metadata, **chunk.metadata}
+                    split_docs.append(chunk)
+            docs = split_docs
+
         if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
             text_splitter = RecursiveCharacterTextSplitter(
                 chunk_size=request.app.state.config.CHUNK_SIZE,
@@ -1332,52 +1357,6 @@ def save_docs_to_vector_db(
                 add_start_index=True,
             )
             docs = text_splitter.split_documents(docs)
-        elif request.app.state.config.TEXT_SPLITTER == "markdown_header":
-            log.info("Using markdown header text splitter")
-
-            # Define headers to split on - covering most common markdown header levels
-            headers_to_split_on = [
-                ("#", "Header 1"),
-                ("##", "Header 2"),
-                ("###", "Header 3"),
-                ("####", "Header 4"),
-                ("#####", "Header 5"),
-                ("######", "Header 6"),
-            ]
-
-            markdown_splitter = MarkdownHeaderTextSplitter(
-                headers_to_split_on=headers_to_split_on,
-                strip_headers=False,  # Keep headers in content for context
-            )
-
-            md_split_docs = []
-            for doc in docs:
-                md_header_splits = markdown_splitter.split_text(doc.page_content)
-                text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=request.app.state.config.CHUNK_SIZE,
-                    chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
-                    add_start_index=True,
-                )
-                md_header_splits = text_splitter.split_documents(md_header_splits)
-
-                # Convert back to Document objects, preserving original metadata
-                for split_chunk in md_header_splits:
-                    headings_list = []
-                    # Extract header values in order based on headers_to_split_on
-                    for _, header_meta_key_name in headers_to_split_on:
-                        if header_meta_key_name in split_chunk.metadata:
-                            headings_list.append(
-                                split_chunk.metadata[header_meta_key_name]
-                            )
-
-                    md_split_docs.append(
-                        Document(
-                            page_content=split_chunk.page_content,
-                            metadata={**doc.metadata, "headings": headings_list},
-                        )
-                    )
-
-            docs = md_split_docs
         else:
             raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
 
@@ -2424,7 +2403,11 @@ class DeleteForm(BaseModel):
 
 
 @router.post("/delete")
-def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
+def delete_entries_from_collection(
+    form_data: DeleteForm,
+    user=Depends(get_admin_user),
+    db: Session = Depends(get_session),
+):
     try:
         if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
             file = Files.get_file_by_id(form_data.file_id, db=db)
@@ -2566,7 +2549,9 @@ async def process_files_batch(
 
             # Update all files with collection name
             for file_update, file_result in zip(file_updates, file_results):
-                Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
+                Files.update_file_by_id(
+                    id=file_result.file_id, form_data=file_update, db=db
+                )
                 file_result.status = "completed"
 
         except Exception as e: