open-webui/backend/open_webui/utils/auth.py
Timothy Jaeryang Baek 1764170307 refac
2025-02-17 21:34:06 -08:00

236 lines
6.6 KiB
Python

import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
import os
from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY, STATIC_DIR
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
# Only ERROR-level (and above) records from the passlib logger are emitted.
logging.getLogger("passlib").setLevel(logging.ERROR)
# Secret used to sign and verify session JWTs (used by create_token / decode_token).
SESSION_SECRET = WEBUI_SECRET_KEY
# JWT signing algorithm: HMAC with SHA-256.
ALGORITHM = "HS256"
##############
# Auth Utils
##############
def verify_signature(
    payload: str,
    signature: str,
    key: Optional[Union[str, bytes]] = None,
) -> bool:
    """
    Verify the base64-encoded HMAC-SHA256 signature of a received payload.

    Args:
        payload: The received message text; it is UTF-8 encoded before hashing.
        signature: Base64-encoded HMAC-SHA256 digest supplied by the sender.
        key: Secret to verify against. Defaults to ``TRUSTED_SIGNATURE_KEY``.
            Either ``str`` (UTF-8 encoded here) or ``bytes`` is accepted.

    Returns:
        True only if a non-empty key is available and ``signature`` matches.
        False otherwise, including on any error (fail closed).
    """
    try:
        secret = TRUSTED_SIGNATURE_KEY if key is None else key
        # hmac.new() requires a bytes-like key; a str key would raise TypeError
        # and (via the except below) make every signature silently fail.
        if isinstance(secret, str):
            secret = secret.encode()
        # With an empty key anyone could compute a "valid" signature: refuse it.
        if not secret:
            return False

        expected_signature = base64.b64encode(
            hmac.new(secret, payload.encode(), hashlib.sha256).digest()
        ).decode()
        # Compare securely to prevent timing attacks
        return hmac.compare_digest(expected_signature, signature)
    except Exception:
        # Fail closed: malformed input (non-str payload, None or non-ASCII
        # signature, ...) never verifies.
        return False
