feat: user stt language

This commit is contained in:
Timothy Jaeryang Baek
2025-05-24 00:36:30 +04:00
parent 9946dc7b5f
commit baaa285534
14 changed files with 149 additions and 103 deletions

View File

@@ -8,6 +8,8 @@ from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import aiohttp
import aiofiles
@@ -18,6 +20,7 @@ from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -527,11 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
def transcription_handler(request, file_path):
def transcription_handler(request, file_path, metadata):
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
metadata = metadata or {}
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -543,7 +548,7 @@ def transcription_handler(request, file_path):
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=WHISPER_LANGUAGE,
language=metadata.get("language") or WHISPER_LANGUAGE,
)
log.info(
"Detected language '%s' with probability %f"
@@ -569,7 +574,14 @@ def transcription_handler(request, file_path):
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
data={
"model": request.app.state.config.STT_MODEL,
**(
{"language": metadata.get("language")}
if metadata.get("language")
else {}
),
},
)
r.raise_for_status()
@@ -777,8 +789,8 @@ def transcription_handler(request, file_path):
)
def transcribe(request: Request, file_path):
log.info(f"transcribe: {file_path}")
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
log.info(f"transcribe: {file_path} {metadata}")
if is_audio_conversion_required(file_path):
file_path = convert_audio_to_mp3(file_path)
@@ -804,7 +816,7 @@ def transcribe(request: Request, file_path):
with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path
futures = [
executor.submit(transcription_handler, request, chunk_path)
executor.submit(transcription_handler, request, chunk_path, metadata)
for chunk_path in chunk_paths
]
# Gather results as they complete
@@ -812,10 +824,9 @@ def transcribe(request: Request, file_path):
try:
results.append(future.result())
except Exception as transcribe_exc:
log.exception(f"Error transcribing chunk: {transcribe_exc}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error during transcription.",
detail=f"Error transcribing chunk: {transcribe_exc}",
)
finally:
# Clean up only the temporary chunks, never the original file
@@ -897,6 +908,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
def transcription(
request: Request,
file: UploadFile = File(...),
language: Optional[str] = Form(None),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
@@ -926,7 +938,12 @@ def transcription(
f.write(contents)
try:
result = transcribe(request, file_path)
metadata = None
if language:
metadata = {"language": language}
result = transcribe(request, file_path, metadata)
return {
**result,

View File

@@ -1,6 +1,7 @@
import logging
import os
import uuid
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import Optional
@@ -10,6 +11,7 @@ from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -84,13 +86,23 @@ def has_access_to_file(
def upload_file(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
metadata: dict = None,
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
)
file_metadata = metadata if metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
@@ -99,7 +111,7 @@ def upload_file(
# Remove the leading dot from the file extension
file_extension = file_extension[1:] if file_extension else ""
if not file_metadata and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
]
@@ -147,7 +159,7 @@ def upload_file(
"video/webm"
}:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path)
result = transcribe(request, file_path, file_metadata)
process_file(
request,

View File

@@ -460,7 +460,7 @@ def upload_image(request, image_data, content_type, metadata, user):
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, metadata=metadata)
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url