From 70de5cf7b8997e63e64a0372f4c7e36888912af5 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 19 Dec 2024 16:18:54 -0800 Subject: [PATCH] fix: audio --- backend/open_webui/routers/audio.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index a26355945..46ee8beab 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -218,7 +218,7 @@ async def update_audio_config( } -def load_speech_pipeline(): +def load_speech_pipeline(request): from transformers import pipeline from datasets import load_dataset @@ -236,7 +236,11 @@ def load_speech_pipeline(): @router.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): body = await request.body() - name = hashlib.sha256(body).hexdigest() + name = hashlib.sha256( + body + + str(request.app.state.config.TTS_ENGINE).encode("utf-8") + + str(request.app.state.config.TTS_MODEL).encode("utf-8") + ).hexdigest() file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") @@ -256,10 +260,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): payload["model"] = request.app.state.config.TTS_MODEL try: + # print(payload) async with aiohttp.ClientSession() as session: async with session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", - data=payload, + json=payload, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", @@ -281,7 +286,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): await f.write(await r.read()) async with aiofiles.open(file_body_path, "w") as f: - await f.write(json.dumps(json.loads(body.decode("utf-8")))) + await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -292,6 +297,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: if r.status != 200: res = await r.json() + if "error" in res: detail = f"External: {res['error'].get('message', '')}" except Exception: @@ -332,7 +338,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): await f.write(await r.read()) async with aiofiles.open(file_body_path, "w") as f: - await f.write(json.dumps(json.loads(body.decode("utf-8")))) + await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -384,6 +390,9 @@ async def speech(request: Request, user=Depends(get_verified_user)): async with aiofiles.open(file_path, "wb") as f: await f.write(await r.read()) + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(payload)) + return FileResponse(file_path) except Exception as e: @@ -414,7 +423,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): import torch import soundfile as sf - load_speech_pipeline() + load_speech_pipeline(request) embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset @@ -436,8 +445,9 @@ async def speech(request: Request, user=Depends(get_verified_user)): ) sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"]) - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(payload)) return FileResponse(file_path)