mirror of
https://github.com/matatonic/openedai-speech
synced 2025-06-26 18:16:32 +00:00
0.15.2 Thread safe version 15
This commit is contained in:
parent
964b23a21c
commit
703dec32b1
@ -29,6 +29,10 @@ If you find a better voice match for `tts-1` or `tts-1-hd`, please let me know s
|
||||
|
||||
## Recent Changes
|
||||
|
||||
Version 0.15.2, 2024-06-28
|
||||
|
||||
* Thread safe version, locked generation at the sentence level
|
||||
|
||||
Version 0.15.1, 2024-06-27
|
||||
|
||||
* Remove deepspeed from requirements.txt, it's too complex for typical users. A more detailed deepspeed install document will be required.
|
||||
|
||||
193
speech.py
193
speech.py
@ -1,20 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
import asyncio
|
||||
import contextlib
|
||||
import gc
|
||||
import io
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import yaml
|
||||
import contextlib
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
@ -46,7 +49,7 @@ def unload_model():
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
class xtts_wrapper():
|
||||
check_interval: int = 1
|
||||
check_interval: int = 1 # too aggressive?
|
||||
|
||||
def __init__(self, model_name, device, model_path=None, unload_timer=None):
|
||||
self.model_name = model_name
|
||||
@ -89,18 +92,21 @@ class xtts_wrapper():
|
||||
self.timer.start()
|
||||
|
||||
def tts(self, text, language, speaker_wav, **hf_generate_kwargs):
|
||||
self.not_idle()
|
||||
try:
|
||||
with torch.no_grad():
|
||||
with self.lock: # this doesn't seem threadsafe, but it's quick enough
|
||||
gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
|
||||
logger.debug(f"waiting lock")
|
||||
with self.lock, torch.no_grad(): # I wish this could be another way, but it seems that inference_stream cannot be access async reliably
|
||||
logger.debug(f"grabbed lock, tts text: {text}")
|
||||
self.last_used = time.time()
|
||||
try:
|
||||
gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=[speaker_wav]) # XXX TODO: allow multiple wav
|
||||
|
||||
for wav in self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs):
|
||||
yield wav.cpu().numpy().tobytes() # assumes wav data is f32le
|
||||
self.not_idle()
|
||||
yield wav.cpu().numpy().tobytes()
|
||||
|
||||
finally:
|
||||
logger.debug(f"held lock for {time.time() - self.last_used:0.1f} sec")
|
||||
self.last_used = time.time()
|
||||
|
||||
|
||||
finally:
|
||||
self.not_idle()
|
||||
|
||||
def default_exists(filename: str):
|
||||
if not os.path.exists(filename):
|
||||
@ -116,7 +122,7 @@ def default_exists(filename: str):
|
||||
|
||||
# Read pre process map on demand so it can be changed without restarting the server
|
||||
def preprocess(raw_input):
|
||||
logger.debug(f"preprocess: before: {[raw_input]}")
|
||||
#logger.debug(f"preprocess: before: {[raw_input]}")
|
||||
default_exists('config/pre_process_map.yaml')
|
||||
with open('config/pre_process_map.yaml', 'r', encoding='utf8') as file:
|
||||
pre_process_map = yaml.safe_load(file)
|
||||
@ -124,7 +130,7 @@ def preprocess(raw_input):
|
||||
raw_input = re.sub(a, b, raw_input)
|
||||
|
||||
raw_input = raw_input.strip()
|
||||
logger.debug(f"preprocess: after: {[raw_input]}")
|
||||
#logger.debug(f"preprocess: after: {[raw_input]}")
|
||||
return raw_input
|
||||
|
||||
# Read voice map on demand so it can be changed without restarting the server
|
||||
@ -200,7 +206,7 @@ async def generate_speech(request: GenerateSpeechRequest):
|
||||
elif model == 'tts-1-hd':
|
||||
media_type = "audio/pcm;rate=24000"
|
||||
else:
|
||||
BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')
|
||||
raise BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')
|
||||
|
||||
ffmpeg_args = None
|
||||
tts_io_out = None
|
||||
@ -225,12 +231,12 @@ async def generate_speech(request: GenerateSpeechRequest):
|
||||
tts_proc = subprocess.Popen(tts_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
tts_proc.stdin.write(bytearray(input_text.encode('utf-8')))
|
||||
tts_proc.stdin.close()
|
||||
tts_io_out = tts_proc.stdout
|
||||
|
||||
ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate="22050")
|
||||
|
||||
# Pipe the output from piper/xtts to the input of ffmpeg
|
||||
ffmpeg_args.extend(["-"])
|
||||
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_io_out, stdout=subprocess.PIPE)
|
||||
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_proc.stdout, stdout=subprocess.PIPE)
|
||||
|
||||
return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
|
||||
# Use xtts for tts-1-hd
|
||||
@ -262,6 +268,9 @@ async def generate_speech(request: GenerateSpeechRequest):
|
||||
ffmpeg_args.extend(["-af", f"atempo={speed}"])
|
||||
speed = 1.0
|
||||
|
||||
# Pipe the output from piper/xtts to the input of ffmpeg
|
||||
ffmpeg_args.extend(["-"])
|
||||
|
||||
language = voice_map.pop('language', 'en')
|
||||
|
||||
comment = voice_map.pop('comment', None) # ignored.
|
||||
@ -273,27 +282,150 @@ async def generate_speech(request: GenerateSpeechRequest):
|
||||
|
||||
hf_generate_kwargs['enable_text_splitting'] = hf_generate_kwargs.get('enable_text_splitting', True) # change the default to true
|
||||
|
||||
# Pipe the output from piper/xtts to the input of ffmpeg
|
||||
ffmpeg_args.extend(["-"])
|
||||
if hf_generate_kwargs['enable_text_splitting']:
|
||||
all_text = split_sentence(input_text, language, xtts.xtts.tokenizer.char_limits[language])
|
||||
else:
|
||||
all_text = [input_text]
|
||||
|
||||
ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
|
||||
def generator():
|
||||
# before the xtts lock, it was:
|
||||
#def generator():
|
||||
# for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
|
||||
# ffmpeg_proc.stdin.write(chunk) # <-- but this blocks forever and holds the xtts lock if a client disconnects
|
||||
#worker = threading.Thread(target=generator)
|
||||
#worker.daemon = True
|
||||
#worker.start()
|
||||
#return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
|
||||
#
|
||||
# What follows is stupidly overcomplicated, but there is no other way I can find (yet) that detects client disconnects without getting blocked up
|
||||
|
||||
os.set_blocking(ffmpeg_proc.stdout.fileno(), False) # this doesn't work on windows until python 3.12
|
||||
os.set_blocking(ffmpeg_proc.stdin.fileno(), False) # this doesn't work on windows until python 3.12
|
||||
ffmpeg_in = io.FileIO(ffmpeg_proc.stdin.fileno(), 'wb')
|
||||
|
||||
in_q = queue.Queue() # speech pcm
|
||||
out_q = queue.Queue() # ffmpeg audio out
|
||||
ex_q = queue.Queue() # exceptions
|
||||
|
||||
def ffmpeg_io():
    """Pump audio through ffmpeg: ``in_q`` -> ffmpeg stdin, ffmpeg stdout -> ``out_q``.

    Intended to run in its own thread. The enclosing code switches both
    ffmpeg pipes to non-blocking mode, so a read or write that cannot make
    progress returns ``None`` instead of stalling this loop.

    Queue protocol:
      * ``in_q``  - speech PCM chunks for ffmpeg; ``None`` means end of input.
      * ``out_q`` - encoded audio read from ffmpeg; ``None`` means end of stream.
      * ``ex_q``  - exceptions raised here, for the generation loop to pick up.

    Every exit path that abandons the stream puts the ``None`` sentinel on
    ``out_q`` so the consumer of ``out_q`` is never left polling forever.
    """
    # in_q -> ffmpeg -> out_q
    while not (ffmpeg_proc.stdout.closed and ffmpeg_proc.stdin.closed):
        # Drain everything ffmpeg has produced so far.
        try:
            while not ffmpeg_proc.stdout.closed:
                chunk = ffmpeg_proc.stdout.read()
                if chunk is None:  # nothing available right now
                    break

                if len(chunk) == 0:  # real end
                    out_q.put(None)
                    ffmpeg_proc.stdout.close()
                    break

                out_q.put(chunk)
                continue  # consume audio without delay

        except Exception as e:
            logger.debug(f"ffmpeg stdout read: {repr(e)}")
            out_q.put(None)
            ex_q.put(e)
            return

        # Feed ffmpeg as much queued PCM as it will accept.
        try:
            while not ffmpeg_proc.stdin.closed:
                chunk = in_q.get_nowait()
                if chunk is None:  # end of input
                    ffmpeg_proc.stdin.close()
                    break

                n = ffmpeg_in.write(chunk)  # BrokenPipeError from here on client disconnect
                if n is None:
                    # Nothing was written: push the whole chunk back on the front.
                    in_q.queue.appendleft(chunk)
                    break

                if n != len(chunk):
                    # Partial write: push the unwritten tail back on the front.
                    in_q.queue.appendleft(chunk[n:])
                    break

        except queue.Empty:
            pass

        except Exception as e:
            # Includes BrokenPipeError on client disconnect. The exception has
            # to reach the generation loop, which holds the lock.
            ex_q.put(e)
            # No more output will be read after this point, so end the output
            # stream explicitly rather than leaving the reader waiting.
            out_q.put(None)
            ffmpeg_proc.kill()
            return

        time.sleep(0.01)
