mirror of https://github.com/matatonic/openedai-speech (synced 2025-06-26 18:16:32 +00:00)
0.14.0 +streaming, +pcm, +wav, +temp, top_p, etc.
This commit is contained in:
parent 65c03e3448, commit ae6a384e75
README.md: 30 lines changed
@@ -10,7 +10,7 @@ An OpenAI API compatible text to speech server.
 Full Compatibility:
 * `tts-1`: `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer` (configurable)
 * `tts-1-hd`: `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer` (configurable, uses OpenAI samples by default)
-* response_format: `mp3`, `opus`, `aac`, or `flac`
+* response_format: `mp3`, `opus`, `aac`, `flac`, `wav` and `pcm`
 * speed 0.25-4.0 (and more)

 Details:
@@ -20,6 +20,8 @@ Details:
 * Custom cloned voices can be used for tts-1-hd, See: [Custom Voices Howto](#custom-voices-howto)
 * 🌐 [Multilingual](#multilingual) support with XTTS voices
 * [Custom fine-tuned XTTS model support](#custom-fine-tuned-model-support)
+* Configurable [generation parameters](#generation-parameters)
+* Streamed output while generating
 * Occasionally, certain words or symbols may sound incorrect, you can fix them with regex via `pre_process_map.yaml`


@@ -27,6 +29,14 @@ If you find a better voice match for `tts-1` or `tts-1-hd`, please let me know s

 ## Recent Changes

+Version 0.14.0, 2024-06-26
+
+* Added `response_format`: `wav` and `pcm` support
+* Output streaming (while generating) for `tts-1` and `tts-1-hd`
+* Enhanced [generation parameters](#generation-parameters) for xtts models (temperature, top_p, etc.)
+* Idle unload timer (optional) - doesn't work perfectly yet
+* Improved error handling
+
 Version 0.13.0, 2024-06-25

 * Added [Custom fine-tuned XTTS model support](#custom-fine-tuned-model-support)
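
Not part of the commit itself, but for illustration: with the server on its default port 8000, the new `pcm` format plus streaming can be exercised from the standard OpenAI Python client roughly like this (the client code is an assumption, not from this repo):

```python
# Hedged sketch: client-side use of the new wav/pcm + streaming support.
# Assumes openedai-speech is serving on localhost:8000; the API key is unused.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-ignored")

# tts-1-hd with pcm streams raw s16le audio at 24000 Hz (see media types below)
with client.audio.speech.with_streaming_response.create(
    model="tts-1-hd",
    voice="alloy",
    input="Audio starts arriving before generation finishes.",
    response_format="pcm",
) as response:
    with open("speech.pcm", "wb") as out:
        for chunk in response.iter_bytes():
            out.write(chunk)
```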
@@ -313,3 +323,21 @@ tts-1-hd:
     model_path: voices/halo
 ```
 3) The model will be loaded when you access the voice for the first time (`--preload` doesn't work with custom models yet)
+
+## Generation Parameters
+
+The generation of XTTSv2 voices can be fine tuned with the following options (defaults included below):
+
+```yaml
+tts-1-hd:
+  alloy:
+    model: xtts
+    speaker: voices/alloy.wav
+    enable_text_splitting: True
+    length_penalty: 1.0
+    repetition_penalty: 10
+    speed: 1.0
+    temperature: 0.75
+    top_k: 50
+    top_p: 0.85
+```
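
For context on how these options are consumed: the speech.py changes further down pop the routing keys (`model`, `speaker`, `model_path`, `comment`) out of the voice config and forward everything else to XTTS generation. A toy sketch of that dict handling (hypothetical values, not code from the commit):

```python
# Toy illustration of the pass-through: everything that isn't a routing key
# becomes a generation kwarg (mirrors the voice_map handling in speech.py).
voice_map = {
    "model": "xtts",
    "speaker": "voices/alloy.wav",
    "temperature": 0.75,
    "top_p": 0.85,
    "top_k": 50,
}
tts_model = voice_map.pop("model")
speaker = voice_map.pop("speaker")
hf_generate_kwargs = dict(speed=1.0, **voice_map)
print(hf_generate_kwargs)
# {'speed': 1.0, 'temperature': 0.75, 'top_p': 0.85, 'top_k': 50}
```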
@@ -2,10 +2,7 @@
 set COQUI_TOS_AGREED=1
 set TTS_HOME=voices

-set MODELS=%*
-if "%MODELS%" == "" set MODELS=xtts
-
-for %%i in (%MODELS%) do (
+for %%i in (%*) do (
     python -c "from TTS.utils.manage import ModelManager; ModelManager().download_model('%%i')"
 )
 call download_samples.bat
@@ -2,8 +2,7 @@
 export COQUI_TOS_AGREED=1
 export TTS_HOME=voices

-MODELS=${*:-xtts}
-for model in $MODELS; do
+for model in $*; do
     python -c "from TTS.utils.manage import ModelManager; ModelManager().download_model('$model')"
 done
 ./download_samples.sh
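
For reference, each loop iteration in the download scripts above amounts to the following Python (a sketch; assumes the `TTS` package pinned below, and that short model names like `xtts` are accepted by `download_model`, as speech.py also assumes):

```python
# What one loop iteration runs, e.g. for the 'xtts' model name.
from TTS.utils.manage import ModelManager

paths = ModelManager().download_model('xtts')
print(paths[0])  # the local model_path, indexed the same way in speech.py below
```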
@@ -4,7 +4,9 @@ loguru
 # piper-tts
 piper-tts==1.2.0
 # xtts
-TTS
+TTS==0.22.0
+# https://github.com/huggingface/transformers/issues/31040
+transformers<4.41.0
 # XXX, 3.8+ has some issue for now
 spacy==3.7.4

@@ -4,7 +4,9 @@ loguru
 # piper-tts
 piper-tts==1.2.0
 # xtts
-TTS
+TTS==0.22.0
+# https://github.com/huggingface/transformers/issues/31040
+transformers<4.41.0
 # XXX, 3.8+ has some issue for now
 spacy==3.7.4

@@ -2,5 +2,5 @@ TTS_HOME=voices
 HF_HOME=voices
-#PRELOAD_MODEL=xtts
+#PRELOAD_MODEL=xtts_v2.0.2
-#EXTRA_ARGS=--log-level DEBUG
+#EXTRA_ARGS=--log-level DEBUG --unload-timer 300
 #USE_ROCM=1
speech.py: 219 lines changed
@@ -1,51 +1,105 @@
 #!/usr/bin/env python3
 import argparse
 import os
+import sys
+import gc
 import re
 import subprocess
 import tempfile
-import sys
 import threading
 import time
 import yaml
-from fastapi.responses import StreamingResponse
-import uvicorn
-from pydantic import BaseModel
-from loguru import logger
+import contextlib

+from fastapi.responses import StreamingResponse
+from loguru import logger
+from pydantic import BaseModel
+import uvicorn
 from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError


+@contextlib.asynccontextmanager
+async def lifespan(app):
+    yield
+    gc.collect()
+    try:
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    except:
+        pass
+
+app = OpenAIStub(lifespan=lifespan)
 xtts = None
 args = None
-app = OpenAIStub()

+def unload_model():
+    import torch, gc
+    global xtts
+    if xtts:
+        logger.info("Unloading model")
+        xtts.xtts.to('cpu') # this was required to free up GPU memory...
+        del xtts
+        xtts = None
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()

 class xtts_wrapper():
-    def __init__(self, model_name, device, model_path=None):
+    check_interval: int = 1
+
+    def __init__(self, model_name, device, model_path=None, unload_timer=None):
         self.model_name = model_name
+        self.unload_timer = unload_timer
+        self.last_used = time.time()
+        self.timer = None
+        self.lock = threading.Lock()

         logger.info(f"Loading model {self.model_name} to {device}")

-        if model_path: # custom model # and config_path
-            config_path=os.path.join(model_path, 'config.json')
-            self.xtts = TTS(model_path=model_path, config_path=config_path).to(device)
-        else:
-            self.xtts = TTS(model_name=model_name).to(device)
+        if model_path is None:
+            model_path = ModelManager().download_model(model_name)[0]

-    def tts(self, text, speaker_wav, speed, language):
-        tf, file_path = tempfile.mkstemp(suffix='.wav', prefix='openedai-speech-')
+        config_path = os.path.join(model_path, 'config.json')
+        config = XttsConfig()
+        config.load_json(config_path)
+        self.xtts = Xtts.init_from_config(config)
+        self.xtts.load_checkpoint(config, checkpoint_dir=model_path, use_deepspeed=False) # XXX there are no prebuilt deepspeed wheels??
+        self.xtts = self.xtts.to(device=device)
+        self.xtts.eval()
+
+        if self.unload_timer:
+            logger.info(f"Setting unload timer to {self.unload_timer} seconds")
+            self.not_idle()
+            self.check_idle()
+
+    def not_idle(self):
+        with self.lock:
+            self.last_used = time.time()
+
+    def check_idle(self):
+        with self.lock:
+            if time.time() - self.last_used >= self.unload_timer:
+                print("Unloading TTS model due to inactivity")
+                unload_model()
+            else:
+                # Reschedule the check
+                self.timer = threading.Timer(self.check_interval, self.check_idle)
+                self.timer.daemon = True
+                self.timer.start()
+
+    def tts(self, text, language, speaker_wav, **hf_generate_kwargs):
+        self.not_idle()
         try:
-            # TODO: support speaker= as voice id instead of just wav
-            file_path = self.xtts.tts_to_file(
-                text=text,
-                language=language,
-                speaker_wav=speaker_wav,
-                speed=speed,
-                file_path=file_path,
-            )
+            with torch.no_grad():
+                gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
+
+                for wav in self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs):
+                    yield wav.cpu().numpy().tobytes() # assumes wav data is f32le
+                    self.not_idle()

         finally:
-            os.unlink(file_path)
-
-        return tf
+            self.not_idle()

 def default_exists(filename: str):
     if not os.path.exists(filename):
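
The idle unload above works by re-arming a one-second `threading.Timer` until the deadline lapses. A self-contained toy of the same pattern (names are made up; illustrative only):

```python
# Standalone sketch of the check_idle pattern: a daemon Timer re-arms itself
# every second until the idle timeout is reached, then fires exactly once.
import threading
import time

class IdleWatcher:
    check_interval = 1  # seconds, as in xtts_wrapper

    def __init__(self, timeout):
        self.timeout = timeout
        self.last_used = time.time()
        self.check_idle()

    def not_idle(self):
        self.last_used = time.time()  # any activity pushes the deadline back

    def check_idle(self):
        if time.time() - self.last_used >= self.timeout:
            print("idle: unload the model here")
        else:
            timer = threading.Timer(self.check_interval, self.check_idle)
            timer.daemon = True
            timer.start()

watcher = IdleWatcher(timeout=2)
time.sleep(3)  # with no not_idle() calls, the watcher fires after ~2s
```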
@@ -92,10 +146,10 @@ class GenerateSpeechRequest(BaseModel):

 def build_ffmpeg_args(response_format, input_format, sample_rate):
     # Convert the output to the desired format using ffmpeg
-    if input_format == 'raw':
-        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "s16le", "-ar", sample_rate, "-ac", "1", "-i", "-"]
-    else:
+    if input_format == 'WAV':
         ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"]
+    else:
+        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"]

     if response_format == "mp3":
         ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"])
@@ -105,6 +159,10 @@ def build_ffmpeg_args(response_format, input_format, sample_rate):
         ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"])
     elif response_format == "flac":
         ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"])
+    elif response_format == "wav":
+        ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"])
+    elif response_format == "pcm": # even though pcm is technically 'raw', we still use ffmpeg to adjust the speed
+        ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"])

     return ffmpeg_args

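
A worked example derived from the branches above (not separate code in the commit): a `pcm` response fed from the xtts `f32le` stream builds roughly this argument list:

```python
# What build_ffmpeg_args("pcm", input_format="f32le", sample_rate="24000")
# evaluates to, per the branches above (the trailing "-" output argument is
# appended later in generate_speech):
expected = [
    "ffmpeg", "-loglevel", "error",
    "-f", "f32le", "-ar", "24000", "-ac", "1", "-i", "-",  # f32le stream in
    "-f", "s16le", "-c:a", "pcm_s16le",                    # raw s16le pcm out
]
```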
@@ -121,18 +179,27 @@ async def generate_speech(request: GenerateSpeechRequest):

     model = request.model
     voice = request.voice
-    response_format = request.response_format
+    response_format = request.response_format.lower()
     speed = request.speed

     # Set the Content-Type header based on the requested format
     if response_format == "mp3":
         media_type = "audio/mpeg"
     elif response_format == "opus":
-        media_type = "audio/ogg;codecs=opus"
+        media_type = "audio/ogg;codec=opus" # codecs?
     elif response_format == "aac":
         media_type = "audio/aac"
     elif response_format == "flac":
         media_type = "audio/x-flac"
+    elif response_format == "wav":
+        media_type = "audio/wav"
+    elif response_format == "pcm":
+        if model == 'tts-1': # piper
+            media_type = "audio/pcm;rate=22050"
+        elif model == 'tts-1-hd':
+            media_type = "audio/pcm;rate=24000"
     else:
         BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')

     ffmpeg_args = None
     tts_io_out = None
@@ -158,51 +225,77 @@ async def generate_speech(request: GenerateSpeechRequest):
         tts_proc.stdin.write(bytearray(input_text.encode('utf-8')))
         tts_proc.stdin.close()
         tts_io_out = tts_proc.stdout
-        ffmpeg_args = build_ffmpeg_args(response_format, input_format="raw", sample_rate="22050")
+        ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate="22050")

+        # Pipe the output from piper/xtts to the input of ffmpeg
+        ffmpeg_args.extend(["-"])
+        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE)
+
+        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
+
     # Use xtts for tts-1-hd
     elif model == 'tts-1-hd':
         voice_map = map_voice_to_speaker(voice, 'tts-1-hd')
         try:
-            tts_model = voice_map['model']
-            speaker = voice_map['speaker']
+            tts_model = voice_map.pop('model')
+            speaker = voice_map.pop('speaker')

         except KeyError as e:
             raise ServiceUnavailableError(f"Configuration error: tts-1-hd voice '{voice}' is missing setting. KeyError: {e}")

-        language = voice_map.get('language', 'en')
-        tts_model_path = voice_map.get('model_path', None)
+        if xtts and xtts.model_name != tts_model:
+            unload_model()

-        if xtts is not None and xtts.model_name != tts_model:
-            import torch, gc
-            del xtts
-            xtts = None
-            gc.collect()
-            torch.cuda.empty_cache()
+        tts_model_path = voice_map.pop('model_path', None) # XXX changing this on the fly is ignored if you keep the same name

-        else:
-            if xtts is None:
-                xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path)
+        if xtts is None:
+            xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path, unload_timer=args.unload_timer)

-        ffmpeg_args = build_ffmpeg_args(response_format, input_format="WAV", sample_rate="24000")
+        ffmpeg_args = build_ffmpeg_args(response_format, input_format="f32le", sample_rate="24000")

-        # tts speed doesn't seem to work well
-        if speed < 0.5:
-            speed = speed / 0.5
-            ffmpeg_args.extend(["-af", "atempo=0.5"])
-        if speed > 1.0:
-            ffmpeg_args.extend(["-af", f"atempo={speed}"])
-            speed = 1.0
+        # tts speed doesn't seem to work well
+        speed = voice_map.pop('speed', speed)
+        if speed < 0.5:
+            speed = speed / 0.5
+            ffmpeg_args.extend(["-af", "atempo=0.5"])
+        if speed > 1.0:
+            ffmpeg_args.extend(["-af", f"atempo={speed}"])
+            speed = 1.0

-        tts_io_out = xtts.tts(text=input_text, speaker_wav=speaker, speed=speed, language=language)
+        language = voice_map.pop('language', 'en')
+
+        comment = voice_map.pop('comment', None) # ignored.
+
+        hf_generate_kwargs = dict(
+            speed=speed,
+            **voice_map,
+        )
+
+        hf_generate_kwargs['enable_text_splitting'] = hf_generate_kwargs.get('enable_text_splitting', True) # change the default to true
+
+        # Pipe the output from piper/xtts to the input of ffmpeg
+        ffmpeg_args.extend(["-"])
+        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+
+        def generator():
+            try:
+                for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
+                    ffmpeg_proc.stdin.write(chunk)
+
+            except Exception as e:
+                logger.error(f"Exception: {repr(e)}")
+                raise e
+
+            finally:
+                ffmpeg_proc.stdin.close()
+
+        worker = threading.Thread(target=generator)
+        worker.daemon = True
+        worker.start()
+
+        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
     else:
         raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')
-
-    # Pipe the output from piper/xtts to the input of ffmpeg
-    ffmpeg_args.extend(["-"])
-    ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE)
-
-    return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)

 # We return 'mps' but currently XTTS will not work with mps devices as the cuda support is incomplete
 def auto_torch_device():
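
The new streaming path never buffers the whole clip: a daemon thread writes xtts chunks into ffmpeg's stdin while the `StreamingResponse` reads ffmpeg's stdout. A minimal stand-alone sketch of that wiring, using `cat` in place of ffmpeg (assumption: a POSIX `cat` is available):

```python
# Minimal sketch of the generator/worker pattern above: the writer thread
# must close stdin so the reader sees EOF and the stream can finish.
import subprocess
import threading

proc = subprocess.Popen(["cat"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)

def producer():
    try:
        for chunk in (b"chunk1 ", b"chunk2 ", b"chunk3\n"):
            proc.stdin.write(chunk)
    finally:
        proc.stdin.close()  # EOF lets the downstream read complete

worker = threading.Thread(target=producer, daemon=True)
worker.start()

print(proc.stdout.read())  # b'chunk1 chunk2 chunk3\n'
```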
@@ -220,6 +313,7 @@ if __name__ == "__main__":

     parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
     parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
+    parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds")
     parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
     parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
     parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
@@ -233,10 +327,13 @@ if __name__ == "__main__":
     logger.add(sink=sys.stderr, level=args.log_level)

     if args.xtts_device != "none":
-        from TTS.api import TTS
+        import torch
+        from TTS.tts.configs.xtts_config import XttsConfig
+        from TTS.tts.models.xtts import Xtts
+        from TTS.utils.manage import ModelManager

     if args.preload:
-        xtts = xtts_wrapper(args.preload, device=args.xtts_device)
+        xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)

     app.register_model('tts-1')
     app.register_model('tts-1-hd')
@@ -48,3 +48,11 @@ tts-1-hd:
   me:
     model: xtts_v2.0.2 # you can specify different xtts version
     speaker: voices/me.wav # this could be you
+    enable_text_splitting: True
+    length_penalty: 1.0
+    repetition_penalty: 10
+    speed: 1.0
+    temperature: 0.75
+    top_k: 50
+    top_p: 0.85
+    comment: You can add a comment here also, which will be persistent and otherwise ignored.