enh: faster whisper custom model support

This commit is contained in:
Timothy J. Baek
2024-10-20 21:34:36 -07:00
parent 145a7bbda5
commit d5c1c2f0a7
6 changed files with 140 additions and 60 deletions

View File

@@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False):
if model and app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
else:
app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
@@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
@@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@@ -176,6 +206,8 @@ async def update_audio_config(
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
return {
"tts": {
@@ -194,6 +226,7 @@ async def update_audio_config(
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@@ -367,27 +400,10 @@ def transcribe(file_path):
id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
if app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
model = app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
@@ -395,7 +411,6 @@ def transcribe(file_path):
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
@@ -403,7 +418,7 @@ def transcribe(file_path):
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
log.debug(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
@@ -417,7 +432,7 @@ def transcribe(file_path):
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL}
print(files, data)
log.debug(files, data)
r = None
try:

View File

@@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user

View File

@@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()

View File

@@ -8,7 +8,6 @@ import tempfile
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.tools import Tools
from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
def extract_frontmatter(content):

View File

@@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# OLLAMA_BASE_URL
####################################
ENABLE_OLLAMA_API = PersistentConfig(
"ENABLE_OLLAMA_API",
"ollama.enable",
@@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
)
####################################
# Transcribe
####################################
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
####################################
# Images
####################################
@@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
# Transcription
WHISPER_MODEL = PersistentConfig(
"WHISPER_MODEL",
"audio.stt.whisper_model",
os.getenv("WHISPER_MODEL", "base"),
)
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
@@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
os.getenv("AUDIO_STT_MODEL", ""),
)
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(