diff --git a/Dockerfile b/Dockerfile index ec879d732..eb5a693bc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_RERANKING_MODEL="" # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken -ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" +ARG USE_TIKTOKEN_ENCODING_MODEL_NAME="cl100k_base" ARG BUILD_HASH=dev-build # Override at your own risk - non-root configurations are untested @@ -77,7 +77,7 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" ## Tiktoken model settings ## -ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \ +ENV TIKTOKEN_ENCODING_MODEL_NAME="$USE_TIKTOKEN_ENCODING_MODEL_NAME" \ TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" ## Hugging Face download cache ## @@ -139,13 +139,13 @@ RUN pip3 install uv && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ - python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ - python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \ fi; \ chown -R $UID:$GID /app/backend/data/ diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 04eece38c..4a09b51a7 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -50,7 +50,7 @@ from open_webui.apps.retrieval.utils import ( from open_webui.apps.webui.models.files import Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, - TIKTOKEN_ENCODING_NAME, + TIKTOKEN_ENCODING_MODEL_NAME, RAG_TEXT_SPLITTER, CHUNK_OVERLAP, CHUNK_SIZE, @@ -135,7 +135,7 @@ app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME +app.state.config.TIKTOKEN_ENCODING_MODEL_NAME = TIKTOKEN_ENCODING_MODEL_NAME app.state.config.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -667,7 +667,7 @@ def save_docs_to_vector_db( ) elif app.state.config.TEXT_SPLITTER == "token": text_splitter = TokenTextSplitter( - encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME, + model_name=app.state.config.TIKTOKEN_ENCODING_MODEL_NAME, chunk_size=app.state.config.CHUNK_SIZE, chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 9d1bd72d8..b2a7e9d42 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1074,10 +1074,10 @@ RAG_TEXT_SPLITTER = PersistentConfig( TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") -TIKTOKEN_ENCODING_NAME = PersistentConfig( - "TIKTOKEN_ENCODING_NAME", - "rag.tiktoken_encoding_name", - os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"), +TIKTOKEN_ENCODING_MODEL_NAME = PersistentConfig( + "TIKTOKEN_ENCODING_MODEL_NAME", + "rag.tiktoken_encoding_model_name", + os.environ.get("TIKTOKEN_ENCODING_MODEL_NAME", "cl100k_base"), )