mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
enh: faster whisper custom model support
This commit is contained in:
parent
145a7bbda5
commit
d5c1c2f0a7
@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
|
|||||||
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
||||||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||||
|
|
||||||
|
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||||
|
app.state.faster_whisper_model = None
|
||||||
|
|
||||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||||
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
||||||
@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
|||||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def set_faster_whisper_model(model: str, auto_update: bool = False):
    """Load a faster-whisper model and store it on ``app.state``.

    The model is loaded only when ``model`` is non-empty and
    ``app.state.config.STT_ENGINE`` is the empty string; in every other
    case ``app.state.faster_whisper_model`` is reset to ``None``.

    Args:
        model: Value passed to ``WhisperModel`` as ``model_size_or_path``.
        auto_update: When false, the first load attempt is made with
            ``local_files_only=True``; if that fails, a second attempt is
            made with ``local_files_only=False``.

    Raises:
        Exception: Whatever ``WhisperModel`` raises when the model cannot
            be loaded with ``local_files_only=False``.
    """
    if not (model and app.state.config.STT_ENGINE == ""):
        app.state.faster_whisper_model = None
        return

    # Imported lazily so the dependency is only needed when this path runs.
    from faster_whisper import WhisperModel

    faster_whisper_kwargs = {
        "model_size_or_path": model,
        "device": whisper_device_type,
        "compute_type": "int8",
        "download_root": WHISPER_MODEL_DIR,
        "local_files_only": not auto_update,
    }

    try:
        app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
    except Exception:
        if not faster_whisper_kwargs["local_files_only"]:
            # The failed attempt already had local_files_only=False, so the
            # fallback below would just repeat the exact same call.
            raise

        log.warning(
            "WhisperModel initialization failed, attempting download with local_files_only=False"
        )
        # Keep the root cause available for troubleshooting without turning
        # the expected "model not cached yet" case into a noisy traceback.
        log.debug("WhisperModel local-only initialization error", exc_info=True)

        faster_whisper_kwargs["local_files_only"] = False
        app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TTSConfigForm(BaseModel):
