Merge remote-tracking branch 'upstream/dev' into playwright

# Conflicts:
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
This commit is contained in:
Rory
2025-02-17 21:53:39 -06:00
226 changed files with 3402 additions and 1802 deletions

View File

@@ -251,9 +251,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
user = Users.get_user_by_email(mail)
if not user:
try:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
role = (
"admin"
if Users.get_num_users() == 0
if user_count == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
@@ -413,6 +423,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
user_count = Users.get_num_users()
if WEBUI_AUTH:
if (
not request.app.state.config.ENABLE_SIGNUP
@@ -422,11 +434,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
if Users.get_num_users() != 0:
if user_count != 0:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -437,12 +454,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
try:
role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
)
if Users.get_num_users() == 0:
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
@@ -484,6 +499,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.WEBUI_NAME,
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{

View File

@@ -192,7 +192,7 @@ async def get_channel_messages(
############################
async def send_notification(webui_url, channel, message, active_user_ids):
async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
for user in users:
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
if webhook_url:
post_webhook(
name,
webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{
@@ -302,6 +303,7 @@ async def post_new_message(
background_tasks.add_task(
send_notification,
request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL,
channel,
message,

View File

@@ -70,6 +70,11 @@ async def set_direct_connections_config(
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
CODE_EXECUTION_ENGINE: str
CODE_EXECUTION_JUPYTER_URL: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
@@ -79,9 +84,14 @@ class CodeInterpreterConfigForm(BaseModel):
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -92,10 +102,25 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
}
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
async def set_code_execution_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
form_data.CODE_EXECUTION_JUPYTER_URL
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
form_data.CODE_EXECUTION_JUPYTER_AUTH
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
@@ -118,6 +143,11 @@ async def set_code_interpreter_config(
)
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,

View File

@@ -9,6 +9,7 @@ from fastapi import (
status,
APIRouter,
)
import aiohttp
import os
import logging
import shutil
@@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models):
async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "":
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
request_data = {
"user": user,
"body": payload,
}
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
)
if "detail" in res:
raise Exception(r.status_code, res["detail"])
raise Exception(response.status, res["detail"])
except Exception as e:
print(f"Connection error: {e}")
return payload
def process_pipeline_outlet_filter(request, payload, user, models):
async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
r = requests.post(
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
request_data = {
"user": user,
"body": payload,
}
try:
async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"},
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
try:
res = r.json()
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res:
return Exception(r.status_code, res)
raise Exception(response.status, res)
except Exception:
pass
else:
pass
except Exception as e:
print(f"Connection error: {e}")
return payload

View File

@@ -371,7 +371,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"RAG_WEB_SEARCH_FULL_CONTEXT": request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
@@ -457,7 +458,8 @@ class WebSearchConfig(BaseModel):
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
RAG_WEB_SEARCH_FULL_CONTEXT: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
@@ -512,11 +514,16 @@ async def update_rag_config(
if form_data.web is not None:
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
form_data.web.web_loader_ssl_verification
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT = (
form_data.web.RAG_WEB_SEARCH_FULL_CONTEXT
)
request.app.state.config.SEARXNG_QUERY_URL = (
form_data.web.search.searxng_query_url
)
@@ -600,7 +607,8 @@ async def update_rag_config(
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"RAG_WEB_SEARCH_FULL_CONTEXT": request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -1262,6 +1270,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.TAVILY_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
@@ -1349,21 +1358,36 @@ async def process_web_search(
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT:
return {
"status": True,
"docs": [
{
"content": doc.page_content,
"metadata": doc.metadata,
}
for doc in docs
],
"filenames": urls,
"loaded_count": len(docs),
}
else:
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
except Exception as e:
log.exception(e)
raise HTTPException(

View File

@@ -208,7 +208,7 @@ async def generate_title(
"stream": False,
**(
{"max_tokens": 1000}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 1000,
}
@@ -571,7 +571,7 @@ async def generate_emoji(
"stream": False,
**(
{"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 4,
}

View File

@@ -4,45 +4,75 @@ import markdown
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.auth import get_admin_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.code_interpreter import execute_code_jupyter
router = APIRouter()
@router.get("/gravatar")
async def get_gravatar(
email: str,
):
async def get_gravatar(email: str, user=Depends(get_verified_user)):
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
class CodeForm(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
return {"code": form_data.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/code/execute")
async def execute_code(
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
):
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
output = await execute_code_jupyter(
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
form_data.code,
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
else None
),
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
else None
),
)
return output
else:
raise HTTPException(
status_code=400,
detail="Code execution engine not supported",
)
class MarkdownForm(BaseModel):
md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
form_data: MarkdownForm, user=Depends(get_verified_user)
):
return {"html": markdown.markdown(form_data.md)}
@@ -54,7 +84,7 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatTitleMessagesForm,
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
):
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()