mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: user stt language
This commit is contained in:
@@ -8,6 +8,8 @@ from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
@@ -18,6 +20,7 @@ from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -527,11 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def transcription_handler(request, file_path):
|
||||
def transcription_handler(request, file_path, metadata):
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
if request.app.state.faster_whisper_model is None:
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -543,7 +548,7 @@ def transcription_handler(request, file_path):
|
||||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
language=WHISPER_LANGUAGE,
|
||||
language=metadata.get("language") or WHISPER_LANGUAGE,
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
@@ -569,7 +574,14 @@ def transcription_handler(request, file_path):
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data={"model": request.app.state.config.STT_MODEL},
|
||||
data={
|
||||
"model": request.app.state.config.STT_MODEL,
|
||||
**(
|
||||
{"language": metadata.get("language")}
|
||||
if metadata.get("language")
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
@@ -777,8 +789,8 @@ def transcription_handler(request, file_path):
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
log.info(f"transcribe: {file_path}")
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
file_path = convert_audio_to_mp3(file_path)
|
||||
@@ -804,7 +816,7 @@ def transcribe(request: Request, file_path):
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit tasks for each chunk_path
|
||||
futures = [
|
||||
executor.submit(transcription_handler, request, chunk_path)
|
||||
executor.submit(transcription_handler, request, chunk_path, metadata)
|
||||
for chunk_path in chunk_paths
|
||||
]
|
||||
# Gather results as they complete
|
||||
@@ -812,10 +824,9 @@ def transcribe(request: Request, file_path):
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as transcribe_exc:
|
||||
log.exception(f"Error transcribing chunk: {transcribe_exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error during transcription.",
|
||||
detail=f"Error transcribing chunk: {transcribe_exc}",
|
||||
)
|
||||
finally:
|
||||
# Clean up only the temporary chunks, never the original file
|
||||
@@ -897,6 +908,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
|
||||
def transcription(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
language: Optional[str] = Form(None),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
@@ -926,7 +938,12 @@ def transcription(
|
||||
f.write(contents)
|
||||
|
||||
try:
|
||||
result = transcribe(request, file_path)
|
||||
metadata = None
|
||||
|
||||
if language:
|
||||
metadata = {"language": language}
|
||||
|
||||
result = transcribe(request, file_path, metadata)
|
||||
|
||||
return {
|
||||
**result,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -10,6 +11,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -84,13 +86,23 @@ def has_access_to_file(
|
||||
def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
metadata: dict = None,
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
internal: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
|
||||
)
|
||||
file_metadata = metadata if metadata else {}
|
||||
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
@@ -99,7 +111,7 @@ def upload_file(
|
||||
# Remove the leading dot from the file extension
|
||||
file_extension = file_extension[1:] if file_extension else ""
|
||||
|
||||
if not file_metadata and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||
]
|
||||
@@ -147,7 +159,7 @@ def upload_file(
|
||||
"video/webm"
|
||||
}:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
|
||||
@@ -460,7 +460,7 @@ def upload_image(request, image_data, content_type, metadata, user):
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, metadata=metadata)
|
||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
|
||||
Reference in New Issue
Block a user