diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py
index f30ff81cf..85781ac9f 100644
--- a/backend/open_webui/apps/audio/main.py
+++ b/backend/open_webui/apps/audio/main.py
@@ -74,6 +74,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
 app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
 app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
+
+app.state.speech_synthesiser = None
+app.state.speech_speaker_embeddings_dataset = None
+
 app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
 app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
@@ -231,6 +235,21 @@ async def update_audio_config(
     }
 
 
+def load_speech_pipeline():
+    from transformers import pipeline
+    from datasets import load_dataset
+
+    if app.state.speech_synthesiser is None:
+        app.state.speech_synthesiser = pipeline(
+            "text-to-speech", "microsoft/speecht5_tts"
+        )
+
+    if app.state.speech_speaker_embeddings_dataset is None:
+        app.state.speech_speaker_embeddings_dataset = load_dataset(
+            "Matthijs/cmu-arctic-xvectors", split="validation"
+        )
+
+
 @app.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
@@ -397,6 +416,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         raise HTTPException(
             status_code=500, detail=f"Error synthesizing speech - {response.reason}"
         )
+    elif app.state.config.TTS_ENGINE == "transformers":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        import torch
+        import soundfile as sf
+
+        load_speech_pipeline()
+
+        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+
+        speaker_index = 6799
+        try:
+            speaker_index = embeddings_dataset["filename"].index(
+                app.state.config.TTS_MODEL
+            )
+        except Exception:
+            pass
+
+        speaker_embedding = torch.tensor(
+            embeddings_dataset[speaker_index]["xvector"]
+        ).unsqueeze(0)
+
+        speech = app.state.speech_synthesiser(
+            payload["input"],
+            forward_params={"speaker_embeddings": speaker_embedding},
+        )
+
+        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        return FileResponse(file_path)
 
 
 def transcribe(file_path):
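The backend hunk above wires a lazily loaded Hugging Face SpeechT5 pipeline into the `/speech` route behind the new `transformers` engine. The same flow can be tried standalone; below is a minimal sketch using the model and dataset IDs from the diff, where the example text and output filename are placeholders and not part of the patch:

```python
# Standalone sketch of the SpeechT5 flow used by the new "transformers" engine.
# Requires: transformers, datasets, torch, soundfile (and sentencepiece for
# the SpeechT5 tokenizer).
import soundfile as sf
import torch
from datasets import load_dataset
from transformers import pipeline

# Same model that load_speech_pipeline() loads lazily on first use.
synthesiser = pipeline("text-to-speech", "microsoft/speecht5_tts")

# SpeechT5 needs a speaker x-vector. The route resolves TTS_MODEL against the
# "filename" column of this dataset and falls back to index 6799.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[6799]["xvector"]).unsqueeze(0)

speech = synthesiser(
    "Hello from Open WebUI.",  # placeholder input text
    forward_params={"speaker_embeddings": speaker_embedding},
)

# The pipeline returns a dict with "audio" (a float32 numpy array) and
# "sampling_rate"; the route writes it out the same way via sf.write().
sf.write("speech.wav", speech["audio"], samplerate=speech["sampling_rate"])
```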
"@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^2.0.0", "@xyflow/svelte": "^0.1.19", diff --git a/src/lib/components/admin/Settings/Audio.svelte b/src/lib/components/admin/Settings/Audio.svelte index ae827e6ab..00c8417fa 100644 --- a/src/lib/components/admin/Settings/Audio.svelte +++ b/src/lib/components/admin/Settings/Audio.svelte @@ -322,6 +322,7 @@ }} > + @@ -396,6 +397,47 @@ + {:else if TTS_ENGINE === 'transformers'} +