open-webui/backend/open_webui/utils/utils.py

211 lines
5.9 KiB
Python
Raw Normal View History

2024-08-27 22:10:27 +00:00
import logging
import uuid
2024-11-15 09:29:07 +00:00
import jwt
2024-08-27 22:10:27 +00:00
from datetime import UTC, datetime, timedelta
2024-11-15 09:29:07 +00:00
from typing import Optional, Union, List, Dict
2024-04-02 17:05:53 +00:00
from open_webui.apps.webui.models.users import Users
2024-11-15 09:29:07 +00:00
from open_webui.apps.webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
2024-11-15 09:29:07 +00:00
2024-10-15 09:48:41 +00:00
from fastapi import Depends, HTTPException, Request, Response, status
2024-08-27 22:10:27 +00:00
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
2023-11-19 00:47:12 +00:00
2024-01-05 20:22:27 +00:00
# Only let the passlib logger emit records at ERROR level or above.
logging.getLogger("passlib").setLevel(logging.ERROR)

# JWTs are signed with the WebUI secret key using HS256 (HMAC-SHA256).
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"

##############
# Auth Utils
##############

# auto_error=False: when no bearer credentials are present the dependency
# yields None instead of failing, which lets get_current_user fall back to the
# "token" cookie.
bearer_security = HTTPBearer(auto_error=False)

# Password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a stored password hash.

    Returns:
        The result of ``pwd_context.verify`` when *hashed_password* is truthy,
        otherwise ``None`` (there is nothing to compare against).
    """
    if not hashed_password:
        return None
    return pwd_context.verify(plain_password, hashed_password)
2023-11-19 00:47:12 +00:00
def get_password_hash(password):
    """Return the bcrypt hash of *password* produced by ``pwd_context``."""
    digest = pwd_context.hash(password)
    return digest
2024-01-05 20:22:27 +00:00
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """Encode *data* as a signed JWT.

    Args:
        data: Claims to embed. The dict is copied, never mutated.
        expires_delta: Token lifetime. When not ``None``, an ``exp`` claim of
            ``datetime.now(UTC) + expires_delta`` is added. ``None`` (the
            default) produces a token without an ``exp`` claim.

    Returns:
        The encoded JWT, signed with ``SESSION_SECRET`` using ``ALGORITHM``.
    """
    payload = data.copy()
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so a plain
    # truthiness test would silently emit a token with no ``exp`` claim.
    if expires_delta is not None:
        payload["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)


def decode_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT signed with ``SESSION_SECRET``/``ALGORITHM``.

    Returns:
        The decoded claims dict, or ``None`` if decoding raises for any reason.
    """
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except Exception:
        # Best-effort by design: callers treat None as "invalid token".
        return None


def extract_token_from_auth_header(auth_header: str):
    """Return the token portion of an ``Authorization`` header value.

    The first ``len("Bearer ")`` characters are dropped unconditionally; the
    scheme text itself is not validated.
    """
    scheme_prefix = "Bearer "
    return auth_header[len(scheme_prefix):]
2023-11-19 00:47:12 +00:00
2024-03-26 10:22:17 +00:00
def create_api_key():
    """Generate a new API key of the form ``sk-<32 lowercase hex chars>``.

    The ``sk-`` prefix is what ``get_current_user`` checks to route a token to
    API-key authentication instead of JWT decoding.
    """
    import secrets

    # token_hex(16) yields 32 hex characters, the same shape as the previous
    # uuid4-based key. It draws all 128 bits from the OS CSPRNG, whereas uuid4
    # fixes 6 bits for version/variant, and ``secrets`` is the stdlib API
    # intended for security tokens.
    return f"sk-{secrets.token_hex(16)}"
2024-02-24 06:44:56 +00:00
def get_http_authorization_cred(auth_header: str):
    """Parse an ``Authorization`` header value into ``HTTPAuthorizationCredentials``.

    The value must split on single spaces into exactly two parts:
    ``"<scheme> <credentials>"``.

    Raises:
        ValueError: carrying ``ERROR_MESSAGES.INVALID_TOKEN`` if the value does
            not split into exactly two parts or the credentials object cannot
            be constructed. The underlying error is kept as ``__cause__``.
    """
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception as err:
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) from err
2024-02-11 01:54:33 +00:00
def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """FastAPI dependency that resolves the authenticated user for a request.

    Token sources, in order: the bearer credentials from ``bearer_security``,
    then the ``token`` cookie. Tokens starting with ``sk-`` are handled by
    ``get_current_user_by_api_key``. Any other token is decoded as a JWT whose
    ``id`` claim is looked up with ``Users.get_user_by_id``. On success,
    ``Users.update_user_last_active_by_id`` is called before returning.

    Raises:
        HTTPException: 403 ("Not authenticated") if no token is supplied.
            401 if the JWT cannot be decoded, lacks an ``id`` claim, or does
            not match an existing user.
    """
    token = auth_token.credentials if auth_token is not None else None
    if token is None:
        # .get returns None when the cookie is absent, so one lookup suffices.
        token = request.cookies.get("token")
    if token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        return get_current_user_by_api_key(token)

    # auth by jwt token
    try:
        data = decode_token(token)
    except Exception:
        # decode_token already maps failures to None; this only guards against
        # an unexpected error so it still surfaces as a 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )

    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(data["id"])
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(user.id)
    return user
