From 63e6f39b3193e2ab6c4583f73110d543b955e4e1 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 14 Feb 2024 23:32:54 -0800 Subject: [PATCH] refac --- Dockerfile | 4 ++-- backend/apps/audio/main.py | 9 +++------ backend/config.py | 4 +++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 010dbc869..520c2964d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,8 +31,8 @@ ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true #Whisper TTS Settings -ENV WHISPER_DIR="/app/backend/data/cache/whisper/models" ENV WHISPER_MODEL="base" +ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" WORKDIR /app/backend @@ -49,7 +49,7 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" -RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_DIR'])" +RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" # copy embedding weight from build diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 85937de80..86e79c473 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -21,7 +21,7 @@ from utils.utils import ( ) from utils.misc import calculate_sha256 -from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME +from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR app = FastAPI() app.add_middleware( @@ -54,14 +54,11 @@ def transcribe( f.write(contents) f.close() - model_name = os.getenv('WHISPER_MODEL', WHISPER_MODEL_NAME) - download_root = os.getenv('WHISPER_DIR', f"{CACHE_DIR}/whisper/models") - model = WhisperModel( - model_name, + WHISPER_MODEL, device="cpu", compute_type="int8", - download_root=download_root, + download_root=WHISPER_MODEL_DIR, ) segments, info = model.transcribe(file_path, beam_size=5) diff --git a/backend/config.py b/backend/config.py index 81b900840..f2e25c6ae 100644 --- a/backend/config.py +++ b/backend/config.py @@ -136,4 +136,6 @@ CHUNK_OVERLAP = 100 #################################### # Transcribe #################################### -WHISPER_MODEL_NAME = "base" + +WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") +WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")