0.15.2 Thread safe version

matatonic 2024-06-28 16:09:48 -04:00
parent 964b23a21c
commit 703dec32b1
2 changed files with 167 additions and 30 deletions

README.md

@ -29,6 +29,10 @@ If you find a better voice match for `tts-1` or `tts-1-hd`, please let me know s
## Recent Changes
Version 0.15.2, 2024-06-28
* Thread safe version, locked generation at the sentence level
Version 0.15.1, 2024-06-27
* Remove deepspeed from requirements.txt; it's too complex for typical users. A more detailed deepspeed install document will be required.

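As a sketch of what "locked generation at the sentence level" means in practice (illustrative only, not the project's code; synthesize() is a hypothetical stand-in for the real XTTS call): one lock shared by all requests is acquired once per sentence rather than once per request, so concurrent requests serialize on the GPU but can interleave at sentence boundaries.

import threading

xtts_lock = threading.Lock()  # one lock shared by every request

def synthesize(sentence: str) -> bytes:
    return sentence.encode()  # hypothetical stand-in for the real XTTS call

def speak(sentences):
    # The lock is taken per sentence, not per request, so other
    # requests can be served between our sentences.
    for sentence in sentences:
        with xtts_lock:
            yield synthesize(sentence)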
speech.py

@ -1,20 +1,23 @@
#!/usr/bin/env python3
import argparse
import os
import asyncio
import contextlib
import gc
import io
import os
import queue
import re
import subprocess
import sys
import threading
import time
import yaml
import contextlib
from fastapi.responses import StreamingResponse
from loguru import logger
from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
from pydantic import BaseModel
import uvicorn
from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
@contextlib.asynccontextmanager
@ -46,7 +49,7 @@ def unload_model():
torch.cuda.ipc_collect()
class xtts_wrapper():
check_interval: int = 1
check_interval: int = 1 # too aggressive?
def __init__(self, model_name, device, model_path=None, unload_timer=None):
self.model_name = model_name
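The idle-unload machinery touched in this hunk works roughly as in the simplified sketch below (the IdleUnloader name and the print are stand-ins for the real wrapper's unload path): a daemon thread wakes every check_interval seconds and unloads the model once it has been idle longer than unload_timer.

import threading
import time

class IdleUnloader:
    check_interval: int = 1  # seconds between idle checks

    def __init__(self, unload_timer: int):
        self.unload_timer = unload_timer
        self.last_used = time.time()
        self.timer = threading.Thread(target=self._watch, daemon=True)
        self.timer.start()

    def not_idle(self):
        self.last_used = time.time()  # reset the idle clock on every use

    def _watch(self):
        while time.time() - self.last_used < self.unload_timer:
            time.sleep(self.check_interval)
        print("idle timeout reached, unloading model")  # real code frees the GPU here

m = IdleUnloader(unload_timer=900)  # with --unload-timer 900, unload after 15 idle minutes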
@ -89,18 +92,21 @@ class xtts_wrapper():
self.timer.start()
def tts(self, text, language, speaker_wav, **hf_generate_kwargs):
self.not_idle()
try:
with torch.no_grad():
with self.lock: # this doesn't seem threadsafe, but it's quick enough
gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
logger.debug("waiting for lock")
with self.lock, torch.no_grad(): # I wish this could be done another way, but inference_stream can't be accessed reliably from async code
logger.debug(f"grabbed lock, tts text: {text}")
self.last_used = time.time()
try:
gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
for wav in self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs):
yield wav.cpu().numpy().tobytes() # assumes wav data is f32le
self.not_idle()
yield wav.cpu().numpy().tobytes()
finally:
logger.debug(f"held lock for {time.time() - self.last_used:0.1f} sec")
self.last_used = time.time()
finally:
self.not_idle()
def default_exists(filename: str):
if not os.path.exists(filename):
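One subtlety behind the try/finally pairs added above: a generator that yields while holding a lock keeps that lock for as long as the consumer leaves the generator suspended, which is why the commit logs the hold time and runs not_idle() in a finally. A minimal demonstration (not project code):

import threading

lock = threading.Lock()

def gen():
    with lock:
        yield 1
        yield 2  # the lock stays held between yields

g = gen()
next(g)               # gen() is now suspended inside `with lock`
print(lock.locked())  # True: the consumer decides how long the lock is held
g.close()             # raises GeneratorExit inside gen(), releasing the lock
print(lock.locked())  # False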
@ -116,7 +122,7 @@ def default_exists(filename: str):
# Read pre process map on demand so it can be changed without restarting the server
def preprocess(raw_input):
logger.debug(f"preprocess: before: {[raw_input]}")
#logger.debug(f"preprocess: before: {[raw_input]}")
default_exists('config/pre_process_map.yaml')
with open('config/pre_process_map.yaml', 'r', encoding='utf8') as file:
pre_process_map = yaml.safe_load(file)
@ -124,7 +130,7 @@ def preprocess(raw_input):
raw_input = re.sub(a, b, raw_input)
raw_input = raw_input.strip()
logger.debug(f"preprocess: after: {[raw_input]}")
#logger.debug(f"preprocess: after: {[raw_input]}")
return raw_input
# Read voice map on demand so it can be changed without restarting the server
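For reference, pre_process_map.yaml is a list of regex/replacement pairs applied in order with re.sub, re-read on every request so edits take effect without a restart. A self-contained sketch of the same idea (the inline YAML is illustrative, not the shipped config):

import re
import yaml

PRE_PROCESS_MAP = r"""
- ['Dr\.', 'Doctor']
- ['\s+', ' ']
"""

def preprocess(raw_input: str) -> str:
    for a, b in yaml.safe_load(PRE_PROCESS_MAP):
        raw_input = re.sub(a, b, raw_input)
    return raw_input.strip()

print(preprocess("Dr.  Smith"))  # -> Doctor Smith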
@ -200,7 +206,7 @@ async def generate_speech(request: GenerateSpeechRequest):
elif model == 'tts-1-hd':
media_type = "audio/pcm;rate=24000"
else:
BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')
raise BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')
ffmpeg_args = None
tts_io_out = None
@ -225,12 +231,12 @@ async def generate_speech(request: GenerateSpeechRequest):
tts_proc = subprocess.Popen(tts_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
tts_proc.stdin.write(bytearray(input_text.encode('utf-8')))
tts_proc.stdin.close()
tts_io_out = tts_proc.stdout
ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate="22050")
# Pipe the output from piper/xtts to the input of ffmpeg
ffmpeg_args.extend(["-"])
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE)
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_proc.stdout, stdout=subprocess.PIPE)
return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
# Use xtts for tts-1-hd
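The tts-1 path chains two processes: text is written to piper's stdin, and piper's PCM stdout is handed directly to ffmpeg as its stdin, so the audio never passes through Python (the commit simply drops the tts_io_out temporary in favor of tts_proc.stdout). The same pattern with generic commands, echo and tr standing in for piper and ffmpeg:

import subprocess

producer = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE)
consumer = subprocess.Popen(["tr", "a-z", "A-Z"],
                            stdin=producer.stdout, stdout=subprocess.PIPE)
producer.stdout.close()  # so the producer gets SIGPIPE if the consumer dies
print(consumer.stdout.read())  # b'HELLO\n'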
@ -262,6 +268,9 @@ async def generate_speech(request: GenerateSpeechRequest):
ffmpeg_args.extend(["-af", f"atempo={speed}"])
speed = 1.0
# Pipe the output from piper/xtts to the input of ffmpeg
ffmpeg_args.extend(["-"])
language = voice_map.pop('language', 'en')
comment = voice_map.pop('comment', None) # ignored.
@ -273,27 +282,150 @@ async def generate_speech(request: GenerateSpeechRequest):
hf_generate_kwargs['enable_text_splitting'] = hf_generate_kwargs.get('enable_text_splitting', True) # change the default to true
# Pipe the output from piper/xtts to the input of ffmpeg
ffmpeg_args.extend(["-"])
if hf_generate_kwargs['enable_text_splitting']:
all_text = split_sentence(input_text, language, xtts.xtts.tokenizer.char_limits[language])
else:
all_text = [input_text]
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
def generator():
# before the xtts lock, it was:
#def generator():
# for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
# ffmpeg_proc.stdin.write(chunk) # <-- but this blocks forever and holds the xtts lock if a client disconnects
#worker = threading.Thread(target=generator)
#worker.daemon = True
#worker.start()
#return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
#
# What follows is stupidly overcomplicated, but it's the only way I've found (yet) that detects client disconnects and doesn't get blocked up
os.set_blocking(ffmpeg_proc.stdout.fileno(), False) # this doesn't work on windows until python 3.12
os.set_blocking(ffmpeg_proc.stdin.fileno(), False) # this doesn't work on windows until python 3.12
ffmpeg_in = io.FileIO(ffmpeg_proc.stdin.fileno(), 'wb')
in_q = queue.Queue() # speech pcm
out_q = queue.Queue() # ffmpeg audio out
ex_q = queue.Queue() # exceptions
def ffmpeg_io():
# in_q -> ffmpeg -> out_q
while not (ffmpeg_proc.stdout.closed and ffmpeg_proc.stdin.closed):
try:
while not ffmpeg_proc.stdout.closed:
chunk = ffmpeg_proc.stdout.read()
if chunk is None:
break
if len(chunk) == 0: # real end
out_q.put(None)
ffmpeg_proc.stdout.close()
break
out_q.put(chunk)
continue # consume audio without delay
except Exception as e:
logger.debug(f"ffmpeg stdout read: {repr(e)}")
out_q.put(None)
ex_q.put(e)
return
try:
while not ffmpeg_proc.stdin.closed:
chunk = in_q.get_nowait()
if chunk is None:
ffmpeg_proc.stdin.close()
break
n = ffmpeg_in.write(chunk) # BrokenPipeError from here on client disconnect
if n is None:
in_q.queue.appendleft(chunk)
break
if n != len(chunk):
in_q.queue.appendleft(chunk[n:])
break
except queue.Empty:
pass
except BrokenPipeError as e:
ex_q.put(e) # we need to get this exception into the generation loop, which holds the lock
ffmpeg_proc.kill()
return
except Exception as e:
ex_q.put(e)
ffmpeg_proc.kill()
return
time.sleep(0.01)
def exception_check(exq: queue.Queue):
try:
for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
ffmpeg_proc.stdin.write(chunk)
e = exq.get_nowait()
except queue.Empty:
return
raise e
def generator():
# text -> in_q
try:
for text in all_text:
for chunk in xtts.tts(text=text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
exception_check(ex_q)
in_q.put(chunk)
in_q.put(None)
except BrokenPipeError as e: # client disconnect lands here
#logger.debug(f"{repr(e)}")
logger.info("Client disconnected")
except asyncio.CancelledError as e:
logger.debug(f"{repr(e)}")
pass
except Exception as e:
logger.error(f"Exception: {repr(e)}")
raise e
finally:
ffmpeg_proc.stdin.close()
worker = threading.Thread(target=generator)
worker.daemon = True
worker = threading.Thread(target=generator, daemon=True)
worker.start()
return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
worker2 = threading.Thread(target=ffmpeg_io, daemon=True)
worker2.start()
async def audio_out():
# out_q -> client
while True:
try:
audio = out_q.get_nowait()
if audio is None:
return
yield audio
except queue.Empty:
pass
except asyncio.CancelledError as e:
logger.debug(f"{repr(e)}")
ex_q.put(e)
return
except Exception as e:
logger.debug(f"{repr(e)}")
ex_q.put(e)
return
await asyncio.sleep(0.01)
def cleanup():
ffmpeg_proc.kill()
del worker
del worker2
return StreamingResponse(audio_out(), media_type=media_type, background=cleanup)
else:
raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')
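The shape of the new tts-1-hd path in miniature: a producer thread feeds PCM into in_q, an I/O thread pumps in_q through non-blocking ffmpeg pipes into out_q, an async generator drains out_q to the client, and ex_q carries errors (notably BrokenPipeError on client disconnect) back into the producer loop so the xtts lock gets released. The sketch below keeps that three-stage shape but substitutes cat for ffmpeg and print for the HTTP response; it is a simplification, not the commit's code (it skips, for example, the partial-write handling done above with io.FileIO).

import asyncio
import os
import queue
import subprocess
import threading
import time

proc = subprocess.Popen(["cat"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
os.set_blocking(proc.stdin.fileno(), False)   # Windows needs Python 3.12+
os.set_blocking(proc.stdout.fileno(), False)

in_q = queue.Queue()   # producer -> cat
out_q = queue.Queue()  # cat -> client
ex_q = queue.Queue()   # errors back to the producer

def pump():
    # Move bytes both ways without ever blocking on a dead peer.
    while not (proc.stdin.closed and proc.stdout.closed):
        chunk = proc.stdout.read()      # None: no data yet; b"": real EOF
        if chunk:
            out_q.put(chunk)
        elif chunk == b"":
            out_q.put(None)
            return
        try:
            data = in_q.get_nowait()
            if data is None:
                proc.stdin.close()      # EOF lets cat flush and exit
            else:
                proc.stdin.write(data)
        except queue.Empty:
            pass
        except BrokenPipeError as e:
            ex_q.put(e)                 # surfaces in the producer loop
            proc.kill()
            return
        time.sleep(0.01)

def produce():
    for piece in (b"hello ", b"world\n"):
        if not ex_q.empty():
            raise ex_q.get()            # abort generation on consumer error
        in_q.put(piece)
    in_q.put(None)

async def drain():
    while True:
        try:
            audio = out_q.get_nowait()
            if audio is None:
                return
            print(audio)                # stands in for yielding to the client
        except queue.Empty:
            await asyncio.sleep(0.01)

threading.Thread(target=pump, daemon=True).start()
threading.Thread(target=produce, daemon=True).start()
asyncio.run(drain())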
@ -314,7 +446,7 @@ if __name__ == "__main__":
parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds")
parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds, Ex. 900 for 15 minutes")
parser.add_argument('--use-deepspeed', action='store_true', default=False, help="Use deepspeed with xtts (this option is unsupported)")
parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
@ -333,6 +465,7 @@ if __name__ == "__main__":
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
from TTS.tts.layers.xtts.tokenizer import split_sentence
if args.preload:
xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)