|
class TTSConfigForm(BaseModel):
|
||||||
OPENAI_API_BASE_URL: str
|
OPENAI_API_BASE_URL: str
|
||||||
OPENAI_API_KEY: str
|
OPENAI_API_KEY: str
|
||||||
@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
|
|||||||
OPENAI_API_KEY: str
|
OPENAI_API_KEY: str
|
||||||
ENGINE: str
|
ENGINE: str
|
||||||
MODEL: str
|
MODEL: str
|
||||||
|
WHISPER_MODEL: str
|
||||||
|
|
||||||
|
|
||||||
class AudioConfigUpdateForm(BaseModel):
|
class AudioConfigUpdateForm(BaseModel):
|
||||||
@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
|
|||||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||||
"ENGINE": app.state.config.STT_ENGINE,
|
"ENGINE": app.state.config.STT_ENGINE,
|
||||||
"MODEL": app.state.config.STT_MODEL,
|
"MODEL": app.state.config.STT_MODEL,
|
||||||
|
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,6 +206,8 @@ async def update_audio_config(
|
|||||||
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
||||||
app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||||
app.state.config.STT_MODEL = form_data.stt.MODEL
|
app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||||
|
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||||
|
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"tts": {
|
"tts": {
|
||||||
@ -194,6 +226,7 @@ async def update_audio_config(
|
|||||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||||
"ENGINE": app.state.config.STT_ENGINE,
|
"ENGINE": app.state.config.STT_ENGINE,
|
||||||
"MODEL": app.state.config.STT_MODEL,
|
"MODEL": app.state.config.STT_MODEL,
|
||||||
|
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -367,27 +400,10 @@ def transcribe(file_path):
|
|||||||
id = filename.split(".")[0]
|
id = filename.split(".")[0]
|
||||||
|
|
||||||
if app.state.config.STT_ENGINE == "":
|
if app.state.config.STT_ENGINE == "":
|
||||||
from faster_whisper import WhisperModel
|
if app.state.faster_whisper_model is None:
|
||||||
|
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
|
||||||
whisper_kwargs = {
|
|
||||||
"model_size_or_path": WHISPER_MODEL,
|
|
||||||
"device": whisper_device_type,
|
|
||||||
"compute_type": "int8",
|
|
||||||
"download_root": WHISPER_MODEL_DIR,
|
|
||||||
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
|
|
||||||
}
|
|
||||||
|
|
||||||
log.debug(f"whisper_kwargs: {whisper_kwargs}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
model = WhisperModel(**whisper_kwargs)
|
|
||||||
except Exception:
|
|
||||||
log.warning(
|
|
||||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
|
||||||
)
|
|
||||||
whisper_kwargs["local_files_only"] = False
|
|
||||||
model = WhisperModel(**whisper_kwargs)
|
|
||||||
|
|
||||||
|
model = app.state.faster_whisper_model
|
||||||
segments, info = model.transcribe(file_path, beam_size=5)
|
segments, info = model.transcribe(file_path, beam_size=5)
|
||||||
log.info(
|
log.info(
|
||||||
"Detected language '%s' with probability %f"
|
"Detected language '%s' with probability %f"
|
||||||
@ -395,7 +411,6 @@ def transcribe(file_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
transcript = "".join([segment.text for segment in list(segments)])
|
transcript = "".join([segment.text for segment in list(segments)])
|
||||||
|
|
||||||
data = {"text": transcript.strip()}
|
data = {"text": transcript.strip()}
|
||||||
|
|
||||||
# save the transcript to a json file
|
# save the transcript to a json file
|
||||||
@ -403,7 +418,7 @@ def transcribe(file_path):
|
|||||||
with open(transcript_file, "w") as f:
|
with open(transcript_file, "w") as f:
|
||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
|
|
||||||
print(data)
|
log.debug(data)
|
||||||
return data
|
return data
|
||||||
elif app.state.config.STT_ENGINE == "openai":
|
elif app.state.config.STT_ENGINE == "openai":
|
||||||
if is_mp4_audio(file_path):
|
if is_mp4_audio(file_path):
|
||||||
@ -417,7 +432,7 @@ def transcribe(file_path):
|
|||||||
files = {"file": (filename, open(file_path, "rb"))}
|
files = {"file": (filename, open(file_path, "rb"))}
|
||||||
data = {"model": app.state.config.STT_MODEL}
|
data = {"model": app.state.config.STT_MODEL}
|
||||||
|
|
||||||
print(files, data)
|
log.debug(files, data)
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
|
@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
|
|||||||
Functions,
|
Functions,
|
||||||
)
|
)
|
||||||
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
|
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
|
||||||
from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
|
from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||||
|
@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|||||||
from open_webui.utils.tools import get_tools_specs
|
from open_webui.utils.tools import get_tools_specs
|
||||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
|
||||||
os.makedirs(TOOLS_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ import tempfile
|
|||||||
|
|
||||||
from open_webui.apps.webui.models.functions import Functions
|
from open_webui.apps.webui.models.functions import Functions
|
||||||
from open_webui.apps.webui.models.tools import Tools
|
from open_webui.apps.webui.models.tools import Tools
|
||||||
from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def extract_frontmatter(content):
|
def extract_frontmatter(content):
|
||||||
|
@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
|||||||
CACHE_DIR = f"{DATA_DIR}/cache"
|
CACHE_DIR = f"{DATA_DIR}/cache"
|
||||||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
####################################
|
|
||||||
# Tools DIR
|
|
||||||
####################################
|
|
||||||
|
|
||||||
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
|
|
||||||
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
####################################
|
|
||||||
# Functions DIR
|
|
||||||
####################################
|
|
||||||
|
|
||||||
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
|
|
||||||
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# OLLAMA_BASE_URL
|
# OLLAMA_BASE_URL
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
|
|
||||||
ENABLE_OLLAMA_API = PersistentConfig(
|
ENABLE_OLLAMA_API = PersistentConfig(
|
||||||
"ENABLE_OLLAMA_API",
|
"ENABLE_OLLAMA_API",
|
||||||
"ollama.enable",
|
"ollama.enable",
|
||||||
@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
|
||||||
# Transcribe
|
|
||||||
####################################
|
|
||||||
|
|
||||||
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
|
|
||||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
|
||||||
WHISPER_MODEL_AUTO_UPDATE = (
|
|
||||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Images
|
# Images
|
||||||
####################################
|
####################################
|
||||||
@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
|
|||||||
# Audio
|
# Audio
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
|
# Transcription
|
||||||
|
WHISPER_MODEL = PersistentConfig(
|
||||||
|
"WHISPER_MODEL",
|
||||||
|
"audio.stt.whisper_model",
|
||||||
|
os.getenv("WHISPER_MODEL", "base"),
|
||||||
|
)
|
||||||
|
|
||||||
|
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||||
|
WHISPER_MODEL_AUTO_UPDATE = (
|
||||||
|
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||||
"audio.stt.openai.api_base_url",
|
"audio.stt.openai.api_base_url",
|
||||||
@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
|
|||||||
AUDIO_STT_MODEL = PersistentConfig(
|
AUDIO_STT_MODEL = PersistentConfig(
|
||||||
"AUDIO_STT_MODEL",
|
"AUDIO_STT_MODEL",
|
||||||
"audio.stt.model",
|
"audio.stt.model",
|
||||||
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
|
os.getenv("AUDIO_STT_MODEL", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
|
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
|
||||||
|
@ -38,6 +38,9 @@
|
|||||||
let STT_OPENAI_API_KEY = '';
|
let STT_OPENAI_API_KEY = '';
|
||||||
let STT_ENGINE = '';
|
let STT_ENGINE = '';
|
||||||
let STT_MODEL = '';
|
let STT_MODEL = '';
|
||||||
|
let STT_WHISPER_MODEL = '';
|
||||||
|
|
||||||
|
let STT_WHISPER_MODEL_LOADING = false;
|
||||||
|
|
||||||
// eslint-disable-next-line no-undef
|
// eslint-disable-next-line no-undef
|
||||||
let voices: SpeechSynthesisVoice[] = [];
|
let voices: SpeechSynthesisVoice[] = [];
|
||||||
@ -99,18 +102,23 @@
|
|||||||
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
|
OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
|
||||||
OPENAI_API_KEY: STT_OPENAI_API_KEY,
|
OPENAI_API_KEY: STT_OPENAI_API_KEY,
|
||||||
ENGINE: STT_ENGINE,
|
ENGINE: STT_ENGINE,
|
||||||
MODEL: STT_MODEL
|
MODEL: STT_MODEL,
|
||||||
|
WHISPER_MODEL: STT_WHISPER_MODEL
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if (res) {
|
if (res) {
|
||||||
saveHandler();
|
saveHandler();
|
||||||
getBackendConfig()
|
config.set(await getBackendConfig());
|
||||||
.then(config.set)
|
|
||||||
.catch(() => {});
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Saves the audio config while showing the loading spinner on the
// whisper-model button.
const sttModelUpdateHandler = async () => {
	STT_WHISPER_MODEL_LOADING = true;
	try {
		await updateConfigHandler();
	} finally {
		// Always clear the flag: the button is disabled while it is true, so
		// leaving it set after a rejected update would lock the control.
		STT_WHISPER_MODEL_LOADING = false;
	}
};
|
||||||
|
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
const res = await getAudioConfig(localStorage.token);
|
const res = await getAudioConfig(localStorage.token);
|
||||||
|
|
||||||
@ -134,6 +142,7 @@
|
|||||||
|
|
||||||
STT_ENGINE = res.stt.ENGINE;
|
STT_ENGINE = res.stt.ENGINE;
|
||||||
STT_MODEL = res.stt.MODEL;
|
STT_MODEL = res.stt.MODEL;
|
||||||
|
STT_WHISPER_MODEL = res.stt.WHISPER_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
await getVoices();
|
await getVoices();
|
||||||
@ -201,6 +210,80 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
{:else if STT_ENGINE === ''}
|
||||||
|
<div>
|
||||||
|
<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
|
||||||
|
|
||||||
|
<div class="flex w-full">
|
||||||
|
<div class="flex-1 mr-2">
|
||||||
|
<input
|
||||||
|
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
|
||||||
|
placeholder={$i18n.t('Set whisper model')}
|
||||||
|
bind:value={STT_WHISPER_MODEL}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button
|
||||||
|
class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
|
||||||
|
on:click={() => {
|
||||||
|
sttModelUpdateHandler();
|
||||||
|
}}
|
||||||
|
disabled={STT_WHISPER_MODEL_LOADING}
|
||||||
|
>
|
||||||
|
{#if STT_WHISPER_MODEL_LOADING}
|
||||||
|
<div class="self-center">
|
||||||
|
<svg
|
||||||
|
class=" w-4 h-4"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="currentColor"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
>
|
||||||
|
<style>
|
||||||
|
.spinner_ajPY {
|
||||||
|
transform-origin: center;
|
||||||
|
animation: spinner_AtaB 0.75s infinite linear;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes spinner_AtaB {
|
||||||
|
100% {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<path
|
||||||
|
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||||
|
opacity=".25"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||||
|
class="spinner_ajPY"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
viewBox="0 0 16 16"
|
||||||
|
fill="currentColor"
|
||||||
|
class="w-4 h-4"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
|
||||||
|
{$i18n.t(
|
||||||
|
'If you want to use a custom model, please enter the model name and click the refresh button.'
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user