mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: native speecht5 support
This commit is contained in:
@@ -74,6 +74,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
|
||||
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
||||
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
||||
|
||||
|
||||
app.state.speech_synthesiser = None
|
||||
app.state.speech_speaker_embeddings_dataset = None
|
||||
|
||||
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
|
||||
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
||||
|
||||
@@ -231,6 +235,21 @@ async def update_audio_config(
|
||||
}
|
||||
|
||||
|
||||
def load_speech_pipeline():
|
||||
from transformers import pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
if app.state.speech_synthesiser is None:
|
||||
app.state.speech_synthesiser = pipeline(
|
||||
"text-to-speech", "microsoft/speecht5_tts"
|
||||
)
|
||||
|
||||
if app.state.speech_speaker_embeddings_dataset is None:
|
||||
app.state.speech_speaker_embeddings_dataset = load_dataset(
|
||||
"Matthijs/cmu-arctic-xvectors", split="validation"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/speech")
|
||||
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
@@ -397,6 +416,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error synthesizing speech - {response.reason}"
|
||||
)
|
||||
elif app.state.config.TTS_ENGINE == "transformers":
|
||||
payload = None
|
||||
try:
|
||||
payload = json.loads(body.decode("utf-8"))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
|
||||
load_speech_pipeline()
|
||||
|
||||
embeddings_dataset = app.state.speech_speaker_embeddings_dataset
|
||||
|
||||
speaker_index = 6799
|
||||
try:
|
||||
speaker_index = embeddings_dataset["filename"].index(
|
||||
app.state.config.TTS_MODEL
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
speaker_embedding = torch.tensor(
|
||||
embeddings_dataset[speaker_index]["xvector"]
|
||||
).unsqueeze(0)
|
||||
|
||||
speech = app.state.speech_synthesiser(
|
||||
payload["input"],
|
||||
forward_params={"speaker_embeddings": speaker_embedding},
|
||||
)
|
||||
|
||||
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(json.loads(body.decode("utf-8")), f)
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def transcribe(file_path):
|
||||
|
||||
Reference in New Issue
Block a user