From 73e64fe7fb29e020159ae034de0811fec539bfec Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 19 May 2025 02:52:48 +0400 Subject: [PATCH] refac: audio upload handling --- backend/open_webui/routers/audio.py | 74 +++++++++++++++-------------- backend/open_webui/routers/files.py | 13 ++--- 2 files changed, 41 insertions(+), 46 deletions(-) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 484bc138a..a0f5af4fc 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -73,33 +73,50 @@ from pydub import AudioSegment from pydub.utils import mediainfo -def get_audio_convert_format(file_path): - """Check if the given file needs to be converted to a different format.""" +def is_audio_conversion_required(file_path): + """ + Check if the given audio file needs conversion to mp3. + """ + SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"} + if not os.path.isfile(file_path): log.error(f"File not found: {file_path}") return False try: info = mediainfo(file_path) + codec_name = info.get("codec_name", "").lower() + codec_type = info.get("codec_type", "").lower() + codec_tag_string = info.get("codec_tag_string", "").lower() + if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a": + # File is AAC/mp4a audio, recommend mp3 conversion + return True + + # If the codec name or file extension is in the supported formats if ( - info.get("codec_name") == "aac" - and info.get("codec_type") == "audio" - and info.get("codec_tag_string") == "mp4a" + codec_name in SUPPORTED_FORMATS + or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS ): - return "mp4" + return False # Already supported + + return True except Exception as e: log.error(f"Error getting audio format: {e}") return False - return None - -def convert_audio_to_wav(file_path, output_path, conversion_type): - """Convert MP4/OGG audio file to WAV format.""" - audio = AudioSegment.from_file(file_path, format=conversion_type) - audio.export(output_path, format="wav") - log.info(f"Converted {file_path} to {output_path}") +def convert_audio_to_mp3(file_path): + """Convert audio file to mp3 format.""" + try: + output_path = os.path.splitext(file_path)[0] + ".mp3" + audio = AudioSegment.from_file(file_path) + audio.export(output_path, format="mp3") + log.info(f"Converted {file_path} to {output_path}") + return output_path + except Exception as e: + log.error(f"Error converting audio file: {e}") + return None def set_faster_whisper_model(model: str, auto_update: bool = False): @@ -544,19 +561,6 @@ def transcription_handler(request, file_path): log.debug(data) return data elif request.app.state.config.STT_ENGINE == "openai": - convert_format = get_audio_convert_format(file_path) - - if convert_format: - ext = convert_format.split(".")[-1] - - os.rename(file_path, file_path.replace(".{ext}", f".{convert_format}")) - # Convert unsupported audio file to WAV format - convert_audio_to_wav( - file_path.replace(".{ext}", f".{convert_format}"), - file_path, - convert_format, - ) - r = None try: r = requests.post( @@ -776,6 +780,9 @@ def transcription_handler(request, file_path): def transcribe(request: Request, file_path): log.info(f"transcribe: {file_path}") + if is_audio_conversion_required(file_path): + file_path = convert_audio_to_mp3(file_path) + try: file_path = compress_audio(file_path) except Exception as e: @@ -894,16 +901,11 @@ def transcription( ): log.info(f"file.content_type: {file.content_type}") - supported_filetypes = ( - "audio/mpeg", - "audio/wav", - "audio/ogg", - "audio/x-m4a", - "audio/webm", - "video/webm", - ) - - if not file.content_type.startswith(supported_filetypes): + SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types! + if not ( + file.content_type.startswith("audio/") + or file.content_type in SUPPORTED_CONTENT_TYPES + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 4ba57f56b..ad556d327 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -140,16 +140,9 @@ def upload_file( if process: try: if file.content_type: - if file.content_type.startswith( - ( - "audio/mpeg", - "audio/wav", - "audio/ogg", - "audio/x-m4a", - "audio/webm", - "video/webm", - ) - ): + if file.content_type.startswith("audio/") or file.content_type in { + "video/webm" + }: file_path = Storage.get_file(file_path) result = transcribe(request, file_path)