mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into main
This commit is contained in:
@@ -7,6 +7,9 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
@@ -17,6 +20,7 @@ from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -33,10 +37,13 @@ from open_webui.config import (
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
CACHE_DIR,
|
||||
WHISPER_LANGUAGE,
|
||||
)
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
DEVICE_TYPE,
|
||||
@@ -47,13 +54,15 @@ from open_webui.env import (
|
||||
router = APIRouter()
|
||||
|
||||
# Constants
|
||||
MAX_FILE_SIZE_MB = 25
|
||||
MAX_FILE_SIZE_MB = 20
|
||||
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
AZURE_MAX_FILE_SIZE_MB = 200
|
||||
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -67,27 +76,47 @@ from pydub import AudioSegment
|
||||
from pydub.utils import mediainfo
|
||||
|
||||
|
||||
def is_mp4_audio(file_path):
|
||||
"""Check if the given file is an MP4 audio file."""
|
||||
def is_audio_conversion_required(file_path):
|
||||
"""
|
||||
Check if the given audio file needs conversion to mp3.
|
||||
"""
|
||||
SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
log.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
info = mediainfo(file_path)
|
||||
if (
|
||||
info.get("codec_name") == "aac"
|
||||
and info.get("codec_type") == "audio"
|
||||
and info.get("codec_tag_string") == "mp4a"
|
||||
):
|
||||
try:
|
||||
info = mediainfo(file_path)
|
||||
codec_name = info.get("codec_name", "").lower()
|
||||
codec_type = info.get("codec_type", "").lower()
|
||||
codec_tag_string = info.get("codec_tag_string", "").lower()
|
||||
|
||||
if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
|
||||
# File is AAC/mp4a audio, recommend mp3 conversion
|
||||
return True
|
||||
|
||||
# If the codec name is in the supported formats
|
||||
if codec_name in SUPPORTED_FORMATS:
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
log.error(f"Error getting audio format: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def convert_mp4_to_wav(file_path, output_path):
|
||||
"""Convert MP4 audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format="mp4")
|
||||
audio.export(output_path, format="wav")
|
||||
print(f"Converted {file_path} to {output_path}")
|
||||
def convert_audio_to_mp3(file_path):
|
||||
"""Convert audio file to mp3 format."""
|
||||
try:
|
||||
output_path = os.path.splitext(file_path)[0] + ".mp3"
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
audio.export(output_path, format="mp3")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
return output_path
|
||||
except Exception as e:
|
||||
log.error(f"Error converting audio file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
@@ -130,6 +159,7 @@ class TTSConfigForm(BaseModel):
|
||||
VOICE: str
|
||||
SPLIT_ON: str
|
||||
AZURE_SPEECH_REGION: str
|
||||
AZURE_SPEECH_BASE_URL: str
|
||||
AZURE_SPEECH_OUTPUT_FORMAT: str
|
||||
|
||||
|
||||
@@ -140,6 +170,11 @@ class STTConfigForm(BaseModel):
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
DEEPGRAM_API_KEY: str
|
||||
AZURE_API_KEY: str
|
||||
AZURE_REGION: str
|
||||
AZURE_LOCALES: str
|
||||
AZURE_BASE_URL: str
|
||||
AZURE_MAX_SPEAKERS: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@@ -159,6 +194,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
"VOICE": request.app.state.config.TTS_VOICE,
|
||||
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
|
||||
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
|
||||
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
|
||||
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
},
|
||||
"stt": {
|
||||
@@ -168,6 +204,11 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
|
||||
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
|
||||
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
||||
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
||||
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -184,6 +225,9 @@ async def update_audio_config(
|
||||
request.app.state.config.TTS_VOICE = form_data.tts.VOICE
|
||||
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
|
||||
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
|
||||
request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
|
||||
form_data.tts.AZURE_SPEECH_BASE_URL
|
||||
)
|
||||
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
|
||||
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
|
||||
)
|
||||
@@ -194,6 +238,13 @@ async def update_audio_config(
|
||||
request.app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
|
||||
request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
|
||||
request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
|
||||
request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
|
||||
request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
|
||||
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
|
||||
form_data.stt.AZURE_MAX_SPEAKERS
|
||||
)
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -210,6 +261,7 @@ async def update_audio_config(
|
||||
"VOICE": request.app.state.config.TTS_VOICE,
|
||||
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
|
||||
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
|
||||
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
|
||||
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
},
|
||||
"stt": {
|
||||
@@ -219,6 +271,11 @@ async def update_audio_config(
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
|
||||
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
|
||||
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
|
||||
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
|
||||
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -265,8 +322,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
json=payload,
|
||||
@@ -284,6 +343,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -309,7 +369,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -323,7 +383,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
|
||||
json={
|
||||
@@ -336,6 +399,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": request.app.state.config.TTS_API_KEY,
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -360,7 +424,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -371,7 +435,8 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
log.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
|
||||
region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
|
||||
base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
|
||||
language = request.app.state.config.TTS_VOICE
|
||||
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
|
||||
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
||||
@@ -380,15 +445,20 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
||||
<voice name="{language}">{payload["input"]}</voice>
|
||||
</speak>"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
|
||||
(base_url or f"https://{region}.tts.speech.microsoft.com")
|
||||
+ "/cognitiveservices/v1",
|
||||
headers={
|
||||
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": output_format,
|
||||
},
|
||||
data=data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -413,7 +483,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -457,12 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
print("transcribe", file_path)
|
||||
def transcription_handler(request, file_path, metadata):
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
if request.app.state.faster_whisper_model is None:
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -470,7 +541,12 @@ def transcribe(request: Request, file_path):
|
||||
)
|
||||
|
||||
model = request.app.state.faster_whisper_model
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
segments, info = model.transcribe(
|
||||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
language=metadata.get("language") or WHISPER_LANGUAGE,
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
% (info.language, info.language_probability)
|
||||
@@ -487,11 +563,6 @@ def transcribe(request: Request, file_path):
|
||||
log.debug(data)
|
||||
return data
|
||||
elif request.app.state.config.STT_ENGINE == "openai":
|
||||
if is_mp4_audio(file_path):
|
||||
os.rename(file_path, file_path.replace(".wav", ".mp4"))
|
||||
# Convert MP4 audio file to WAV format
|
||||
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
|
||||
|
||||
r = None
|
||||
try:
|
||||
r = requests.post(
|
||||
@@ -500,7 +571,14 @@ def transcribe(request: Request, file_path):
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data={"model": request.app.state.config.STT_MODEL},
|
||||
data={
|
||||
"model": request.app.state.config.STT_MODEL,
|
||||
**(
|
||||
{"language": metadata.get("language")}
|
||||
if metadata.get("language")
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
@@ -589,34 +667,254 @@ def transcribe(request: Request, file_path):
|
||||
detail = f"External: {e}"
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
elif request.app.state.config.STT_ENGINE == "azure":
|
||||
# Check file exists and size
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=400, detail="Audio file not found")
|
||||
|
||||
# Check file size (Azure has a larger limit of 200MB)
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > AZURE_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
|
||||
)
|
||||
|
||||
api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
|
||||
region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
|
||||
locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
|
||||
base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
|
||||
max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
|
||||
|
||||
# IF NO LOCALES, USE DEFAULTS
|
||||
if len(locales) < 2:
|
||||
locales = [
|
||||
"en-US",
|
||||
"es-ES",
|
||||
"es-MX",
|
||||
"fr-FR",
|
||||
"hi-IN",
|
||||
"it-IT",
|
||||
"de-DE",
|
||||
"en-GB",
|
||||
"en-IN",
|
||||
"ja-JP",
|
||||
"ko-KR",
|
||||
"pt-BR",
|
||||
"zh-CN",
|
||||
]
|
||||
locales = ",".join(locales)
|
||||
|
||||
if not api_key or not region:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Azure API key is required for Azure STT",
|
||||
)
|
||||
|
||||
r = None
|
||||
try:
|
||||
# Prepare the request
|
||||
data = {
|
||||
"definition": json.dumps(
|
||||
{
|
||||
"locales": locales.split(","),
|
||||
"diarization": {"maxSpeakers": max_speakers, "enabled": True},
|
||||
}
|
||||
if locales
|
||||
else {}
|
||||
)
|
||||
}
|
||||
|
||||
url = (
|
||||
base_url or f"https://{region}.api.cognitive.microsoft.com"
|
||||
) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
|
||||
|
||||
# Use context manager to ensure file is properly closed
|
||||
with open(file_path, "rb") as audio_file:
|
||||
r = requests.post(
|
||||
url=url,
|
||||
files={"audio": audio_file},
|
||||
data=data,
|
||||
headers={
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
response = r.json()
|
||||
|
||||
# Extract transcript from response
|
||||
if not response.get("combinedPhrases"):
|
||||
raise ValueError("No transcription found in response")
|
||||
|
||||
# Get the full transcript from combinedPhrases
|
||||
transcript = response["combinedPhrases"][0].get("text", "").strip()
|
||||
if not transcript:
|
||||
raise ValueError("Empty transcript in response")
|
||||
|
||||
data = {"text": transcript}
|
||||
|
||||
# Save transcript to json file (consistent with other providers)
|
||||
transcript_file = f"{file_dir}/{id}.json"
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
log.debug(data)
|
||||
return data
|
||||
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
log.exception("Error parsing Azure response")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to parse Azure response: {str(e)}",
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
|
||||
try:
|
||||
if r is not None and r.status_code != 200:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status_code", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
file_path = convert_audio_to_mp3(file_path)
|
||||
|
||||
try:
|
||||
file_path = compress_audio(file_path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
# Always produce a list of chunk paths (could be one entry if small)
|
||||
try:
|
||||
chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
|
||||
print(f"Chunk paths: {chunk_paths}")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
results = []
|
||||
try:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit tasks for each chunk_path
|
||||
futures = [
|
||||
executor.submit(transcription_handler, request, chunk_path, metadata)
|
||||
for chunk_path in chunk_paths
|
||||
]
|
||||
# Gather results as they complete
|
||||
for future in futures:
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as transcribe_exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error transcribing chunk: {transcribe_exc}",
|
||||
)
|
||||
finally:
|
||||
# Clean up only the temporary chunks, never the original file
|
||||
for chunk_path in chunk_paths:
|
||||
if chunk_path != file_path and os.path.isfile(chunk_path):
|
||||
try:
|
||||
os.remove(chunk_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"text": " ".join([result["text"] for result in results]),
|
||||
}
|
||||
|
||||
|
||||
def compress_audio(file_path):
|
||||
if os.path.getsize(file_path) > MAX_FILE_SIZE:
|
||||
id = os.path.splitext(os.path.basename(file_path))[
|
||||
0
|
||||
] # Handles names with multiple dots
|
||||
file_dir = os.path.dirname(file_path)
|
||||
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
|
||||
compressed_path = f"{file_dir}/{id}_compressed.opus"
|
||||
audio.export(compressed_path, format="opus", bitrate="32k")
|
||||
log.debug(f"Compressed audio to {compressed_path}")
|
||||
|
||||
if (
|
||||
os.path.getsize(compressed_path) > MAX_FILE_SIZE
|
||||
): # Still larger than MAX_FILE_SIZE after compression
|
||||
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
|
||||
compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
|
||||
audio.export(compressed_path, format="mp3", bitrate="32k")
|
||||
# log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
|
||||
|
||||
return compressed_path
|
||||
else:
|
||||
return file_path
|
||||
|
||||
|
||||
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
|
||||
"""
|
||||
Splits audio into chunks not exceeding max_bytes.
|
||||
Returns a list of chunk file paths. If audio fits, returns list with original path.
|
||||
"""
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size <= max_bytes:
|
||||
return [file_path] # Nothing to split
|
||||
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
duration_ms = len(audio)
|
||||
orig_size = file_size
|
||||
|
||||
approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
|
||||
chunks = []
|
||||
start = 0
|
||||
i = 0
|
||||
|
||||
base, _ = os.path.splitext(file_path)
|
||||
|
||||
while start < duration_ms:
|
||||
end = min(start + approx_chunk_ms, duration_ms)
|
||||
chunk = audio[start:end]
|
||||
chunk_path = f"{base}_chunk_{i}.{format}"
|
||||
chunk.export(chunk_path, format=format, bitrate=bitrate)
|
||||
|
||||
# Reduce chunk duration if still too large
|
||||
while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
|
||||
end = start + ((end - start) // 2)
|
||||
chunk = audio[start:end]
|
||||
chunk.export(chunk_path, format=format, bitrate=bitrate)
|
||||
|
||||
if os.path.getsize(chunk_path) > max_bytes:
|
||||
os.remove(chunk_path)
|
||||
raise Exception("Audio chunk cannot be reduced below max file size.")
|
||||
|
||||
chunks.append(chunk_path)
|
||||
start = end
|
||||
i += 1
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@router.post("/transcriptions")
|
||||
def transcription(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
language: Optional[str] = Form(None),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
|
||||
SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types!
|
||||
if not (
|
||||
file.content_type.startswith("audio/")
|
||||
or file.content_type in SUPPORTED_CONTENT_TYPES
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
||||
@@ -637,19 +935,18 @@ def transcription(
|
||||
f.write(contents)
|
||||
|
||||
try:
|
||||
try:
|
||||
file_path = compress_audio(file_path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
metadata = None
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
if language:
|
||||
metadata = {"language": language}
|
||||
|
||||
result = transcribe(request, file_path, metadata)
|
||||
|
||||
return {
|
||||
**result,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
|
||||
data = transcribe(request, file_path)
|
||||
file_path = file_path.split("/")[-1]
|
||||
return {**data, "filename": file_path}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
@@ -670,7 +967,22 @@ def transcription(
|
||||
def get_available_models(request: Request) -> list[dict]:
|
||||
available_models = []
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
available_models = data.get("models", [])
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from custom endpoint: {str(e)}")
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
else:
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
response = requests.get(
|
||||
@@ -701,14 +1013,37 @@ def get_available_voices(request) -> dict:
|
||||
"""Returns {voice_id: voice_name} dict"""
|
||||
available_voices = {}
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
voices_list = data.get("voices", [])
|
||||
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
else:
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
available_voices = get_elevenlabs_voices(
|
||||
@@ -720,7 +1055,10 @@ def get_available_voices(request) -> dict:
|
||||
elif request.app.state.config.TTS_ENGINE == "azure":
|
||||
try:
|
||||
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
|
||||
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
|
||||
base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
|
||||
url = (
|
||||
base_url or f"https://{region}.tts.speech.microsoft.com"
|
||||
) + "/cognitiveservices/voices/list"
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
|
||||
}
|
||||
|
||||
@@ -19,40 +19,45 @@ from open_webui.models.auths import (
|
||||
UserResponse,
|
||||
)
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from open_webui.config import (
|
||||
OPENID_PROVIDER_URL,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
)
|
||||
from fastapi.responses import RedirectResponse, Response, JSONResponse
|
||||
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||
from open_webui.utils.auth import (
|
||||
decode_token,
|
||||
create_api_key,
|
||||
create_token,
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
get_current_user,
|
||||
get_password_hash,
|
||||
get_http_authorization_cred,
|
||||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from open_webui.utils.access_control import get_permissions
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
|
||||
|
||||
if ENABLE_LDAP.value:
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -73,31 +78,36 @@ class SessionUserResponse(Token, UserResponse):
|
||||
async def get_session_user(
|
||||
request: Request, response: Response, user=Depends(get_current_user)
|
||||
):
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
auth_token = get_http_authorization_cred(auth_header)
|
||||
token = auth_token.credentials
|
||||
data = decode_token(token)
|
||||
|
||||
expires_at = None
|
||||
if expires_delta:
|
||||
expires_at = int(time.time()) + int(expires_delta.total_seconds())
|
||||
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=expires_delta,
|
||||
)
|
||||
if data:
|
||||
expires_at = data.get("exp")
|
||||
|
||||
datetime_expires_at = (
|
||||
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
|
||||
if expires_at
|
||||
else None
|
||||
)
|
||||
if (expires_at is not None) and int(time.time()) > expires_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=(
|
||||
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
|
||||
if expires_at
|
||||
else None
|
||||
),
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
@@ -178,6 +188,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
|
||||
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
|
||||
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
|
||||
LDAP_VALIDATE_CERT = (
|
||||
CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
|
||||
)
|
||||
LDAP_CIPHERS = (
|
||||
request.app.state.config.LDAP_CIPHERS
|
||||
if request.app.state.config.LDAP_CIPHERS
|
||||
@@ -189,14 +202,14 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
|
||||
try:
|
||||
tls = Tls(
|
||||
validate=CERT_REQUIRED,
|
||||
validate=LDAP_VALIDATE_CERT,
|
||||
version=PROTOCOL_TLS,
|
||||
ca_certs_file=LDAP_CA_CERT_FILE,
|
||||
ciphers=LDAP_CIPHERS,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"An error occurred on TLS: {str(e)}")
|
||||
raise HTTPException(400, detail=str(e))
|
||||
log.error(f"TLS configuration error: {str(e)}")
|
||||
raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")
|
||||
|
||||
try:
|
||||
server = Server(
|
||||
@@ -211,7 +224,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
LDAP_APP_DN,
|
||||
LDAP_APP_PASSWORD,
|
||||
auto_bind="NONE",
|
||||
authentication="SIMPLE",
|
||||
authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
|
||||
)
|
||||
if not connection_app.bind():
|
||||
raise HTTPException(400, detail="Application account bind failed")
|
||||
@@ -226,14 +239,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
],
|
||||
)
|
||||
|
||||
if not search_success:
|
||||
if not search_success or not connection_app.entries:
|
||||
raise HTTPException(400, detail="User not found in the LDAP server")
|
||||
|
||||
entry = connection_app.entries[0]
|
||||
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not mail or mail == "" or mail == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have mail.")
|
||||
email = entry[
|
||||
f"{LDAP_ATTRIBUTE_FOR_MAIL}"
|
||||
].value # retrieve the Attribute value
|
||||
if not email:
|
||||
raise HTTPException(400, "User does not have a valid email address.")
|
||||
elif isinstance(email, str):
|
||||
email = email.lower()
|
||||
elif isinstance(email, list):
|
||||
email = email[0].lower()
|
||||
else:
|
||||
email = str(email).lower()
|
||||
|
||||
cn = str(entry["cn"])
|
||||
user_dn = entry.entry_dn
|
||||
|
||||
@@ -246,19 +268,24 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
authentication="SIMPLE",
|
||||
)
|
||||
if not connection_user.bind():
|
||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||
raise HTTPException(400, "Authentication failed.")
|
||||
|
||||
user = Users.get_user_by_email(mail)
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
if user_count == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, role=role
|
||||
email=email,
|
||||
password=str(uuid.uuid4()),
|
||||
name=cn,
|
||||
role=role,
|
||||
)
|
||||
|
||||
if not user:
|
||||
@@ -269,23 +296,38 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"LDAP user creation error: {str(err)}")
|
||||
raise HTTPException(
|
||||
500, detail="Internal error occurred during LDAP user creation."
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user_by_trusted_header(mail)
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
|
||||
if user:
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
expires_at = None
|
||||
if expires_delta:
|
||||
expires_at = int(time.time()) + int(expires_delta.total_seconds())
|
||||
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(
|
||||
request.app.state.config.JWT_EXPIRES_IN
|
||||
),
|
||||
expires_delta=expires_delta,
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=(
|
||||
datetime.datetime.fromtimestamp(
|
||||
expires_at, datetime.timezone.utc
|
||||
)
|
||||
if expires_at
|
||||
else None
|
||||
),
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
@@ -295,6 +337,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_at": expires_at,
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
@@ -305,12 +348,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
else:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
|
||||
)
|
||||
raise HTTPException(400, "User record mismatch.")
|
||||
except Exception as e:
|
||||
raise HTTPException(400, detail=str(e))
|
||||
log.error(f"LDAP authentication error: {str(e)}")
|
||||
raise HTTPException(400, detail="LDAP authentication failed.")
|
||||
|
||||
|
||||
############################
|
||||
@@ -324,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
||||
|
||||
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
trusted_name = trusted_email
|
||||
email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
name = email
|
||||
|
||||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||
trusted_name = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
|
||||
)
|
||||
if not Users.get_user_by_email(trusted_email.lower()):
|
||||
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
|
||||
|
||||
if not Users.get_user_by_email(email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
|
||||
),
|
||||
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
|
||||
)
|
||||
user = Auths.authenticate_user_by_trusted_header(trusted_email)
|
||||
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
|
||||
group_names = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
|
||||
).split(",")
|
||||
group_names = [name.strip() for name in group_names if name.strip()]
|
||||
|
||||
if group_names:
|
||||
Groups.sync_user_groups_by_group_names(user.id, group_names)
|
||||
|
||||
elif WEBUI_AUTH == False:
|
||||
admin_email = "admin@localhost"
|
||||
admin_password = "admin"
|
||||
@@ -413,6 +462,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
not request.app.state.config.ENABLE_SIGNUP
|
||||
@@ -427,6 +477,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
@@ -437,14 +488,15 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
try:
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
if Users.get_num_users() == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||
if len(form_data.password.encode("utf-8")) > 72:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
|
||||
)
|
||||
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
@@ -484,6 +536,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
@@ -497,6 +550,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
if user_count == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
@@ -511,7 +568,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"Signup error: {str(err)}")
|
||||
raise HTTPException(500, detail="An internal error occurred during signup.")
|
||||
|
||||
|
||||
@router.get("/signout")
|
||||
@@ -529,8 +587,14 @@ async def signout(request: Request, response: Response):
|
||||
logout_url = openid_data.get("end_session_endpoint")
|
||||
if logout_url:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
return RedirectResponse(
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}"
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": True,
|
||||
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
},
|
||||
headers=response.headers,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -538,9 +602,25 @@ async def signout(request: Request, response: Response):
|
||||
detail="Failed to fetch OpenID configuration",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
log.error(f"OpenID signout error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to sign out from the OpenID provider.",
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": True,
|
||||
"redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
},
|
||||
headers=response.headers,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200, content={"status": True}, headers=response.headers
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
@@ -582,7 +662,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
log.error(f"Add user error: {str(err)}")
|
||||
raise HTTPException(
|
||||
500, detail="An internal error occurred while adding the user."
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
@@ -596,7 +679,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
||||
print(admin_email, admin_name)
|
||||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
@@ -630,11 +713,16 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
||||
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
|
||||
}
|
||||
|
||||
|
||||
@@ -645,11 +733,16 @@ class AdminConfig(BaseModel):
|
||||
ENABLE_API_KEY: bool
|
||||
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
|
||||
API_KEY_ALLOWED_ENDPOINTS: str
|
||||
ENABLE_CHANNELS: bool
|
||||
DEFAULT_USER_ROLE: str
|
||||
JWT_EXPIRES_IN: str
|
||||
ENABLE_COMMUNITY_SHARING: bool
|
||||
ENABLE_MESSAGE_RATING: bool
|
||||
ENABLE_CHANNELS: bool
|
||||
ENABLE_NOTES: bool
|
||||
ENABLE_USER_WEBHOOKS: bool
|
||||
PENDING_USER_OVERLAY_TITLE: Optional[str] = None
|
||||
PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
|
||||
RESPONSE_WATERMARK: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/admin/config")
|
||||
@@ -669,6 +762,7 @@ async def update_admin_config(
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
|
||||
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
|
||||
|
||||
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
|
||||
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
|
||||
@@ -684,6 +778,17 @@ async def update_admin_config(
|
||||
)
|
||||
request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING
|
||||
|
||||
request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
|
||||
|
||||
request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
|
||||
form_data.PENDING_USER_OVERLAY_TITLE
|
||||
)
|
||||
request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
|
||||
form_data.PENDING_USER_OVERLAY_CONTENT
|
||||
)
|
||||
|
||||
request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
|
||||
|
||||
return {
|
||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||
@@ -691,11 +796,16 @@ async def update_admin_config(
|
||||
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
|
||||
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
|
||||
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
|
||||
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
|
||||
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
|
||||
}
|
||||
|
||||
|
||||
@@ -711,6 +821,7 @@ class LdapServerConfig(BaseModel):
|
||||
search_filters: str = ""
|
||||
use_tls: bool = True
|
||||
certificate_path: Optional[str] = None
|
||||
validate_cert: bool = True
|
||||
ciphers: Optional[str] = "ALL"
|
||||
|
||||
|
||||
@@ -728,6 +839,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
|
||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
|
||||
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||
}
|
||||
|
||||
@@ -750,11 +862,6 @@ async def update_ldap_server(
|
||||
if not value:
|
||||
raise HTTPException(400, detail=f"Required field {key} is empty")
|
||||
|
||||
if form_data.use_tls and not form_data.certificate_path:
|
||||
raise HTTPException(
|
||||
400, detail="TLS is enabled but certificate file path is missing"
|
||||
)
|
||||
|
||||
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
|
||||
request.app.state.config.LDAP_SERVER_HOST = form_data.host
|
||||
request.app.state.config.LDAP_SERVER_PORT = form_data.port
|
||||
@@ -768,6 +875,7 @@ async def update_ldap_server(
|
||||
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
|
||||
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
|
||||
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
|
||||
request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
|
||||
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
|
||||
|
||||
return {
|
||||
@@ -782,6 +890,7 @@ async def update_ldap_server(
|
||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
|
||||
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||
}
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ async def get_channel_messages(
|
||||
############################
|
||||
|
||||
|
||||
async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
name,
|
||||
webhook_url,
|
||||
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
|
||||
{
|
||||
@@ -302,6 +303,7 @@ async def post_new_message(
|
||||
|
||||
background_tasks.add_task(
|
||||
send_notification,
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
|
||||
@@ -2,6 +2,8 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from open_webui.socket.main import get_event_emitter
|
||||
from open_webui.models.chats import (
|
||||
ChatForm,
|
||||
ChatImportForm,
|
||||
@@ -74,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
|
||||
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_user_id(
|
||||
user_id: str,
|
||||
page: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_admin_user),
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
):
|
||||
if not ENABLE_ADMIN_CHAT_ACCESS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
return Chats.get_chat_list_by_user_id(
|
||||
user_id, include_archived=True, skip=skip, limit=limit
|
||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
|
||||
@@ -192,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/pinned", response_model=list[ChatResponse])
|
||||
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
@@ -265,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
|
||||
@router.get("/archived", response_model=list[ChatTitleIdResponse])
|
||||
async def get_archived_session_user_chat_list(
|
||||
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
|
||||
page: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
chat_list = [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_archived_chat_list_by_user_id(
|
||||
user.id,
|
||||
filter=filter,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
]
|
||||
|
||||
return chat_list
|
||||
|
||||
|
||||
############################
|
||||
@@ -372,6 +419,107 @@ async def update_chat_by_id(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChatMessageById
|
||||
############################
|
||||
class MessageForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_message_by_id(
|
||||
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if chat.user_id != user.id and user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
id,
|
||||
message_id,
|
||||
{
|
||||
"content": form_data.content,
|
||||
},
|
||||
)
|
||||
|
||||
event_emitter = get_event_emitter(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"chat_id": id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
False,
|
||||
)
|
||||
|
||||
if event_emitter:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message",
|
||||
"data": {
|
||||
"chat_id": id,
|
||||
"message_id": message_id,
|
||||
"content": form_data.content,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
||||
|
||||
############################
|
||||
# SendChatMessageEventById
|
||||
############################
|
||||
class EventForm(BaseModel):
|
||||
type: str
|
||||
data: dict
|
||||
|
||||
|
||||
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
|
||||
async def send_chat_message_event_by_id(
|
||||
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if chat.user_id != user.id and user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
event_emitter = get_event_emitter(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"chat_id": id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
if event_emitter:
|
||||
await event_emitter(form_data.model_dump())
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
############################
|
||||
# DeleteChatById
|
||||
############################
|
||||
@@ -476,7 +624,12 @@ async def clone_chat_by_id(
|
||||
|
||||
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
|
||||
if user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
else:
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
@@ -530,8 +683,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
||||
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
if not has_permission(
|
||||
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
|
||||
if chat:
|
||||
if chat.share_id:
|
||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.config import get_config, save_config
|
||||
from open_webui.config import BannerModel
|
||||
|
||||
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,10 +68,86 @@ async def set_direct_connections_config(
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# ToolServers Config
|
||||
############################
|
||||
|
||||
|
||||
class ToolServerConnection(BaseModel):
|
||||
url: str
|
||||
path: str
|
||||
auth_type: Optional[str]
|
||||
key: Optional[str]
|
||||
config: Optional[dict]
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ToolServersConfigForm(BaseModel):
|
||||
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
|
||||
|
||||
|
||||
@router.get("/tool_servers", response_model=ToolServersConfigForm)
|
||||
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tool_servers", response_model=ToolServersConfigForm)
|
||||
async def set_tool_servers_config(
|
||||
request: Request,
|
||||
form_data: ToolServersConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tool_servers/verify")
|
||||
async def verify_tool_servers_config(
|
||||
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
|
||||
):
|
||||
"""
|
||||
Verify the connection to the tool server.
|
||||
"""
|
||||
try:
|
||||
|
||||
token = None
|
||||
if form_data.auth_type == "bearer":
|
||||
token = form_data.key
|
||||
elif form_data.auth_type == "session":
|
||||
token = request.state.token.credentials
|
||||
|
||||
url = f"{form_data.url}/{form_data.path}"
|
||||
return await get_tool_server_data(token, url)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to connect to the tool server: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
class CodeInterpreterConfigForm(BaseModel):
|
||||
ENABLE_CODE_EXECUTION: bool
|
||||
CODE_EXECUTION_ENGINE: str
|
||||
CODE_EXECUTION_JUPYTER_URL: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
|
||||
ENABLE_CODE_INTERPRETER: bool
|
||||
CODE_INTERPRETER_ENGINE: str
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
|
||||
@@ -77,11 +155,19 @@ class CodeInterpreterConfigForm(BaseModel):
|
||||
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
|
||||
|
||||
|
||||
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
|
||||
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
@@ -89,13 +175,34 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_interpreter_config(
|
||||
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_execution_config(
|
||||
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
|
||||
|
||||
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_URL
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
|
||||
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
|
||||
@@ -116,8 +223,18 @@ async def set_code_interpreter_config(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
return {
|
||||
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
@@ -125,6 +242,7 @@ async def set_code_interpreter_config(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -56,19 +56,35 @@ async def update_config(
|
||||
}
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str = "pending"
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FeedbackUserResponse(FeedbackResponse):
|
||||
user: Optional[UserModel] = None
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
|
||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
return [
|
||||
FeedbackUserResponse(
|
||||
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
|
||||
|
||||
feedback_list = []
|
||||
for feedback in feedbacks:
|
||||
user = Users.get_user_by_id(feedback.user_id)
|
||||
feedback_list.append(
|
||||
FeedbackUserResponse(
|
||||
**feedback.model_dump(),
|
||||
user=UserResponse(**user.model_dump()) if user else None,
|
||||
)
|
||||
)
|
||||
for feedback in feedbacks
|
||||
]
|
||||
return feedback_list
|
||||
|
||||
|
||||
@router.delete("/feedbacks/all")
|
||||
@@ -80,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
|
||||
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
return [
|
||||
FeedbackModel(
|
||||
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
|
||||
)
|
||||
for feedback in feedbacks
|
||||
]
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
|
||||
|
||||
@@ -1,21 +1,39 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
Query,
|
||||
)
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.files import (
|
||||
FileForm,
|
||||
FileModel,
|
||||
FileModelResponse,
|
||||
Files,
|
||||
)
|
||||
from open_webui.models.knowledge import Knowledges
|
||||
|
||||
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
|
||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||
from open_webui.routers.audio import transcribe
|
||||
from open_webui.storage.provider import Storage
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from pydantic import BaseModel
|
||||
@@ -26,6 +44,39 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# Check if the current user has access to a file through any knowledge bases the user may be in.
|
||||
############################
|
||||
|
||||
|
||||
def has_access_to_file(
|
||||
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
|
||||
) -> bool:
|
||||
file = Files.get_file_by_id(file_id)
|
||||
log.debug(f"Checking if user has {access_type} access to file")
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
has_access = False
|
||||
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
||||
|
||||
if knowledge_base_id:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
||||
user.id, access_type
|
||||
)
|
||||
for knowledge_base in knowledge_bases:
|
||||
if knowledge_base.id == knowledge_base_id:
|
||||
has_access = True
|
||||
break
|
||||
|
||||
return has_access
|
||||
|
||||
|
||||
############################
|
||||
# Upload File
|
||||
############################
|
||||
@@ -35,19 +86,55 @@ router = APIRouter()
|
||||
def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
internal: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
file_metadata: dict = {},
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
|
||||
)
|
||||
file_metadata = metadata if metadata else {}
|
||||
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
|
||||
file_extension = os.path.splitext(filename)[1]
|
||||
# Remove the leading dot from the file extension
|
||||
file_extension = file_extension[1:] if file_extension else ""
|
||||
|
||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||
]
|
||||
|
||||
if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(
|
||||
f"File type {file_extension} is not allowed"
|
||||
),
|
||||
)
|
||||
|
||||
# replace filename with uuid
|
||||
id = str(uuid.uuid4())
|
||||
name = filename
|
||||
filename = f"{id}_{filename}"
|
||||
contents, file_path = Storage.upload_file(file.file, filename)
|
||||
tags = {
|
||||
"OpenWebUI-User-Email": user.email,
|
||||
"OpenWebUI-User-Id": user.id,
|
||||
"OpenWebUI-User-Name": user.name,
|
||||
"OpenWebUI-File-Id": id,
|
||||
}
|
||||
contents, file_path = Storage.upload_file(file.file, filename, tags)
|
||||
|
||||
file_item = Files.insert_new_file(
|
||||
user.id,
|
||||
@@ -65,19 +152,40 @@ def upload_file(
|
||||
}
|
||||
),
|
||||
)
|
||||
if process:
|
||||
try:
|
||||
if file.content_type:
|
||||
if file.content_type.startswith("audio/") or file.content_type in {
|
||||
"video/webm"
|
||||
}:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
try:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
file_item = FileModelResponse(
|
||||
**{
|
||||
**file_item.model_dump(),
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
}
|
||||
)
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
file_item = FileModelResponse(
|
||||
**{
|
||||
**file_item.model_dump(),
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
}
|
||||
)
|
||||
|
||||
if file_item:
|
||||
return file_item
|
||||
@@ -91,7 +199,7 @@ def upload_file(
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||
)
|
||||
|
||||
|
||||
@@ -101,14 +209,62 @@ def upload_file(
|
||||
|
||||
|
||||
@router.get("/", response_model=list[FileModelResponse])
|
||||
async def list_files(user=Depends(get_verified_user)):
|
||||
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
|
||||
if user.role == "admin":
|
||||
files = Files.get_files()
|
||||
else:
|
||||
files = Files.get_files_by_user_id(user.id)
|
||||
|
||||
if not content:
|
||||
for file in files:
|
||||
if "content" in file.data:
|
||||
del file.data["content"]
|
||||
|
||||
return files
|
||||
|
||||
|
||||
############################
|
||||
# Search Files
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/search", response_model=list[FileModelResponse])
|
||||
async def search_files(
|
||||
filename: str = Query(
|
||||
...,
|
||||
description="Filename pattern to search for. Supports wildcards such as '*.txt'",
|
||||
),
|
||||
content: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
"""
|
||||
Search for files by filename with support for wildcard patterns.
|
||||
"""
|
||||
# Get files according to user role
|
||||
if user.role == "admin":
|
||||
files = Files.get_files()
|
||||
else:
|
||||
files = Files.get_files_by_user_id(user.id)
|
||||
|
||||
# Get matching files
|
||||
matching_files = [
|
||||
file for file in files if fnmatch(file.filename.lower(), filename.lower())
|
||||
]
|
||||
|
||||
if not matching_files:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No files found matching the pattern.",
|
||||
)
|
||||
|
||||
if not content:
|
||||
for file in matching_files:
|
||||
if "content" in file.data:
|
||||
del file.data["content"]
|
||||
|
||||
return matching_files
|
||||
|
||||
|
||||
############################
|
||||
# Delete All Files
|
||||
############################
|
||||
@@ -144,7 +300,17 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
||||
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
return file
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -162,7 +328,17 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
return {"content": file.data.get("content", "")}
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -186,7 +362,17 @@ async def update_file_data_content_by_id(
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "write", user)
|
||||
):
|
||||
try:
|
||||
process_file(
|
||||
request,
|
||||
@@ -212,9 +398,22 @@ async def update_file_data_content_by_id(
|
||||
|
||||
|
||||
@router.get("/{id}/content")
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def get_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
try:
|
||||
file_path = Storage.get_file(file.path)
|
||||
file_path = Path(file_path)
|
||||
@@ -225,17 +424,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename) # RFC5987 encoding
|
||||
|
||||
content_type = file.meta.get("content_type")
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename)
|
||||
headers = {}
|
||||
if file.meta.get("content_type") not in [
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
]:
|
||||
headers = {
|
||||
**headers,
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
|
||||
return FileResponse(file_path, headers=headers)
|
||||
if attachment:
|
||||
headers["Content-Disposition"] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
else:
|
||||
if content_type == "application/pdf" or filename.lower().endswith(
|
||||
".pdf"
|
||||
):
|
||||
headers["Content-Disposition"] = (
|
||||
f"inline; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
content_type = "application/pdf"
|
||||
elif content_type != "text/plain":
|
||||
headers["Content-Disposition"] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
|
||||
return FileResponse(file_path, headers=headers, media_type=content_type)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -259,14 +470,32 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
@router.get("/{id}/content/html")
|
||||
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
file_user = Users.get_user_by_id(file.user_id)
|
||||
if not file_user.role == "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
try:
|
||||
file_path = Storage.get_file(file.path)
|
||||
file_path = Path(file_path)
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
log.info(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -291,7 +520,17 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
):
|
||||
file_path = file.path
|
||||
|
||||
# Handle Unicode filenames
|
||||
@@ -342,7 +581,18 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
@router.delete("/{id}")
|
||||
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "write", user)
|
||||
):
|
||||
# We should add Chroma cleanup here
|
||||
|
||||
result = Files.delete_file_by_id(id)
|
||||
|
||||
@@ -20,11 +20,13 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_permission
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -228,7 +230,19 @@ async def update_folder_is_expanded_by_id(
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def delete_folder_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
chat_delete_permission = has_permission(
|
||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
if user.role != "admin" and not chat_delete_permission:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import logging
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -8,11 +12,22 @@ from open_webui.models.functions import (
|
||||
FunctionResponse,
|
||||
Functions,
|
||||
)
|
||||
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
replace_imports,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -36,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
|
||||
return Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
# LoadFunctionFromLink
|
||||
############################
|
||||
|
||||
|
||||
class LoadUrlForm(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
def github_url_to_raw_url(url: str) -> str:
|
||||
# Handle 'tree' (folder) URLs (add main.py at the end)
|
||||
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
|
||||
if m1:
|
||||
org, repo, branch, path = m1.groups()
|
||||
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
|
||||
|
||||
# Handle 'blob' (file) URLs
|
||||
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
|
||||
if m2:
|
||||
org, repo, branch, path = m2.groups()
|
||||
return (
|
||||
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
|
||||
)
|
||||
|
||||
# No match; return as-is
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/load/url", response_model=Optional[dict])
|
||||
async def load_function_from_url(
|
||||
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# NOTE: This is NOT a SSRF vulnerability:
|
||||
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
||||
# and does NOT accept untrusted user input. Access is enforced by authentication.
|
||||
|
||||
url = str(form_data.url)
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
||||
|
||||
url = github_url_to_raw_url(url)
|
||||
url_parts = url.rstrip("/").split("/")
|
||||
|
||||
file_name = url_parts[-1]
|
||||
function_name = (
|
||||
file_name[:-3]
|
||||
if (
|
||||
file_name.endswith(".py")
|
||||
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
||||
)
|
||||
else url_parts[-2] if len(url_parts) > 1 else "function"
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=resp.status, detail="Failed to fetch the function"
|
||||
)
|
||||
data = await resp.text()
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No data received from the URL"
|
||||
)
|
||||
return {
|
||||
"name": function_name,
|
||||
"content": data,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
|
||||
|
||||
|
||||
############################
|
||||
# SyncFunctions
|
||||
############################
|
||||
|
||||
|
||||
class SyncFunctionsForm(FunctionForm):
|
||||
functions: list[FunctionModel] = []
|
||||
|
||||
|
||||
@router.post("/sync", response_model=Optional[FunctionModel])
|
||||
async def sync_functions(
|
||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewFunction
|
||||
############################
|
||||
@@ -68,7 +174,7 @@ async def create_new_function(
|
||||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
|
||||
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
|
||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if function:
|
||||
@@ -79,7 +185,7 @@ async def create_new_function(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to create a new function: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -183,7 +289,7 @@ async def update_function_by_id(
|
||||
FUNCTIONS[id] = function_module
|
||||
|
||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
|
||||
@@ -256,11 +362,9 @@ async def get_function_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -284,11 +388,9 @@ async def update_function_valves_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -299,7 +401,7 @@ async def update_function_valves_by_id(
|
||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -347,11 +449,9 @@ async def get_function_user_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
@@ -371,11 +471,9 @@ async def update_function_user_valves_by_id(
|
||||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
@@ -388,7 +486,7 @@ async def update_function_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
16
backend/open_webui/routers/groups.py
Normal file → Executable file
16
backend/open_webui/routers/groups.py
Normal file → Executable file
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import (
|
||||
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new group: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -94,7 +100,7 @@ async def update_group_by_id(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error deleting group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
@@ -25,7 +25,7 @@ from pydantic import BaseModel
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||
|
||||
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
||||
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
|
||||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
|
||||
class GeminiConfigForm(BaseModel):
|
||||
GEMINI_API_BASE_URL: str
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
|
||||
class ConfigForm(BaseModel):
|
||||
enabled: bool
|
||||
engine: str
|
||||
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
|
||||
openai: OpenAIConfigForm
|
||||
automatic1111: Automatic1111ConfigForm
|
||||
comfyui: ComfyUIConfigForm
|
||||
gemini: GeminiConfigForm
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
@@ -103,6 +113,11 @@ async def update_config(
|
||||
)
|
||||
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
||||
|
||||
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
||||
form_data.gemini.GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
|
||||
|
||||
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
||||
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
||||
)
|
||||
@@ -129,6 +144,8 @@ async def update_config(
|
||||
request.app.state.config.COMFYUI_BASE_URL = (
|
||||
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
||||
)
|
||||
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
|
||||
|
||||
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
||||
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
|
||||
form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
||||
@@ -155,6 +172,10 @@ async def update_config(
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
|
||||
headers = None
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||
headers=headers,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
@@ -224,6 +253,12 @@ def get_image_model(request):
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "dall-e-2"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "imagen-3.0-generate-002"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
@@ -298,6 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
return [
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
@@ -322,7 +362,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
if model_node_id:
|
||||
model_list_key = None
|
||||
|
||||
print(workflow[model_node_id]["class_type"])
|
||||
log.info(workflow[model_node_id]["class_type"])
|
||||
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
||||
"required"
|
||||
]:
|
||||
@@ -411,7 +451,7 @@ def load_url_image_data(url, headers=None):
|
||||
return None
|
||||
|
||||
|
||||
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
def upload_image(request, image_data, content_type, metadata, user):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
@@ -420,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
@@ -461,7 +501,11 @@ async def image_generations(
|
||||
if form_data.size
|
||||
else request.app.state.config.IMAGE_SIZE
|
||||
),
|
||||
"response_format": "b64_json",
|
||||
**(
|
||||
{}
|
||||
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else {"response_format": "b64_json"}
|
||||
),
|
||||
}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
@@ -478,11 +522,50 @@ async def image_generations(
|
||||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
if image_url := image.get("url", None):
|
||||
image_data, content_type = load_url_image_data(image_url, headers)
|
||||
else:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
|
||||
|
||||
model = get_image_model(request)
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = load_b64_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
@@ -529,9 +612,9 @@ async def image_generations(
|
||||
image_data, content_type = load_url_image_data(image["url"], headers)
|
||||
url = upload_image(
|
||||
request,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
image_data,
|
||||
content_type,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
@@ -541,7 +624,7 @@ async def image_generations(
|
||||
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
||||
):
|
||||
if form_data.model:
|
||||
set_image_model(form_data.model)
|
||||
set_image_model(request, form_data.model)
|
||||
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
@@ -582,9 +665,9 @@ async def image_generations(
|
||||
image_data, content_type = load_b64_image_data(image)
|
||||
url = upload_image(
|
||||
request,
|
||||
{**data, "info": res["info"]},
|
||||
image_data,
|
||||
content_type,
|
||||
{**data, "info": res["info"]},
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
|
||||
@@ -9,8 +9,8 @@ from open_webui.models.knowledge import (
|
||||
KnowledgeResponse,
|
||||
KnowledgeUserResponse,
|
||||
)
|
||||
from open_webui.models.files import Files, FileModel
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.models.files import Files, FileModel, FileMetadataResponse
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
from open_webui.routers.retrieval import (
|
||||
process_file,
|
||||
ProcessFileForm,
|
||||
@@ -161,13 +161,94 @@ async def create_new_knowledge(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ReindexKnowledgeFiles
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/reindex", response_model=bool)
|
||||
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
|
||||
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
|
||||
|
||||
deleted_knowledge_bases = []
|
||||
|
||||
for knowledge_base in knowledge_bases:
|
||||
# -- Robust error handling for missing or invalid data
|
||||
if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
|
||||
log.warning(
|
||||
f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
|
||||
)
|
||||
try:
|
||||
Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
|
||||
deleted_knowledge_bases.append(knowledge_base.id)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
file_ids = knowledge_base.data.get("file_ids", [])
|
||||
files = Files.get_files_by_ids(file_ids)
|
||||
try:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
|
||||
VECTOR_DB_CLIENT.delete_collection(
|
||||
collection_name=knowledge_base.id
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
|
||||
continue # Skip, don't raise
|
||||
|
||||
failed_files = []
|
||||
for file in files:
|
||||
try:
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(
|
||||
file_id=file.id, collection_name=knowledge_base.id
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
|
||||
)
|
||||
failed_files.append({"file_id": file.id, "error": str(e)})
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
|
||||
# Don't raise, just continue
|
||||
continue
|
||||
|
||||
if failed_files:
|
||||
log.warning(
|
||||
f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
|
||||
)
|
||||
for failed in failed_files:
|
||||
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
|
||||
|
||||
log.info(
|
||||
f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
############################
|
||||
# GetKnowledgeById
|
||||
############################
|
||||
|
||||
|
||||
class KnowledgeFilesResponse(KnowledgeResponse):
|
||||
files: list[FileModel]
|
||||
files: list[FileMetadataResponse]
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
|
||||
@@ -183,7 +264,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
):
|
||||
|
||||
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
||||
files = Files.get_files_by_ids(file_ids)
|
||||
files = Files.get_file_metadatas_by_ids(file_ids)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
@@ -311,7 +392,7 @@ def add_file_to_knowledge_by_id(
|
||||
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
||||
|
||||
if knowledge:
|
||||
files = Files.get_files_by_ids(file_ids)
|
||||
files = Files.get_file_metadatas_by_ids(file_ids)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
@@ -388,7 +469,7 @@ def update_file_from_knowledge_by_id(
|
||||
data = knowledge.data or {}
|
||||
file_ids = data.get("file_ids", [])
|
||||
|
||||
files = Files.get_files_by_ids(file_ids)
|
||||
files = Files.get_file_metadatas_by_ids(file_ids)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
@@ -437,14 +518,24 @@ def remove_file_from_knowledge_by_id(
|
||||
)
|
||||
|
||||
# Remove content from the vector database
|
||||
VECTOR_DB_CLIENT.delete(
|
||||
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
|
||||
)
|
||||
try:
|
||||
VECTOR_DB_CLIENT.delete(
|
||||
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug("This was most likely caused by bypassing embedding processing")
|
||||
log.debug(e)
|
||||
pass
|
||||
|
||||
# Remove the file's collection from vector database
|
||||
file_collection = f"file-{form_data.file_id}"
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
|
||||
try:
|
||||
# Remove the file's collection from vector database
|
||||
file_collection = f"file-{form_data.file_id}"
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
|
||||
except Exception as e:
|
||||
log.debug("This was most likely caused by bypassing embedding processing")
|
||||
log.debug(e)
|
||||
pass
|
||||
|
||||
# Delete file from database
|
||||
Files.delete_file_by_id(form_data.file_id)
|
||||
@@ -460,7 +551,7 @@ def remove_file_from_knowledge_by_id(
|
||||
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
||||
|
||||
if knowledge:
|
||||
files = Files.get_files_by_ids(file_ids)
|
||||
files = Files.get_file_metadatas_by_ids(file_ids)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
@@ -614,7 +705,7 @@ def add_files_to_knowledge_batch(
|
||||
)
|
||||
|
||||
# Get files content
|
||||
print(f"files/batch/add - {len(form_data)} files")
|
||||
log.info(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
@@ -656,7 +747,7 @@ def add_files_to_knowledge_batch(
|
||||
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Files.get_files_by_ids(existing_file_ids),
|
||||
files=Files.get_file_metadatas_by_ids(existing_file_ids),
|
||||
warnings={
|
||||
"message": "Some files failed to process",
|
||||
"errors": error_details,
|
||||
@@ -664,5 +755,6 @@ def add_files_to_knowledge_batch(
|
||||
)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
|
||||
**knowledge.model_dump(),
|
||||
files=Files.get_file_metadatas_by_ids(existing_file_ids),
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.models.memories import Memories, MemoryModel
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.auth import get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
@@ -57,7 +57,9 @@ async def add_memory(
|
||||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user=user
|
||||
),
|
||||
"metadata": {"created_at": memory.created_at},
|
||||
}
|
||||
],
|
||||
@@ -82,7 +84,7 @@ async def query_memory(
|
||||
):
|
||||
results = VECTOR_DB_CLIENT.search(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
|
||||
limit=form_data.k,
|
||||
)
|
||||
|
||||
@@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
|
||||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user=user
|
||||
),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
@@ -149,7 +153,9 @@ async def update_memory_by_id(
|
||||
form_data: MemoryUpdateModel,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
memory = Memories.update_memory_by_id(memory_id, form_data.content)
|
||||
memory = Memories.update_memory_by_id_and_user_id(
|
||||
memory_id, user.id, form_data.content
|
||||
)
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
@@ -161,7 +167,7 @@ async def update_memory_by_id(
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user
|
||||
memory.content, user=user
|
||||
),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
|
||||
218
backend/open_webui/routers/notes.py
Normal file
218
backend/open_webui/routers/notes.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetNotes
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[NoteUserResponse])
|
||||
async def get_notes(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
notes = [
|
||||
NoteUserResponse(
|
||||
**{
|
||||
**note.model_dump(),
|
||||
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
|
||||
}
|
||||
)
|
||||
for note in Notes.get_notes_by_user_id(user.id, "write")
|
||||
]
|
||||
|
||||
return notes
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[NoteUserResponse])
|
||||
async def get_note_list(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
notes = [
|
||||
NoteUserResponse(
|
||||
**{
|
||||
**note.model_dump(),
|
||||
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
|
||||
}
|
||||
)
|
||||
for note in Notes.get_notes_by_user_id(user.id, "read")
|
||||
]
|
||||
|
||||
return notes
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewNote
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[NoteModel])
|
||||
async def create_new_note(
|
||||
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
try:
|
||||
note = Notes.insert_new_note(form_data, user.id)
|
||||
return note
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetNoteById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[NoteModel])
|
||||
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
and not has_access(user.id, type="read", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
return note
|
||||
|
||||
|
||||
############################
|
||||
# UpdateNoteById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update", response_model=Optional[NoteModel])
|
||||
async def update_note_by_id(
|
||||
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
note = Notes.update_note_by_id(id, form_data)
|
||||
return note
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteNoteById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
|
||||
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
note = Notes.delete_note_by_id(id)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
@@ -9,11 +9,18 @@ import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
import requests
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.env import (
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -26,7 +33,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
||||
@@ -49,8 +56,9 @@ from open_webui.config import (
|
||||
from open_webui.env import (
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@@ -66,12 +74,27 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -96,6 +119,7 @@ async def send_post_request(
|
||||
stream: bool = True,
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
):
|
||||
|
||||
r = None
|
||||
@@ -110,7 +134,18 @@ async def send_post_request(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -186,12 +221,26 @@ async def verify_connection(
|
||||
key = form_data.key
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/api/version",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
detail = f"HTTP Error: {r.status}"
|
||||
@@ -253,8 +302,24 @@ async def update_config(
|
||||
}
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request):
|
||||
def merge_ollama_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
@@ -262,7 +327,7 @@ async def get_all_models(request: Request):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags"))
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
@@ -275,7 +340,9 @@ async def get_all_models(request: Request):
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", key))
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/tags", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
@@ -291,7 +358,10 @@ async def get_all_models(request: Request):
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
connection_type = api_config.get("connection_type", "local")
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
tags = api_config.get("tags", [])
|
||||
model_ids = api_config.get("model_ids", [])
|
||||
|
||||
if len(model_ids) != 0 and "models" in response:
|
||||
@@ -302,27 +372,18 @@ async def get_all_models(request: Request):
|
||||
)
|
||||
)
|
||||
|
||||
if prefix_id:
|
||||
for model in response.get("models", []):
|
||||
for model in response.get("models", []):
|
||||
if prefix_id:
|
||||
model["model"] = f"{prefix_id}.{model['model']}"
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
if tags:
|
||||
model["tags"] = tags
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
models = {
|
||||
"models": merge_models_lists(
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
@@ -330,6 +391,22 @@ async def get_all_models(request: Request):
|
||||
)
|
||||
}
|
||||
|
||||
try:
|
||||
loaded_models = await get_ollama_loaded_models(request, user=user)
|
||||
expires_map = {
|
||||
m["name"]: m["expires_at"]
|
||||
for m in loaded_models["models"]
|
||||
if "expires_at" in m
|
||||
}
|
||||
|
||||
for m in models["models"]:
|
||||
if m["name"] in expires_map:
|
||||
# Parse ISO8601 datetime with offset, get unix timestamp as int
|
||||
dt = datetime.fromisoformat(expires_map[m["name"]])
|
||||
m["expires_at"] = int(dt.timestamp())
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to get loaded models: {e}")
|
||||
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
@@ -360,7 +437,7 @@ async def get_ollama_tags(
|
||||
models = []
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
@@ -370,7 +447,19 @@ async def get_ollama_tags(
|
||||
r = requests.request(
|
||||
method="GET",
|
||||
url=f"{url}/api/tags",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -398,24 +487,95 @@ async def get_ollama_tags(
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
enable = api_config.get("enable", True)
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/ps", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
if response:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
|
||||
for model in response.get("models", []):
|
||||
if prefix_id:
|
||||
model["model"] = f"{prefix_id}.{model['model']}"
|
||||
|
||||
models = {
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
)
|
||||
)
|
||||
}
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/version")
|
||||
@router.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
if url_idx is None:
|
||||
# returns lowest version
|
||||
request_tasks = [
|
||||
send_get_request(
|
||||
f"{url}/api/version",
|
||||
request_tasks = []
|
||||
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
|
||||
enable = api_config.get("enable", True)
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(
|
||||
send_get_request(
|
||||
f"{url}/api/version",
|
||||
key,
|
||||
)
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
responses = list(filter(lambda x: x is not None, responses))
|
||||
|
||||
@@ -462,35 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
return {"version": False}
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = [
|
||||
send_get_request(
|
||||
f"{url}/api/ps",
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class ModelNameForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.post("/api/unload")
|
||||
async def unload_model(
|
||||
request: Request,
|
||||
form_data: ModelNameForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
model_name = form_data.name
|
||||
if not model_name:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing 'name' of model to unload."
|
||||
)
|
||||
|
||||
# Refresh/load models if needed, get mapping from name to URLs
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
# Canonicalize model name (if not supplied with version)
|
||||
if ":" not in model_name:
|
||||
model_name = f"{model_name}:latest"
|
||||
|
||||
if model_name not in models:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
|
||||
)
|
||||
url_indices = models[model_name]["urls"]
|
||||
|
||||
# Send unload to ALL url_indices
|
||||
results = []
|
||||
errors = []
|
||||
for idx in url_indices:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
)
|
||||
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id and model_name.startswith(f"{prefix_id}."):
|
||||
model_name = model_name[len(f"{prefix_id}.") :]
|
||||
|
||||
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
|
||||
|
||||
try:
|
||||
res = await send_post_request(
|
||||
url=f"{url}/api/generate",
|
||||
payload=json.dumps(payload),
|
||||
stream=False,
|
||||
key=key,
|
||||
user=user,
|
||||
)
|
||||
results.append({"url_idx": idx, "success": True, "response": res})
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to unload model on node {idx}: {e}")
|
||||
errors.append({"url_idx": idx, "success": False, "error": str(e)})
|
||||
|
||||
if len(errors) > 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
|
||||
|
||||
@router.post("/api/pull")
|
||||
@router.post("/api/pull/{url_idx}")
|
||||
async def pull_model(
|
||||
@@ -509,6 +708,7 @@ async def pull_model(
|
||||
url=f"{url}/api/pull",
|
||||
payload=json.dumps(payload),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -527,7 +727,7 @@ async def push_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -545,6 +745,7 @@ async def push_model(
|
||||
url=f"{url}/api/push",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -571,6 +772,7 @@ async def create_model(
|
||||
url=f"{url}/api/create",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -588,7 +790,7 @@ async def copy_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.source in models:
|
||||
@@ -609,6 +811,16 @@ async def copy_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -643,7 +855,7 @@ async def delete_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -665,6 +877,16 @@ async def delete_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -693,7 +915,7 @@ async def delete_model(
|
||||
async def show_model_info(
|
||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name not in models:
|
||||
@@ -714,6 +936,16 @@ async def show_model_info(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -757,7 +989,7 @@ async def embed(
|
||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -774,8 +1006,16 @@ async def embed(
|
||||
)
|
||||
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(url_idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
@@ -783,6 +1023,16 @@ async def embed(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -826,7 +1076,7 @@ async def embeddings(
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -843,8 +1093,16 @@ async def embeddings(
|
||||
)
|
||||
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(url_idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
@@ -852,6 +1110,16 @@ async def embeddings(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -882,7 +1150,7 @@ class GenerateCompletionForm(BaseModel):
|
||||
prompt: str
|
||||
suffix: Optional[str] = None
|
||||
images: Optional[list[str]] = None
|
||||
format: Optional[str] = None
|
||||
format: Optional[Union[dict, str]] = None
|
||||
options: Optional[dict] = None
|
||||
system: Optional[str] = None
|
||||
template: Optional[str] = None
|
||||
@@ -901,7 +1169,7 @@ async def generate_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -931,20 +1199,34 @@ async def generate_completion(
|
||||
url=f"{url}/api/generate",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[dict]] = None
|
||||
images: Optional[list[str]] = None
|
||||
|
||||
@validator("content", pre=True)
|
||||
@classmethod
|
||||
def check_at_least_one_field(cls, field_value, values, **kwargs):
|
||||
# Raise an error if both 'content' and 'tool_calls' are None
|
||||
if field_value is None and (
|
||||
"tool_calls" not in values or values["tool_calls"] is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of 'content' or 'tool_calls' must be provided"
|
||||
)
|
||||
|
||||
return field_value
|
||||
|
||||
|
||||
class GenerateChatCompletionForm(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
format: Optional[dict] = None
|
||||
format: Optional[Union[dict, str]] = None
|
||||
options: Optional[dict] = None
|
||||
template: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
@@ -1001,13 +1283,14 @@ async def generate_chat_completion(
|
||||
params = model_info.params.model_dump()
|
||||
|
||||
if params:
|
||||
if payload.get("options") is None:
|
||||
payload["options"] = {}
|
||||
system = params.pop("system", None)
|
||||
|
||||
# Unlike OpenAI, Ollama does not support params directly in the body
|
||||
payload["options"] = apply_model_params_to_body_ollama(
|
||||
params, payload["options"]
|
||||
params, (payload.get("options", {}) or {})
|
||||
)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -1040,13 +1323,14 @@ async def generate_chat_completion(
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
# payload["keep_alive"] = -1 # keep alive forever
|
||||
return await send_post_request(
|
||||
url=f"{url}/api/chat",
|
||||
payload=json.dumps(payload),
|
||||
stream=form_data.stream,
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1058,7 +1342,7 @@ class OpenAIChatMessageContent(BaseModel):
|
||||
|
||||
class OpenAIChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Union[str, list[OpenAIChatMessageContent]]
|
||||
content: Union[Optional[str], list[OpenAIChatMessageContent]]
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -1149,6 +1433,7 @@ async def generate_openai_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1187,8 +1472,10 @@ async def generate_openai_chat_completion(
|
||||
params = model_info.params.model_dump()
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
@@ -1227,6 +1514,7 @@ async def generate_openai_chat_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1240,7 +1528,7 @@ async def get_openai_models(
|
||||
|
||||
models = []
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models(request)
|
||||
model_list = await get_all_models(request, user=user)
|
||||
models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@@ -1341,7 +1629,9 @@ async def download_file_stream(
|
||||
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(file_url, headers=headers) as response:
|
||||
async with session.get(
|
||||
file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as response:
|
||||
total_size = int(response.headers.get("content-length", 0)) + current_size
|
||||
|
||||
with open(file_path, "ab+") as file:
|
||||
@@ -1356,7 +1646,8 @@ async def download_file_stream(
|
||||
|
||||
if done:
|
||||
file.seek(0)
|
||||
hashed = calculate_sha256(file)
|
||||
chunk_size = 1024 * 1024 * 2
|
||||
hashed = calculate_sha256(file, chunk_size)
|
||||
file.seek(0)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
@@ -1420,7 +1711,9 @@ async def upload_model(
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
||||
|
||||
filename = os.path.basename(file.filename)
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
# --- P1: save file locally ---
|
||||
@@ -1465,13 +1758,13 @@ async def upload_model(
|
||||
os.remove(file_path)
|
||||
|
||||
# Create model in ollama
|
||||
model_name, ext = os.path.splitext(file.filename)
|
||||
model_name, ext = os.path.splitext(filename)
|
||||
log.info(f"Created Model: {model_name}") # DEBUG
|
||||
|
||||
create_payload = {
|
||||
"model": model_name,
|
||||
# Reference the file by its original name => the uploaded blob's digest
|
||||
"files": {file.filename: f"sha256:{file_hash}"},
|
||||
"files": {filename: f"sha256:{file_hash}"},
|
||||
}
|
||||
log.info(f"Model Payload: {create_payload}") # DEBUG
|
||||
|
||||
@@ -1488,7 +1781,7 @@ async def upload_model(
|
||||
done_msg = {
|
||||
"done": True,
|
||||
"blob": f"sha256:{file_hash}",
|
||||
"name": file.filename,
|
||||
"name": filename,
|
||||
"model_created": model_name,
|
||||
}
|
||||
yield f"data: {json.dumps(done_msg)}\n\n"
|
||||
|
||||
@@ -21,11 +21,13 @@ from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
)
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
@@ -35,6 +37,9 @@ from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
convert_logit_bias_input_to_json,
|
||||
)
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access
|
||||
@@ -51,12 +56,26 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -75,18 +94,23 @@ async def cleanup_response(
|
||||
await session.close()
|
||||
|
||||
|
||||
def openai_o1_o3_handler(payload):
|
||||
def openai_o_series_handler(payload):
|
||||
"""
|
||||
Handle o1, o3 specific parameters
|
||||
Handle "o" series specific parameters
|
||||
"""
|
||||
if "max_tokens" in payload:
|
||||
# Remove "max_tokens" from the payload
|
||||
# Convert "max_tokens" to "max_completion_tokens" for all o-series models
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
|
||||
# Handle system role conversion based on model type
|
||||
if payload["messages"][0]["role"] == "system":
|
||||
payload["messages"][0]["role"] = "user"
|
||||
model_lower = payload["model"].lower()
|
||||
# Legacy models use "user" role instead of "system"
|
||||
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
|
||||
payload["messages"][0]["role"] = "user"
|
||||
else:
|
||||
payload["messages"][0]["role"] = "developer"
|
||||
|
||||
return payload
|
||||
|
||||
@@ -172,7 +196,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
name = hashlib.sha256(body).hexdigest()
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
@@ -247,7 +271,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||
|
||||
|
||||
async def get_all_models_responses(request: Request) -> list:
|
||||
async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
|
||||
@@ -271,7 +295,9 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
):
|
||||
request_tasks.append(
|
||||
send_get_request(
|
||||
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -291,6 +317,7 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
send_get_request(
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -326,14 +353,22 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
connection_type = api_config.get("connection_type", "external")
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
tags = api_config.get("tags", [])
|
||||
|
||||
if prefix_id:
|
||||
for model in (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
for model in (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
if prefix_id:
|
||||
model["id"] = f"{prefix_id}.{model['id']}"
|
||||
|
||||
if tags:
|
||||
model["tags"] = tags
|
||||
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
log.debug(f"get_all_models:responses() {responses}")
|
||||
return responses
|
||||
|
||||
@@ -351,14 +386,14 @@ async def get_filtered_models(models, user):
|
||||
return filtered_models
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request) -> dict[str, list]:
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||
log.info("get_all_models()")
|
||||
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user=user)
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
@@ -373,6 +408,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
|
||||
|
||||
for idx, models in enumerate(model_lists):
|
||||
if models is not None and "error" not in models:
|
||||
|
||||
merged_list.extend(
|
||||
[
|
||||
{
|
||||
@@ -380,21 +416,25 @@ async def get_all_models(request: Request) -> dict[str, list]:
|
||||
"name": model.get("name", model["id"]),
|
||||
"owned_by": "openai",
|
||||
"openai": model,
|
||||
"connection_type": model.get("connection_type", "external"),
|
||||
"urlIdx": idx,
|
||||
}
|
||||
for model in models
|
||||
if "api.openai.com"
|
||||
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
or not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
if (model.get("id") or model.get("name"))
|
||||
and (
|
||||
"api.openai.com"
|
||||
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
or not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -418,65 +458,79 @@ async def get_models(
|
||||
}
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(url_idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
||||
)
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
response_data = await r.json()
|
||||
if api_config.get("azure", False):
|
||||
models = {
|
||||
"data": api_config.get("model_ids", []) or [],
|
||||
"object": "list",
|
||||
}
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
# Check if we're calling OpenAI API based on the URL
|
||||
if "api.openai.com" in url:
|
||||
# Filter models according to the specified conditions
|
||||
response_data["data"] = [
|
||||
model
|
||||
for model in response_data.get("data", [])
|
||||
if not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
models = response_data
|
||||
response_data = await r.json()
|
||||
|
||||
# Check if we're calling OpenAI API based on the URL
|
||||
if "api.openai.com" in url:
|
||||
# Filter models according to the specified conditions
|
||||
response_data["data"] = [
|
||||
model
|
||||
for model in response_data.get("data", [])
|
||||
if not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
models = response_data
|
||||
except aiohttp.ClientError as e:
|
||||
# ClientError covers all aiohttp requests issues
|
||||
log.exception(f"Client error: {str(e)}")
|
||||
@@ -498,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
@router.post("/verify")
|
||||
async def verify_connection(
|
||||
@@ -506,27 +562,64 @@ async def verify_connection(
|
||||
url = form_data.url
|
||||
key = form_data.key
|
||||
|
||||
api_config = form_data.config or {}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
if api_config.get("azure", False):
|
||||
headers["api-key"] = key
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
|
||||
async with session.get(
|
||||
url=f"{url}/openai/models?api-version={api_version}",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
# ClientError covers all aiohttp requests issues
|
||||
@@ -540,6 +633,63 @@ async def verify_connection(
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
|
||||
def convert_to_azure_payload(
|
||||
url,
|
||||
payload: dict,
|
||||
):
|
||||
model = payload.get("model", "")
|
||||
|
||||
# Filter allowed parameters based on Azure OpenAI API
|
||||
allowed_params = {
|
||||
"messages",
|
||||
"temperature",
|
||||
"role",
|
||||
"content",
|
||||
"contentPart",
|
||||
"contentPartImage",
|
||||
"enhancements",
|
||||
"dataSources",
|
||||
"n",
|
||||
"stream",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"log_probs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"max_completion_tokens",
|
||||
}
|
||||
|
||||
# Special handling for o-series models
|
||||
if model.startswith("o") and model.endswith("-mini"):
|
||||
# Convert max_tokens to max_completion_tokens for o-series models
|
||||
if "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Remove temperature if not 1 for o-series models
|
||||
if "temperature" in payload and payload["temperature"] != 1:
|
||||
log.debug(
|
||||
f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
|
||||
)
|
||||
del payload["temperature"]
|
||||
|
||||
# Filter out unsupported parameters
|
||||
payload = {k: v for k, v in payload.items() if k in allowed_params}
|
||||
|
||||
url = f"{url}/openai/deployments/{model}"
|
||||
return url, payload
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def generate_chat_completion(
|
||||
request: Request,
|
||||
@@ -565,8 +715,12 @@ async def generate_chat_completion(
|
||||
model_id = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -587,7 +741,7 @@ async def generate_chat_completion(
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
@@ -621,10 +775,10 @@ async def generate_chat_completion(
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
|
||||
if is_o1_o3:
|
||||
payload = openai_o1_o3_handler(payload)
|
||||
# Check if model is from "o" series
|
||||
is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
|
||||
if is_o_series:
|
||||
payload = openai_o_series_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
@@ -635,6 +789,43 @@ async def generate_chat_completion(
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
if "logit_bias" in payload:
|
||||
payload["logit_bias"] = json.loads(
|
||||
convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
request_url, payload = convert_to_azure_payload(url, payload)
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = api_version
|
||||
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
||||
else:
|
||||
request_url = f"{url}/chat/completions"
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
payload = json.dumps(payload)
|
||||
|
||||
r = None
|
||||
@@ -649,30 +840,10 @@ async def generate_chat_completion(
|
||||
|
||||
r = await session.request(
|
||||
method="POST",
|
||||
url=f"{url}/chat/completions",
|
||||
url=request_url,
|
||||
data=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
|
||||
# Check if response is SSE
|
||||
@@ -801,31 +972,54 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
idx = 0
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = (
|
||||
api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
)
|
||||
|
||||
payload = json.loads(body)
|
||||
url, payload = convert_to_azure_payload(url, payload)
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
request_url = f"{url}/{path}"
|
||||
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
url=f"{url}/{path}",
|
||||
url=request_url,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -851,7 +1045,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import (
|
||||
status,
|
||||
APIRouter,
|
||||
)
|
||||
import aiohttp
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
@@ -17,7 +18,7 @@ from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
@@ -56,96 +57,111 @@ def get_sorted_filters(model_id, models):
|
||||
return sorted_filters
|
||||
|
||||
|
||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
res = r.json()
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
res = (
|
||||
await response.json()
|
||||
if response.content_type == "application/json"
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
raise Exception(r.status_code, res["detail"])
|
||||
raise Exception(response.status, res["detail"])
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
r = requests.post(
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
payload = data
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
try:
|
||||
res = r.json()
|
||||
res = (
|
||||
await response.json()
|
||||
if "application/json" in response.content_type
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
return Exception(r.status_code, res)
|
||||
raise Exception(response.status, res)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
@@ -161,7 +177,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("/list")
|
||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user)
|
||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||
|
||||
urlIdxs = [
|
||||
@@ -188,9 +204,11 @@ async def upload_pipeline(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
print("upload_pipeline", urlIdx, file.filename)
|
||||
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
|
||||
filename = os.path.basename(file.filename)
|
||||
|
||||
# Check if the uploaded file is a python file
|
||||
if not (file.filename and file.filename.endswith(".py")):
|
||||
if not (filename and filename.endswith(".py")):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only Python (.py) files are allowed.",
|
||||
@@ -198,7 +216,7 @@ async def upload_pipeline(
|
||||
|
||||
upload_folder = f"{CACHE_DIR}/pipelines"
|
||||
os.makedirs(upload_folder, exist_ok=True)
|
||||
file_path = os.path.join(upload_folder, file.filename)
|
||||
file_path = os.path.join(upload_folder, filename)
|
||||
|
||||
r = None
|
||||
try:
|
||||
@@ -223,7 +241,7 @@ async def upload_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
@@ -274,7 +292,7 @@ async def add_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -319,7 +337,7 @@ async def delete_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -353,7 +371,7 @@ async def get_pipelines(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -392,7 +410,7 @@ async def get_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -432,7 +450,7 @@ async def get_pipeline_valves_spec(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -474,7 +492,7 @@ async def update_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,6 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -182,35 +183,28 @@ async def generate_title(
|
||||
else:
|
||||
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
messages = form_data["messages"]
|
||||
|
||||
# Remove reasoning details from the messages
|
||||
for message in messages:
|
||||
message["content"] = re.sub(
|
||||
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
|
||||
"",
|
||||
message["content"],
|
||||
flags=re.S,
|
||||
).strip()
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
messages,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
max_tokens = (
|
||||
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 1000}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
{"max_tokens": max_tokens}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 1000,
|
||||
"max_completion_tokens": max_tokens,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
@@ -221,6 +215,12 @@ async def generate_title(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -290,6 +290,12 @@ async def generate_chat_tags(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -356,6 +362,12 @@ async def generate_image_prompt(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -433,6 +445,12 @@ async def generate_queries(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -514,6 +532,12 @@ async def generate_autocompletion(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -571,7 +595,7 @@ async def generate_emoji(
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
@@ -584,6 +608,12 @@ async def generate_emoji(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -613,17 +643,6 @@ async def generate_moa_response(
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
|
||||
|
||||
template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = moa_response_generation_template(
|
||||
@@ -633,7 +652,7 @@ async def generate_moa_response(
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"metadata": {
|
||||
@@ -644,6 +663,12 @@ async def generate_moa_response(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import time
|
||||
import re
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from open_webui.models.tools import (
|
||||
ToolForm,
|
||||
@@ -8,13 +13,20 @@ from open_webui.models.tools import (
|
||||
ToolUserResponse,
|
||||
Tools,
|
||||
)
|
||||
from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
|
||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.tools import get_tool_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -25,11 +37,51 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ToolUserResponse])
|
||||
async def get_tools(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
tools = Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "read")
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if not request.app.state.TOOL_SERVERS:
|
||||
# If the tool servers are not set, we need to set them
|
||||
# This is done only once when the server starts
|
||||
# This is done to avoid loading the tool servers every time
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
tools = Tools.get_tools()
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
"id": f"server:{server['idx']}",
|
||||
"user_id": f"server:{server['idx']}",
|
||||
"name": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("title", "Tool Server"),
|
||||
"meta": {
|
||||
"description": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("description", ""),
|
||||
},
|
||||
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
||||
server["idx"]
|
||||
]
|
||||
.get("config", {})
|
||||
.get("access_control", None),
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if user.role != "admin":
|
||||
tools = [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user.id
|
||||
or has_access(user.id, "read", tool.access_control)
|
||||
]
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@@ -47,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
# LoadFunctionFromLink
|
||||
############################
|
||||
|
||||
|
||||
class LoadUrlForm(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
def github_url_to_raw_url(url: str) -> str:
|
||||
# Handle 'tree' (folder) URLs (add main.py at the end)
|
||||
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
|
||||
if m1:
|
||||
org, repo, branch, path = m1.groups()
|
||||
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
|
||||
|
||||
# Handle 'blob' (file) URLs
|
||||
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
|
||||
if m2:
|
||||
org, repo, branch, path = m2.groups()
|
||||
return (
|
||||
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
|
||||
)
|
||||
|
||||
# No match; return as-is
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/load/url", response_model=Optional[dict])
|
||||
async def load_tool_from_url(
|
||||
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# NOTE: This is NOT a SSRF vulnerability:
|
||||
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
||||
# and does NOT accept untrusted user input. Access is enforced by authentication.
|
||||
|
||||
url = str(form_data.url)
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
||||
|
||||
url = github_url_to_raw_url(url)
|
||||
url_parts = url.rstrip("/").split("/")
|
||||
|
||||
file_name = url_parts[-1]
|
||||
tool_name = (
|
||||
file_name[:-3]
|
||||
if (
|
||||
file_name.endswith(".py")
|
||||
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
||||
)
|
||||
else url_parts[-2] if len(url_parts) > 1 else "function"
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=resp.status, detail="Failed to fetch the tool"
|
||||
)
|
||||
data = await resp.text()
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No data received from the URL"
|
||||
)
|
||||
return {
|
||||
"name": tool_name,
|
||||
"content": data,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
|
||||
|
||||
|
||||
############################
|
||||
# ExportTools
|
||||
############################
|
||||
@@ -89,18 +216,18 @@ async def create_new_tools(
|
||||
if tools is None:
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
tools_module, frontmatter = load_tools_module_by_id(
|
||||
tool_module, frontmatter = load_tool_module_by_id(
|
||||
form_data.id, content=form_data.content
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[form_data.id] = tools_module
|
||||
TOOLS[form_data.id] = tool_module
|
||||
|
||||
specs = get_tools_specs(TOOLS[form_data.id])
|
||||
specs = get_tool_specs(TOOLS[form_data.id])
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
|
||||
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if tools:
|
||||
@@ -111,7 +238,7 @@ async def create_new_tools(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -178,22 +305,20 @@ async def update_tools_by_id(
|
||||
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
tools_module, frontmatter = load_tools_module_by_id(
|
||||
id, content=form_data.content
|
||||
)
|
||||
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[id] = tools_module
|
||||
TOOLS[id] = tool_module
|
||||
|
||||
specs = get_tools_specs(TOOLS[id])
|
||||
specs = get_tool_specs(TOOLS[id])
|
||||
|
||||
updated = {
|
||||
**form_data.model_dump(exclude={"id"}),
|
||||
"specs": specs,
|
||||
}
|
||||
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
|
||||
if tools:
|
||||
@@ -284,7 +409,7 @@ async def get_tools_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tools_module_by_id(id)
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "Valves"):
|
||||
@@ -327,7 +452,7 @@ async def update_tools_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tools_module_by_id(id)
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if not hasattr(tools_module, "Valves"):
|
||||
@@ -343,7 +468,7 @@ async def update_tools_valves_by_id(
|
||||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -383,7 +508,7 @@ async def get_tools_user_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tools_module_by_id(id)
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "UserValves"):
|
||||
@@ -407,7 +532,7 @@ async def update_tools_user_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tools_module_by_id(id)
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "UserValves"):
|
||||
@@ -421,7 +546,7 @@ async def update_tools_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
||||
@@ -2,9 +2,11 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.users import (
|
||||
UserModel,
|
||||
UserListResponse,
|
||||
UserRoleUpdateForm,
|
||||
Users,
|
||||
UserSettings,
|
||||
@@ -17,7 +19,10 @@ from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
|
||||
from open_webui.utils.access_control import get_permissions, has_permission
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -29,13 +34,38 @@ router = APIRouter()
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[UserModel])
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
|
||||
@router.get("/", response_model=UserListResponse)
|
||||
async def get_users(
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
return Users.get_users(skip, limit)
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
page = max(1, page)
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
return Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@router.get("/all", response_model=UserListResponse)
|
||||
async def get_all_users(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
return Users.get_users()
|
||||
|
||||
|
||||
############################
|
||||
@@ -45,7 +75,7 @@ async def get_users(
|
||||
|
||||
@router.get("/groups")
|
||||
async def get_user_groups(user=Depends(get_verified_user)):
|
||||
return Users.get_user_groups(user.id)
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
@@ -54,8 +84,12 @@ async def get_user_groups(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.get("/permissions")
|
||||
async def get_user_permissisions(user=Depends(get_verified_user)):
|
||||
return Users.get_user_groups(user.id)
|
||||
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
return user_permissions
|
||||
|
||||
|
||||
############################
|
||||
@@ -68,32 +102,52 @@ class WorkspacePermissions(BaseModel):
|
||||
tools: bool = False
|
||||
|
||||
|
||||
class SharingPermissions(BaseModel):
|
||||
public_models: bool = True
|
||||
public_knowledge: bool = True
|
||||
public_prompts: bool = True
|
||||
public_tools: bool = True
|
||||
|
||||
|
||||
class ChatPermissions(BaseModel):
|
||||
controls: bool = True
|
||||
file_upload: bool = True
|
||||
delete: bool = True
|
||||
edit: bool = True
|
||||
share: bool = True
|
||||
export: bool = True
|
||||
stt: bool = True
|
||||
tts: bool = True
|
||||
call: bool = True
|
||||
multiple_models: bool = True
|
||||
temporary: bool = True
|
||||
temporary_enforced: bool = False
|
||||
|
||||
|
||||
class FeaturesPermissions(BaseModel):
|
||||
direct_tool_servers: bool = False
|
||||
web_search: bool = True
|
||||
image_generation: bool = True
|
||||
code_interpreter: bool = True
|
||||
notes: bool = True
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
workspace: WorkspacePermissions
|
||||
sharing: SharingPermissions
|
||||
chat: ChatPermissions
|
||||
features: FeaturesPermissions
|
||||
|
||||
|
||||
@router.get("/default/permissions", response_model=UserPermissions)
|
||||
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"workspace": WorkspacePermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("workspace", {})
|
||||
),
|
||||
"sharing": SharingPermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("sharing", {})
|
||||
),
|
||||
"chat": ChatPermissions(
|
||||
**request.app.state.config.USER_PERMISSIONS.get("chat", {})
|
||||
),
|
||||
@@ -104,7 +158,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
|
||||
|
||||
@router.post("/default/permissions")
|
||||
async def update_user_permissions(
|
||||
async def update_default_user_permissions(
|
||||
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
|
||||
@@ -151,9 +205,22 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
|
||||
@router.post("/user/settings/update", response_model=UserSettings)
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: UserSettings, user=Depends(get_verified_user)
|
||||
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
|
||||
updated_user_settings = form_data.model_dump()
|
||||
if (
|
||||
user.role != "admin"
|
||||
and "toolServers" in updated_user_settings.get("ui").keys()
|
||||
and not has_permission(
|
||||
user.id,
|
||||
"features.direct_tool_servers",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
)
|
||||
):
|
||||
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
||||
updated_user_settings["ui"].pop("toolServers", None)
|
||||
|
||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
@@ -263,6 +330,21 @@ async def update_user_by_id(
|
||||
form_data: UserUpdateForm,
|
||||
session_user=Depends(get_admin_user),
|
||||
):
|
||||
# Prevent modification of the primary admin user by other admins
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
if first_user and user_id == first_user.id and session_user.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error checking primary admin status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not verify primary admin status.",
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(user_id)
|
||||
|
||||
if user:
|
||||
@@ -310,6 +392,21 @@ async def update_user_by_id(
|
||||
|
||||
@router.delete("/{user_id}", response_model=bool)
|
||||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
# Prevent deletion of the primary admin user
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
if first_user and user_id == first_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error checking primary admin status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not verify primary admin status.",
|
||||
)
|
||||
|
||||
if user.id != user_id:
|
||||
result = Auths.delete_auth_by_id(user_id)
|
||||
|
||||
@@ -321,6 +418,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
|
||||
)
|
||||
|
||||
# Prevent self-deletion
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
|
||||
@@ -1,48 +1,84 @@
|
||||
import black
|
||||
import logging
|
||||
import markdown
|
||||
|
||||
from open_webui.models.chats import ChatTitleMessagesForm
|
||||
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
from open_webui.utils.misc import get_gravatar_url
|
||||
from open_webui.utils.pdf_generator import PDFGenerator
|
||||
from open_webui.utils.auth import get_admin_user
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
):
|
||||
async def get_gravatar(email: str, user=Depends(get_verified_user)):
|
||||
return get_gravatar_url(email)
|
||||
|
||||
|
||||
class CodeFormatRequest(BaseModel):
|
||||
class CodeForm(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
@router.post("/code/format")
|
||||
async def format_code(request: CodeFormatRequest):
|
||||
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
formatted_code = black.format_str(request.code, mode=black.Mode())
|
||||
formatted_code = black.format_str(form_data.code, mode=black.Mode())
|
||||
return {"code": formatted_code}
|
||||
except black.NothingChanged:
|
||||
return {"code": request.code}
|
||||
return {"code": form_data.code}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/code/execute")
|
||||
async def execute_code(
|
||||
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
|
||||
):
|
||||
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
form_data.code,
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
|
||||
else None
|
||||
),
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
)
|
||||
|
||||
return output
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Code execution engine not supported",
|
||||
)
|
||||
|
||||
|
||||
class MarkdownForm(BaseModel):
|
||||
md: str
|
||||
|
||||
|
||||
@router.post("/markdown")
|
||||
async def get_html_from_markdown(
|
||||
form_data: MarkdownForm,
|
||||
form_data: MarkdownForm, user=Depends(get_verified_user)
|
||||
):
|
||||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
|
||||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
|
||||
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating PDF: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user