This commit is contained in:
Timothy J. Baek 2024-10-25 21:46:14 -07:00
parent 50dcad0f73
commit 780591e991
3 changed files with 13 additions and 11 deletions

View File

@ -13,7 +13,7 @@ ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL="" ARG USE_RERANKING_MODEL=""
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
ARG USE_TIKTOKEN_ENCODING_MODEL_NAME="cl100k_base" ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
ARG BUILD_HASH=dev-build ARG BUILD_HASH=dev-build
# Override at your own risk - non-root configurations are untested # Override at your own risk - non-root configurations are untested
@ -77,7 +77,7 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
## Tiktoken model settings ## ## Tiktoken model settings ##
ENV TIKTOKEN_ENCODING_MODEL_NAME="$USE_TIKTOKEN_ENCODING_MODEL_NAME" \ ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken" TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
## Hugging Face download cache ## ## Hugging Face download cache ##
@ -139,13 +139,13 @@ RUN pip3 install uv && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
else \ else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \ fi; \
chown -R $UID:$GID /app/backend/data/ chown -R $UID:$GID /app/backend/data/

View File

@ -14,6 +14,7 @@ from typing import Iterator, Optional, Sequence, Union
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
import tiktoken
from open_webui.storage.provider import Storage from open_webui.storage.provider import Storage
@ -50,7 +51,7 @@ from open_webui.apps.retrieval.utils import (
from open_webui.apps.webui.models.files import Files from open_webui.apps.webui.models.files import Files
from open_webui.config import ( from open_webui.config import (
BRAVE_SEARCH_API_KEY, BRAVE_SEARCH_API_KEY,
TIKTOKEN_ENCODING_MODEL_NAME, TIKTOKEN_ENCODING_NAME,
RAG_TEXT_SPLITTER, RAG_TEXT_SPLITTER,
CHUNK_OVERLAP, CHUNK_OVERLAP,
CHUNK_SIZE, CHUNK_SIZE,
@ -135,7 +136,7 @@ app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_MODEL_NAME = TIKTOKEN_ENCODING_MODEL_NAME app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@ -666,8 +667,9 @@ def save_docs_to_vector_db(
add_start_index=True, add_start_index=True,
) )
elif app.state.config.TEXT_SPLITTER == "token": elif app.state.config.TEXT_SPLITTER == "token":
text_splitter = TokenTextSplitter( text_splitter = TokenTextSplitter(
model_name=app.state.config.TIKTOKEN_ENCODING_MODEL_NAME, encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
chunk_size=app.state.config.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,

View File

@ -1074,10 +1074,10 @@ RAG_TEXT_SPLITTER = PersistentConfig(
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_MODEL_NAME = PersistentConfig( TIKTOKEN_ENCODING_NAME = PersistentConfig(
"TIKTOKEN_ENCODING_MODEL_NAME", "TIKTOKEN_ENCODING_NAME",
"rag.tiktoken_encoding_model_name", "rag.tiktoken_encoding_name",
os.environ.get("TIKTOKEN_ENCODING_MODEL_NAME", "cl100k_base"), os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
) )