diff --git a/Dockerfile b/Dockerfile index 38f2a53fe..a7692fdb5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,15 +30,21 @@ ENV WEBUI_SECRET_KEY "" ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true +######## Preloaded models ######## # whisper TTS Settings ENV WHISPER_MODEL="base" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" +# RAG Embedding Model Settings # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard -# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" +# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" +ENV SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" +# device type for whisper tts and embedding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance +ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" +######## Preloaded models ######## WORKDIR /app/backend @@ -55,9 +61,9 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # preload embedding model -RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])" +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])" # preload tts model -RUN python -c "import os; from 
faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" +RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" # copy embedding weight from build diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 86e79c473..d8cb415fc 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -56,7 +56,7 @@ def transcribe( model = WhisperModel( WHISPER_MODEL, - device="cpu", + device="auto", compute_type="int8", download_root=WHISPER_MODEL_DIR, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 5ab3b843c..656ba4e59 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,6 +13,7 @@ import os, shutil from pathlib import Path from typing import List +from sentence_transformers import SentenceTransformer from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( @@ -52,6 +53,7 @@ from config import ( UPLOAD_DIR, DOCS_DIR, RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_DEVICE_TYPE, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, @@ -60,10 +62,18 @@ from config import ( from constants import ERROR_MESSAGES +# +#if RAG_EMBEDDING_MODEL: +# sentence_transformer_ef = SentenceTransformer( +# model_name_or_path=RAG_EMBEDDING_MODEL, +# cache_folder=RAG_EMBEDDING_MODEL_DIR, +# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, +# ) if RAG_EMBEDDING_MODEL: sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=RAG_EMBEDDING_MODEL + model_name=RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, ) app = FastAPI() diff --git a/backend/config.py b/backend/config.py index 2cc6c2a5e..175b228e9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -138,6 +138,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": 
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "") + +# device type for embedding models - "cpu" (Dockerfile default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance +RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get("RAG_EMBEDDING_MODEL_DEVICE_TYPE", "") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False),