From df0cdd9f3ca064d5d8498538a32a675697d3cac9 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 11 Dec 2024 04:37:47 -0800 Subject: [PATCH] wip --- backend/open_webui/routers/audio.py | 474 ++++++++++++++------------- backend/open_webui/routers/ollama.py | 377 +++++++++------------ 2 files changed, 404 insertions(+), 447 deletions(-) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 3203727a7..d410369af 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -11,25 +11,27 @@ from pydub.silence import split_on_silence import aiohttp import aiofiles import requests + +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + status, + APIRouter, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel + + +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import ( - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_TTS_API_KEY, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_VOICE, - AUDIO_TTS_AZURE_SPEECH_REGION, - AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, CACHE_DIR, - AppConfig, ) from open_webui.constants import ERROR_MESSAGES @@ -40,78 +42,25 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, ) -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse -from pydantic import BaseModel -from open_webui.utils.auth import get_admin_user, get_verified_user + +router = APIRouter() # Constants MAX_FILE_SIZE_MB = 25 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) - -# setting device type for whisper model -whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" -log.info(f"whisper_device_type: {whisper_device_type}") - SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) -def set_faster_whisper_model(model: str, auto_update: bool = False): - if model and app.state.config.STT_ENGINE == "": - from faster_whisper import WhisperModel - - faster_whisper_kwargs = { - "model_size_or_path": model, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not auto_update, - } - - try: - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - except Exception: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" - ) - faster_whisper_kwargs["local_files_only"] = False - app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs) - - else: - app.state.faster_whisper_model = None - - -class TTSConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - API_KEY: str - ENGINE: str - MODEL: str - VOICE: str - SPLIT_ON: str - AZURE_SPEECH_REGION: str - AZURE_SPEECH_OUTPUT_FORMAT: str - - -class STTConfigForm(BaseModel): - OPENAI_API_BASE_URL: str - OPENAI_API_KEY: str - ENGINE: str - MODEL: str - WHISPER_MODEL: str - - -class AudioConfigUpdateForm(BaseModel): - tts: TTSConfigForm - stt: STTConfigForm - +########################################## +# +# Utility functions +# +########################################## from pydub import AudioSegment from pydub.utils import mediainfo @@ -140,71 +89,124 @@ def convert_mp4_to_wav(file_path, output_path): print(f"Converted {file_path} to {output_path}") -@app.get("/config") -async def get_audio_config(user=Depends(get_admin_user)): +def set_faster_whisper_model(model: str, auto_update: bool = False): + whisper_model = None + if model: + from faster_whisper import WhisperModel + + faster_whisper_kwargs = { + "model_size_or_path": model, + "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not auto_update, + } + + try: + whisper_model = WhisperModel(**faster_whisper_kwargs) + except Exception: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + faster_whisper_kwargs["local_files_only"] = False + whisper_model = WhisperModel(**faster_whisper_kwargs) + return whisper_model + + +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + SPLIT_ON: str + AZURE_SPEECH_REGION: str + AZURE_SPEECH_OUTPUT_FORMAT: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + WHISPER_MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +@router.get("/config") +async def get_audio_config(request: Request, user=Depends(get_admin_user)): return { "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, }, } -@app.post("/config/update") +@router.post("/config/update") async def update_audio_config( - form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) ): - app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL - app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY - app.state.config.TTS_API_KEY = form_data.tts.API_KEY - app.state.config.TTS_ENGINE = form_data.tts.ENGINE - app.state.config.TTS_MODEL = form_data.tts.MODEL - app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON - app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION - app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( + request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY + request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE + request.app.state.config.TTS_MODEL = form_data.tts.MODEL + request.app.state.config.TTS_VOICE = form_data.tts.VOICE + request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON + request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION + request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = ( form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT ) - app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL - app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY - app.state.config.STT_ENGINE = form_data.stt.ENGINE - app.state.config.STT_MODEL = form_data.stt.MODEL - app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL - set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE) + request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY + request.app.state.config.STT_ENGINE = form_data.stt.ENGINE + request.app.state.config.STT_MODEL = form_data.stt.MODEL + request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL + + if request.app.state.config.STT_ENGINE == "": + request.app.state.faster_whisper_model = set_faster_whisper_model( + form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE + ) return { "tts": { - "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": app.state.config.TTS_API_KEY, - "ENGINE": app.state.config.TTS_ENGINE, - "MODEL": app.state.config.TTS_MODEL, - "VOICE": app.state.config.TTS_VOICE, - "SPLIT_ON": app.state.config.TTS_SPLIT_ON, - "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION, - "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "API_KEY": request.app.state.config.TTS_API_KEY, + "ENGINE": request.app.state.config.TTS_ENGINE, + "MODEL": request.app.state.config.TTS_MODEL, + "VOICE": request.app.state.config.TTS_VOICE, + "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, + "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, + "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, }, "stt": { - "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, - "ENGINE": app.state.config.STT_ENGINE, - "MODEL": app.state.config.STT_MODEL, - "WHISPER_MODEL": app.state.config.WHISPER_MODEL, + "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, + "ENGINE": request.app.state.config.STT_ENGINE, + "MODEL": request.app.state.config.STT_MODEL, + "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, }, } @@ -213,18 +215,18 @@ def load_speech_pipeline(): from transformers import pipeline from datasets import load_dataset - if app.state.speech_synthesiser is None: - app.state.speech_synthesiser = pipeline( + if request.app.state.speech_synthesiser is None: + request.app.state.speech_synthesiser = pipeline( "text-to-speech", "microsoft/speecht5_tts" ) - if app.state.speech_speaker_embeddings_dataset is None: - app.state.speech_speaker_embeddings_dataset = load_dataset( + if request.app.state.speech_speaker_embeddings_dataset is None: + request.app.state.speech_speaker_embeddings_dataset = load_dataset( "Matthijs/cmu-arctic-xvectors", split="validation" ) -@app.post("/speech") +@router.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - if app.state.config.TTS_ENGINE == "openai": + if request.app.state.config.TTS_ENGINE == "openai": headers = {} - headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" + headers["Authorization"] = ( + f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}" + ) headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: @@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: body = body.decode("utf-8") body = json.loads(body) - body["model"] = app.state.config.TTS_MODEL + body["model"] = request.app.state.config.TTS_MODEL body = json.dumps(body).encode("utf-8") except Exception: pass @@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: async with aiohttp.ClientSession() as session: async with session.post( - url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, ) as r: @@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=error_detail, ) - elif app.state.config.TTS_ENGINE == "elevenlabs": + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: payload = json.loads(body.decode("utf-8")) except Exception as e: @@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): headers = { "Accept": "audio/mpeg", "Content-Type": "application/json", - "xi-api-key": app.state.config.TTS_API_KEY, + "xi-api-key": request.app.state.config.TTS_API_KEY, } data = { "text": payload["input"], - "model_id": app.state.config.TTS_MODEL, + "model_id": request.app.state.config.TTS_MODEL, "voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, } @@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=error_detail, ) - elif app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == "azure": try: payload = json.loads(body.decode("utf-8")) except Exception as e: log.exception(e) raise HTTPException(status_code=400, detail="Invalid JSON payload") - region = app.state.config.TTS_AZURE_SPEECH_REGION - language = app.state.config.TTS_VOICE - locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) - output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT + region = request.app.state.config.TTS_AZURE_SPEECH_REGION + language = request.app.state.config.TTS_VOICE + locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1]) + output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" headers = { - "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY, "Content-Type": "application/ssml+xml", "X-Microsoft-OutputFormat": output_format, } @@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) raise HTTPException(status_code=500, detail=str(e)) - elif app.state.config.TTS_ENGINE == "transformers": + elif request.app.state.config.TTS_ENGINE == "transformers": payload = None try: payload = json.loads(body.decode("utf-8")) @@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): load_speech_pipeline() - embeddings_dataset = app.state.speech_speaker_embeddings_dataset + embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset speaker_index = 6799 try: speaker_index = embeddings_dataset["filename"].index( - app.state.config.TTS_MODEL + request.app.state.config.TTS_MODEL ) except Exception: pass @@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): embeddings_dataset[speaker_index]["xvector"] ).unsqueeze(0) - speech = app.state.speech_synthesiser( + speech = request.app.state.speech_synthesiser( payload["input"], forward_params={"speaker_embeddings": speaker_embedding}, ) @@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) -def transcribe(file_path): +def transcribe(request: Request, file_path): print("transcribe", file_path) filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] - if app.state.config.STT_ENGINE == "": - if app.state.faster_whisper_model is None: - set_faster_whisper_model(app.state.config.WHISPER_MODEL) + if request.app.state.config.STT_ENGINE == "": + if request.app.state.faster_whisper_model is None: + request.app.state.faster_whisper_model = set_faster_whisper_model( + request.app.state.config.WHISPER_MODEL + ) - model = app.state.faster_whisper_model + model = request.app.state.faster_whisper_model segments, info = model.transcribe(file_path, beam_size=5) log.info( "Detected language '%s' with probability %f" @@ -444,31 +450,24 @@ def transcribe(file_path): log.debug(data) return data - elif app.state.config.STT_ENGINE == "openai": + elif request.app.state.config.STT_ENGINE == "openai": if is_mp4_audio(file_path): - print("is_mp4_audio") os.rename(file_path, file_path.replace(".wav", ".mp4")) # Convert MP4 audio file to WAV format convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) - headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} - - files = {"file": (filename, open(file_path, "rb"))} - data = {"model": app.state.config.STT_MODEL} - - log.debug(files, data) - r = None try: r = requests.post( - url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers=headers, - files=files, - data=data, + url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers={ + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + }, + files={"file": (filename, open(file_path, "rb"))}, + data={"model": request.app.state.config.STT_MODEL}, ) r.raise_for_status() - data = r.json() # save the transcript to a json file @@ -476,24 +475,43 @@ def transcribe(file_path): with open(transcript_file, "w") as f: json.dump(data, f) - print(data) return data except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"External: {res['error']['message']}" + detail = f"External: {res['error'].get('message', '')}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" - raise Exception(error_detail) + raise Exception(detail if detail else "Open WebUI: Server Connection Error") -@app.post("/transcriptions") +def compress_audio(file_path): + if os.path.getsize(file_path) > MAX_FILE_SIZE: + file_dir = os.path.dirname(file_path) + audio = AudioSegment.from_file(file_path) + audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio + compressed_path = f"{file_dir}/{id}_compressed.opus" + audio.export(compressed_path, format="opus", bitrate="32k") + log.debug(f"Compressed audio to {compressed_path}") + + if ( + os.path.getsize(compressed_path) > MAX_FILE_SIZE + ): # Still larger than MAX_FILE_SIZE after compression + raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB")) + return compressed_path + else: + return file_path + + +@router.post("/transcriptions") def transcription( + request: Request, file: UploadFile = File(...), user=Depends(get_verified_user), ): @@ -520,36 +538,22 @@ def transcription( f.write(contents) try: - if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB - log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB") - audio = AudioSegment.from_file(file_path) - audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio - compressed_path = f"{file_dir}/{id}_compressed.opus" - audio.export(compressed_path, format="opus", bitrate="32k") - log.debug(f"Compressed audio to {compressed_path}") - file_path = compressed_path + try: + file_path = compress_audio(file_path) + except Exception as e: + log.exception(e) - if ( - os.path.getsize(file_path) > MAX_FILE_SIZE - ): # Still larger than 25MB after compression - log.debug( - f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}" - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.FILE_TOO_LARGE( - size=f"{MAX_FILE_SIZE_MB}MB" - ), - ) - - data = transcribe(file_path) - else: - data = transcribe(file_path) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + data = transcribe(request, file_path) file_path = file_path.split("/")[-1] return {**data, "filename": file_path} except Exception as e: log.exception(e) + raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -564,39 +568,41 @@ def transcription( ) -def get_available_models() -> list[dict]: - if app.state.config.TTS_ENGINE == "openai": - return [{"id": "tts-1"}, {"id": "tts-1-hd"}] - elif app.state.config.TTS_ENGINE == "elevenlabs": - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } - +def get_available_models(request: Request) -> list[dict]: + available_models = [] + if request.app.state.config.TTS_ENGINE == "openai": + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: response = requests.get( - "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5 + "https://api.elevenlabs.io/v1/models", + headers={ + "xi-api-key": request.app.state.config.TTS_API_KEY, + "Content-Type": "application/json", + }, + timeout=5, ) response.raise_for_status() models = response.json() - return [ + + available_models = [ {"name": model["name"], "id": model["model_id"]} for model in models ] except requests.RequestException as e: log.error(f"Error fetching voices: {str(e)}") - return [] + return available_models -@app.get("/models") -async def get_models(user=Depends(get_verified_user)): - return {"models": get_available_models()} +@router.get("/models") +async def get_models(request: Request, user=Depends(get_verified_user)): + return {"models": get_available_models(request)} -def get_available_voices() -> dict: +def get_available_voices(request) -> dict: """Returns {voice_id: voice_name} dict""" - ret = {} - if app.state.config.TTS_ENGINE == "openai": - ret = { + available_voices = {} + if request.app.state.config.TTS_ENGINE == "openai": + available_voices = { "alloy": "alloy", "echo": "echo", "fable": "fable", @@ -604,33 +610,38 @@ def get_available_voices() -> dict: "nova": "nova", "shimmer": "shimmer", } - elif app.state.config.TTS_ENGINE == "elevenlabs": + elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: - ret = get_elevenlabs_voices() + available_voices = get_elevenlabs_voices( + api_key=request.app.state.config.TTS_API_KEY + ) except Exception: # Avoided @lru_cache with exception pass - elif app.state.config.TTS_ENGINE == "azure": + elif request.app.state.config.TTS_ENGINE == "azure": try: - region = app.state.config.TTS_AZURE_SPEECH_REGION + region = request.app.state.config.TTS_AZURE_SPEECH_REGION url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" - headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} + headers = { + "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY + } response = requests.get(url, headers=headers) response.raise_for_status() voices = response.json() + for voice in voices: - ret[voice["ShortName"]] = ( + available_voices[voice["ShortName"]] = ( f"{voice['DisplayName']} ({voice['ShortName']})" ) except requests.RequestException as e: log.error(f"Error fetching voices: {str(e)}") - return ret + return available_voices @lru_cache -def get_elevenlabs_voices() -> dict: +def get_elevenlabs_voices(api_key: str) -> dict: """ Note, set the following in your .env file to use Elevenlabs: AUDIO_TTS_ENGINE=elevenlabs @@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict: AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices AUDIO_TTS_MODEL=eleven_multilingual_v2 """ - headers = { - "xi-api-key": app.state.config.TTS_API_KEY, - "Content-Type": "application/json", - } + try: # TODO: Add retries - response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) + response = requests.get( + "https://api.elevenlabs.io/v1/voices", + headers={ + "xi-api-key": api_key, + "Content-Type": "application/json", + }, + ) response.raise_for_status() voices_data = response.json() @@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict: return voices -@app.get("/voices") -async def get_voices(user=Depends(get_verified_user)): - return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]} +@router.get("/voices") +async def get_voices(request: Request, user=Depends(get_verified_user)): + return { + "voices": [ + {"id": k, "name": v} for k, v in get_available_voices(request).items() + ] + } diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 19bc12e21..082d14ec3 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -385,7 +385,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version - tasks = [ + request_tasks = [ send_get_request( f"{url}/api/version", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( @@ -394,7 +394,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): ) for url in request.app.state.config.OLLAMA_BASE_URLS ] - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) responses = list(filter(lambda x: x is not None, responses)) if len(responses) > 0: @@ -446,7 +446,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u List models that are currently loaded into Ollama memory, and which node they are loaded on. """ if request.app.state.config.ENABLE_OLLAMA_API: - tasks = [ + request_tasks = [ send_get_request( f"{url}/api/ps", request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( @@ -455,7 +455,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u ) for url in request.app.state.config.OLLAMA_BASE_URLS ] - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) else: @@ -502,8 +502,8 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models(request) + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = models[form_data.name]["urls"][0] @@ -540,7 +540,6 @@ async def create_model( ): log.debug(f"form_data: {form_data}") url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") return await send_post_request( url=f"{url}/api/create", @@ -563,8 +562,8 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.source in models: url_idx = models[form_data.source]["urls"][0] @@ -575,45 +574,37 @@ async def copy_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/copy", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) @@ -626,8 +617,8 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.name in models: url_idx = models[form_data.name]["urls"][0] @@ -638,44 +629,37 @@ async def delete_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="DELETE", - url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), - headers=headers, - ) try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + ) r.raise_for_status() log.debug(f"r.text: {r.text}") - return True except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) @@ -683,8 +667,8 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS if form_data.name not in models: raise HTTPException( @@ -693,53 +677,41 @@ async def show_model_info( ) url_idx = random.choice(models[form_data.name]["urls"]) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/show", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() return r.json() except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status_code if r else 500, - detail=error_detail, + detail=detail if detail else "Open WebUI: Server Connection Error", ) -class GenerateEmbeddingsForm(BaseModel): - model: str - prompt: str - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None - - class GenerateEmbedForm(BaseModel): model: str input: list[str] | str @@ -750,103 +722,17 @@ class GenerateEmbedForm(BaseModel): @router.post("/api/embed") @router.post("/api/embed/{url_idx}") -async def generate_embeddings( +async def embed( + request: Request, form_data: GenerateEmbedForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), -): - return await generate_ollama_batch_embeddings(form_data, url_idx) - - -@router.post("/api/embeddings") -@router.post("/api/embeddings/{url_idx}") -async def generate_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, - user=Depends(get_verified_user), -): - return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) - - -async def generate_ollama_embeddings( - form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, -): - log.info(f"generate_ollama_embeddings {form_data}") - - if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} - - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in models: - url_idx = random.choice(models[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - data = r.json() - - log.info(f"generate_ollama_embeddings {data}") - - if "embedding" in data: - return data - else: - raise Exception("Something went wrong :/") - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) - - -async def generate_ollama_batch_embeddings( - form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - model_list = await get_all_models() - models = {model["model"]: model for model in model_list["models"]} + await get_all_models() + models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -862,47 +748,107 @@ async def generate_ollama_batch_embeddings( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) - key = api_config.get("key", None) - - headers = {"Content-Type": "application/json"} - if key: - headers["Authorization"] = f"Bearer {key}" - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers=headers, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) try: + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) r.raise_for_status() data = r.json() - - log.info(f"generate_ollama_batch_embeddings {data}") - - if "embeddings" in data: - return data - else: - raise Exception("Something went wrong :/") + return data except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"Ollama: {res['error']}" + detail = f"Ollama: {res['error']}" except Exception: - error_detail = f"Ollama: {e}" + detail = f"Ollama: {e}" - raise Exception(error_detail) + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@router.post("/api/embeddings") +@router.post("/api/embeddings/{url_idx}") +async def embeddings( + request: Request, + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): + log.info(f"generate_ollama_embeddings {form_data}") + + if url_idx is None: + await get_all_models() + models = request.app.state.OLLAMA_MODELS + + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in models: + url_idx = random.choice(models[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + }, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + return data + except Exception as e: + log.exception(e) + + detail = None + if r is not None: + try: + res = r.json() + if "error" in res: + detail = f"Ollama: {res['error']}" + except Exception: + detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) class GenerateCompletionForm(BaseModel): @@ -947,10 +893,10 @@ async def generate_completion( url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: form_data.model = form_data.model.replace(f"{prefix_id}.", "") - log.info(f"url: {url}") return await send_post_request( url=f"{url}/api/generate", @@ -975,7 +921,7 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -async def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): if url_idx is None: models = request.app.state.OLLAMA_MODELS if model not in models: @@ -1001,7 +947,6 @@ async def generate_chat_completion( bypass_filter = True payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: del payload["metadata"] @@ -1045,13 +990,9 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.debug(f"generate_chat_completion() - 2.payload = {payload}") + url = await get_ollama_url(request, payload["model"], url_idx) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) - parsed_url = urlparse(url) - base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") @@ -1148,10 +1089,9 @@ async def generate_openai_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - + url = await get_ollama_url(request, payload["model"], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1223,10 +1163,9 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = await get_ollama_url(url_idx, payload["model"]) - log.info(f"url: {url}") - + url = await get_ollama_url(request, payload["model"], url_idx) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "")