mirror of
https://github.com/open-webui/open-webui
synced 2024-11-25 13:29:53 +00:00
enh: faster whisper custom model support
This commit is contained in:
parent
145a7bbda5
commit
d5c1c2f0a7
@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
|
||||
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
||||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||
|
||||
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||
app.state.faster_whisper_model = None
|
||||
|
||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
||||
@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
if model and app.state.config.STT_ENGINE == "":
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
faster_whisper_kwargs = {
|
||||
"model_size_or_path": model,
|
||||
"device": whisper_device_type,
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not auto_update,
|
||||
}
|
||||
|
||||
try:
|
||||
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
faster_whisper_kwargs["local_files_only"] = False
|
||||
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
|
||||
else:
|
||||
app.state.faster_whisper_model = None
|
||||
|
||||
|
||||
class TTSConfigForm(BaseModel):
|
||||
OPENAI_API_BASE_URL: str
|
||||
OPENAI_API_KEY: str
|
||||
@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
|
||||
OPENAI_API_KEY: str
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
|
||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": app.state.config.STT_ENGINE,
|
||||
"MODEL": app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
@ -176,6 +206,8 @@ async def update_audio_config(
|
||||
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
||||
app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
|
||||
|
||||
return {
|
||||
"tts": {
|
||||
@ -194,6 +226,7 @@ async def update_audio_config(
|
||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": app.state.config.STT_ENGINE,
|
||||
"MODEL": app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
@ -367,27 +400,10 @@ def transcribe(file_path):
|
||||
id = filename.split(".")[0]
|
||||
|
||||
if app.state.config.STT_ENGINE == "":
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
whisper_kwargs = {
|
||||
"model_size_or_path": WHISPER_MODEL,
|
||||
"device": whisper_device_type,
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
|
||||
}
|
||||
|
||||
log.debug(f"whisper_kwargs: {whisper_kwargs}")
|
||||
|
||||
try:
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
whisper_kwargs["local_files_only"] = False
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
if app.state.faster_whisper_model is None:
|
||||
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
|
||||
|
||||
model = app.state.faster_whisper_model
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
@ -395,7 +411,6 @@ def transcribe(file_path):
|
||||
)
|
||||
|
||||
transcript = "".join([segment.text for segment in list(segments)])
|
||||
|
||||
data = {"text": transcript.strip()}
|
||||
|
||||
# save the transcript to a json file
|
||||
@ -403,7 +418,7 @@ def transcribe(file_path):
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
print(data)
|
||||
log.debug(data)
|
||||
return data
|
||||
elif app.state.config.STT_ENGINE == "openai":
|
||||
if is_mp4_audio(file_path):
|
||||
@ -417,7 +432,7 @@ def transcribe(file_path):
|
||||
files = {"file": (filename, open(file_path, "rb"))}
|
||||
data = {"model": app.state.config.STT_MODEL}
|
||||
|
||||
print(files, data)
|
||||
log.debug(files, data)
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
|
||||
Functions,
|
||||
)
|
||||
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
|
||||
from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
||||
os.makedirs(TOOLS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -8,7 +8,6 @@ import tempfile
|
||||
|
||||
from open_webui.apps.webui.models.functions import Functions
|
||||
from open_webui.apps.webui.models.tools import Tools
|
||||
from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
|
||||
|
||||
|
||||
def extract_frontmatter(content):
|
||||
|
@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||
CACHE_DIR = f"{DATA_DIR}/cache"
|
||||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
####################################
|
||||
# Tools DIR
|
||||
####################################
|
||||
|
||||
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
|
||||
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Functions DIR
|
||||
####################################
|
||||
|
||||
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
|
||||
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
####################################
|
||||
# OLLAMA_BASE_URL
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_OLLAMA_API = PersistentConfig(
|
||||
"ENABLE_OLLAMA_API",
|
||||
"ollama.enable",
|
||||
@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Transcribe
|
||||
####################################
|
||||
|
||||
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Images
|
||||
####################################
|
||||
@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
|
||||
# Audio
|
||||
####################################
|
||||
|
||||
# Transcription
|
||||
WHISPER_MODEL = PersistentConfig(
|
||||
"WHISPER_MODEL",
|
||||
"audio.stt.whisper_model",
|
||||
os.getenv("WHISPER_MODEL", "base"),
|
||||
)
|
||||
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||
"audio.stt.openai.api_base_url",
|
||||
@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
|
||||
AUDIO_STT_MODEL = PersistentConfig(
|
||||
"AUDIO_STT_MODEL",
|
||||
"audio.stt.model",
|
||||
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
|
||||
os.getenv("AUDIO_STT_MODEL", ""),
|
||||
)
|
||||
|
||||
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
|
@ -38,6 +38,9 @@
|
||||
let STT_OPENAI_API_KEY = '';
|
||||
let STT_ENGINE = '';
|
||||
let STT_MODEL = '';
|
||||
let STT_WHISPER_MODEL = '';
|
||||
|
||||
let STT_WHISPER_MODEL_LOADING = false;
|
||||
|
||||
// eslint-disable-next-line no-undef
|
||||
let voices: SpeechSynthesisVoice[] = [];
|
||||
@ -99,18 +102,23 @@
|
||||
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
|
||||
OPENAI_API_KEY: STT_OPENAI_API_KEY,
|
||||
ENGINE: STT_ENGINE,
|
||||
MODEL: STT_MODEL
|
||||
MODEL: STT_MODEL,
|
||||
WHISPER_MODEL: STT_WHISPER_MODEL
|
||||
}
|
||||
});
|
||||
|
||||
if (res) {
|
||||
saveHandler();
|
||||
getBackendConfig()
|
||||
.then(config.set)
|
||||
.catch(() => {});
|
||||
config.set(await getBackendConfig());
|
||||
}
|
||||
};
|
||||
|
||||
const sttModelUpdateHandler = async () => {
|
||||
STT_WHISPER_MODEL_LOADING = true;
|
||||
await updateConfigHandler();
|
||||
STT_WHISPER_MODEL_LOADING = false;
|
||||
};
|
||||
|
||||
onMount(async () => {
|
||||
const res = await getAudioConfig(localStorage.token);
|
||||
|
||||
@ -134,6 +142,7 @@
|
||||
|
||||
STT_ENGINE = res.stt.ENGINE;
|
||||
STT_MODEL = res.stt.MODEL;
|
||||
STT_WHISPER_MODEL = res.stt.WHISPER_MODEL;
|
||||
}
|
||||
|
||||
await getVoices();
|
||||
@ -201,6 +210,80 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{:else if STT_ENGINE === ''}
|
||||
<div>
|
||||
<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
|
||||
|
||||
<div class="flex w-full">
|
||||
<div class="flex-1 mr-2">
|
||||
<input
|
||||
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
|
||||
placeholder={$i18n.t('Set whisper model')}
|
||||
bind:value={STT_WHISPER_MODEL}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
|
||||
on:click={() => {
|
||||
sttModelUpdateHandler();
|
||||
}}
|
||||
disabled={STT_WHISPER_MODEL_LOADING}
|
||||
>
|
||||
{#if STT_WHISPER_MODEL_LOADING}
|
||||
<div class="self-center">
|
||||
<svg
|
||||
class=" w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<style>
|
||||
.spinner_ajPY {
|
||||
transform-origin: center;
|
||||
animation: spinner_AtaB 0.75s infinite linear;
|
||||
}
|
||||
|
||||
@keyframes spinner_AtaB {
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<path
|
||||
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||
opacity=".25"
|
||||
/>
|
||||
<path
|
||||
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||
class="spinner_ajPY"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
{:else}
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 16 16"
|
||||
fill="currentColor"
|
||||
class="w-4 h-4"
|
||||
>
|
||||
<path
|
||||
d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
|
||||
/>
|
||||
<path
|
||||
d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
|
||||
{$i18n.t(
|
||||
'If you want to use a custom model, please enter the model name and click the refresh button.'
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user