Merge remote-tracking branch 'upstream/dev' into playwright

# Conflicts:
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
This commit is contained in:
Rory
2025-02-17 21:53:39 -06:00
226 changed files with 3402 additions and 1802 deletions

View File

@@ -1,6 +1,11 @@
import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict
@@ -8,7 +13,8 @@ from typing import Optional, Union, List, Dict
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
from open_webui.config import override_static
from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -24,6 +30,53 @@ ALGORITHM = "HS256"
# Auth Utils
##############
def verify_signature(payload: str, signature: str) -> bool:
"""
Verifies the HMAC signature of the received payload.
"""
try:
expected_signature = base64.b64encode(
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
).decode()
# Compare securely to prevent timing attacks
return hmac.compare_digest(expected_signature, signature)
except Exception:
return False
def get_license_data(app, key):
if key:
try:
res = requests.post(
"https://api.openwebui.com/api/v1/license",
json={"key": key, "version": "1"},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
for k, v in payload.items():
if k == "resources":
for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c)
elif k == "user_count":
setattr(app.state, "USER_COUNT", v)
elif k == "webui_name":
setattr(app.state, "WEBUI_NAME", v)
return True
else:
print(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
except Exception as ex:
print(f"License: Uncaught Exception: {ex}")
return False
bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

View File

@@ -186,12 +186,6 @@ async def generate_chat_completion(
if model_id not in models:
raise Exception("Model not found")
# Process the form_data through the pipeline
try:
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
except Exception as e:
raise e
model = models[model_id]
if getattr(request.state, "direct", False):
@@ -206,7 +200,7 @@ async def generate_chat_completion(
except Exception as e:
raise e
if model["owned_by"] == "arena":
if model.get("owned_by") == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
@@ -259,7 +253,7 @@ async def generate_chat_completion(
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
if model.get("owned_by") == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
@@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
data = await process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")

View File

@@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.utils.webhook import post_webhook
@@ -359,14 +362,25 @@ async def chat_web_search_handler(
)
files = form_data.get("files", [])
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"urls": results["filenames"],
}
)
if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT:
files.append(
{
"docs": results.get("docs", []),
"name": searchQuery,
"type": "web_search_docs",
"urls": results["filenames"],
}
)
else:
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"urls": results["filenames"],
}
)
form_data["files"] = files
else:
await event_emitter(
@@ -676,6 +690,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
variables = form_data.pop("variables", None)
# Process the form_data through the pipeline
try:
form_data = await process_pipeline_inlet_filter(
request, form_data, user, models
)
except Exception as e:
raise e
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
features = form_data.pop("features", None)
if features:
if "web_search" in features and features["web_search"]:
@@ -698,17 +731,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
form_data["messages"],
)
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Remove files duplicates
@@ -789,7 +811,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
if model.get("owned_by") == "ollama":
form_data["messages"] = prepend_to_first_user_message_content(
rag_template(
request.app.state.config.RAG_TEMPLATE, context_string, prompt
@@ -997,6 +1019,7 @@ async def process_chat_response(
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
request.app.state.WEBUI_NAME,
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{
@@ -1335,7 +1358,14 @@ async def process_chat_response(
)
tool_calls = []
content = message.get("content", "") if message else ""
last_assistant_message = get_last_assistant_message(form_data["messages"])
content = (
message.get("content", "")
if message
else last_assistant_message if last_assistant_message else ""
)
content_blocks = [
{
"type": "text",
@@ -1862,6 +1892,7 @@ async def process_chat_response(
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
request.app.state.WEBUI_NAME,
webhook_url,
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
{

View File

@@ -142,7 +142,7 @@ async def get_all_models(request):
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
owned_by = model.get("owned_by", "unknown owner")
if "pipe" in model:
pipe = model["pipe"]
break

View File

@@ -36,7 +36,11 @@ from open_webui.config import (
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE
from open_webui.env import (
WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
)
from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
@@ -66,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
def __init__(self, app):
self.oauth = OAuth()
self.app = app
for _, provider_config in OAUTH_PROVIDERS.items():
provider_config["register"](self.oauth)
@@ -200,7 +205,7 @@ class OAuthManager:
id=group_model.id, form_data=update_form, overwrite=False
)
async def handle_login(self, provider, request):
async def handle_login(self, request, provider):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
@@ -212,7 +217,7 @@ class OAuthManager:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
async def handle_callback(self, request, provider, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
@@ -266,6 +271,17 @@ class OAuthManager:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
403,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
@@ -334,6 +350,7 @@ class OAuthManager:
if auth_manager_config.WEBHOOK_URL:
post_webhook(
WEBUI_NAME,
auth_manager_config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
@@ -380,6 +397,3 @@ class OAuthManager:
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url, headers=response.headers)
oauth_manager = OAuthManager()

View File

@@ -22,7 +22,7 @@ def get_task_model_id(
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id]["owned_by"] == "ollama":
if models[task_model_id].get("owned_by") == "ollama":
if task_model and task_model in models:
task_model_id = task_model
else:

View File

@@ -2,14 +2,14 @@ import json
import logging
import requests
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
from open_webui.config import WEBUI_FAVICON_URL
from open_webui.env import SRC_LOG_LEVELS, VERSION
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
def post_webhook(url: str, message: str, event_data: dict) -> bool:
def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
try:
log.debug(f"post_webhook: {url}, {message}, {event_data}")
payload = {}
@@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
"sections": [
{
"activityTitle": message,
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
"activitySubtitle": f"{name} ({VERSION}) - {action}",
"activityImage": WEBUI_FAVICON_URL,
"facts": facts,
"markdown": True,