mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	enh: very long audio transcription
This commit is contained in:
		
							parent
							
								
									d54e588ec3
								
							
						
					
					
						commit
						b280f828b0
					
				@ -7,6 +7,7 @@ from functools import lru_cache
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from pydub import AudioSegment
 | 
			
		||||
from pydub.silence import split_on_silence
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
import aiofiles
 | 
			
		||||
@ -50,7 +51,7 @@ from open_webui.env import (
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
# Constants
# Size threshold in megabytes; audio larger than this is compressed and, if
# still too large, split into chunks before transcription.
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
# Separate, larger limit for Azure (in megabytes).
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 | 
			
		||||
@ -87,8 +88,6 @@ def get_audio_convert_format(file_path):
 | 
			
		||||
            and info.get("codec_tag_string") == "mp4a"
 | 
			
		||||
        ):
 | 
			
		||||
            return "mp4"
 | 
			
		||||
        elif info.get("format_name") == "ogg":
 | 
			
		||||
            return "ogg"
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.error(f"Error getting audio format: {e}")
 | 
			
		||||
        return False
 | 
			
		||||
@ -511,8 +510,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
        return FileResponse(file_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def transcribe(request: Request, file_path):
 | 
			
		||||
    log.info(f"transcribe: {file_path}")
 | 
			
		||||
def transcription_handler(request, file_path):
 | 
			
		||||
    filename = os.path.basename(file_path)
 | 
			
		||||
    file_dir = os.path.dirname(file_path)
 | 
			
		||||
    id = filename.split(".")[0]
 | 
			
		||||
@ -775,24 +773,119 @@ def transcribe(request: Request, file_path):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def transcribe(request: Request, file_path):
    """Transcribe an audio file, chunking it first when it is too large.

    The file is compressed on a best-effort basis, split into chunks no
    larger than ``MAX_FILE_SIZE``, and every chunk is passed to
    ``transcription_handler`` on a thread pool. The per-chunk texts are
    joined with single spaces in chunk order.

    Args:
        request: The incoming request, forwarded to ``transcription_handler``.
        file_path: Path of the audio file to transcribe.

    Returns:
        A dict with a single ``"text"`` key holding the combined transcript.

    Raises:
        HTTPException: 400 if the audio cannot be split, 500 if any chunk
            fails to transcribe.
    """
    log.info(f"transcribe: {file_path}")

    # Compression is best-effort: on failure, carry on with the original file.
    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        log.exception(e)

    # Always produce a list of chunk paths (could be one entry if small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
    log.debug(f"Chunk paths: {chunk_paths}")

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit tasks for each chunk_path
            futures = [
                executor.submit(transcription_handler, request, chunk_path)
                for chunk_path in chunk_paths
            ]
            # Collect in submission order (not completion order) so the
            # joined transcript keeps the chunks in their original sequence.
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    log.exception(f"Error transcribing chunk: {transcribe_exc}")
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail="Error during transcription.",
                    ) from transcribe_exc
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except OSError as cleanup_exc:
                    # Cleanup is best-effort; report it but do not mask the
                    # transcription result or error.
                    log.warning(
                        f"Could not remove chunk {chunk_path}: {cleanup_exc}"
                    )

    return {
        "text": " ".join(result["text"] for result in results),
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compress_audio(file_path):
    """Re-encode an oversized audio file as 16 kHz mono MP3 at 32k bitrate.

    Files no larger than ``MAX_FILE_SIZE`` are returned untouched. Larger
    files are written next to the original as ``<name>_compressed.mp3``.
    The result is not size-checked here: ``transcribe`` splits audio that is
    still too large, so raising would only discard the compressed file.

    Args:
        file_path: Path of the audio file to inspect.

    Returns:
        ``file_path`` itself when the file is small enough, otherwise the
        path of the newly written compressed file.
    """
    if os.path.getsize(file_path) <= MAX_FILE_SIZE:
        return file_path

    # splitext keeps everything before the final extension, so names with
    # multiple dots are handled.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    file_dir = os.path.dirname(file_path)

    audio = AudioSegment.from_file(file_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio

    compressed_path = os.path.join(file_dir, f"{stem}_compressed.mp3")
    audio.export(compressed_path, format="mp3", bitrate="32k")
    log.debug(f"Compressed audio to {compressed_path}")

    return compressed_path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Splits audio into chunks not exceeding max_bytes.
    Returns a list of chunk file paths. If audio fits, returns list with original path.

    Chunks are written next to the source as ``<base>_chunk_<i>.<format>``.
    The first guess at chunk length scales the total duration by
    ``max_bytes / file size`` (minus a one-second margin, minimum one second).
    A chunk that still comes out too large is halved repeatedly while it is
    longer than five seconds.

    Args:
        file_path: Path of the audio file to split.
        max_bytes: Maximum allowed size of each chunk file, in bytes.
        format: Export format passed to ``AudioSegment.export``.
        bitrate: Export bitrate passed to ``AudioSegment.export``.

    Raises:
        Exception: If a chunk cannot be reduced below ``max_bytes``. Any
            chunk files already written are removed before the error
            propagates, as they are for any other failure during splitting.
    """
    file_size = os.path.getsize(file_path)
    if file_size <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)

    approx_chunk_ms = max(int(duration_ms * (max_bytes / file_size)) - 1000, 1000)
    base, _ = os.path.splitext(file_path)

    chunks = []
    start = 0
    try:
        while start < duration_ms:
            end = min(start + approx_chunk_ms, duration_ms)
            chunk_path = f"{base}_chunk_{len(chunks)}.{format}"
            # Record the path before exporting so that a failed or oversized
            # export is still cleaned up by the handler below.
            chunks.append(chunk_path)
            audio[start:end].export(chunk_path, format=format, bitrate=bitrate)

            # Reduce chunk duration if still too large
            while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
                end = start + ((end - start) // 2)
                audio[start:end].export(chunk_path, format=format, bitrate=bitrate)

            if os.path.getsize(chunk_path) > max_bytes:
                raise Exception("Audio chunk cannot be reduced below max file size.")

            start = end
    except Exception:
        # The caller never receives the list on failure, so the chunk files
        # written so far must be removed here.
        for written_path in chunks:
            try:
                if os.path.isfile(written_path):
                    os.remove(written_path)
            except OSError:
                # Best-effort cleanup; the original error matters more.
                pass
        raise

    return chunks
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/transcriptions")
 | 
			
		||||
def transcription(
 | 
			
		||||
    request: Request,
 | 
			
		||||
@ -807,6 +900,7 @@ def transcription(
 | 
			
		||||
        "audio/ogg",
 | 
			
		||||
        "audio/x-m4a",
 | 
			
		||||
        "audio/webm",
 | 
			
		||||
        "video/webm",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if not file.content_type.startswith(supported_filetypes):
 | 
			
		||||
@ -830,19 +924,13 @@ def transcription(
 | 
			
		||||
            f.write(contents)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            try:
 | 
			
		||||
                file_path = compress_audio(file_path)
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                log.exception(e)
 | 
			
		||||
            result = transcribe(request, file_path)
 | 
			
		||||
 | 
			
		||||
                raise HTTPException(
 | 
			
		||||
                    status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                    detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                )
 | 
			
		||||
            return {
 | 
			
		||||
                **result,
 | 
			
		||||
                "filename": os.path.basename(file_path),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            data = transcribe(request, file_path)
 | 
			
		||||
            file_path = file_path.split("/")[-1]
 | 
			
		||||
            return {**data, "filename": file_path}
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.exception(e)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user