enh: very long audio transcription

This commit is contained in:
Timothy Jaeryang Baek 2025-05-17 02:51:28 +04:00
parent d54e588ec3
commit b280f828b0

View File

@ -7,6 +7,7 @@ from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
import aiohttp
import aiofiles
@ -50,7 +51,7 @@ from open_webui.env import (
router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
@ -87,8 +88,6 @@ def get_audio_convert_format(file_path):
and info.get("codec_tag_string") == "mp4a"
):
return "mp4"
elif info.get("format_name") == "ogg":
return "ogg"
except Exception as e:
log.error(f"Error getting audio format: {e}")
return False
@ -511,8 +510,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
def transcribe(request: Request, file_path):
log.info(f"transcribe: {file_path}")
def transcription_handler(request, file_path):
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
@ -775,24 +773,119 @@ def transcribe(request: Request, file_path):
)
def transcribe(request: Request, file_path):
log.info(f"transcribe: {file_path}")
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
# Always produce a list of chunk paths (could be one entry if small)
try:
chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
print(f"Chunk paths: {chunk_paths}")
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
results = []
try:
with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path
futures = [
executor.submit(transcription_handler, request, chunk_path)
for chunk_path in chunk_paths
]
# Gather results as they complete
for future in futures:
try:
results.append(future.result())
except Exception as transcribe_exc:
log.exception(f"Error transcribing chunk: {transcribe_exc}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error during transcription.",
)
finally:
# Clean up only the temporary chunks, never the original file
for chunk_path in chunk_paths:
if chunk_path != file_path and os.path.isfile(chunk_path):
try:
os.remove(chunk_path)
except Exception:
pass
return {
"text": " ".join([result["text"] for result in results]),
}
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
id = os.path.splitext(os.path.basename(file_path))[
0
] # Handles names with multiple dots
file_dir = os.path.dirname(file_path)
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
if (
os.path.getsize(compressed_path) > MAX_FILE_SIZE
): # Still larger than MAX_FILE_SIZE after compression
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
audio.export(compressed_path, format="mp3", bitrate="32k")
# log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
return compressed_path
else:
return file_path
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
"""
Splits audio into chunks not exceeding max_bytes.
Returns a list of chunk file paths. If audio fits, returns list with original path.
"""
file_size = os.path.getsize(file_path)
if file_size <= max_bytes:
return [file_path] # Nothing to split
audio = AudioSegment.from_file(file_path)
duration_ms = len(audio)
orig_size = file_size
approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
chunks = []
start = 0
i = 0
base, _ = os.path.splitext(file_path)
while start < duration_ms:
end = min(start + approx_chunk_ms, duration_ms)
chunk = audio[start:end]
chunk_path = f"{base}_chunk_{i}.{format}"
chunk.export(chunk_path, format=format, bitrate=bitrate)
# Reduce chunk duration if still too large
while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
end = start + ((end - start) // 2)
chunk = audio[start:end]
chunk.export(chunk_path, format=format, bitrate=bitrate)
if os.path.getsize(chunk_path) > max_bytes:
os.remove(chunk_path)
raise Exception("Audio chunk cannot be reduced below max file size.")
chunks.append(chunk_path)
start = end
i += 1
return chunks
@router.post("/transcriptions")
def transcription(
request: Request,
@ -807,6 +900,7 @@ def transcription(
"audio/ogg",
"audio/x-m4a",
"audio/webm",
"video/webm",
)
if not file.content_type.startswith(supported_filetypes):
@ -830,19 +924,13 @@ def transcription(
f.write(contents)
try:
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
result = transcribe(request, file_path)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
return {
**result,
"filename": os.path.basename(file_path),
}
data = transcribe(request, file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)