feat: add voxtral support
This commit is contained in:
@@ -3403,6 +3403,24 @@ AUDIO_STT_AZURE_MAX_SPEAKERS = PersistentConfig(
|
||||
os.getenv("AUDIO_STT_AZURE_MAX_SPEAKERS", ""),
|
||||
)
|
||||
|
||||
AUDIO_STT_MISTRAL_API_KEY = PersistentConfig(
|
||||
"AUDIO_STT_MISTRAL_API_KEY",
|
||||
"audio.stt.mistral.api_key",
|
||||
os.getenv("AUDIO_STT_MISTRAL_API_KEY", ""),
|
||||
)
|
||||
|
||||
AUDIO_STT_MISTRAL_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_MISTRAL_API_BASE_URL",
|
||||
"audio.stt.mistral.api_base_url",
|
||||
os.getenv("AUDIO_STT_MISTRAL_API_BASE_URL", "https://api.mistral.ai/v1"),
|
||||
)
|
||||
|
||||
AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = PersistentConfig(
|
||||
"AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS",
|
||||
"audio.stt.mistral.use_chat_completions",
|
||||
os.getenv("AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS", "false").lower() == "true",
|
||||
)
|
||||
|
||||
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_TTS_OPENAI_API_BASE_URL",
|
||||
"audio.tts.openai.api_base_url",
|
||||
|
||||
@@ -175,6 +175,9 @@ from open_webui.config import (
|
||||
AUDIO_STT_AZURE_LOCALES,
|
||||
AUDIO_STT_AZURE_BASE_URL,
|
||||
AUDIO_STT_AZURE_MAX_SPEAKERS,
|
||||
AUDIO_STT_MISTRAL_API_KEY,
|
||||
AUDIO_STT_MISTRAL_API_BASE_URL,
|
||||
AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
|
||||
AUDIO_TTS_ENGINE,
|
||||
AUDIO_TTS_MODEL,
|
||||
AUDIO_TTS_VOICE,
|
||||
@@ -1108,6 +1111,10 @@ app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
|
||||
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
|
||||
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
|
||||
|
||||
app.state.config.AUDIO_STT_MISTRAL_API_KEY = AUDIO_STT_MISTRAL_API_KEY
|
||||
app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = AUDIO_STT_MISTRAL_API_BASE_URL
|
||||
app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
|
||||
|
||||
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
||||
|
||||
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
import html
|
||||
import base64
|
||||
from functools import lru_cache
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
@@ -178,6 +179,9 @@ class STTConfigForm(BaseModel):
|
||||
AZURE_LOCALES: str
|
||||
AZURE_BASE_URL: str
|
||||
AZURE_MAX_SPEAKERS: str
|
||||
MISTRAL_API_KEY: str
|
||||
MISTRAL_API_BASE_URL: str
|
||||
MISTRAL_USE_CHAT_COMPLETIONS: bool
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@@ -214,6 +218,9 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
||||
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
||||
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
||||
"MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
|
||||
"MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
|
||||
"MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -255,6 +262,13 @@ async def update_audio_config(
|
||||
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
|
||||
form_data.stt.AZURE_MAX_SPEAKERS
|
||||
)
|
||||
request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY
|
||||
request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = (
|
||||
form_data.stt.MISTRAL_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = (
|
||||
form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS
|
||||
)
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -290,6 +304,9 @@ async def update_audio_config(
|
||||
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
||||
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
||||
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
||||
"MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
|
||||
"MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
|
||||
"MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -828,6 +845,184 @@ def transcription_handler(request, file_path, metadata):
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
elif request.app.state.config.STT_ENGINE == "mistral":
|
||||
# Check file exists
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=400, detail="Audio file not found")
|
||||
|
||||
# Check file size
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB}MB",
|
||||
)
|
||||
|
||||
api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY
|
||||
api_base_url = (
|
||||
request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
use_chat_completions = (
|
||||
request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Mistral API key is required for Mistral STT",
|
||||
)
|
||||
|
||||
r = None
|
||||
try:
|
||||
# Use voxtral-mini-latest as the default model for transcription
|
||||
model = request.app.state.config.STT_MODEL or "voxtral-mini-latest"
|
||||
|
||||
log.info(
|
||||
f"Mistral STT - model: {model}, "
|
||||
f"method: {'chat_completions' if use_chat_completions else 'transcriptions'}"
|
||||
)
|
||||
|
||||
if use_chat_completions:
|
||||
# Use chat completions API with audio input
|
||||
# This method requires mp3 or wav format
|
||||
audio_file_to_use = file_path
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
log.debug("Converting audio to mp3 for chat completions API")
|
||||
converted_path = convert_audio_to_mp3(file_path)
|
||||
if converted_path:
|
||||
audio_file_to_use = converted_path
|
||||
else:
|
||||
log.error("Audio conversion failed")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Audio conversion failed. Chat completions API requires mp3 or wav format.",
|
||||
)
|
||||
|
||||
# Read and encode audio file as base64
|
||||
with open(audio_file_to_use, "rb") as audio_file:
|
||||
audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
|
||||
|
||||
# Prepare chat completions request
|
||||
url = f"{api_base_url}/chat/completions"
|
||||
|
||||
# Add language instruction if specified
|
||||
language = metadata.get("language", None) if metadata else None
|
||||
if language:
|
||||
text_instruction = f"Transcribe this audio exactly as spoken in {language}. Do not translate it."
|
||||
else:
|
||||
text_instruction = "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language."
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": audio_base64,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": text_instruction
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
r = requests.post(
|
||||
url=url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
response = r.json()
|
||||
|
||||
# Extract transcript from chat completion response
|
||||
transcript = response.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
|
||||
if not transcript:
|
||||
raise ValueError("Empty transcript in response")
|
||||
|
||||
data = {"text": transcript}
|
||||
|
||||
else:
|
||||
# Use dedicated transcriptions API
|
||||
url = f"{api_base_url}/audio/transcriptions"
|
||||
|
||||
# Determine the MIME type
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
if not mime_type:
|
||||
mime_type = "audio/webm"
|
||||
|
||||
# Use context manager to ensure file is properly closed
|
||||
with open(file_path, "rb") as audio_file:
|
||||
files = {"file": (filename, audio_file, mime_type)}
|
||||
data_form = {"model": model}
|
||||
|
||||
# Add language if specified in metadata
|
||||
language = metadata.get("language", None) if metadata else None
|
||||
if language:
|
||||
data_form["language"] = language
|
||||
|
||||
r = requests.post(
|
||||
url=url,
|
||||
files=files,
|
||||
data=data_form,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
response = r.json()
|
||||
|
||||
# Extract transcript from response
|
||||
transcript = response.get("text", "").strip()
|
||||
if not transcript:
|
||||
raise ValueError("Empty transcript in response")
|
||||
|
||||
data = {"text": transcript}
|
||||
|
||||
# Save transcript to json file (consistent with other providers)
|
||||
transcript_file = f"{file_dir}/{id}.json"
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
log.debug(data)
|
||||
return data
|
||||
|
||||
except ValueError as e:
|
||||
log.exception("Error parsing Mistral response")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to parse Mistral response: {str(e)}",
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
|
||||
try:
|
||||
if r is not None and r.status_code != 200:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
else:
|
||||
detail = f"External: {r.text}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status_code", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
Reference in New Issue
Block a user