2024-04-02 16:42:45 +00:00
2024-06-24 11:45:33 +00:00
def get_current_user_by_api_key(api_key: str):
    """Resolve the user that owns *api_key*.

    On success ``Users.update_user_last_active_by_id`` is called for the user
    before it is returned.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.INVALID_TOKEN`` if
            ``Users.get_user_by_api_key`` finds no user.
    """
    owner = Users.get_user_by_api_key(api_key)
    if owner is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(owner.id)
    return owner
2024-04-02 16:42:45 +00:00
2024-02-11 01:54:33 +00:00
def get_verified_user(user=Depends(get_current_user)):
    """FastAPI dependency admitting only users whose role is "user" or "admin".

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            other role.
    """
    verified_roles = {"user", "admin"}
    if user.role in verified_roles:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
2024-02-11 01:54:33 +00:00
def get_admin_user(user=Depends(get_current_user)):
    """FastAPI dependency admitting only users whose role is "admin".

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            other role.
    """
    if user.role == "admin":
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
2024-11-15 09:29:07 +00:00
def has_permission(
    user_id: str,
    permission_key: str,
    default_permissions: Optional[dict] = None,
) -> bool:
    """
    Check whether a user holds a specific permission.

    The permissions of every group the user belongs to are checked first.
    If none grants the permission, *default_permissions* is consulted.
    Permission keys are hierarchical and dot-separated: ``"a.b"`` looks up
    ``permissions["a"]["b"]``.

    Args:
        user_id: ID of the user whose group memberships are checked.
        permission_key: Dot-separated path into a (possibly nested)
            permissions mapping.
        default_permissions: Fallback mapping used when no group grants the
            permission. ``None`` (the default) means "no defaults".

    Returns:
        True if any group's permissions, or the defaults, hold a truthy value
        at the given path; False otherwise.
    """
    from collections.abc import Mapping

    def get_permission(permissions, keys: List[str]) -> bool:
        """Traverse *permissions* along *keys*; a missing level denies access."""
        for key in keys:
            # A level that is not a mapping (e.g. None, or a leaf bool where a
            # nested dict was expected) counts as "not granted" instead of
            # raising TypeError from the ``in`` test.
            if not isinstance(permissions, Mapping) or key not in permissions:
                return False
            permissions = permissions[key]  # Go one level deeper
        return bool(permissions)  # Truthiness of the final level

    permission_hierarchy = permission_key.split(".")

    # Any group that grants the permission is sufficient.
    for group in Groups.get_groups_by_member_id(user_id):
        if get_permission(group.permissions, permission_hierarchy):
            return True

    # Fall back to the defaults when no group allows it.
    return get_permission(default_permissions or {}, permission_hierarchy)


def has_access(
    user_id: str,
    type: str = "write",
    access_control: Optional[dict] = None,
) -> bool:
    """
    Decide whether *user_id* has *type* access under *access_control*.

    *access_control* maps an access type (e.g. ``"read"`` or ``"write"``) to a
    dict with optional ``"user_ids"`` and ``"group_ids"`` lists. Access is
    granted if the user is listed directly or belongs to a listed group.
    Missing or ``None`` entries are treated as empty.

    Args:
        user_id: ID of the user requesting access.
        type: Access type to check. The name shadows the builtin but is kept
            for keyword-argument compatibility.
        access_control: Access-control mapping. When ``None``, only
            ``"read"`` access is granted.

    Returns:
        True if access is granted, False otherwise.
    """
    # Lazy %-style args: formatting only happens if DEBUG is enabled.
    logging.getLogger(__name__).debug(
        "has_access user_id=%s type=%s access_control=%s",
        user_id,
        type,
        access_control,
    )

    if access_control is None:
        return type == "read"

    # ``or {}`` / ``or []`` also cover keys that are present with a None value.
    permission_access = access_control.get(type) or {}

    permitted_user_ids = permission_access.get("user_ids") or []
    if user_id in permitted_user_ids:
        # Direct grant: no need to look up group membership.
        return True

    permitted_group_ids = permission_access.get("group_ids") or []
    if not permitted_group_ids:
        return False

    return any(
        group.id in permitted_group_ids
        for group in Groups.get_groups_by_member_id(user_id)
    )