|
||||
|
||||
def exception_check(exq: queue.Queue):
|
||||
try:
|
||||
for chunk in xtts.tts(text=input_text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
|
||||
ffmpeg_proc.stdin.write(chunk)
|
||||
e = exq.get_nowait()
|
||||
except queue.Empty:
|
||||
return
|
||||
|
||||
raise e
|
||||
|
||||
def generator():
|
||||
# text -> in_q
|
||||
try:
|
||||
for text in all_text:
|
||||
for chunk in xtts.tts(text=text, language=language, speaker_wav=speaker, **hf_generate_kwargs):
|
||||
exception_check(ex_q)
|
||||
in_q.put(chunk)
|
||||
|
||||
in_q.put(None)
|
||||
|
||||
except BrokenPipeError as e: # client disconnect lands here
|
||||
#logger.debug(f"{repr(e)}")
|
||||
logger.info("Client disconnected")
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
logger.debug(f"{repr(e)}")
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {repr(e)}")
|
||||
raise e
|
||||
|
||||
finally:
|
||||
ffmpeg_proc.stdin.close()
|
||||
|
||||
worker = threading.Thread(target=generator)
|
||||
worker.daemon = True
|
||||
worker = threading.Thread(target=generator, daemon = True)
|
||||
worker.start()
|
||||
|
||||
return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)
|
||||
worker2 = threading.Thread(target=ffmpeg_io, daemon = True)
|
||||
worker2.start()
|
||||
|
||||
async def audio_out():
    """Async generator that streams encoded audio from ``out_q`` to the client.

    Polls ``out_q`` with ``get_nowait()`` and yields each chunk it finds. A
    ``None`` item is the end-of-stream sentinel and finishes the generator.
    When the queue is empty it sleeps for 10 ms on the event loop before
    polling again.

    Any exception raised inside the polling block (including
    ``asyncio.CancelledError``) is logged, put on ``ex_q`` and ends the
    stream.
    """
    # out_q -> client
    while True:
        try:
            audio = out_q.get_nowait()
            if audio is None:
                return
            yield audio

        except queue.Empty:
            pass

        except asyncio.CancelledError as e:
            # f-string is required here: without it the literal text
            # "{repr(e)}" is logged instead of the exception.
            logger.debug(f"{repr(e)}")
            ex_q.put(e)
            return

        except Exception as e:
            logger.debug(f"{repr(e)}")
            ex_q.put(e)
            return

        await asyncio.sleep(0.01)