def override_static(
    path: str,
    content: str,
    static_dir: Optional[Union[str, os.PathLike]] = None,
):
    """
    Write a base64-encoded resource as a single file in the static directory.

    Args:
        path: Bare file name. Anything containing a directory separator or
            ``..``, an empty name, ``.``, or a name with a drive component is
            rejected (a message is printed and nothing is written).
        content: Base64-encoded file contents.
        static_dir: Target directory; defaults to ``STATIC_DIR``.

    Raises:
        binascii.Error: if ``content`` is not valid base64. Any existing file
            at the target path is left untouched in that case.
    """
    # Ensure path is a plain file name that cannot escape the static directory.
    # The basename check also catches Windows drive-relative names like "C:x".
    if (
        not path
        or path == "."
        or "/" in path
        or "\\" in path
        or ".." in path
        or os.path.basename(path) != path
    ):
        print(f"Invalid path: {path}")
        return

    # Decode BEFORE opening: open(..., "wb") truncates immediately, so a decode
    # failure after opening would leave an empty file behind.
    data = base64.b64decode(content)  # Convert Base64 back to raw binary

    target_dir = STATIC_DIR if static_dir is None else static_dir
    file_path = os.path.join(target_dir, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(data)
def get_license_data(app, key):
    """
    Fetch license data for ``key`` and apply it to ``app``.

    POSTs ``{"key": key, "version": "1"}`` to the Open WebUI license endpoint
    with a 5 second timeout. On a successful response, fields of the returned
    JSON object are applied:

    * ``resources``: mapping of path -> content, each pair passed to
      ``override_static`` (or ignored if that function is not defined).
      Presumably base64 file contents; confirm against the license API.
    * ``user_count``: stored as ``app.state.USER_COUNT``.
    * ``webui_name``: stored as ``app.state.WEBUI_NAME``.

    Returns:
        True if data was retrieved and applied. False if ``key`` is falsy, the
        server responded with an error, or any exception occurred (the failure
        is printed).
    """
    if not key:
        return False

    # JSON field name -> attribute name set on app.state
    state_attrs = {"user_count": "USER_COUNT", "webui_name": "WEBUI_NAME"}

    try:
        response = requests.post(
            "https://api.openwebui.com/api/v1/license",
            json={"key": key, "version": "1"},
            timeout=5,
        )
        if not getattr(response, "ok", False):
            print(
                f"License: retrieval issue: {getattr(response, 'text', 'unknown error')}"
            )
            return False

        license_payload = getattr(response, "json", lambda: {})()
        for field, value in license_payload.items():
            if field == "resources":
                apply_resource = globals().get("override_static", lambda a, b: None)
                for resource_path, resource_content in value.items():
                    apply_resource(resource_path, resource_content)
            elif field in state_attrs:
                setattr(app.state, state_attrs[field], value)
        return True
    except Exception as ex:
        print(f"License: Uncaught Exception: {ex}")

    return False
# Reads "Authorization: Bearer <token>" credentials. With auto_error=False the
# dependency yields None instead of rejecting the request when no bearer token
# is present, so get_current_user can fall back to the "token" cookie.
bearer_security = HTTPBearer(auto_error=False)
# Password hashing context using bcrypt; deprecated="auto" marks any non-default
# scheme as deprecated (only bcrypt is configured here).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
    """
    Check ``plain_password`` against a stored password hash.

    Returns:
        The result of ``pwd_context.verify`` when ``hashed_password`` is set.
        ``None`` (falsy) when ``hashed_password`` is empty or None, so callers
        should only test the result for truthiness.
    """
    if not hashed_password:
        return None
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Return the bcrypt hash of ``password`` produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """
    Create a signed JWT carrying the claims in ``data``.

    Args:
        data: Claims to encode; the dict is copied and never mutated.
        expires_delta: Token lifetime. When provided (including a zero or
            negative delta), an ``exp`` claim of now (UTC) + delta is added.
            ``None`` means the token has no ``exp`` claim.

    Returns:
        The JWT signed with ``SESSION_SECRET`` using ``ALGORITHM``.
    """
    payload = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, and a plain
    # truthiness test would silently issue a token that never expires.
    if expires_delta is not None:
        payload["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
def decode_token(token: str) -> Optional[dict]:
    """
    Decode and verify a JWT signed with ``SESSION_SECRET`` / ``ALGORITHM``.

    Returns:
        The decoded claims, or ``None`` if decoding or verification raises
        for any reason (e.g. bad signature or malformed token).
    """
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except Exception:
        return None
def extract_token_from_auth_header(auth_header: str):
    """
    Return the credential part of an ``Authorization`` header value.

    ``"Bearer <token>"`` yields ``<token>``. The scheme is matched
    case-insensitively (HTTP auth-schemes are case-insensitive), and extra
    whitespace around or between the parts is ignored. A value without a
    Bearer scheme is returned stripped rather than having its first seven
    characters chopped off. An empty or None value yields ``""``.
    """
    if not auth_header:
        return ""
    # Split the scheme from the credentials on the first run of whitespace,
    # tolerating repeated spaces or tabs.
    parts = auth_header.split(None, 1)
    if not parts:  # the header consisted only of whitespace
        return ""
    if parts[0].lower() == "bearer":
        return parts[1].strip() if len(parts) > 1 else ""
    return auth_header.strip()
def create_api_key():
    """Return a new random API key: ``"sk-"`` followed by 32 lowercase hex digits."""
    return "sk-" + uuid.uuid4().hex
def get_http_authorization_cred(auth_header: str):
    """
    Parse an ``Authorization`` header value of the form ``"<scheme> <credentials>"``.

    Returns:
        ``HTTPAuthorizationCredentials`` built from the two parts.

    Raises:
        ValueError: carrying ``ERROR_MESSAGES.INVALID_TOKEN`` if the value is not
            a string splitting into exactly two parts on a single space, or if
            the credentials object cannot be constructed.
    """
    try:
        scheme, credentials = auth_header.split(" ")
        parsed = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception:
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
    return parsed
def _api_key_path_allowed(request: Request) -> bool:
    """Return True if API-key auth is permitted for ``request.url.path`` under the endpoint-restriction config."""
    config = request.app.state.config
    if not config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
        return True
    # API_KEY_ALLOWED_ENDPOINTS is a comma-separated list of exact paths.
    allowed_paths = {
        entry.strip() for entry in str(config.API_KEY_ALLOWED_ENDPOINTS).split(",")
    }
    return request.url.path in allowed_paths


def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """
    FastAPI dependency that resolves the authenticated user for a request.

    The credential comes from the ``Authorization: Bearer`` header, falling
    back to the ``token`` cookie. Credentials starting with ``sk-`` are treated
    as API keys. Anything else is decoded as a JWT whose ``id`` claim
    identifies the user.

    Raises:
        HTTPException(403): no credential supplied, or API keys are disabled or
            not permitted for this path.
        HTTPException(401): invalid API key, undecodable JWT or one without an
            ``id`` claim, or the referenced user does not exist.
    """
    token = auth_token.credentials if auth_token is not None else None
    if token is None:
        token = request.cookies.get("token")
    if token is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Not authenticated"
        )

    # auth by api key
    if token.startswith("sk-"):
        if not request.state.enable_api_key or not _api_key_path_allowed(request):
            raise HTTPException(
                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
            )
        return get_current_user_by_api_key(token)

    # auth by jwt token; decode_token returns None (it never raises) on failure
    data = decode_token(token)
    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(data["id"])
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )
    Users.update_user_last_active_by_id(user.id)
    return user
def get_current_user_by_api_key(api_key: str):
    """
    Resolve the user that owns ``api_key`` and update their last-active time.

    Raises:
        HTTPException(401): with ``INVALID_TOKEN`` if no user has this API key.
    """
    owner = Users.get_user_by_api_key(api_key)
    if owner is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )
    Users.update_user_last_active_by_id(owner.id)
    return owner
def get_verified_user(user=Depends(get_current_user)):
    """
    FastAPI dependency admitting only users whose role is ``"user"`` or ``"admin"``.

    Raises:
        HTTPException(401): with ``ACCESS_PROHIBITED`` for any other role.
    """
    permitted_roles = {"user", "admin"}
    if user.role in permitted_roles:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
def get_admin_user(user=Depends(get_current_user)):
    """
    FastAPI dependency admitting only users whose role is ``"admin"``.

    Raises:
        HTTPException(401): with ``ACCESS_PROHIBITED`` for any other role.
    """
    if user.role == "admin":
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )