mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into playwright
# Conflicts: # backend/open_webui/retrieval/web/utils.py # backend/open_webui/routers/retrieval.py
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
import logging
|
||||
import uuid
|
||||
import jwt
|
||||
import base64
|
||||
import hmac
|
||||
import hashlib
|
||||
import requests
|
||||
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional, Union, List, Dict
|
||||
@@ -8,7 +13,8 @@ from typing import Optional, Union, List, Dict
|
||||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SECRET_KEY
|
||||
from open_webui.config import override_static
|
||||
from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -24,6 +30,53 @@ ALGORITHM = "HS256"
|
||||
# Auth Utils
|
||||
##############
|
||||
|
||||
|
||||
def verify_signature(payload: str, signature: str) -> bool:
|
||||
"""
|
||||
Verifies the HMAC signature of the received payload.
|
||||
"""
|
||||
try:
|
||||
expected_signature = base64.b64encode(
|
||||
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
|
||||
).decode()
|
||||
|
||||
# Compare securely to prevent timing attacks
|
||||
return hmac.compare_digest(expected_signature, signature)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_license_data(app, key):
|
||||
if key:
|
||||
try:
|
||||
res = requests.post(
|
||||
"https://api.openwebui.com/api/v1/license",
|
||||
json={"key": key, "version": "1"},
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
if getattr(res, "ok", False):
|
||||
payload = getattr(res, "json", lambda: {})()
|
||||
for k, v in payload.items():
|
||||
if k == "resources":
|
||||
for p, c in v.items():
|
||||
globals().get("override_static", lambda a, b: None)(p, c)
|
||||
elif k == "user_count":
|
||||
setattr(app.state, "USER_COUNT", v)
|
||||
elif k == "webui_name":
|
||||
setattr(app.state, "WEBUI_NAME", v)
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
|
||||
)
|
||||
except Exception as ex:
|
||||
print(f"License: Uncaught Exception: {ex}")
|
||||
return False
|
||||
|
||||
|
||||
bearer_security = HTTPBearer(auto_error=False)
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
@@ -186,12 +186,6 @@ async def generate_chat_completion(
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
|
||||
# Process the form_data through the pipeline
|
||||
try:
|
||||
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
model = models[model_id]
|
||||
|
||||
if getattr(request.state, "direct", False):
|
||||
@@ -206,7 +200,7 @@ async def generate_chat_completion(
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if model["owned_by"] == "arena":
|
||||
if model.get("owned_by") == "arena":
|
||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
||||
if model_ids and filter_mode == "exclude":
|
||||
@@ -259,7 +253,7 @@ async def generate_chat_completion(
|
||||
return await generate_function_chat_completion(
|
||||
request, form_data, user=user, models=models
|
||||
)
|
||||
if model["owned_by"] == "ollama":
|
||||
if model.get("owned_by") == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
response = await generate_ollama_chat_completion(
|
||||
@@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
model = models[model_id]
|
||||
|
||||
try:
|
||||
data = process_pipeline_outlet_filter(request, data, user, models)
|
||||
data = await process_pipeline_outlet_filter(request, data, user, models)
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
|
||||
@@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.routers.images import image_generations, GenerateImageForm
|
||||
|
||||
from open_webui.routers.pipelines import (
|
||||
process_pipeline_inlet_filter,
|
||||
process_pipeline_outlet_filter,
|
||||
)
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
@@ -359,14 +362,25 @@ async def chat_web_search_handler(
|
||||
)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
|
||||
if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT:
|
||||
files.append(
|
||||
{
|
||||
"docs": results.get("docs", []),
|
||||
"name": searchQuery,
|
||||
"type": "web_search_docs",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
form_data["files"] = files
|
||||
else:
|
||||
await event_emitter(
|
||||
@@ -676,6 +690,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
variables = form_data.pop("variables", None)
|
||||
|
||||
# Process the form_data through the pipeline
|
||||
try:
|
||||
form_data = await process_pipeline_inlet_filter(
|
||||
request, form_data, user, models
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
form_data, flags = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="inlet",
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "web_search" in features and features["web_search"]:
|
||||
@@ -698,17 +731,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
form_data["messages"],
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="inlet",
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
|
||||
tool_ids = form_data.pop("tool_ids", None)
|
||||
files = form_data.pop("files", None)
|
||||
# Remove files duplicates
|
||||
@@ -789,7 +811,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
# Workaround for Ollama 2.0+ system prompt issue
|
||||
# TODO: replace with add_or_update_system_message
|
||||
if model["owned_by"] == "ollama":
|
||||
if model.get("owned_by") == "ollama":
|
||||
form_data["messages"] = prepend_to_first_user_message_content(
|
||||
rag_template(
|
||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||
@@ -997,6 +1019,7 @@ async def process_chat_response(
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
@@ -1335,7 +1358,14 @@ async def process_chat_response(
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
content = message.get("content", "") if message else ""
|
||||
|
||||
last_assistant_message = get_last_assistant_message(form_data["messages"])
|
||||
content = (
|
||||
message.get("content", "")
|
||||
if message
|
||||
else last_assistant_message if last_assistant_message else ""
|
||||
)
|
||||
|
||||
content_blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -1862,6 +1892,7 @@ async def process_chat_response(
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
|
||||
@@ -142,7 +142,7 @@ async def get_all_models(request):
|
||||
custom_model.base_model_id == model["id"]
|
||||
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||
):
|
||||
owned_by = model["owned_by"]
|
||||
owned_by = model.get("owned_by", "unknown owner")
|
||||
if "pipe" in model:
|
||||
pipe = model["pipe"]
|
||||
break
|
||||
|
||||
@@ -36,7 +36,11 @@ from open_webui.config import (
|
||||
AppConfig,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE
|
||||
from open_webui.env import (
|
||||
WEBUI_NAME,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.auth import get_password_hash, create_token
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
@@ -66,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
|
||||
|
||||
class OAuthManager:
|
||||
def __init__(self):
|
||||
def __init__(self, app):
|
||||
self.oauth = OAuth()
|
||||
self.app = app
|
||||
for _, provider_config in OAUTH_PROVIDERS.items():
|
||||
provider_config["register"](self.oauth)
|
||||
|
||||
@@ -200,7 +205,7 @@ class OAuthManager:
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
async def handle_login(self, provider, request):
|
||||
async def handle_login(self, request, provider):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
||||
@@ -212,7 +217,7 @@ class OAuthManager:
|
||||
raise HTTPException(404)
|
||||
return await client.authorize_redirect(request, redirect_uri)
|
||||
|
||||
async def handle_callback(self, provider, request, response):
|
||||
async def handle_callback(self, request, provider, response):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
client = self.get_client(provider)
|
||||
@@ -266,6 +271,17 @@ class OAuthManager:
|
||||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
|
||||
if not user:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
if (
|
||||
request.app.state.USER_COUNT
|
||||
and user_count >= request.app.state.USER_COUNT
|
||||
):
|
||||
raise HTTPException(
|
||||
403,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
@@ -334,6 +350,7 @@ class OAuthManager:
|
||||
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
WEBUI_NAME,
|
||||
auth_manager_config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
@@ -380,6 +397,3 @@ class OAuthManager:
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
|
||||
oauth_manager = OAuthManager()
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_task_model_id(
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id]["owned_by"] == "ollama":
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
|
||||
@@ -2,14 +2,14 @@ import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
|
||||
from open_webui.config import WEBUI_FAVICON_URL
|
||||
from open_webui.env import SRC_LOG_LEVELS, VERSION
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||
|
||||
|
||||
def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
|
||||
try:
|
||||
log.debug(f"post_webhook: {url}, {message}, {event_data}")
|
||||
payload = {}
|
||||
@@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
"sections": [
|
||||
{
|
||||
"activityTitle": message,
|
||||
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
|
||||
"activitySubtitle": f"{name} ({VERSION}) - {action}",
|
||||
"activityImage": WEBUI_FAVICON_URL,
|
||||
"facts": facts,
|
||||
"markdown": True,
|
||||
|
||||
Reference in New Issue
Block a user