This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 04:37:47 -08:00
parent df48eac22b
commit df0cdd9f3c
2 changed files with 404 additions and 447 deletions

View File

@ -11,25 +11,27 @@ from pydub.silence import split_on_silence
import aiohttp import aiohttp
import aiofiles import aiofiles
import requests import requests
from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Request,
UploadFile,
status,
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import ( from open_webui.config import (
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
CACHE_DIR, CACHE_DIR,
AppConfig,
) )
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
@ -40,78 +42,25 @@ from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS, ENABLE_FORWARD_USER_INFO_HEADERS,
) )
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware router = APIRouter()
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
# Constants # Constants
MAX_FILE_SIZE_MB = 25 MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"]) log.setLevel(SRC_LOG_LEVELS["AUDIO"])
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False): ##########################################
if model and app.state.config.STT_ENGINE == "": #
from faster_whisper import WhisperModel # Utility functions
#
faster_whisper_kwargs = { ##########################################
"model_size_or_path": model,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
else:
app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment from pydub import AudioSegment
from pydub.utils import mediainfo from pydub.utils import mediainfo
@ -140,71 +89,124 @@ def convert_mp4_to_wav(file_path, output_path):
print(f"Converted {file_path} to {output_path}") print(f"Converted {file_path} to {output_path}")
def set_faster_whisper_model(model: str, auto_update: bool = False):
    """Create a faster-whisper ``WhisperModel`` for ``model``.

    Args:
        model: Model size or path handed to ``WhisperModel``. A falsy value
            means no model is created.
        auto_update: When false, the first attempt is made with
            ``local_files_only=True``.

    Returns:
        The constructed ``WhisperModel``, or ``None`` when ``model`` is falsy.

    Raises:
        Exception: Whatever ``WhisperModel`` raises on the second attempt
            (the one made with ``local_files_only=False``).
    """
    if not model:
        return None

    from faster_whisper import WhisperModel

    # Only an explicit "cuda" device type is honoured; anything else runs on CPU.
    device = "cpu"
    if DEVICE_TYPE and DEVICE_TYPE == "cuda":
        device = DEVICE_TYPE

    options = dict(
        model_size_or_path=model,
        device=device,
        compute_type="int8",
        download_root=WHISPER_MODEL_DIR,
        local_files_only=not auto_update,
    )

    try:
        return WhisperModel(**options)
    except Exception:
        log.warning(
            "WhisperModel initialization failed, attempting download with local_files_only=False"
        )
        # Second and final attempt: allow files to be fetched.
        options["local_files_only"] = False
        return WhisperModel(**options)
class TTSConfigForm(BaseModel):
    """Text-to-speech settings accepted by the audio config update endpoint.

    ``update_audio_config`` copies each field onto the matching ``TTS_*``
    attribute of the application config.
    """

    # Base URL; the speech handler appends "/audio/speech" for the "openai" engine.
    OPENAI_API_BASE_URL: str
    # Sent as a Bearer token in the Authorization header for the "openai" engine.
    OPENAI_API_KEY: str
    # Sent as "xi-api-key" (elevenlabs) or "Ocp-Apim-Subscription-Key" (azure).
    API_KEY: str
    # The speech handler branches on "openai", "elevenlabs", "azure", "transformers".
    ENGINE: str
    # Used as the request "model" (openai), "model_id" (elevenlabs), or the
    # speaker-embedding filename to look up (transformers).
    MODEL: str
    # Read by the azure branch to derive the language and locale.
    VOICE: str
    SPLIT_ON: str
    # Interpolated into the Azure TTS host name.
    AZURE_SPEECH_REGION: str
    # Sent as the "X-Microsoft-OutputFormat" header.
    AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
    """Speech-to-text settings accepted by the audio config update endpoint.

    ``update_audio_config`` copies these onto ``STT_*`` / ``WHISPER_MODEL``
    attributes of the application config.
    """

    # Base URL; transcribe appends "/audio/transcriptions" for the "openai" engine.
    OPENAI_API_BASE_URL: str
    # Sent as a Bearer token in the Authorization header for the "openai" engine.
    OPENAI_API_KEY: str
    # "" selects the local faster-whisper model; "openai" selects the remote API.
    ENGINE: str
    # Sent as the "model" form field of the remote transcription request.
    MODEL: str
    # Passed to set_faster_whisper_model when the engine is "".
    WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
    """Request body of the audio config update endpoint: TTS plus STT settings."""

    tts: TTSConfigForm  # text-to-speech section
    stt: STTConfigForm  # speech-to-text section
@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
    """Return the current audio settings as ``{"tts": {...}, "stt": {...}}``.

    Values are read from ``request.app.state.config``; the ``user`` argument
    is resolved through the ``get_admin_user`` dependency.
    """
    config = request.app.state.config

    tts_settings = {
        "OPENAI_API_BASE_URL": config.TTS_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.TTS_OPENAI_API_KEY,
        "API_KEY": config.TTS_API_KEY,
        "ENGINE": config.TTS_ENGINE,
        "MODEL": config.TTS_MODEL,
        "VOICE": config.TTS_VOICE,
        "SPLIT_ON": config.TTS_SPLIT_ON,
        "AZURE_SPEECH_REGION": config.TTS_AZURE_SPEECH_REGION,
        "AZURE_SPEECH_OUTPUT_FORMAT": config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
    }
    stt_settings = {
        "OPENAI_API_BASE_URL": config.STT_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.STT_OPENAI_API_KEY,
        "ENGINE": config.STT_ENGINE,
        "MODEL": config.STT_MODEL,
        "WHISPER_MODEL": config.WHISPER_MODEL,
    }

    return {"tts": tts_settings, "stt": stt_settings}
@router.post("/config/update")
async def update_audio_config(
    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store the submitted audio settings and return the resulting config.

    Every field of ``form_data.tts`` / ``form_data.stt`` is written to the
    matching attribute of ``request.app.state.config``. When the STT engine is
    the empty string, the faster-whisper model is rebuilt from the submitted
    ``WHISPER_MODEL`` and stored on ``request.app.state``.

    Returns:
        ``{"tts": {...}, "stt": {...}}`` read back from the config after the
        assignments.
    """
    config = request.app.state.config
    tts_form = form_data.tts
    stt_form = form_data.stt

    # Text-to-speech settings.
    config.TTS_OPENAI_API_BASE_URL = tts_form.OPENAI_API_BASE_URL
    config.TTS_OPENAI_API_KEY = tts_form.OPENAI_API_KEY
    config.TTS_API_KEY = tts_form.API_KEY
    config.TTS_ENGINE = tts_form.ENGINE
    config.TTS_MODEL = tts_form.MODEL
    config.TTS_VOICE = tts_form.VOICE
    config.TTS_SPLIT_ON = tts_form.SPLIT_ON
    config.TTS_AZURE_SPEECH_REGION = tts_form.AZURE_SPEECH_REGION
    config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = tts_form.AZURE_SPEECH_OUTPUT_FORMAT

    # Speech-to-text settings.
    config.STT_OPENAI_API_BASE_URL = stt_form.OPENAI_API_BASE_URL
    config.STT_OPENAI_API_KEY = stt_form.OPENAI_API_KEY
    config.STT_ENGINE = stt_form.ENGINE
    config.STT_MODEL = stt_form.MODEL
    config.WHISPER_MODEL = stt_form.WHISPER_MODEL

    # The local whisper model is only (re)built for the built-in engine ("").
    if config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
            stt_form.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
        )

    tts_settings = {
        "OPENAI_API_BASE_URL": config.TTS_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.TTS_OPENAI_API_KEY,
        "API_KEY": config.TTS_API_KEY,
        "ENGINE": config.TTS_ENGINE,
        "MODEL": config.TTS_MODEL,
        "VOICE": config.TTS_VOICE,
        "SPLIT_ON": config.TTS_SPLIT_ON,
        "AZURE_SPEECH_REGION": config.TTS_AZURE_SPEECH_REGION,
        "AZURE_SPEECH_OUTPUT_FORMAT": config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
    }
    stt_settings = {
        "OPENAI_API_BASE_URL": config.STT_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.STT_OPENAI_API_KEY,
        "ENGINE": config.STT_ENGINE,
        "MODEL": config.STT_MODEL,
        "WHISPER_MODEL": config.WHISPER_MODEL,
    }

    return {"tts": tts_settings, "stt": stt_settings}
@ -213,18 +215,18 @@ def load_speech_pipeline():
from transformers import pipeline from transformers import pipeline
from datasets import load_dataset from datasets import load_dataset
if app.state.speech_synthesiser is None: if request.app.state.speech_synthesiser is None:
app.state.speech_synthesiser = pipeline( request.app.state.speech_synthesiser = pipeline(
"text-to-speech", "microsoft/speecht5_tts" "text-to-speech", "microsoft/speecht5_tts"
) )
if app.state.speech_speaker_embeddings_dataset is None: if request.app.state.speech_speaker_embeddings_dataset is None:
app.state.speech_speaker_embeddings_dataset = load_dataset( request.app.state.speech_speaker_embeddings_dataset = load_dataset(
"Matthijs/cmu-arctic-xvectors", split="validation" "Matthijs/cmu-arctic-xvectors", split="validation"
) )
@app.post("/speech") @router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
if file_path.is_file(): if file_path.is_file():
return FileResponse(file_path) return FileResponse(file_path)
if app.state.config.TTS_ENGINE == "openai": if request.app.state.config.TTS_ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" headers["Authorization"] = (
f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}"
)
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS: if ENABLE_FORWARD_USER_INFO_HEADERS:
@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
try: try:
body = body.decode("utf-8") body = body.decode("utf-8")
body = json.loads(body) body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL body["model"] = request.app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8") body = json.dumps(body).encode("utf-8")
except Exception: except Exception:
pass pass
@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
) as r: ) as r:
@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail=error_detail, detail=error_detail,
) )
elif app.state.config.TTS_ENGINE == "elevenlabs": elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try: try:
payload = json.loads(body.decode("utf-8")) payload = json.loads(body.decode("utf-8"))
except Exception as e: except Exception as e:
@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
headers = { headers = {
"Accept": "audio/mpeg", "Accept": "audio/mpeg",
"Content-Type": "application/json", "Content-Type": "application/json",
"xi-api-key": app.state.config.TTS_API_KEY, "xi-api-key": request.app.state.config.TTS_API_KEY,
} }
data = { data = {
"text": payload["input"], "text": payload["input"],
"model_id": app.state.config.TTS_MODEL, "model_id": request.app.state.config.TTS_MODEL,
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5}, "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
} }
@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail=error_detail, detail=error_detail,
) )
elif app.state.config.TTS_ENGINE == "azure": elif request.app.state.config.TTS_ENGINE == "azure":
try: try:
payload = json.loads(body.decode("utf-8")) payload = json.loads(body.decode("utf-8"))
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload") raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = app.state.config.TTS_AZURE_SPEECH_REGION region = request.app.state.config.TTS_AZURE_SPEECH_REGION
language = app.state.config.TTS_VOICE language = request.app.state.config.TTS_VOICE
locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1]) locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1" url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = { headers = {
"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY, "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml", "Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format, "X-Microsoft-OutputFormat": output_format,
} }
@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
elif app.state.config.TTS_ENGINE == "transformers": elif request.app.state.config.TTS_ENGINE == "transformers":
payload = None payload = None
try: try:
payload = json.loads(body.decode("utf-8")) payload = json.loads(body.decode("utf-8"))
@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
load_speech_pipeline() load_speech_pipeline()
embeddings_dataset = app.state.speech_speaker_embeddings_dataset embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
speaker_index = 6799 speaker_index = 6799
try: try:
speaker_index = embeddings_dataset["filename"].index( speaker_index = embeddings_dataset["filename"].index(
app.state.config.TTS_MODEL request.app.state.config.TTS_MODEL
) )
except Exception: except Exception:
pass pass
@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
embeddings_dataset[speaker_index]["xvector"] embeddings_dataset[speaker_index]["xvector"]
).unsqueeze(0) ).unsqueeze(0)
speech = app.state.speech_synthesiser( speech = request.app.state.speech_synthesiser(
payload["input"], payload["input"],
forward_params={"speaker_embeddings": speaker_embedding}, forward_params={"speaker_embeddings": speaker_embedding},
) )
@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
def transcribe(file_path): def transcribe(request: Request, file_path):
print("transcribe", file_path) print("transcribe", file_path)
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path) file_dir = os.path.dirname(file_path)
id = filename.split(".")[0] id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "": if request.app.state.config.STT_ENGINE == "":
if app.state.faster_whisper_model is None: if request.app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL) request.app.state.faster_whisper_model = set_faster_whisper_model(
request.app.state.config.WHISPER_MODEL
)
model = app.state.faster_whisper_model model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(file_path, beam_size=5)
log.info( log.info(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
@ -444,31 +450,24 @@ def transcribe(file_path):
log.debug(data) log.debug(data)
return data return data
elif app.state.config.STT_ENGINE == "openai": elif request.app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path): if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4")) os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format # Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL}
log.debug(files, data)
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers, headers={
files=files, "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
data=data, },
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
) )
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
# save the transcript to a json file # save the transcript to a json file
@ -476,24 +475,43 @@ def transcribe(file_path):
with open(transcript_file, "w") as f: with open(transcript_file, "w") as f:
json.dump(data, f) json.dump(data, f)
print(data)
return data return data
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" detail = f"External: {res['error'].get('message', '')}"
except Exception: except Exception:
error_detail = f"External: {e}" detail = f"External: {e}"
raise Exception(error_detail) raise Exception(detail if detail else "Open WebUI: Server Connection Error")
def compress_audio(file_path):
    """Return a path to an audio file that fits within ``MAX_FILE_SIZE``.

    A file already within the limit is returned unchanged. A larger file is
    re-encoded as 16 kHz mono Opus at 32k bitrate next to the original, under
    the name ``<stem>_compressed.opus``.

    Args:
        file_path: Path of the audio file to check.

    Returns:
        ``file_path`` itself, or the path of the compressed copy.

    Raises:
        Exception: With ``ERROR_MESSAGES.FILE_TOO_LARGE`` when the compressed
            copy is still larger than ``MAX_FILE_SIZE``.
    """
    if os.path.getsize(file_path) <= MAX_FILE_SIZE:
        return file_path

    file_dir = os.path.dirname(file_path)
    # Derive the stem from the file name. No local `id` exists in this
    # function, so the bare name would resolve to the builtin `id` function
    # and produce "<built-in function id>_compressed.opus".
    file_id = os.path.splitext(os.path.basename(file_path))[0]

    audio = AudioSegment.from_file(file_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio

    # os.path.join keeps the result relative when file_dir is empty.
    compressed_path = os.path.join(file_dir, f"{file_id}_compressed.opus")
    audio.export(compressed_path, format="opus", bitrate="32k")
    log.debug(f"Compressed audio to {compressed_path}")

    # Still larger than MAX_FILE_SIZE after compression
    if os.path.getsize(compressed_path) > MAX_FILE_SIZE:
        raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
    return compressed_path
@router.post("/transcriptions")
def transcription( def transcription(
request: Request,
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
@ -520,36 +538,22 @@ def transcription(
f.write(contents) f.write(contents)
try: try:
if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB try:
log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB") file_path = compress_audio(file_path)
audio = AudioSegment.from_file(file_path) except Exception as e:
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio log.exception(e)
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
file_path = compressed_path
if ( raise HTTPException(
os.path.getsize(file_path) > MAX_FILE_SIZE status_code=status.HTTP_400_BAD_REQUEST,
): # Still larger than 25MB after compression detail=ERROR_MESSAGES.DEFAULT(e),
log.debug( )
f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_TOO_LARGE(
size=f"{MAX_FILE_SIZE_MB}MB"
),
)
data = transcribe(file_path)
else:
data = transcribe(file_path)
data = transcribe(request, file_path)
file_path = file_path.split("/")[-1] file_path = file_path.split("/")[-1]
return {**data, "filename": file_path} return {**data, "filename": file_path}
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
@ -564,39 +568,41 @@ def transcription(
) )
def get_available_models(request: Request) -> list[dict]:
    """List the TTS models offered by the configured engine.

    Args:
        request: Incoming request; the engine and API key are read from
            ``request.app.state.config``.

    Returns:
        For ``"openai"``, the fixed ``tts-1`` / ``tts-1-hd`` entries. For
        ``"elevenlabs"``, one ``{"name", "id"}`` dict per model returned by
        the ElevenLabs models endpoint, or an empty list when that request
        fails. An empty list for any other engine.
    """
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            # This call fetches models; the message previously said "voices".
            log.error(f"Error fetching models: {str(e)}")

    return available_models
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    """Return the models from ``get_available_models`` as ``{"models": [...]}``."""
    models = get_available_models(request)
    return {"models": models}
def get_available_voices() -> dict: def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict""" """Returns {voice_id: voice_name} dict"""
ret = {} available_voices = {}
if app.state.config.TTS_ENGINE == "openai": if request.app.state.config.TTS_ENGINE == "openai":
ret = { available_voices = {
"alloy": "alloy", "alloy": "alloy",
"echo": "echo", "echo": "echo",
"fable": "fable", "fable": "fable",
@ -604,33 +610,38 @@ def get_available_voices() -> dict:
"nova": "nova", "nova": "nova",
"shimmer": "shimmer", "shimmer": "shimmer",
} }
elif app.state.config.TTS_ENGINE == "elevenlabs": elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try: try:
ret = get_elevenlabs_voices() available_voices = get_elevenlabs_voices(
api_key=request.app.state.config.TTS_API_KEY
)
except Exception: except Exception:
# Avoided @lru_cache with exception # Avoided @lru_cache with exception
pass pass
elif app.state.config.TTS_ENGINE == "azure": elif request.app.state.config.TTS_ENGINE == "azure":
try: try:
region = app.state.config.TTS_AZURE_SPEECH_REGION region = request.app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list" url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY} headers = {
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
}
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
response.raise_for_status() response.raise_for_status()
voices = response.json() voices = response.json()
for voice in voices: for voice in voices:
ret[voice["ShortName"]] = ( available_voices[voice["ShortName"]] = (
f"{voice['DisplayName']} ({voice['ShortName']})" f"{voice['DisplayName']} ({voice['ShortName']})"
) )
except requests.RequestException as e: except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}") log.error(f"Error fetching voices: {str(e)}")
return ret return available_voices
@lru_cache @lru_cache
def get_elevenlabs_voices() -> dict: def get_elevenlabs_voices(api_key: str) -> dict:
""" """
Note, set the following in your .env file to use Elevenlabs: Note, set the following in your .env file to use Elevenlabs:
AUDIO_TTS_ENGINE=elevenlabs AUDIO_TTS_ENGINE=elevenlabs
@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict:
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
AUDIO_TTS_MODEL=eleven_multilingual_v2 AUDIO_TTS_MODEL=eleven_multilingual_v2
""" """
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
}
try: try:
# TODO: Add retries # TODO: Add retries
response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers) response = requests.get(
"https://api.elevenlabs.io/v1/voices",
headers={
"xi-api-key": api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status() response.raise_for_status()
voices_data = response.json() voices_data = response.json()
@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict:
return voices return voices
@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
    """Return the voices from ``get_available_voices`` as ``{"voices": [...]}``.

    Each entry of the id-to-name mapping becomes an ``{"id", "name"}`` dict.
    """
    voices = []
    for voice_id, voice_name in get_available_voices(request).items():
        voices.append({"id": voice_id, "name": voice_name})
    return {"voices": voices}

View File

@ -385,7 +385,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
if request.app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
if url_idx is None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ request_tasks = [
send_get_request( send_get_request(
f"{url}/api/version", f"{url}/api/version",
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@ -394,7 +394,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
) )
for url in request.app.state.config.OLLAMA_BASE_URLS for url in request.app.state.config.OLLAMA_BASE_URLS
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*request_tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
if len(responses) > 0: if len(responses) > 0:
@ -446,7 +446,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
List models that are currently loaded into Ollama memory, and which node they are loaded on. List models that are currently loaded into Ollama memory, and which node they are loaded on.
""" """
if request.app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
tasks = [ request_tasks = [
send_get_request( send_get_request(
f"{url}/api/ps", f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@ -455,7 +455,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
) )
for url in request.app.state.config.OLLAMA_BASE_URLS for url in request.app.state.config.OLLAMA_BASE_URLS
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else: else:
@ -502,8 +502,8 @@ async def push_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
model_list = await get_all_models() await get_all_models(request)
models = {model["model"]: model for model in model_list["models"]} models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if form_data.name in models:
url_idx = models[form_data.name]["urls"][0] url_idx = models[form_data.name]["urls"][0]
@ -540,7 +540,6 @@ async def create_model(
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await send_post_request( return await send_post_request(
url=f"{url}/api/create", url=f"{url}/api/create",
@ -563,8 +562,8 @@ async def copy_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
model_list = await get_all_models() await get_all_models()
models = {model["model"]: model for model in model_list["models"]} models = request.app.state.OLLAMA_MODELS
if form_data.source in models: if form_data.source in models:
url_idx = models[form_data.source]["urls"][0] url_idx = models[form_data.source]["urls"][0]
@ -575,45 +574,37 @@ async def copy_model(
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/copy",
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/copy",
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" detail = f"Ollama: {res['error']}"
except Exception: except Exception:
error_detail = f"Ollama: {e}" detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
status_code=r.status_code if r else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=detail if detail else "Open WebUI: Server Connection Error",
) )
@ -626,8 +617,8 @@ async def delete_model(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
model_list = await get_all_models() await get_all_models()
models = {model["model"]: model for model in model_list["models"]} models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if form_data.name in models:
url_idx = models[form_data.name]["urls"][0] url_idx = models[form_data.name]["urls"][0]
@ -638,44 +629,37 @@ async def delete_model(
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(),
headers=headers,
)
try: try:
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(),
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
},
)
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" detail = f"Ollama: {res['error']}"
except Exception: except Exception:
error_detail = f"Ollama: {e}" detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
status_code=r.status_code if r else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=detail if detail else "Open WebUI: Server Connection Error",
) )
@ -683,8 +667,8 @@ async def delete_model(
async def show_model_info( async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
): ):
model_list = await get_all_models() await get_all_models()
models = {model["model"]: model for model in model_list["models"]} models = request.app.state.OLLAMA_MODELS
if form_data.name not in models: if form_data.name not in models:
raise HTTPException( raise HTTPException(
@ -693,53 +677,41 @@ async def show_model_info(
) )
url_idx = random.choice(models[form_data.name]["urls"]) url_idx = random.choice(models[form_data.name]["urls"])
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/show",
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/show",
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" detail = f"Ollama: {res['error']}"
except Exception: except Exception:
error_detail = f"Ollama: {e}" detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
status_code=r.status_code if r else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=detail if detail else "Open WebUI: Server Connection Error",
) )
class GenerateEmbeddingsForm(BaseModel):
model: str
prompt: str
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None
class GenerateEmbedForm(BaseModel): class GenerateEmbedForm(BaseModel):
model: str model: str
input: list[str] | str input: list[str] | str
@ -750,103 +722,17 @@ class GenerateEmbedForm(BaseModel):
@router.post("/api/embed") @router.post("/api/embed")
@router.post("/api/embed/{url_idx}") @router.post("/api/embed/{url_idx}")
async def generate_embeddings( async def embed(
request: Request,
form_data: GenerateEmbedForm, form_data: GenerateEmbedForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
):
return await generate_ollama_batch_embeddings(form_data, url_idx)
@router.post("/api/embeddings")
@router.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
async def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in models:
url_idx = random.choice(models[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
data = r.json()
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
async def generate_ollama_batch_embeddings(
form_data: GenerateEmbedForm,
url_idx: Optional[int] = None,
): ):
log.info(f"generate_ollama_batch_embeddings {form_data}") log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None: if url_idx is None:
model_list = await get_all_models() await get_all_models()
models = {model["model"]: model for model in model_list["models"]} models = request.app.state.OLLAMA_MODELS
model = form_data.model model = form_data.model
@ -862,47 +748,107 @@ async def generate_ollama_batch_embeddings(
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
return data
log.info(f"generate_ollama_batch_embeddings {data}")
if "embeddings" in data:
return data
else:
raise Exception("Something went wrong :/")
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" detail = f"Ollama: {res['error']}"
except Exception: except Exception:
error_detail = f"Ollama: {e}" detail = f"Ollama: {e}"
raise Exception(error_detail) raise HTTPException(
status_code=r.status_code if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
class GenerateEmbeddingsForm(BaseModel):
    # Request body accepted by the `/api/embeddings` routes; the handler
    # serialises it with `model_dump_json(exclude_none=True)` and forwards it
    # to the backend, so unset optional fields are omitted from the payload.

    # Model name used to pick a backend; the handler looks it up with
    # ":latest" appended when the name contains no ":" tag.
    model: str
    # Text to embed (forwarded as-is in the JSON body).
    prompt: str
    # Optional dict forwarded untouched — presumably Ollama runtime options;
    # TODO confirm against the Ollama API reference.
    options: Optional[dict] = None
    # Optional int or str forwarded untouched — presumably how long the model
    # stays loaded upstream; TODO confirm against the Ollama API reference.
    keep_alive: Optional[Union[int, str]] = None
@router.post("/api/embeddings")
@router.post("/api/embeddings/{url_idx}")
async def embeddings(
    request: Request,
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a single-prompt embedding request to an Ollama backend.

    Args:
        request: Incoming request; gives access to ``app.state`` (configured
            base URLs, per-URL API configs and the known-model table).
        form_data: Embedding payload, sent upstream as JSON with ``None``
            fields dropped.
        url_idx: Index into ``OLLAMA_BASE_URLS``. When omitted, a backend is
            picked at random among those listed for the requested model.
        user: Verified user injected by the auth dependency.

    Returns:
        The JSON document returned by the backend's ``/api/embeddings``.

    Raises:
        HTTPException: 400 when no backend is known for the model; otherwise
            the upstream error status (or 500 when the backend could not be
            reached or its reply could not be processed).
    """
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        # Pass `request` through, matching the `get_all_models(request)` call
        # used by `push_model`; presumably this refreshes
        # `request.app.state.OLLAMA_MODELS` before it is read below.
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )
        url_idx = random.choice(models[model]["urls"])

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    # Bind `r` before the try block: if the request itself raises (DNS error,
    # refused connection, ...) the handler below must still be able to test it.
    r = None
    try:
        # NOTE(review): `requests` is synchronous and blocks the event loop
        # for the duration of the upstream call.
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        # A `requests.Response` is falsy for 4xx/5xx replies (its truth value
        # is `ok`), so compare against None explicitly; otherwise every
        # upstream error would be reported as 500. A 2xx reply that still
        # landed here (e.g. unparseable body) is reported as 500.
        upstream_failed = r is not None and r.status_code >= 400
        raise HTTPException(
            status_code=r.status_code if upstream_failed else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
@ -947,10 +893,10 @@ async def generate_completion(
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "") form_data.model = form_data.model.replace(f"{prefix_id}.", "")
log.info(f"url: {url}")
return await send_post_request( return await send_post_request(
url=f"{url}/api/generate", url=f"{url}/api/generate",
@ -975,7 +921,7 @@ class GenerateChatCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
async def get_ollama_url(url_idx: Optional[int], model: str): async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
if url_idx is None: if url_idx is None:
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if model not in models: if model not in models:
@ -1001,7 +947,6 @@ async def generate_chat_completion(
bypass_filter = True bypass_filter = True
payload = {**form_data.model_dump(exclude_none=True)} payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"generate_chat_completion() - 1.payload = {payload}")
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]
@ -1045,13 +990,9 @@ async def generate_chat_completion(
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(request, payload["model"], url_idx)
log.debug(f"generate_chat_completion() - 2.payload = {payload}") api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@ -1148,10 +1089,9 @@ async def generate_openai_completion(
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(request, payload["model"], url_idx)
log.info(f"url: {url}")
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
@ -1223,10 +1163,9 @@ async def generate_openai_chat_completion(
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(request, payload["model"], url_idx)
log.info(f"url: {url}")
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")