diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 9f053351d..a810cc586 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -71,21 +71,26 @@ from pydub import AudioSegment from pydub.utils import mediainfo -def get_audio_format(file_path): +def get_audio_convert_format(file_path): """Check if the given file needs to be converted to a different format.""" if not os.path.isfile(file_path): log.error(f"File not found: {file_path}") return False - info = mediainfo(file_path) - if ( - info.get("codec_name") == "aac" - and info.get("codec_type") == "audio" - and info.get("codec_tag_string") == "mp4a" - ): - return "mp4" - elif info.get("format_name") == "ogg": - return "ogg" + try: + info = mediainfo(file_path) + + if ( + info.get("codec_name") == "aac" + and info.get("codec_type") == "audio" + and info.get("codec_tag_string") == "mp4a" + ): + return "mp4" + elif info.get("format_name") == "ogg": + return "ogg" + except Exception as e: + log.error(f"Error getting audio format: {e}") + return False return None @@ -537,14 +542,18 @@ def transcribe(request: Request, file_path): log.debug(data) return data elif request.app.state.config.STT_ENGINE == "openai": - audio_format = get_audio_format(file_path) - if audio_format: - os.rename(file_path, file_path.replace(".wav", f".{audio_format}")) + convert_format = get_audio_convert_format(file_path) + + print(f"convert_format: {convert_format}") + if convert_format: + ext = convert_format.split(".")[-1] + + os.rename(file_path, file_path.replace(".{ext}", f".{convert_format}")) # Convert unsupported audio file to WAV format convert_audio_to_wav( - file_path.replace(".wav", f".{audio_format}"), + file_path.replace(".{ext}", f".{convert_format}"), file_path, - audio_format, + convert_format, ) r = None diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index d963cd632..475905da1 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -133,6 +133,7 @@ def upload_file( "audio/ogg", "audio/x-m4a", "audio/webm", + "video/webm", ) ): file_path = Storage.get_file(file_path) @@ -150,7 +151,6 @@ def upload_file( "video/mp4", "video/ogg", "video/quicktime", - "video/webm", ]: process_file(request, ProcessFileForm(file_id=id), user=user)