From d5c1c2f0a7ff219166efc33b54b3e907ef6b63c6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 20 Oct 2024 21:34:36 -0700 Subject: [PATCH] enh: faster whisper custom model support --- backend/open_webui/apps/audio/main.py | 61 ++++++++----- .../apps/webui/routers/functions.py | 2 +- .../open_webui/apps/webui/routers/tools.py | 3 - backend/open_webui/apps/webui/utils.py | 1 - backend/open_webui/config.py | 42 +++------ .../components/admin/Settings/Audio.svelte | 91 ++++++++++++++++++- 6 files changed, 140 insertions(+), 60 deletions(-) diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 6398b9ee1..148430da8 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY app.state.config.STT_ENGINE = AUDIO_STT_ENGINE app.state.config.STT_MODEL = AUDIO_STT_MODEL +app.state.config.WHISPER_MODEL = WHISPER_MODEL +app.state.faster_whisper_model = None + app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE @@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) +def set_faster_whisper_model(model: str, auto_update: bool = False): + if model and app.state.config.STT_ENGINE == "": + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": whisper_device_type, + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) + + else: + app.state.faster_whisper_model = None + + class TTSConfigForm(BaseModel): OPENAI_API_BASE_URL: str OPENAI_API_KEY: str @@ -99,6 +127,7 @@ class STTConfigForm(BaseModel): OPENAI_API_KEY: str ENGINE: str MODEL: str + WHISPER_MODEL: str class AudioConfigUpdateForm(BaseModel): @@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)): "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "ENGINE": app.state.config.STT_ENGINE, "MODEL": app.state.config.STT_MODEL, + "WHISPER_MODEL": app.state.config.WHISPER_MODEL, }, } @@ -176,6 +206,8 @@ async def update_audio_config( app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY app.state.config.STT_ENGINE = form_data.stt.ENGINE app.state.config.STT_MODEL = form_data.stt.MODEL + app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) return { "tts": { @@ -194,6 +226,7 @@ async def update_audio_config( "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, "ENGINE": app.state.config.STT_ENGINE, "MODEL": app.state.config.STT_MODEL, + "WHISPER_MODEL": app.state.config.WHISPER_MODEL, }, } @@ -367,27 +400,10 @@ def transcribe(file_path): id = filename.split(".")[0] if app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - whisper_kwargs = { - "model_size_or_path": WHISPER_MODEL, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not 
WHISPER_MODEL_AUTO_UPDATE, - } - - log.debug(f"whisper_kwargs: {whisper_kwargs}") - - try: - model = WhisperModel(**whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - whisper_kwargs["local_files_only"] = False - model = WhisperModel(**whisper_kwargs) + if app.state.faster_whisper_model is None: + set_faster_whisper_model(app.state.config.WHISPER_MODEL) + model = app.state.faster_whisper_model segments, info = model.transcribe(file_path, beam_size=5) log.info( "Detected language '%s' with probability %f" @@ -395,7 +411,6 @@ def transcribe(file_path): ) transcript = "".join([segment.text for segment in list(segments)]) - data = {"text": transcript.strip()} # save the transcript to a json file @@ -403,7 +418,7 @@ def transcribe(file_path): with open(transcript_file, "w") as f: json.dump(data, f) - print(data) + log.debug(data) return data elif app.state.config.STT_ENGINE == "openai": if is_mp4_audio(file_path): @@ -417,7 +432,7 @@ def transcribe(file_path): files = {"file": (filename, open(file_path, "rb"))} data = {"model": app.state.config.STT_MODEL} - print(files, data) + log.debug(files, data) r = None try: diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/apps/webui/routers/functions.py index 130603462..aeaceecfb 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/apps/webui/routers/functions.py @@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import ( Functions, ) from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports -from open_webui.config import CACHE_DIR, FUNCTIONS_DIR +from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.utils import get_admin_user, get_verified_user diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index 0db21c895..d1ad89dea 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.utils import get_admin_user, get_verified_user -TOOLS_DIR = f"{DATA_DIR}/tools" -os.makedirs(TOOLS_DIR, exist_ok=True) - router = APIRouter() diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index 969d5622c..51d379656 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -8,7 +8,6 @@ import tempfile from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.tools import Tools -from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR def extract_frontmatter(content): diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 496f2395f..2fed66798 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) -#################################### -# Tools DIR -#################################### - -TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") -Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) - - -#################################### -# Functions DIR 
-#################################### - -FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") -Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) - #################################### # OLLAMA_BASE_URL #################################### - ENABLE_OLLAMA_API = PersistentConfig( "ENABLE_OLLAMA_API", "ollama.enable", @@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( ) -#################################### -# Transcribe -#################################### - -WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") -WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") -WHISPER_MODEL_AUTO_UPDATE = ( - os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" -) - - #################################### # Images #################################### @@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig( # Audio #################################### +# Transcription +WHISPER_MODEL = PersistentConfig( + "WHISPER_MODEL", + "audio.stt.whisper_model", + os.getenv("WHISPER_MODEL", "base"), +) + +WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") +WHISPER_MODEL_AUTO_UPDATE = ( + os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" +) + + AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig( "AUDIO_STT_OPENAI_API_BASE_URL", "audio.stt.openai.api_base_url", @@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig( AUDIO_STT_MODEL = PersistentConfig( "AUDIO_STT_MODEL", "audio.stt.model", - os.getenv("AUDIO_STT_MODEL", "whisper-1"), + os.getenv("AUDIO_STT_MODEL", ""), ) AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig( diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index ed9ae7334..41230a08b 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -38,6 +38,9 @@ let STT_OPENAI_API_KEY = ''; let STT_ENGINE = ''; let STT_MODEL = ''; + let STT_WHISPER_MODEL = ''; + + let STT_WHISPER_MODEL_LOADING = false; // eslint-disable-next-line no-undef let voices: SpeechSynthesisVoice[] = []; @@ -99,18 +102,23 @@ OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL, OPENAI_API_KEY: STT_OPENAI_API_KEY, ENGINE: STT_ENGINE, - MODEL: STT_MODEL + MODEL: STT_MODEL, + WHISPER_MODEL: STT_WHISPER_MODEL } }); if (res) { saveHandler(); - getBackendConfig() - .then(config.set) - .catch(() => {}); + config.set(await getBackendConfig()); } }; + const sttModelUpdateHandler = async () => { + STT_WHISPER_MODEL_LOADING = true; + await updateConfigHandler(); + STT_WHISPER_MODEL_LOADING = false; + }; + onMount(async () => { const res = await getAudioConfig(localStorage.token); @@ -134,6 +142,7 @@ STT_ENGINE = res.stt.ENGINE; STT_MODEL = res.stt.MODEL; + STT_WHISPER_MODEL = res.stt.WHISPER_MODEL; } await getVoices(); @@ -201,6 +210,80 @@ + {:else if STT_ENGINE === ''} +
+			<div>
+				<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
+
+				<div class="flex w-full">
+					<div class="flex-1 mr-2">
+						<input
+							class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
+							bind:value={STT_WHISPER_MODEL}
+						/>
+					</div>
+
+					<button
+						class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+						on:click={() => {
+							sttModelUpdateHandler();
+						}}
+						disabled={STT_WHISPER_MODEL_LOADING}
+					>
+						{#if STT_WHISPER_MODEL_LOADING}
+							<!-- spinner icon SVG -->
+						{:else}
+							<!-- download/refresh icon SVG -->
+						{/if}
+					</button>
+				</div>
+
+				<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+					{$i18n.t(
+						'If you want to use a custom model, please enter the model name and click the refresh button.'
+					)}
+				</div>
+			</div>
 		{/if}
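
Note on the loading behavior introduced above: with STT_ENGINE left empty, transcription now reuses the cached app.state.faster_whisper_model, which set_faster_whisper_model() loads lazily from WHISPER_MODEL_DIR and only falls back to downloading when the local load fails (or when auto_update is passed). The sketch below shows that same load/fallback pattern as a standalone script against the public faster-whisper API; the model id, device, cache directory, and audio file are illustrative placeholders rather than values taken from this patch.

    from faster_whisper import WhisperModel

    def load_whisper_model(model: str, cache_dir: str, auto_update: bool = False) -> WhisperModel:
        # Mirrors the try-local-first behavior above: prefer an already-cached
        # model, then retry with downloads enabled if nothing usable is found.
        kwargs = {
            "model_size_or_path": model,          # e.g. "base" or a Hugging Face repo id
            "device": "cpu",                      # placeholder; the app passes whisper_device_type
            "compute_type": "int8",
            "download_root": cache_dir,
            "local_files_only": not auto_update,  # True = do not hit the network
        }
        try:
            return WhisperModel(**kwargs)
        except Exception:
            kwargs["local_files_only"] = False
            return WhisperModel(**kwargs)

    model = load_whisper_model("base", "./data/cache/whisper/models")
    segments, info = model.transcribe("sample.wav", beam_size=5)
    print(info.language, "".join(segment.text for segment in segments))

Because the model object lives on app.state, saving a new WHISPER_MODEL from the admin Audio settings rebuilds it immediately in update_audio_config, so the change applies to the next transcription without a server restart.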