diff --git a/Dockerfile b/Dockerfile
index dc1fba7b3..64fa80175 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,7 @@
 # syntax=docker/dockerfile:1
+# Initialize device type args
+ARG USE_CUDA=false
+ARG USE_MPS=false
 
 ######## WebUI frontend ########
 FROM node:21-alpine3.19 as build
@@ -23,6 +26,10 @@ RUN npm run build
 
 ######## WebUI backend ########
 FROM python:3.11-slim-bookworm as base
+# Use the build args in this stage (ARG values do not carry across FROM boundaries)
+ARG USE_CUDA
+ARG USE_MPS
+
 ## Basis ##
 ENV ENV=prod \
     PORT=8080
@@ -54,7 +61,8 @@ ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" \
     # Important:
     # If you want to use CUDA you need to install the nvidia-container-toolkit (https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
     # you can set this to "cuda" but its recomended to use --build-arg CUDA_ENABLED=true flag when building the image
-    RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+    RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" \
+    DEVICE_COMPUTE_TYPE="int8"
 # device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
 
 #### Preloaded models ##########################################################
@@ -62,19 +70,24 @@ WORKDIR /app/backend
 
 # install python dependencies
 COPY ./backend/requirements.txt ./requirements.txt
-RUN pip3 install -r requirements.txt --no-cache-dir
-
-RUN if [ "$RAG_EMBEDDING_MODEL_DEVICE_TYPE" = "cuda" ]; then \
-    echo "CUDA enabled" && \
-    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir; \
-    else \
+RUN if [ "$USE_CUDA" = "true" ]; then \
+    export DEVICE_TYPE="cuda" && \
+    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir && \
+    pip3 install -r requirements.txt --no-cache-dir; \
+    elif [ "$USE_MPS" = "true" ]; then \
+    export DEVICE_TYPE="mps" && \
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
-    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"; \
+    pip3 install -r requirements.txt --no-cache-dir && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
+    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \
+    else \
+    export DEVICE_TYPE="cpu" && \
+    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
+    pip3 install -r requirements.txt --no-cache-dir && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
+    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \
     fi
-# preload tts model
-RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
-
 
 # install required packages
 RUN apt-get update \
     # Install pandoc and netcat
@@ -100,4 +113,4 @@ COPY ./backend .
 
 EXPOSE 8080
 
-CMD [ "bash", "start.sh"]
+CMD [ "bash", "start.sh"]
\ No newline at end of file
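Note on usage: with the Dockerfile changes above, the device type is chosen at build time through the two new build args rather than the old RAG_EMBEDDING_MODEL_DEVICE_TYPE setting. A minimal sketch of the resulting build commands (the image tags are illustrative, not part of this PR):

    # CPU-only image; both args default to false
    docker build -t open-webui .

    # CUDA image: installs the cu117 torch wheels (this branch skips the model preload)
    docker build --build-arg USE_CUDA=true -t open-webui:cuda .

    # Apple silicon: installs the CPU torch wheels, preloads whisper on "cpu" and the embedding model on "mps"
    docker build --build-arg USE_MPS=true -t open-webui:mps .
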
diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py
index d8cb415fc..2faf07aa4 100644
--- a/backend/apps/audio/main.py
+++ b/backend/apps/audio/main.py
@@ -21,7 +21,11 @@ from utils.utils import (
 )
 from utils.misc import calculate_sha256
 
-from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
+from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE
+
+# faster-whisper runs on CTranslate2, which has no MPS backend; use "cpu" for anything but "cuda"
+whisper_device_type = DEVICE_TYPE if DEVICE_TYPE == "cuda" else "cpu"
+
 
 app = FastAPI()
 app.add_middleware(
@@ -56,7 +60,7 @@ def transcribe(
 
     model = WhisperModel(
         WHISPER_MODEL,
-        device="auto",
+        device=whisper_device_type,
         compute_type="int8",
         download_root=WHISPER_MODEL_DIR,
     )
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index b21724cc9..5a93f9e0c 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -57,7 +57,7 @@ from config import (
     UPLOAD_DIR,
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -87,7 +87,7 @@ app.state.TOP_K = 4
 
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
         model_name=app.state.RAG_EMBEDDING_MODEL,
-        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        device=DEVICE_TYPE,
     )
 )
@@ -175,7 +175,7 @@ async def update_embedding_model(
 
     app.state.sentence_transformer_ef = (
         embedding_functions.SentenceTransformerEmbeddingFunction(
             model_name=app.state.RAG_EMBEDDING_MODEL,
-            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+            device=DEVICE_TYPE,
         )
     )
diff --git a/backend/config.py b/backend/config.py
index 831371bb7..7200b9126 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -330,8 +330,8 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
 RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
-RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
-    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
+DEVICE_TYPE = os.environ.get(
+    "DEVICE_TYPE", "cpu"
 )
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
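One behavioral note: the export DEVICE_TYPE=... statements in the Dockerfile's RUN instruction only live for the duration of that build step, so at runtime backend/config.py falls back to its "cpu" default unless DEVICE_TYPE is supplied as a container environment variable. A minimal sketch of run commands under that assumption (ports and tags illustrative):

    # default: DEVICE_TYPE resolves to "cpu" in config.py
    docker run -d -p 3000:8080 open-webui

    # CUDA at runtime: needs the nvidia-container-toolkit on the host, per the Dockerfile comment
    docker run -d -p 3000:8080 --gpus all -e DEVICE_TYPE=cuda open-webui:cuda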