mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
enh: faster whisper custom model support
This commit is contained in:
@@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
|
||||
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
||||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||
|
||||
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||
app.state.faster_whisper_model = None
|
||||
|
||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
||||
@@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
if model and app.state.config.STT_ENGINE == "":
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
faster_whisper_kwargs = {
|
||||
"model_size_or_path": model,
|
||||
"device": whisper_device_type,
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not auto_update,
|
||||
}
|
||||
|
||||
try:
|
||||
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
faster_whisper_kwargs["local_files_only"] = False
|
||||
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
|
||||
else:
|
||||
app.state.faster_whisper_model = None
|
||||
|
||||
|
||||
class TTSConfigForm(BaseModel):
|
||||
OPENAI_API_BASE_URL: str
|
||||
OPENAI_API_KEY: str
|
||||
@@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
|
||||
OPENAI_API_KEY: str
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
|
||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": app.state.config.STT_ENGINE,
|
||||
"MODEL": app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -176,6 +206,8 @@ async def update_audio_config(
|
||||
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
||||
app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
|
||||
|
||||
return {
|
||||
"tts": {
|
||||
@@ -194,6 +226,7 @@ async def update_audio_config(
|
||||
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": app.state.config.STT_ENGINE,
|
||||
"MODEL": app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -367,27 +400,10 @@ def transcribe(file_path):
|
||||
id = filename.split(".")[0]
|
||||
|
||||
if app.state.config.STT_ENGINE == "":
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
whisper_kwargs = {
|
||||
"model_size_or_path": WHISPER_MODEL,
|
||||
"device": whisper_device_type,
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
|
||||
}
|
||||
|
||||
log.debug(f"whisper_kwargs: {whisper_kwargs}")
|
||||
|
||||
try:
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
whisper_kwargs["local_files_only"] = False
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
if app.state.faster_whisper_model is None:
|
||||
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
|
||||
|
||||
model = app.state.faster_whisper_model
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
@@ -395,7 +411,6 @@ def transcribe(file_path):
|
||||
)
|
||||
|
||||
transcript = "".join([segment.text for segment in list(segments)])
|
||||
|
||||
data = {"text": transcript.strip()}
|
||||
|
||||
# save the transcript to a json file
|
||||
@@ -403,7 +418,7 @@ def transcribe(file_path):
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
print(data)
|
||||
log.debug(data)
|
||||
return data
|
||||
elif app.state.config.STT_ENGINE == "openai":
|
||||
if is_mp4_audio(file_path):
|
||||
@@ -417,7 +432,7 @@ def transcribe(file_path):
|
||||
files = {"file": (filename, open(file_path, "rb"))}
|
||||
data = {"model": app.state.config.STT_MODEL}
|
||||
|
||||
print(files, data)
|
||||
log.debug(files, data)
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
|
||||
Functions,
|
||||
)
|
||||
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
|
||||
from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
@@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
||||
os.makedirs(TOOLS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import tempfile
|
||||
|
||||
from open_webui.apps.webui.models.functions import Functions
|
||||
from open_webui.apps.webui.models.tools import Tools
|
||||
from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
|
||||
|
||||
|
||||
def extract_frontmatter(content):
|
||||
|
||||
@@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||
CACHE_DIR = f"{DATA_DIR}/cache"
|
||||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
####################################
|
||||
# Tools DIR
|
||||
####################################
|
||||
|
||||
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
|
||||
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Functions DIR
|
||||
####################################
|
||||
|
||||
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
|
||||
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
####################################
|
||||
# OLLAMA_BASE_URL
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_OLLAMA_API = PersistentConfig(
|
||||
"ENABLE_OLLAMA_API",
|
||||
"ollama.enable",
|
||||
@@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Transcribe
|
||||
####################################
|
||||
|
||||
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Images
|
||||
####################################
|
||||
@@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
|
||||
# Audio
|
||||
####################################
|
||||
|
||||
# Transcription
|
||||
WHISPER_MODEL = PersistentConfig(
|
||||
"WHISPER_MODEL",
|
||||
"audio.stt.whisper_model",
|
||||
os.getenv("WHISPER_MODEL", "base"),
|
||||
)
|
||||
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||
"audio.stt.openai.api_base_url",
|
||||
@@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
|
||||
AUDIO_STT_MODEL = PersistentConfig(
|
||||
"AUDIO_STT_MODEL",
|
||||
"audio.stt.model",
|
||||
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
|
||||
os.getenv("AUDIO_STT_MODEL", ""),
|
||||
)
|
||||
|
||||
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
|
||||
Reference in New Issue
Block a user