diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py
index a1e6e94fa..1cad5f4c7 100644
--- a/backend/open_webui/apps/audio/main.py
+++ b/backend/open_webui/apps/audio/main.py
@@ -5,6 +5,8 @@ import os
 import uuid
 from functools import lru_cache
 from pathlib import Path
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
 
 import requests
 from open_webui.config import (
@@ -35,7 +37,12 @@ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile,
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_current_user, get_verified_user
+from open_webui.utils.utils import get_admin_user, get_verified_user
+
+# Constants
+MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
+
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -353,10 +360,103 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
 
 
+def transcribe(file_path):
+    print("transcribe", file_path)
+    filename = os.path.basename(file_path)
+    file_dir = os.path.dirname(file_path)
+    id = filename.split(".")[0]
+
+    if app.state.config.STT_ENGINE == "":
+        from faster_whisper import WhisperModel
+
+        whisper_kwargs = {
+            "model_size_or_path": WHISPER_MODEL,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+        }
+
+        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+        try:
+            model = WhisperModel(**whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            whisper_kwargs["local_files_only"] = False
+            model = WhisperModel(**whisper_kwargs)
+
+        segments, info = model.transcribe(file_path, beam_size=5)
+        log.info(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
+        )
+
+        transcript = "".join([segment.text for segment in list(segments)])
+
+        data = {"text": transcript.strip()}
+
+        # save the transcript to a json file
+        transcript_file = f"{file_dir}/{id}.json"
+        with open(transcript_file, "w") as f:
+            json.dump(data, f)
+
+        print(data)
+        return data
+    elif app.state.config.STT_ENGINE == "openai":
+        if is_mp4_audio(file_path):
+            print("is_mp4_audio")
+            os.rename(file_path, file_path.replace(".wav", ".mp4"))
+            # Convert MP4 audio file to WAV format
+            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+
+        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
+
+        files = {"file": (filename, open(file_path, "rb"))}
+        data = {"model": app.state.config.STT_MODEL}
+
+        print(files, data)
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers=headers,
+                files=files,
+                data=data,
+            )
+
+            r.raise_for_status()
+
+            data = r.json()
+
+            # save the transcript to a json file
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            print(data)
+            return data
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']['message']}"
+                except Exception:
+                    error_detail = f"External: {e}"
+
+            raise Exception(error_detail)
+
+
 @app.post("/transcriptions")
-def transcribe(
+def transcription(
     file: UploadFile = File(...),
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     log.info(f"file.content_type: {file.content_type}")
 
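The hunk above extracts the actual speech-to-text work into a reusable `transcribe(file_path)` helper. Read in isolation, its local (faster-whisper) path amounts to the sketch below; it assumes `faster-whisper` is installed, and the model name and device are placeholder values rather than the app's `WHISPER_MODEL` / `whisper_device_type` configuration.

```python
# Minimal sketch of the local STT path of the new transcribe() helper.
# Assumes faster-whisper is installed; "base" and "cpu" are placeholders.
from faster_whisper import WhisperModel


def transcribe_locally(file_path: str, model_size: str = "base") -> str:
    kwargs = {
        "model_size_or_path": model_size,
        "device": "cpu",
        "compute_type": "int8",
        "local_files_only": True,  # prefer a locally cached model first
    }
    try:
        model = WhisperModel(**kwargs)
    except Exception:
        # Same fallback as the helper: retry with downloads allowed.
        kwargs["local_files_only"] = False
        model = WhisperModel(**kwargs)

    segments, info = model.transcribe(file_path, beam_size=5)
    print(f"Detected language {info.language} (p={info.language_probability:.2f})")
    return "".join(segment.text for segment in segments).strip()
```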
@@ -368,111 +468,53 @@ def transcribe(
 
     try:
         ext = file.filename.split(".")[-1]
-
         id = uuid.uuid4()
+
         filename = f"{id}.{ext}"
+        contents = file.file.read()
 
         file_dir = f"{CACHE_DIR}/audio/transcriptions"
         os.makedirs(file_dir, exist_ok=True)
         file_path = f"{file_dir}/{filename}"
 
-        print(filename)
-
-        contents = file.file.read()
         with open(file_path, "wb") as f:
             f.write(contents)
-            f.close()
 
-        if app.state.config.STT_ENGINE == "":
-            from faster_whisper import WhisperModel
+        try:
+            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
+                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
+                audio = AudioSegment.from_file(file_path)
+                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+                compressed_path = f"{file_dir}/{id}_compressed.opus"
+                audio.export(compressed_path, format="opus", bitrate="32k")
+                log.debug(f"Compressed audio to {compressed_path}")
+                file_path = compressed_path
 
-            whisper_kwargs = {
-                "model_size_or_path": WHISPER_MODEL,
-                "device": whisper_device_type,
-                "compute_type": "int8",
-                "download_root": WHISPER_MODEL_DIR,
-                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-            }
-
-            log.debug(f"whisper_kwargs: {whisper_kwargs}")
-
-            try:
-                model = WhisperModel(**whisper_kwargs)
-            except Exception:
-                log.warning(
-                    "WhisperModel initialization failed, attempting download with local_files_only=False"
-                )
-                whisper_kwargs["local_files_only"] = False
-                model = WhisperModel(**whisper_kwargs)
-
-            segments, info = model.transcribe(file_path, beam_size=5)
-            log.info(
-                "Detected language '%s' with probability %f"
-                % (info.language, info.language_probability)
-            )
-
-            transcript = "".join([segment.text for segment in list(segments)])
-
-            data = {"text": transcript.strip()}
-
-            # save the transcript to a json file
-            transcript_file = f"{file_dir}/{id}.json"
-            with open(transcript_file, "w") as f:
-                json.dump(data, f)
-
-            print(data)
+                if (
+                    os.path.getsize(file_path) > MAX_FILE_SIZE
+                ):  # Still larger than 25MB after compression
+                    chunks = split_on_silence(
+                        audio, min_silence_len=500, silence_thresh=-40
+                    )
+                    texts = []
+                    for i, chunk in enumerate(chunks):
+                        chunk_file_path = f"{file_dir}/{id}_chunk{i}.{ext}"
+                        chunk.export(chunk_file_path, format=ext)
+                        text = transcribe(chunk_file_path)
+                        texts.append(text["text"])
+                    data = {"text": " ".join(texts)}
+                else:
+                    data = transcribe(file_path)
+            else:
+                data = transcribe(file_path)
 
             return data
-
-        elif app.state.config.STT_ENGINE == "openai":
-            if is_mp4_audio(file_path):
-                print("is_mp4_audio")
-                os.rename(file_path, file_path.replace(".wav", ".mp4"))
-                # Convert MP4 audio file to WAV format
-                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
-
-            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-            files = {"file": (filename, open(file_path, "rb"))}
-            data = {"model": app.state.config.STT_MODEL}
-
-            print(files, data)
-
-            r = None
-            try:
-                r = requests.post(
-                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                    headers=headers,
-                    files=files,
-                    data=data,
-                )
-
-                r.raise_for_status()
-
-                data = r.json()
-
-                # save the transcript to a json file
-                transcript_file = f"{file_dir}/{id}.json"
-                with open(transcript_file, "w") as f:
-                    json.dump(data, f)
-
-                print(data)
-                return data
-            except Exception as e:
-                log.exception(e)
-                error_detail = "Open WebUI: Server Connection Error"
-                if r is not None:
-                    try:
-                        res = r.json()
-                        if "error" in res:
-                            error_detail = f"External: {res['error']['message']}"
-                    except Exception:
-                        error_detail = f"External: {e}"
-
-                raise HTTPException(
-                    status_code=r.status_code if r != None else 500,
-                    detail=error_detail,
-                )
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
 
     except Exception as e:
         log.exception(e)
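Taken together, the endpoint changes implement a simple size policy: uploads at or under 25 MB go straight to `transcribe()`; larger ones are downsampled to 16 kHz mono Opus, and only if the compressed file still exceeds the limit is the audio split on silence and transcribed chunk by chunk. A standalone sketch of that flow is shown below, assuming pydub and an ffmpeg build with Opus support are available; `transcribe_chunk` stands in for the endpoint's `transcribe()` helper.

```python
# Sketch of the oversize-audio strategy, assuming pydub + ffmpeg (with libopus).
# transcribe_chunk() is a stand-in for the transcribe() helper added above.
import os

from pydub import AudioSegment
from pydub.silence import split_on_silence

MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB, matching MAX_FILE_SIZE_MB


def transcribe_any_size(file_path: str, transcribe_chunk) -> str:
    if os.path.getsize(file_path) <= MAX_FILE_SIZE:
        return transcribe_chunk(file_path)

    # Downsample to 16 kHz mono and re-encode as low-bitrate Opus.
    audio = AudioSegment.from_file(file_path).set_frame_rate(16000).set_channels(1)
    compressed_path = f"{file_path}.opus"
    audio.export(compressed_path, format="opus", bitrate="32k")

    if os.path.getsize(compressed_path) <= MAX_FILE_SIZE:
        return transcribe_chunk(compressed_path)

    # Still too large: split on silence and transcribe each chunk separately.
    chunks = split_on_silence(audio, min_silence_len=500, silence_thresh=-40)
    texts = []
    for i, chunk in enumerate(chunks):
        chunk_path = f"{file_path}.chunk{i}.wav"
        chunk.export(chunk_path, format="wav")
        texts.append(transcribe_chunk(chunk_path))
    return " ".join(texts)
```

Note that the chunks are cut from the already-downsampled audio, so one oversized upload turns into several small STT calls rather than a single request the upstream API would reject.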
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 066272e43..eeb2f330f 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -700,7 +700,7 @@
 			childrenIds: [],
 			role: 'user',
 			content: userPrompt,
-			files: chatFiles.length > 0 ? chatFiles : undefined,
+			files: _files.length > 0 ? _files : undefined,
 			timestamp: Math.floor(Date.now() / 1000), // Unix epoch
 			models: selectedModels
 		};
diff --git a/src/lib/components/common/FileItemModal.svelte b/src/lib/components/common/FileItemModal.svelte
index e98f554be..f97e4f33d 100644
--- a/src/lib/components/common/FileItemModal.svelte
+++ b/src/lib/components/common/FileItemModal.svelte
@@ -54,7 +54,7 @@
-
+
 					{#if file.size}
 						{formatFileSize(file.size)}
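Assuming the audio app stays mounted under `/audio/api/v1` and requests are authenticated with a Bearer token (both assumptions about the surrounding Open WebUI setup, not something this diff changes), the reworked endpoint can be exercised with a plain multipart upload; files above the 25 MB limit should now be compressed or chunked server-side rather than failing outright.

```python
# Hypothetical client call; the base URL, token, and file name are placeholders.
import requests

BASE_URL = "http://localhost:8080/audio/api/v1"  # assumed mount point
TOKEN = "YOUR_API_TOKEN"  # placeholder

with open("meeting.wav", "rb") as f:
    r = requests.post(
        f"{BASE_URL}/transcriptions",
        headers={"Authorization": f"Bearer {TOKEN}"},
        files={"file": ("meeting.wav", f, "audio/wav")},
    )
r.raise_for_status()
print(r.json()["text"])
```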