enh: ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER

This commit is contained in:
Timothy Jaeryang Baek
2025-12-30 19:31:59 +04:00
parent 61e25dc2dc
commit d3a682759f
3 changed files with 40 additions and 48 deletions

View File

@@ -2862,6 +2862,12 @@ RAG_TEXT_SPLITTER = PersistentConfig(
os.environ.get("RAG_TEXT_SPLITTER", ""),
)
ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = PersistentConfig(
"ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER",
"rag.enable_markdown_header_text_splitter",
os.environ.get("ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", "True").lower() == "true",
)
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(

View File

@@ -282,6 +282,7 @@ from open_webui.config import (
MISTRAL_OCR_API_BASE_URL,
MISTRAL_OCR_API_KEY,
RAG_TEXT_SPLITTER,
ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES,
YOUTUBE_LOADER_LANGUAGE,
@@ -888,6 +889,10 @@ app.state.config.MINERU_API_TIMEOUT = MINERU_API_TIMEOUT
app.state.config.MINERU_PARAMS = MINERU_PARAMS
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = (
ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER
)
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE

View File

@@ -1312,6 +1312,27 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
if request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER:
log.info("Using markdown header text splitter")
# Define headers to split on - covering most common markdown header levels
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=[
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
],
strip_headers=False, # Keep headers in content for context
)
split_docs = []
for doc in docs:
split_docs.extend(markdown_splitter.split_text(doc.page_content))
docs = split_docs
if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=request.app.state.config.CHUNK_SIZE,
@@ -1332,52 +1353,6 @@ def save_docs_to_vector_db(
add_start_index=True,
)
docs = text_splitter.split_documents(docs)
elif request.app.state.config.TEXT_SPLITTER == "markdown_header":
log.info("Using markdown header text splitter")
# Define headers to split on - covering most common markdown header levels
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
strip_headers=False, # Keep headers in content for context
)
md_split_docs = []
for doc in docs:
md_header_splits = markdown_splitter.split_text(doc.page_content)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
md_header_splits = text_splitter.split_documents(md_header_splits)
# Convert back to Document objects, preserving original metadata
for split_chunk in md_header_splits:
headings_list = []
# Extract header values in order based on headers_to_split_on
for _, header_meta_key_name in headers_to_split_on:
if header_meta_key_name in split_chunk.metadata:
headings_list.append(
split_chunk.metadata[header_meta_key_name]
)
md_split_docs.append(
Document(
page_content=split_chunk.page_content,
metadata={**doc.metadata, "headings": headings_list},
)
)
docs = md_split_docs
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
@@ -2424,7 +2399,11 @@ class DeleteForm(BaseModel):
@router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
def delete_entries_from_collection(
form_data: DeleteForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
file = Files.get_file_by_id(form_data.file_id, db=db)
@@ -2566,7 +2545,9 @@ async def process_files_batch(
# Update all files with collection name
for file_update, file_result in zip(file_updates, file_results):
Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
Files.update_file_by_id(
id=file_result.file_id, form_data=file_update, db=db
)
file_result.status = "completed"
except Exception as e: