From dff3732fcd721b060c36c10d8f968da25594773c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 13 Oct 2024 02:07:50 -0700 Subject: [PATCH] enh: tiktoken/token splitter support --- Dockerfile | 10 ++++++++ backend/open_webui/apps/retrieval/main.py | 28 ++++++++++++++++++----- backend/open_webui/config.py | 16 +++++++++++++ backend/open_webui/constants.py | 2 +- 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2e898dc89..b7e2ce983 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,10 @@ ARG USE_CUDA_VER=cu121 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_RERANKING_MODEL="" + +# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken +ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" + ARG BUILD_HASH=dev-build # Override at your own risk - non-root configurations are untested ARG UID=0 @@ -72,6 +76,10 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" +## Tiktoken model settings ## +ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \ + TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" + ## Hugging Face download cache ## ENV HF_HOME="/app/backend/data/cache/embedding/models" @@ -131,11 +139,13 @@ RUN pip3 install uv && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ fi; \ chown -R $UID:$GID /app/backend/data/ diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index c39f8be16..eacb0d7cd 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -47,6 +47,8 @@ from open_webui.apps.retrieval.utils import ( from open_webui.apps.webui.models.files import Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, + TIKTOKEN_ENCODING_NAME, + RAG_TEXT_SPLITTER, CHUNK_OVERLAP, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, @@ -102,7 +104,7 @@ from open_webui.utils.misc import ( ) from open_webui.utils.utils import get_admin_user, get_verified_user -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter from langchain_community.document_loaders import ( YoutubeLoader, ) @@ -129,6 +131,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + app.state.config.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -648,11 +653,22 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) + if app.state.config.TEXT_SPLITTER in ["", "character"]: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + elif app.state.config.TEXT_SPLITTER == "token": + text_splitter = TokenTextSplitter( + encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME, + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + else: + raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) + docs = text_splitter.split_documents(docs) if len(docs) == 0: diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 98d342897..d55619ee0 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1014,6 +1014,22 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) + +RAG_TEXT_SPLITTER = PersistentConfig( + "RAG_TEXT_SPLITTER", + "rag.text_splitter", + os.environ.get("RAG_TEXT_SPLITTER", ""), +) + + +TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") +TIKTOKEN_ENCODING_NAME = PersistentConfig( + "TIKTOKEN_ENCODING_NAME", + "rag.tiktoken_encoding_name", + os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"), +) + + CHUNK_SIZE = PersistentConfig( "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000")) ) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 37461402b..704cdd074 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -20,7 +20,7 @@ class ERROR_MESSAGES(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}" + DEFAULT = lambda err="": f"Something went wrong :/\n[ERROR: {err if err else ''}]" ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."