diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 20519b59b..21ad2886a 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -1,5 +1,12 @@ -import os +import hashlib +import json import logging +import os +import uuid +from functools import lru_cache +from pathlib import Path + +import requests from fastapi import ( FastAPI, Request, @@ -8,34 +15,14 @@ from fastapi import ( status, UploadFile, File, - Form, ) -from fastapi.responses import StreamingResponse, JSONResponse, FileResponse - from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse from pydantic import BaseModel - -import uuid -import requests -import hashlib -from pathlib import Path -import json - -from constants import ERROR_MESSAGES -from utils.utils import ( - decode_token, - get_current_user, - get_verified_user, - get_admin_user, -) -from utils.misc import calculate_sha256 - - from config import ( SRC_LOG_LEVELS, CACHE_DIR, - UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, WHISPER_MODEL_AUTO_UPDATE, @@ -52,6 +39,12 @@ from config import ( AUDIO_TTS_VOICE, AppConfig, ) +from constants import ERROR_MESSAGES +from utils.utils import ( + get_current_user, + get_verified_user, + get_admin_user, +) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) @@ -261,6 +254,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail="Invalid JSON payload") voice_id = payload.get("voice", "") + + if voice_id not in get_available_voices(): + raise HTTPException( + status_code=400, + detail="Invalid voice id", + ) + url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" headers = { @@ -466,39 +466,60 @@ async def get_models(user=Depends(get_verified_user)): return {"models": get_available_models()} -def get_available_voices() -> list[dict]: +def get_available_voices() -> dict: + """Returns {voice_id: voice_name} dict""" + ret = {} if app.state.config.TTS_ENGINE == "openai": - return [ - {"name": "alloy", "id": "alloy"}, - {"name": "echo", "id": "echo"}, - {"name": "fable", "id": "fable"}, - {"name": "onyx", "id": "onyx"}, - {"name": "nova", "id": "nova"}, - {"name": "shimmer", "id": "shimmer"}, - ] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", + ret = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", } - + elif app.state.config.TTS_ENGINE == "elevenlabs": try: - response = requests.get( - "https://api.elevenlabs.io/v1/voices", headers=headers - ) - response.raise_for_status() - voices_data = response.json() + ret = get_elevenlabs_voices() + except Exception as e: + # Avoided @lru_cache with exception + pass - voices = [] - for voice in voices_data.get("voices", []): - voices.append({"name": voice["name"], "id": voice["voice_id"]}) - return voices - except requests.RequestException as e: - log.error(f"Error fetching voices: {str(e)}") + return ret - return [] + +@lru_cache +def get_elevenlabs_voices() -> dict: + """ + Note, set the following in your .env file to use Elevenlabs: + AUDIO_TTS_ENGINE=elevenlabs + AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key + AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices + AUDIO_TTS_MODEL=eleven_multilingual_v2 + """ + headers = { + "xi-api-key": app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + } + try: + # TODO: Add retries + response = requests.get( + "https://api.elevenlabs.io/v1/voices", headers=headers + ) + response.raise_for_status() + voices_data = response.json() + + voices = {} + for voice in voices_data.get("voices", []): + voices[voice["voice_id"]] = voice["name"] + except requests.RequestException as e: + # Avoid @lru_cache with exception + log.error(f"Error fetching voices: {str(e)}") + raise RuntimeError(f"Error fetching voices: {str(e)}") + + return voices @app.get("/voices") async def get_voices(user=Depends(get_verified_user)): - return {"voices": get_available_voices()} + return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} diff --git a/backend/config.py b/backend/config.py index 07ee06a58..ef2feb8c9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1410,13 +1410,13 @@ AUDIO_TTS_ENGINE = PersistentConfig( AUDIO_TTS_MODEL = PersistentConfig( "AUDIO_TTS_MODEL", "audio.tts.model", - os.getenv("AUDIO_TTS_MODEL", "tts-1"), + os.getenv("AUDIO_TTS_MODEL", "tts-1"), # OpenAI default model ) AUDIO_TTS_VOICE = PersistentConfig( "AUDIO_TTS_VOICE", "audio.tts.voice", - os.getenv("AUDIO_TTS_VOICE", "alloy"), + os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice )