Merge pull request #4674 from crizCraig/sanitize-11labs-voiceid

sec: Sanitize 11labs voice id to address semgrep security issue: tainted-path-traversal-stdlib-fastapi
This commit is contained in:
Timothy Jaeryang Baek 2024-08-17 15:49:56 +02:00 committed by GitHub
commit bd8df3583d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 51 deletions

View File

@ -1,5 +1,12 @@
import os
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
import requests
from fastapi import (
FastAPI,
Request,
@ -8,34 +15,14 @@ from fastapi import (
status,
UploadFile,
File,
Form,
)
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
import uuid
import requests
import hashlib
from pathlib import Path
import json
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
@ -52,6 +39,12 @@ from config import (
AUDIO_TTS_VOICE,
AppConfig,
)
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_verified_user,
get_admin_user,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@ -261,6 +254,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=400, detail="Invalid JSON payload")
voice_id = payload.get("voice", "")
if voice_id not in get_available_voices():
raise HTTPException(
status_code=400,
detail="Invalid voice id",
)
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
headers = {
@ -466,39 +466,60 @@ async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()}
def get_available_voices() -> list[dict]:
def get_available_voices() -> dict:
"""Returns {voice_id: voice_name} dict"""
ret = {}
if app.state.config.TTS_ENGINE == "openai":
return [
{"name": "alloy", "id": "alloy"},
{"name": "echo", "id": "echo"},
{"name": "fable", "id": "fable"},
{"name": "onyx", "id": "onyx"},
{"name": "nova", "id": "nova"},
{"name": "shimmer", "id": "shimmer"},
]
elif app.state.config.TTS_ENGINE == "elevenlabs":
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
ret = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
"https://api.elevenlabs.io/v1/voices", headers=headers
)
response.raise_for_status()
voices_data = response.json()
ret = get_elevenlabs_voices()
except Exception as e:
# Avoided @lru_cache with exception
pass
voices = []
for voice in voices_data.get("voices", []):
voices.append({"name": voice["name"], "id": voice["voice_id"]})
return voices
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return ret
return []
@lru_cache
def get_elevenlabs_voices() -> dict:
"""
Note, set the following in your .env file to use Elevenlabs:
AUDIO_TTS_ENGINE=elevenlabs
AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
AUDIO_TTS_MODEL=eleven_multilingual_v2
"""
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
}
try:
# TODO: Add retries
response = requests.get(
"https://api.elevenlabs.io/v1/voices", headers=headers
)
response.raise_for_status()
voices_data = response.json()
voices = {}
for voice in voices_data.get("voices", []):
voices[voice["voice_id"]] = voice["name"]
except requests.RequestException as e:
# Avoid @lru_cache with exception
log.error(f"Error fetching voices: {str(e)}")
raise RuntimeError(f"Error fetching voices: {str(e)}")
return voices
@app.get("/voices")
async def get_voices(user=Depends(get_verified_user)):
return {"voices": get_available_voices()}
return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

View File

@ -1410,13 +1410,13 @@ AUDIO_TTS_ENGINE = PersistentConfig(
AUDIO_TTS_MODEL = PersistentConfig(
"AUDIO_TTS_MODEL",
"audio.tts.model",
os.getenv("AUDIO_TTS_MODEL", "tts-1"),
os.getenv("AUDIO_TTS_MODEL", "tts-1"), # OpenAI default model
)
AUDIO_TTS_VOICE = PersistentConfig(
"AUDIO_TTS_VOICE",
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"),
os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice
)