diff --git a/Dockerfile b/Dockerfile index 10fc3f116..4573de780 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,10 @@ ARG INCLUDE_OLLAMA ## Basis ## ENV ENV=prod \ PORT=8080 \ - INCLUDE_OLLAMA_ENV=${INCLUDE_OLLAMA} + # pass build args to the build + INCLUDE_OLLAMA_DOCKER=${INCLUDE_OLLAMA} \ + USE_MPS_DOCKER=${USE_MPS} \ + USE_CUDA_DOCKER=${USE_CUDA} ## Basis URL Config ## ENV OLLAMA_BASE_URL="/ollama" \ @@ -65,7 +68,7 @@ ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" \ # Important: # If you want to use CUDA you need to install the nvidia-container-toolkit (https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) # you can set this to "cuda" but its recomended to use --build-arg CUDA_ENABLED=true flag when building the image - RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" \ + # RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" \ DEVICE_COMPUTE_TYPE="int8" # device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance #### Preloaded models ########################################################## @@ -75,21 +78,18 @@ WORKDIR /app/backend COPY ./backend/requirements.txt ./requirements.txt RUN if [ "$USE_CUDA" = "true" ]; then \ - export DEVICE_TYPE="cuda" && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir && \ pip3 install -r requirements.txt --no-cache-dir; \ elif [ "$USE_MPS" = "true" ]; then \ - export DEVICE_TYPE="mps" && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install -r requirements.txt --no-cache-dir && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ - python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \ + python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='mps')"; \ else \ - export DEVICE_TYPE="cpu" && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install -r requirements.txt --no-cache-dir && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ - python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['DEVICE_TYPE'])"; \ + python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ fi diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 5a93f9e0c..82d2d28b9 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -71,7 +71,7 @@ from constants import ERROR_MESSAGES # sentence_transformer_ef = SentenceTransformer( # model_name_or_path=RAG_EMBEDDING_MODEL, # cache_folder=RAG_EMBEDDING_MODEL_DIR, -# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, +# device=DEVICE_TYPE, # ) @@ -178,7 +178,6 @@ async def update_embedding_model( device=DEVICE_TYPE, ) ) - return { "status": True, "embedding_model": app.state.RAG_EMBEDDING_MODEL, diff --git a/backend/config.py b/backend/config.py index c0fd56ba9..3bc00323f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -208,7 +208,7 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") -INCLUDE_OLLAMA = os.environ.get("INCLUDE_OLLAMA", "false") +INCLUDE_OLLAMA = os.environ.get("INCLUDE_OLLAMA_ENV", "false") if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": @@ -220,7 +220,7 @@ if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": if ENV == "prod": if OLLAMA_BASE_URL == "/ollama": - if INCLUDE_OLLAMA == "true": + if INCLUDE_OLLAMA.lower() == "true": # if you use all-in-one docker container (Open WebUI + Ollama) # with the docker build arg INCLUDE_OLLAMA=true (--build-arg="INCLUDE_OLLAMA=true") this only works with http://localhost:11434 OLLAMA_BASE_URL = "http://localhost:11434" @@ -336,9 +336,20 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -DEVICE_TYPE = os.environ.get( - "DEVICE_TYPE", "cpu" -) +USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") +USE_MPS = os.environ.get("USE_MPS_DOCKER", "false") + +if USE_CUDA.lower() == "true" and USE_MPS.lower() == "true": + print("Both USE_CUDA and USE_MPS cannot be set to true. Defaulting to CPU.") + DEVICE_TYPE = "cpu" +elif USE_CUDA.lower() == "true": + DEVICE_TYPE = "cuda" +elif USE_MPS.lower() == "true": + DEVICE_TYPE = "mps" +else: + DEVICE_TYPE = "cpu" + + CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), diff --git a/backend/start.sh b/backend/start.sh index 73a337e23..ca0e9688c 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -2,7 +2,7 @@ # Get the INCLUDE_OLLAMA_ENV environment variable which is set in the Dockerfile # This includes the ollama in the image -INCLUDE_OLLAMA=${INCLUDE_OLLAMA_ENV:-false} +INCLUDE_OLLAMA=${INCLUDE_OLLAMA_DOCKER} SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cd "$SCRIPT_DIR" || exit