|
||||
|
||||
def cleanup():
    """Kill the ffmpeg subprocess; passed as the response's ``background`` callback.

    The enclosing code starts both worker threads with ``daemon=True``, so
    nothing is torn down for them here.
    """
    # The previous `del worker` / `del worker2` statements were removed: a
    # `del` target is a name binding, so both names became locals of this
    # function and every call raised UnboundLocalError right after the kill.
    # NOTE(review): Starlette appears to `await` the result of calling the
    # background object; a plain function returns None. Consider wrapping
    # this in starlette.background.BackgroundTask - confirm against the
    # installed version.
    ffmpeg_proc.kill()
|
||||
|
||||
return StreamingResponse(audio_out(), media_type=media_type, background=cleanup)
|
||||
else:
|
||||
raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')
|
||||
|
||||
@ -314,7 +446,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
|
||||
parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
|
||||
parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds")
|
||||
parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds, Ex. 900 for 15 minutes")
|
||||
parser.add_argument('--use-deepspeed', action='store_true', default=False, help="Use deepspeed with xtts (this option is unsupported)")
|
||||
parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
|
||||
parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
|
||||
@ -333,6 +465,7 @@ if __name__ == "__main__":
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.tts.layers.xtts.tokenizer import split_sentence
|
||||
|
||||
if args.preload:
|
||||
xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user