mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: merge with main
This commit is contained in:
@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import requests
|
||||
import mimetypes
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -36,6 +37,7 @@ from open_webui.config import (
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
DEVICE_TYPE,
|
||||
@@ -52,7 +54,7 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -69,7 +71,7 @@ from pydub.utils import mediainfo
|
||||
def is_mp4_audio(file_path):
|
||||
"""Check if the given file is an MP4 audio file."""
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
log.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
info = mediainfo(file_path)
|
||||
@@ -86,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
|
||||
"""Convert MP4 audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format="mp4")
|
||||
audio.export(output_path, format="wav")
|
||||
print(f"Converted {file_path} to {output_path}")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
@@ -138,6 +140,7 @@ class STTConfigForm(BaseModel):
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
DEEPGRAM_API_KEY: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@@ -165,6 +168,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -190,6 +194,7 @@ async def update_audio_config(
|
||||
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
request.app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -214,6 +219,7 @@ async def update_audio_config(
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -260,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
json=payload,
|
||||
@@ -318,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
|
||||
json={
|
||||
@@ -375,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
||||
<voice name="{language}">{payload["input"]}</voice>
|
||||
</speak>"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
|
||||
headers={
|
||||
@@ -453,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
print("transcribe", file_path)
|
||||
log.info(f"transcribe: {file_path}")
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
@@ -521,6 +535,69 @@ def transcribe(request: Request, file_path):
|
||||
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
elif request.app.state.config.STT_ENGINE == "deepgram":
|
||||
try:
|
||||
# Determine the MIME type of the file
|
||||
mime, _ = mimetypes.guess_type(file_path)
|
||||
if not mime:
|
||||
mime = "audio/wav" # fallback to wav if undetectable
|
||||
|
||||
# Read the audio file
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Build headers and parameters
|
||||
headers = {
|
||||
"Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
|
||||
"Content-Type": mime,
|
||||
}
|
||||
|
||||
# Add model if specified
|
||||
params = {}
|
||||
if request.app.state.config.STT_MODEL:
|
||||
params["model"] = request.app.state.config.STT_MODEL
|
||||
|
||||
# Make request to Deepgram API
|
||||
r = requests.post(
|
||||
"https://api.deepgram.com/v1/listen",
|
||||
headers=headers,
|
||||
params=params,
|
||||
data=file_data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
response_data = r.json()
|
||||
|
||||
# Extract transcript from Deepgram response
|
||||
try:
|
||||
transcript = response_data["results"]["channels"][0]["alternatives"][
|
||||
0
|
||||
].get("transcript", "")
|
||||
except (KeyError, IndexError) as e:
|
||||
log.error(f"Malformed response from Deepgram: {str(e)}")
|
||||
raise Exception(
|
||||
"Failed to parse Deepgram response - unexpected response format"
|
||||
)
|
||||
data = {"text": transcript.strip()}
|
||||
|
||||
# Save transcript
|
||||
transcript_file = f"{file_dir}/{id}.json"
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
|
||||
def compress_audio(file_path):
|
||||
if os.path.getsize(file_path) > MAX_FILE_SIZE:
|
||||
@@ -602,7 +679,22 @@ def transcription(
|
||||
def get_available_models(request: Request) -> list[dict]:
|
||||
available_models = []
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
available_models = data.get("models", [])
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from custom endpoint: {str(e)}")
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
else:
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
response = requests.get(
|
||||
@@ -633,14 +725,37 @@ def get_available_voices(request) -> dict:
|
||||
"""Returns {voice_id: voice_name} dict"""
|
||||
available_voices = {}
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
voices_list = data.get("voices", [])
|
||||
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
else:
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
available_voices = get_elevenlabs_voices(
|
||||
|
||||
@@ -25,16 +25,13 @@ from open_webui.env import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from open_webui.config import (
|
||||
OPENID_PROVIDER_URL,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
)
|
||||
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||
from pydantic import BaseModel
|
||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||
from open_webui.utils.auth import (
|
||||
@@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
|
||||
from typing import Optional, List
|
||||
|
||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
if ENABLE_LDAP.value:
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -95,8 +94,8 @@ async def get_session_user(
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
@@ -164,7 +163,7 @@ async def update_password(
|
||||
############################
|
||||
# LDAP Authentication
|
||||
############################
|
||||
@router.post("/ldap", response_model=SigninResponse)
|
||||
@router.post("/ldap", response_model=SessionUserResponse)
|
||||
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
|
||||
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
|
||||
@@ -231,9 +230,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
|
||||
entry = connection_app.entries[0]
|
||||
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not mail or mail == "" or mail == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have mail.")
|
||||
email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not email or email == "" or email == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have email.")
|
||||
else:
|
||||
email = email.lower()
|
||||
|
||||
cn = str(entry["cn"])
|
||||
user_dn = entry.entry_dn
|
||||
|
||||
@@ -248,17 +250,22 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
if not connection_user.bind():
|
||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||
|
||||
user = Users.get_user_by_email(mail)
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
if user_count == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, role=role
|
||||
email=email,
|
||||
password=str(uuid.uuid4()),
|
||||
name=cn,
|
||||
role=role,
|
||||
)
|
||||
|
||||
if not user:
|
||||
@@ -271,7 +278,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
|
||||
user = Auths.authenticate_user_by_trusted_header(mail)
|
||||
user = Auths.authenticate_user_by_trusted_header(email)
|
||||
|
||||
if user:
|
||||
token = create_token(
|
||||
@@ -288,6 +295,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
@@ -296,6 +307,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
"permissions": user_permissions,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
@@ -378,8 +390,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
@@ -408,6 +420,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
not request.app.state.config.ENABLE_SIGNUP
|
||||
@@ -422,6 +435,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
@@ -432,12 +446,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
try:
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
if Users.get_num_users() == 0:
|
||||
if user_count == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
@@ -473,12 +485,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
@@ -525,7 +538,8 @@ async def signout(request: Request, response: Response):
|
||||
if logout_url:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
return RedirectResponse(
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}"
|
||||
headers=response.headers,
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -591,7 +605,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
||||
print(admin_email, admin_name)
|
||||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
|
||||
@@ -192,7 +192,7 @@ async def get_channel_messages(
|
||||
############################
|
||||
|
||||
|
||||
async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
name,
|
||||
webhook_url,
|
||||
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
|
||||
{
|
||||
@@ -302,6 +303,7 @@ async def post_new_message(
|
||||
|
||||
background_tasks.add_task(
|
||||
send_notification,
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
|
||||
@@ -444,15 +444,21 @@ async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
class CloneForm(BaseModel):
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def clone_chat_by_id(
|
||||
form_data: CloneForm, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
"originalChatId": chat.id,
|
||||
"branchPointMessageId": chat.chat["history"]["currentId"],
|
||||
"title": f"Clone of {chat.title}",
|
||||
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
|
||||
@@ -36,6 +36,140 @@ async def export_config(user=Depends(get_admin_user)):
|
||||
return get_config()
|
||||
|
||||
|
||||
############################
|
||||
# Direct Connections Config
|
||||
############################
|
||||
|
||||
|
||||
class DirectConnectionsConfigForm(BaseModel):
|
||||
ENABLE_DIRECT_CONNECTIONS: bool
|
||||
|
||||
|
||||
@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
|
||||
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
|
||||
async def set_direct_connections_config(
|
||||
request: Request,
|
||||
form_data: DirectConnectionsConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
|
||||
form_data.ENABLE_DIRECT_CONNECTIONS
|
||||
)
|
||||
return {
|
||||
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
class CodeInterpreterConfigForm(BaseModel):
|
||||
CODE_EXECUTION_ENGINE: str
|
||||
CODE_EXECUTION_JUPYTER_URL: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
|
||||
ENABLE_CODE_INTERPRETER: bool
|
||||
CODE_INTERPRETER_ENGINE: str
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_URL: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
|
||||
|
||||
|
||||
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_execution_config(
|
||||
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_URL
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
|
||||
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
|
||||
form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_URL
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# SetDefaultModels
|
||||
############################
|
||||
|
||||
@@ -3,30 +3,23 @@ import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import mimetypes
|
||||
from urllib.parse import quote
|
||||
|
||||
from open_webui.storage.provider import Storage
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.models.files import (
|
||||
FileForm,
|
||||
FileModel,
|
||||
FileModelResponse,
|
||||
Files,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_file, ProcessFileForm
|
||||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||
from open_webui.routers.audio import transcribe
|
||||
from open_webui.storage.provider import Storage
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from pydantic import BaseModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -41,7 +34,10 @@ router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=FileModelResponse)
|
||||
def upload_file(
|
||||
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
file_metadata: dict = {},
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
@@ -65,13 +61,29 @@ def upload_file(
|
||||
"name": name,
|
||||
"content_type": file.content_type,
|
||||
"size": len(contents),
|
||||
"data": file_metadata,
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
process_file(request, ProcessFileForm(file_id=id))
|
||||
if file.content_type in [
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/x-m4a",
|
||||
]:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
@@ -126,7 +138,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
||||
Storage.delete_all_files()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting files")
|
||||
log.error("Error deleting files")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||
@@ -193,7 +205,9 @@ async def update_file_data_content_by_id(
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=id, content=form_data.content)
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=form_data.content),
|
||||
user=user,
|
||||
)
|
||||
file = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
@@ -227,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename) # RFC5987 encoding
|
||||
|
||||
content_type = file.meta.get("content_type")
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename)
|
||||
headers = {}
|
||||
if file.meta.get("content_type") not in [
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
]:
|
||||
headers = {
|
||||
**headers,
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
|
||||
return FileResponse(file_path, headers=headers)
|
||||
if content_type == "application/pdf" or filename.lower().endswith(
|
||||
".pdf"
|
||||
):
|
||||
headers["Content-Disposition"] = (
|
||||
f"inline; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
content_type = "application/pdf"
|
||||
elif content_type != "text/plain":
|
||||
headers["Content-Disposition"] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
|
||||
return FileResponse(file_path, headers=headers, media_type=content_type)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -246,7 +267,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error getting file content")
|
||||
log.error("Error getting file content")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||
@@ -268,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
log.info(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -277,7 +298,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error getting file content")
|
||||
log.error("Error getting file content")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||
@@ -353,7 +374,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
Storage.delete_file(file.path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting files")
|
||||
log.error("Error deleting files")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -68,7 +74,7 @@ async def create_new_function(
|
||||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
|
||||
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
|
||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if function:
|
||||
@@ -79,7 +85,7 @@ async def create_new_function(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to create a new function: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -183,7 +189,7 @@ async def update_function_by_id(
|
||||
FUNCTIONS[id] = function_module
|
||||
|
||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
|
||||
@@ -299,7 +305,7 @@ async def update_function_valves_by_id(
|
||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
16
backend/open_webui/routers/groups.py
Normal file → Executable file
16
backend/open_webui/routers/groups.py
Normal file → Executable file
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import (
|
||||
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new group: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -94,7 +100,7 @@ async def update_group_by_id(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error deleting group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
@@ -1,37 +1,31 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||
from open_webui.routers.files import upload_file
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.images.comfyui import (
|
||||
ComfyUIGenerateImageForm,
|
||||
ComfyUIWorkflow,
|
||||
comfyui_generate_image,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||
|
||||
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
||||
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -61,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -84,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
|
||||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
|
||||
class GeminiConfigForm(BaseModel):
|
||||
GEMINI_API_BASE_URL: str
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
|
||||
class ConfigForm(BaseModel):
|
||||
enabled: bool
|
||||
engine: str
|
||||
@@ -91,6 +94,7 @@ class ConfigForm(BaseModel):
|
||||
openai: OpenAIConfigForm
|
||||
automatic1111: Automatic1111ConfigForm
|
||||
comfyui: ComfyUIConfigForm
|
||||
gemini: GeminiConfigForm
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
@@ -109,6 +113,11 @@ async def update_config(
|
||||
)
|
||||
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
||||
|
||||
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
||||
form_data.gemini.GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
|
||||
|
||||
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
||||
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
||||
)
|
||||
@@ -135,6 +144,8 @@ async def update_config(
|
||||
request.app.state.config.COMFYUI_BASE_URL = (
|
||||
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
||||
)
|
||||
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
|
||||
|
||||
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
||||
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
|
||||
form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
||||
@@ -161,6 +172,10 @@ async def update_config(
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -190,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
|
||||
headers = None
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||
headers=headers,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
@@ -230,6 +253,12 @@ def get_image_model(request):
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "dall-e-2"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "imagen-3.0-generate-002"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
@@ -271,7 +300,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
||||
async def update_image_config(
|
||||
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
set_image_model(request, form_data.MODEL)
|
||||
|
||||
pattern = r"^\d+x\d+$"
|
||||
@@ -306,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
headers = {
|
||||
@@ -329,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
if model_node_id:
|
||||
model_list_key = None
|
||||
|
||||
print(workflow[model_node_id]["class_type"])
|
||||
log.info(workflow[model_node_id]["class_type"])
|
||||
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
||||
"required"
|
||||
]:
|
||||
@@ -383,40 +415,22 @@ class GenerateImageForm(BaseModel):
|
||||
negative_prompt: Optional[str] = None
|
||||
|
||||
|
||||
def save_b64_image(b64_str):
|
||||
def load_b64_image_data(b64_str):
|
||||
try:
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
if "," in b64_str:
|
||||
header, encoded = b64_str.split(",", 1)
|
||||
mime_type = header.split(";")[0]
|
||||
|
||||
img_data = base64.b64decode(encoded)
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
else:
|
||||
image_filename = f"{image_id}.png"
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
||||
|
||||
mime_type = "image/png"
|
||||
img_data = base64.b64decode(b64_str)
|
||||
|
||||
# Write the image data to a file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
|
||||
return img_data, mime_type
|
||||
except Exception as e:
|
||||
log.exception(f"Error saving image: {e}")
|
||||
log.exception(f"Error loading image data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_url_image(url, headers=None):
|
||||
image_id = str(uuid.uuid4())
|
||||
def load_url_image_data(url, headers=None):
|
||||
try:
|
||||
if headers:
|
||||
r = requests.get(url, headers=headers)
|
||||
@@ -426,18 +440,7 @@ def save_url_image(url, headers=None):
|
||||
r.raise_for_status()
|
||||
if r.headers["content-type"].split("/")[0] == "image":
|
||||
mime_type = r.headers["content-type"]
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
if not image_format:
|
||||
raise ValueError("Could not determine image type from MIME type")
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
|
||||
with open(file_path, "wb") as image_file:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
image_file.write(chunk)
|
||||
return image_filename
|
||||
return r.content, mime_type
|
||||
else:
|
||||
log.error("Url does not point to an image.")
|
||||
return None
|
||||
@@ -447,6 +450,20 @@ def save_url_image(url, headers=None):
|
||||
return None
|
||||
|
||||
|
||||
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
|
||||
headers={
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/generations")
|
||||
async def image_generations(
|
||||
request: Request,
|
||||
@@ -500,12 +517,49 @@ async def image_generations(
|
||||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_filename = save_b64_image(image["b64_json"])
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
if "url" in image:
|
||||
image_data, content_type = load_url_image_data(
|
||||
image["url"], headers
|
||||
)
|
||||
else:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(data, f)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
|
||||
|
||||
model = get_image_model(request)
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = load_b64_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
@@ -552,14 +606,15 @@ async def image_generations(
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
image_filename = save_url_image(image["url"], headers)
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(form_data.model_dump(exclude_none=True), f)
|
||||
|
||||
log.debug(f"images: {images}")
|
||||
image_data, content_type = load_url_image_data(image["url"], headers)
|
||||
url = upload_image(
|
||||
request,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
image_data,
|
||||
content_type,
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
elif (
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
||||
@@ -604,13 +659,15 @@ async def image_generations(
|
||||
images = []
|
||||
|
||||
for image in res["images"]:
|
||||
image_filename = save_b64_image(image)
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump({**data, "info": res["info"]}, f)
|
||||
|
||||
image_data, content_type = load_b64_image_data(image)
|
||||
url = upload_image(
|
||||
request,
|
||||
{**data, "info": res["info"]},
|
||||
image_data,
|
||||
content_type,
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
except Exception as e:
|
||||
error = e
|
||||
|
||||
@@ -264,7 +264,11 @@ def add_file_to_knowledge_by_id(
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -285,7 +289,9 @@ def add_file_to_knowledge_by_id(
|
||||
# Add content to the vector database
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
|
||||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
@@ -342,7 +348,12 @@ def update_file_from_knowledge_by_id(
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -363,7 +374,9 @@ def update_file_from_knowledge_by_id(
|
||||
# Add content to the vector database
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
|
||||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
@@ -406,7 +419,11 @@ def remove_file_from_knowledge_by_id(
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -429,10 +446,6 @@ def remove_file_from_knowledge_by_id(
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
|
||||
|
||||
# Delete physical file
|
||||
if file.path:
|
||||
Storage.delete_file(file.path)
|
||||
|
||||
# Delete file from database
|
||||
Files.delete_file_by_id(form_data.file_id)
|
||||
|
||||
@@ -484,7 +497,11 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -543,7 +560,11 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -582,14 +603,18 @@ def add_files_to_knowledge_batch(
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# Get files content
|
||||
print(f"files/batch/add - {len(form_data)} files")
|
||||
log.info(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
|
||||
@@ -57,7 +57,7 @@ async def add_memory(
|
||||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"metadata": {"created_at": memory.created_at},
|
||||
}
|
||||
],
|
||||
@@ -82,7 +82,7 @@ async def query_memory(
|
||||
):
|
||||
results = VECTOR_DB_CLIENT.search(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
|
||||
limit=form_data.k,
|
||||
)
|
||||
|
||||
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
|
||||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
@@ -160,7 +160,9 @@ async def update_memory_by_id(
|
||||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user
|
||||
),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
|
||||
@@ -183,7 +183,11 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if model.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
user.role != "admin"
|
||||
and model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
|
||||
@@ -11,11 +11,14 @@ import re
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
|
||||
import requests
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.env import (
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -28,7 +31,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
||||
@@ -52,7 +55,7 @@ from open_webui.env import (
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@@ -68,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -98,6 +115,7 @@ async def send_post_request(
|
||||
stream: bool = True,
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
):
|
||||
|
||||
r = None
|
||||
@@ -112,6 +130,16 @@ async def send_post_request(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -188,12 +216,24 @@ async def verify_connection(
|
||||
key = form_data.key
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/api/version",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
detail = f"HTTP Error: {r.status}"
|
||||
@@ -256,7 +296,7 @@ async def update_config(
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request):
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
@@ -264,7 +304,7 @@ async def get_all_models(request: Request):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags"))
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
@@ -277,7 +317,9 @@ async def get_all_models(request: Request):
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", key))
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/tags", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
@@ -362,7 +404,7 @@ async def get_ollama_tags(
|
||||
models = []
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
@@ -372,7 +414,19 @@ async def get_ollama_tags(
|
||||
r = requests.request(
|
||||
method="GET",
|
||||
url=f"{url}/api/tags",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -395,7 +449,7 @@ async def get_ollama_tags(
|
||||
)
|
||||
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models["models"] = get_filtered_models(models, user)
|
||||
models["models"] = await get_filtered_models(models, user)
|
||||
|
||||
return models
|
||||
|
||||
@@ -479,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
user=user,
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
@@ -511,6 +566,7 @@ async def pull_model(
|
||||
url=f"{url}/api/pull",
|
||||
payload=json.dumps(payload),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -529,7 +585,7 @@ async def push_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -547,6 +603,7 @@ async def push_model(
|
||||
url=f"{url}/api/push",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -573,6 +630,7 @@ async def create_model(
|
||||
url=f"{url}/api/create",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -590,7 +648,7 @@ async def copy_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.source in models:
|
||||
@@ -611,6 +669,16 @@ async def copy_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -645,7 +713,7 @@ async def delete_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -667,6 +735,16 @@ async def delete_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -695,7 +773,7 @@ async def delete_model(
|
||||
async def show_model_info(
|
||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name not in models:
|
||||
@@ -716,6 +794,16 @@ async def show_model_info(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -759,7 +847,7 @@ async def embed(
|
||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -785,6 +873,16 @@ async def embed(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -828,7 +926,7 @@ async def embeddings(
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -854,6 +952,16 @@ async def embeddings(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -903,7 +1011,7 @@ async def generate_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -933,23 +1041,39 @@ async def generate_completion(
|
||||
url=f"{url}/api/generate",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[dict]] = None
|
||||
images: Optional[list[str]] = None
|
||||
|
||||
@validator("content", pre=True)
|
||||
@classmethod
|
||||
def check_at_least_one_field(cls, field_value, values, **kwargs):
|
||||
# Raise an error if both 'content' and 'tool_calls' are None
|
||||
if field_value is None and (
|
||||
"tool_calls" not in values or values["tool_calls"] is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of 'content' or 'tool_calls' must be provided"
|
||||
)
|
||||
|
||||
return field_value
|
||||
|
||||
|
||||
class GenerateChatCompletionForm(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
format: Optional[dict] = None
|
||||
format: Optional[Union[dict, str]] = None
|
||||
options: Optional[dict] = None
|
||||
template: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
|
||||
|
||||
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
|
||||
@@ -977,6 +1101,7 @@ async def generate_chat_completion(
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
metadata = form_data.pop("metadata", None)
|
||||
try:
|
||||
form_data = GenerateChatCompletionForm(**form_data)
|
||||
except Exception as e:
|
||||
@@ -1006,7 +1131,7 @@ async def generate_chat_completion(
|
||||
payload["options"] = apply_model_params_to_body_ollama(
|
||||
params, payload["options"]
|
||||
)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -1046,6 +1171,7 @@ async def generate_chat_completion(
|
||||
stream=form_data.stream,
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1148,6 +1274,7 @@ async def generate_openai_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1159,6 +1286,8 @@ async def generate_openai_chat_completion(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
metadata = form_data.pop("metadata", None)
|
||||
|
||||
try:
|
||||
completion_form = OpenAIChatCompletionForm(**form_data)
|
||||
except Exception as e:
|
||||
@@ -1185,7 +1314,7 @@ async def generate_openai_chat_completion(
|
||||
|
||||
if params:
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
@@ -1224,6 +1353,7 @@ async def generate_openai_chat_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1237,7 +1367,7 @@ async def get_openai_models(
|
||||
|
||||
models = []
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models(request)
|
||||
model_list = await get_all_models(request, user=user)
|
||||
models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@@ -1405,9 +1535,10 @@ async def download_model(
|
||||
return None
|
||||
|
||||
|
||||
# TODO: Progress bar does not reflect size & duration of upload.
|
||||
@router.post("/models/upload")
|
||||
@router.post("/models/upload/{url_idx}")
|
||||
def upload_model(
|
||||
async def upload_model(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
url_idx: Optional[int] = None,
|
||||
@@ -1416,59 +1547,85 @@ def upload_model(
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
# --- P1: save file locally ---
|
||||
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
|
||||
with open(file_path, "wb") as out_f:
|
||||
while True:
|
||||
chunk = file.file.read(chunk_size)
|
||||
# log.info(f"Chunk: {str(chunk)}") # DEBUG
|
||||
if not chunk:
|
||||
break
|
||||
out_f.write(chunk)
|
||||
|
||||
# Save file in chunks
|
||||
with open(file_path, "wb+") as f:
|
||||
for chunk in file.file:
|
||||
f.write(chunk)
|
||||
|
||||
def file_process_stream():
|
||||
async def file_process_stream():
|
||||
nonlocal ollama_url
|
||||
total_size = os.path.getsize(file_path)
|
||||
chunk_size = 1024 * 1024
|
||||
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
|
||||
|
||||
# --- P2: SSE progress + calculate sha256 hash ---
|
||||
file_hash = calculate_sha256(file_path, chunk_size)
|
||||
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
total = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
done = True
|
||||
continue
|
||||
|
||||
total += len(chunk)
|
||||
progress = round((total / total_size) * 100, 2)
|
||||
|
||||
res = {
|
||||
bytes_read = 0
|
||||
while chunk := f.read(chunk_size):
|
||||
bytes_read += len(chunk)
|
||||
progress = round(bytes_read / total_size * 100, 2)
|
||||
data_msg = {
|
||||
"progress": progress,
|
||||
"total": total_size,
|
||||
"completed": total,
|
||||
"completed": bytes_read,
|
||||
}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
yield f"data: {json.dumps(data_msg)}\n\n"
|
||||
|
||||
if done:
|
||||
f.seek(0)
|
||||
hashed = calculate_sha256(f)
|
||||
f.seek(0)
|
||||
# --- P3: Upload to ollama /api/blobs ---
|
||||
with open(file_path, "rb") as f:
|
||||
url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
|
||||
response = requests.post(url, data=f)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=f)
|
||||
if response.ok:
|
||||
log.info(f"Uploaded to /api/blobs") # DEBUG
|
||||
# Remove local file
|
||||
os.remove(file_path)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file.filename,
|
||||
}
|
||||
os.remove(file_path)
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
"Ollama: Could not create blob, Please try again."
|
||||
)
|
||||
# Create model in ollama
|
||||
model_name, ext = os.path.splitext(file.filename)
|
||||
log.info(f"Created Model: {model_name}") # DEBUG
|
||||
|
||||
create_payload = {
|
||||
"model": model_name,
|
||||
# Reference the file by its original name => the uploaded blob's digest
|
||||
"files": {file.filename: f"sha256:{file_hash}"},
|
||||
}
|
||||
log.info(f"Model Payload: {create_payload}") # DEBUG
|
||||
|
||||
# Call ollama /api/create
|
||||
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
|
||||
create_resp = requests.post(
|
||||
url=f"{ollama_url}/api/create",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(create_payload),
|
||||
)
|
||||
|
||||
if create_resp.ok:
|
||||
log.info(f"API SUCCESS!") # DEBUG
|
||||
done_msg = {
|
||||
"done": True,
|
||||
"blob": f"sha256:{file_hash}",
|
||||
"name": file.filename,
|
||||
"model_created": model_name,
|
||||
}
|
||||
yield f"data: {json.dumps(done_msg)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to create model in Ollama. {create_resp.text}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Ollama: Could not create blob, Please try again.")
|
||||
|
||||
except Exception as e:
|
||||
res = {"error": str(e)}
|
||||
|
||||
@@ -22,10 +22,11 @@ from open_webui.config import (
|
||||
)
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
@@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -75,18 +89,24 @@ async def cleanup_response(
|
||||
await session.close()
|
||||
|
||||
|
||||
def openai_o1_handler(payload):
|
||||
def openai_o1_o3_handler(payload):
|
||||
"""
|
||||
Handle O1 specific parameters
|
||||
Handle o1, o3 specific parameters
|
||||
"""
|
||||
if "max_tokens" in payload:
|
||||
# Remove "max_tokens" from the payload
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
|
||||
# Fix: o1 and o3 do not support the "system" role directly.
|
||||
# For older models like "o1-mini" or "o1-preview", use role "user".
|
||||
# For newer o1/o3 models, replace "system" with "developer".
|
||||
if payload["messages"][0]["role"] == "system":
|
||||
payload["messages"][0]["role"] = "user"
|
||||
model_lower = payload["model"].lower()
|
||||
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
|
||||
payload["messages"][0]["role"] = "user"
|
||||
else:
|
||||
payload["messages"][0]["role"] = "developer"
|
||||
|
||||
return payload
|
||||
|
||||
@@ -172,7 +192,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
name = hashlib.sha256(body).hexdigest()
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
@@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||
|
||||
|
||||
async def get_all_models_responses(request: Request) -> list:
|
||||
async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
|
||||
@@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
):
|
||||
request_tasks.append(
|
||||
send_get_request(
|
||||
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
send_get_request(
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -352,13 +375,13 @@ async def get_filtered_models(models, user):
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request) -> dict[str, list]:
|
||||
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||
log.info("get_all_models()")
|
||||
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user=user)
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
@@ -418,16 +441,14 @@ async def get_models(
|
||||
}
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
|
||||
r = None
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
||||
)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
@@ -489,7 +510,7 @@ async def get_models(
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models["data"] = get_filtered_models(models, user)
|
||||
models["data"] = await get_filtered_models(models, user)
|
||||
|
||||
return models
|
||||
|
||||
@@ -507,7 +528,7 @@ async def verify_connection(
|
||||
key = form_data.key
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
@@ -515,6 +536,16 @@ async def verify_connection(
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
@@ -551,9 +582,9 @@ async def generate_chat_completion(
|
||||
bypass_filter = True
|
||||
|
||||
idx = 0
|
||||
|
||||
payload = {**form_data}
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
metadata = payload.pop("metadata", None)
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
@@ -566,7 +597,7 @@ async def generate_chat_completion(
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -587,7 +618,7 @@ async def generate_chat_completion(
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
@@ -621,10 +652,10 @@ async def generate_chat_completion(
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1 = payload["model"].lower().startswith("o1-")
|
||||
if is_o1:
|
||||
payload = openai_o1_handler(payload)
|
||||
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
|
||||
if is_o1_o3:
|
||||
payload = openai_o1_o3_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
@@ -777,7 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import (
|
||||
status,
|
||||
APIRouter,
|
||||
)
|
||||
import aiohttp
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
@@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
|
||||
return sorted_filters
|
||||
|
||||
|
||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
res = r.json()
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
res = (
|
||||
await response.json()
|
||||
if response.content_type == "application/json"
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
raise Exception(r.status_code, res["detail"])
|
||||
raise Exception(response.status, res["detail"])
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
r = requests.post(
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
payload = data
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
try:
|
||||
res = r.json()
|
||||
res = (
|
||||
await response.json()
|
||||
if "application/json" in response.content_type
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
return Exception(r.status_code, res)
|
||||
raise Exception(response.status, res)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
@@ -161,7 +169,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("/list")
|
||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user)
|
||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||
|
||||
urlIdxs = [
|
||||
@@ -188,7 +196,7 @@ async def upload_pipeline(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
print("upload_pipeline", urlIdx, file.filename)
|
||||
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
|
||||
# Check if the uploaded file is a python file
|
||||
if not (file.filename and file.filename.endswith(".py")):
|
||||
raise HTTPException(
|
||||
@@ -223,7 +231,7 @@ async def upload_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
@@ -274,7 +282,7 @@ async def add_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -319,7 +327,7 @@ async def delete_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -353,7 +361,7 @@ async def get_pipelines(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -392,7 +400,7 @@ async def get_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -432,7 +440,7 @@ async def get_pipeline_valves_spec(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -474,7 +482,7 @@ async def update_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
|
||||
|
||||
@@ -147,7 +147,11 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if prompt.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
||||
@@ -21,6 +21,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
import tiktoken
|
||||
|
||||
@@ -45,17 +46,20 @@ from open_webui.retrieval.web.utils import get_web_loader
|
||||
from open_webui.retrieval.web.brave import search_brave
|
||||
from open_webui.retrieval.web.kagi import search_kagi
|
||||
from open_webui.retrieval.web.mojeek import search_mojeek
|
||||
from open_webui.retrieval.web.bocha import search_bocha
|
||||
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
|
||||
from open_webui.retrieval.web.google_pse import search_google_pse
|
||||
from open_webui.retrieval.web.jina_search import search_jina
|
||||
from open_webui.retrieval.web.searchapi import search_searchapi
|
||||
from open_webui.retrieval.web.serpapi import search_serpapi
|
||||
from open_webui.retrieval.web.searxng import search_searxng
|
||||
from open_webui.retrieval.web.serper import search_serper
|
||||
from open_webui.retrieval.web.serply import search_serply
|
||||
from open_webui.retrieval.web.serpstack import search_serpstack
|
||||
from open_webui.retrieval.web.tavily import search_tavily
|
||||
from open_webui.retrieval.web.bing import search_bing
|
||||
|
||||
from open_webui.retrieval.web.exa import search_exa
|
||||
from open_webui.retrieval.web.perplexity import search_perplexity
|
||||
|
||||
from open_webui.retrieval.utils import (
|
||||
get_embedding_function,
|
||||
@@ -347,11 +351,18 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -368,10 +379,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
@@ -379,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
|
||||
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": request.app.state.config.SERPER_API_KEY,
|
||||
@@ -386,11 +400,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
|
||||
"searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
|
||||
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
|
||||
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
|
||||
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
|
||||
"jina_api_key": request.app.state.config.JINA_API_KEY,
|
||||
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -401,10 +421,16 @@ class FileConfig(BaseModel):
|
||||
max_count: Optional[int] = None
|
||||
|
||||
|
||||
class DocumentIntelligenceConfigForm(BaseModel):
|
||||
endpoint: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContentExtractionConfig(BaseModel):
|
||||
engine: str = ""
|
||||
tika_server_url: Optional[str] = None
|
||||
docling_server_url: Optional[str] = None
|
||||
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
|
||||
|
||||
|
||||
class ChunkParamUpdateForm(BaseModel):
|
||||
@@ -428,6 +454,7 @@ class WebSearchConfig(BaseModel):
|
||||
brave_search_api_key: Optional[str] = None
|
||||
kagi_search_api_key: Optional[str] = None
|
||||
mojeek_search_api_key: Optional[str] = None
|
||||
bocha_search_api_key: Optional[str] = None
|
||||
serpstack_api_key: Optional[str] = None
|
||||
serpstack_https: Optional[bool] = None
|
||||
serper_api_key: Optional[str] = None
|
||||
@@ -435,21 +462,31 @@ class WebSearchConfig(BaseModel):
|
||||
tavily_api_key: Optional[str] = None
|
||||
searchapi_api_key: Optional[str] = None
|
||||
searchapi_engine: Optional[str] = None
|
||||
serpapi_api_key: Optional[str] = None
|
||||
serpapi_engine: Optional[str] = None
|
||||
jina_api_key: Optional[str] = None
|
||||
bing_search_v7_endpoint: Optional[str] = None
|
||||
bing_search_v7_subscription_key: Optional[str] = None
|
||||
exa_api_key: Optional[str] = None
|
||||
perplexity_api_key: Optional[str] = None
|
||||
result_count: Optional[int] = None
|
||||
concurrent_requests: Optional[int] = None
|
||||
trust_env: Optional[bool] = None
|
||||
domain_filter_list: Optional[List[str]] = []
|
||||
|
||||
|
||||
class WebConfig(BaseModel):
|
||||
search: WebSearchConfig
|
||||
web_loader_ssl_verification: Optional[bool] = None
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
|
||||
|
||||
class ConfigUpdateForm(BaseModel):
|
||||
RAG_FULL_CONTEXT: Optional[bool] = None
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
pdf_extract_images: Optional[bool] = None
|
||||
enable_google_drive_integration: Optional[bool] = None
|
||||
enable_onedrive_integration: Optional[bool] = None
|
||||
file: Optional[FileConfig] = None
|
||||
content_extraction: Optional[ContentExtractionConfig] = None
|
||||
chunk: Optional[ChunkParamUpdateForm] = None
|
||||
@@ -467,18 +504,38 @@ async def update_rag_config(
|
||||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_FULL_CONTEXT = (
|
||||
form_data.RAG_FULL_CONTEXT
|
||||
if form_data.RAG_FULL_CONTEXT is not None
|
||||
else request.app.state.config.RAG_FULL_CONTEXT
|
||||
)
|
||||
|
||||
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
|
||||
else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
|
||||
form_data.enable_google_drive_integration
|
||||
if form_data.enable_google_drive_integration is not None
|
||||
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
|
||||
form_data.enable_onedrive_integration
|
||||
if form_data.enable_onedrive_integration is not None
|
||||
else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
if form_data.file is not None:
|
||||
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
|
||||
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
||||
|
||||
if form_data.content_extraction is not None:
|
||||
log.info(f"Updating text settings: {form_data.content_extraction}")
|
||||
log.info(
|
||||
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
|
||||
)
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
|
||||
form_data.content_extraction.engine
|
||||
)
|
||||
@@ -488,6 +545,13 @@ async def update_rag_config(
|
||||
request.app.state.config.DOCLING_SERVER_URL = (
|
||||
form_data.content_extraction.docling_server_url
|
||||
)
|
||||
if form_data.content_extraction.document_intelligence_config is not None:
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
|
||||
form_data.content_extraction.document_intelligence_config.endpoint
|
||||
)
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
|
||||
form_data.content_extraction.document_intelligence_config.key
|
||||
)
|
||||
|
||||
if form_data.chunk is not None:
|
||||
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
||||
@@ -502,11 +566,16 @@ async def update_rag_config(
|
||||
if form_data.web is not None:
|
||||
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
|
||||
form_data.web.web_loader_ssl_verification
|
||||
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
|
||||
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
|
||||
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.SEARXNG_QUERY_URL = (
|
||||
form_data.web.search.searxng_query_url
|
||||
)
|
||||
@@ -525,6 +594,9 @@ async def update_rag_config(
|
||||
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
|
||||
form_data.web.search.mojeek_search_api_key
|
||||
)
|
||||
request.app.state.config.BOCHA_SEARCH_API_KEY = (
|
||||
form_data.web.search.bocha_search_api_key
|
||||
)
|
||||
request.app.state.config.SERPSTACK_API_KEY = (
|
||||
form_data.web.search.serpstack_api_key
|
||||
)
|
||||
@@ -539,6 +611,9 @@ async def update_rag_config(
|
||||
form_data.web.search.searchapi_engine
|
||||
)
|
||||
|
||||
request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
|
||||
request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
|
||||
|
||||
request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
|
||||
request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
|
||||
form_data.web.search.bing_search_v7_endpoint
|
||||
@@ -547,16 +622,30 @@ async def update_rag_config(
|
||||
form_data.web.search.bing_search_v7_subscription_key
|
||||
)
|
||||
|
||||
request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
|
||||
|
||||
request.app.state.config.PERPLEXITY_API_KEY = (
|
||||
form_data.web.search.perplexity_api_key
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
|
||||
form_data.web.search.result_count
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||
form_data.web.search.concurrent_requests
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
|
||||
form_data.web.search.trust_env
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
|
||||
form_data.web.search.domain_filter_list
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"file": {
|
||||
"max_size": request.app.state.config.FILE_MAX_SIZE,
|
||||
"max_count": request.app.state.config.FILE_MAX_COUNT,
|
||||
@@ -565,6 +654,10 @@ async def update_rag_config(
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -577,7 +670,8 @@ async def update_rag_config(
|
||||
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
@@ -587,18 +681,25 @@ async def update_rag_config(
|
||||
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
|
||||
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": request.app.state.config.SERPER_API_KEY,
|
||||
"serply_api_key": request.app.state.config.SERPLY_API_KEY,
|
||||
"serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
|
||||
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
|
||||
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
|
||||
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
|
||||
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
|
||||
"jina_api_key": request.app.state.config.JINA_API_KEY,
|
||||
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -666,6 +767,7 @@ def save_docs_to_vector_db(
|
||||
overwrite: bool = False,
|
||||
split: bool = True,
|
||||
add: bool = False,
|
||||
user=None,
|
||||
) -> bool:
|
||||
def _get_docs_info(docs: list[Document]) -> str:
|
||||
docs_info = set()
|
||||
@@ -746,7 +848,11 @@ def save_docs_to_vector_db(
|
||||
# for meta-data so convert them to string.
|
||||
for metadata in metadatas:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, datetime):
|
||||
if (
|
||||
isinstance(value, datetime)
|
||||
or isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
):
|
||||
metadata[key] = str(value)
|
||||
|
||||
try:
|
||||
@@ -781,7 +887,7 @@ def save_docs_to_vector_db(
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts))
|
||||
list(map(lambda x: x.replace("\n", " "), texts)), user=user
|
||||
)
|
||||
|
||||
items = [
|
||||
@@ -829,7 +935,12 @@ def process_file(
|
||||
# Update the content in the file
|
||||
# Usage: /files/{file_id}/data/content/update
|
||||
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
try:
|
||||
# /files/{file_id}/data/content/update
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
except:
|
||||
# Audio file upload pipeline
|
||||
pass
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
@@ -887,6 +998,8 @@ def process_file(
|
||||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
)
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
@@ -929,35 +1042,45 @@ def process_file(
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
)
|
||||
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
@@ -997,7 +1120,7 @@ def process_text(
|
||||
text_content = form_data.content
|
||||
log.debug(f"text_content: {text_content}")
|
||||
|
||||
result = save_docs_to_vector_db(request, docs, collection_name)
|
||||
result = save_docs_to_vector_db(request, docs, collection_name, user=user)
|
||||
if result:
|
||||
return {
|
||||
"status": True,
|
||||
@@ -1030,7 +1153,9 @@ def process_youtube_video(
|
||||
content = " ".join([doc.page_content for doc in docs])
|
||||
log.debug(f"text_content: {content}")
|
||||
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
save_docs_to_vector_db(
|
||||
request, docs, collection_name, overwrite=True, user=user
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@@ -1071,7 +1196,13 @@ def process_web(
|
||||
content = " ".join([doc.page_content for doc in docs])
|
||||
|
||||
log.debug(f"text_content: {content}")
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
|
||||
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
save_docs_to_vector_db(
|
||||
request, docs, collection_name, overwrite=True, user=user
|
||||
)
|
||||
else:
|
||||
collection_name = None
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@@ -1083,6 +1214,7 @@ def process_web(
|
||||
},
|
||||
"meta": {
|
||||
"name": form_data.url,
|
||||
"source": form_data.url,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1102,11 +1234,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
- BRAVE_SEARCH_API_KEY
|
||||
- KAGI_SEARCH_API_KEY
|
||||
- MOJEEK_SEARCH_API_KEY
|
||||
- BOCHA_SEARCH_API_KEY
|
||||
- SERPSTACK_API_KEY
|
||||
- SERPER_API_KEY
|
||||
- SERPLY_API_KEY
|
||||
- TAVILY_API_KEY
|
||||
- EXA_API_KEY
|
||||
- PERPLEXITY_API_KEY
|
||||
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
|
||||
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
|
||||
Args:
|
||||
query (str): The query to search for
|
||||
"""
|
||||
@@ -1168,6 +1304,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
)
|
||||
else:
|
||||
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
|
||||
elif engine == "bocha":
|
||||
if request.app.state.config.BOCHA_SEARCH_API_KEY:
|
||||
return search_bocha(
|
||||
request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
|
||||
elif engine == "serpstack":
|
||||
if request.app.state.config.SERPSTACK_API_KEY:
|
||||
return search_serpstack(
|
||||
@@ -1211,6 +1357,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
request.app.state.config.TAVILY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No TAVILY_API_KEY found in environment variables")
|
||||
@@ -1225,6 +1372,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
)
|
||||
else:
|
||||
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
|
||||
elif engine == "serpapi":
|
||||
if request.app.state.config.SERPAPI_API_KEY:
|
||||
return search_serpapi(
|
||||
request.app.state.config.SERPAPI_API_KEY,
|
||||
request.app.state.config.SERPAPI_ENGINE,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SERPAPI_API_KEY found in environment variables")
|
||||
elif engine == "jina":
|
||||
return search_jina(
|
||||
request.app.state.config.JINA_API_KEY,
|
||||
@@ -1240,12 +1398,26 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "exa":
|
||||
return search_exa(
|
||||
request.app.state.config.EXA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "perplexity":
|
||||
return search_perplexity(
|
||||
request.app.state.config.PERPLEXITY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No search engine API key found in environment variables")
|
||||
|
||||
|
||||
@router.post("/process/web/search")
|
||||
def process_web_search(
|
||||
async def process_web_search(
|
||||
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
@@ -1277,15 +1449,40 @@ def process_web_search(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = loader.load()
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
docs = await loader.aload()
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
}
|
||||
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filenames": urls,
|
||||
"docs": [
|
||||
{
|
||||
"content": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
for doc in docs
|
||||
],
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
else:
|
||||
await run_in_threadpool(
|
||||
save_docs_to_vector_db,
|
||||
request,
|
||||
docs,
|
||||
collection_name,
|
||||
overwrite=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
@@ -1313,7 +1510,9 @@ def query_doc_handler(
|
||||
return query_doc_with_hybrid_search(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=(
|
||||
@@ -1321,12 +1520,16 @@ def query_doc_handler(
|
||||
if form_data.r
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(
|
||||
form_data.query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
@@ -1355,7 +1558,9 @@ def query_collection_handler(
|
||||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=(
|
||||
@@ -1368,7 +1573,9 @@ def query_collection_handler(
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
)
|
||||
|
||||
@@ -1432,11 +1639,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"The directory {folder} does not exist")
|
||||
log.warning(f"The directory {folder} does not exist")
|
||||
except Exception as e:
|
||||
print(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -1516,6 +1723,7 @@ def process_files_batch(
|
||||
docs=all_docs,
|
||||
collection_name=collection_name,
|
||||
add=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Update all files with collection name
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import logging
|
||||
import re
|
||||
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.task import (
|
||||
@@ -19,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -57,6 +62,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
@@ -67,6 +73,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
ENABLE_TITLE_GENERATION: bool
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
@@ -85,10 +92,15 @@ async def update_task_config(
|
||||
):
|
||||
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
|
||||
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
|
||||
form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
)
|
||||
@@ -117,6 +129,7 @@ async def update_task_config(
|
||||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
@@ -134,7 +147,19 @@ async def update_task_config(
|
||||
async def generate_title(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
models = request.app.state.MODELS
|
||||
|
||||
if not request.app.state.config.ENABLE_TITLE_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Title generation is disabled"},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -161,9 +186,20 @@ async def generate_title(
|
||||
else:
|
||||
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
messages = form_data["messages"]
|
||||
|
||||
# Remove reasoning details from the messages
|
||||
for message in messages:
|
||||
message["content"] = re.sub(
|
||||
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
|
||||
"",
|
||||
message["content"],
|
||||
flags=re.S,
|
||||
).strip()
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
messages,
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
@@ -175,19 +211,26 @@ async def generate_title(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 50}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
{"max_tokens": 1000}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 50,
|
||||
"max_completion_tokens": 1000,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TITLE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -209,7 +252,12 @@ async def generate_chat_tags(
|
||||
content={"detail": "Tags generation is disabled"},
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -245,12 +293,19 @@ async def generate_chat_tags(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TAGS_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -265,7 +320,12 @@ async def generate_chat_tags(
|
||||
async def generate_image_prompt(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -305,12 +365,19 @@ async def generate_image_prompt(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -340,7 +407,12 @@ async def generate_queries(
|
||||
detail=f"Query generation is disabled",
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -376,12 +448,19 @@ async def generate_queries(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.QUERY_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -415,7 +494,12 @@ async def generate_autocompletion(
|
||||
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -451,12 +535,19 @@ async def generate_autocompletion(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -472,7 +563,12 @@ async def generate_emoji(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
@@ -509,15 +605,25 @@ async def generate_emoji(
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.EMOJI_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -532,7 +638,13 @@ async def generate_moa_response(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
|
||||
if model_id not in models:
|
||||
@@ -565,12 +677,19 @@ async def generate_moa_response(
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -100,7 +105,7 @@ async def create_new_tools(
|
||||
specs = get_tools_specs(TOOLS[form_data.id])
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
|
||||
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if tools:
|
||||
@@ -111,7 +116,7 @@ async def create_new_tools(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -193,7 +198,7 @@ async def update_tools_by_id(
|
||||
"specs": specs,
|
||||
}
|
||||
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
|
||||
if tools:
|
||||
@@ -227,7 +232,11 @@ async def delete_tools_by_id(
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if tools.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
@@ -339,7 +348,7 @@ async def update_tools_valves_by_id(
|
||||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -417,7 +426,7 @@ async def update_tools_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
||||
@@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
|
||||
class FeaturesPermissions(BaseModel):
|
||||
web_search: bool = True
|
||||
image_generation: bool = True
|
||||
code_interpreter: bool = True
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
@@ -152,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: UserSettings, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
|
||||
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
||||
@@ -1,48 +1,84 @@
|
||||
import black
|
||||
import logging
|
||||
import markdown
|
||||
|
||||
from open_webui.models.chats import ChatTitleMessagesForm
|
||||
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
from open_webui.utils.misc import get_gravatar_url
|
||||
from open_webui.utils.pdf_generator import PDFGenerator
|
||||
from open_webui.utils.auth import get_admin_user
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
):
|
||||
async def get_gravatar(email: str, user=Depends(get_verified_user)):
|
||||
return get_gravatar_url(email)
|
||||
|
||||
|
||||
class CodeFormatRequest(BaseModel):
|
||||
class CodeForm(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
@router.post("/code/format")
|
||||
async def format_code(request: CodeFormatRequest):
|
||||
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
formatted_code = black.format_str(request.code, mode=black.Mode())
|
||||
formatted_code = black.format_str(form_data.code, mode=black.Mode())
|
||||
return {"code": formatted_code}
|
||||
except black.NothingChanged:
|
||||
return {"code": request.code}
|
||||
return {"code": form_data.code}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/code/execute")
|
||||
async def execute_code(
|
||||
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
|
||||
):
|
||||
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
form_data.code,
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
|
||||
else None
|
||||
),
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
)
|
||||
|
||||
return output
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Code execution engine not supported",
|
||||
)
|
||||
|
||||
|
||||
class MarkdownForm(BaseModel):
|
||||
md: str
|
||||
|
||||
|
||||
@router.post("/markdown")
|
||||
async def get_html_from_markdown(
|
||||
form_data: MarkdownForm,
|
||||
form_data: MarkdownForm, user=Depends(get_verified_user)
|
||||
):
|
||||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
|
||||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
|
||||
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating PDF: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user