mirror of
https://github.com/open-webui/open-webui
synced 2025-03-03 19:07:21 +00:00
Merge pull request #10436 from victorstevansuse/feat/audits
feat: add audit logging feature
This commit is contained in:
commit
76e90d9f3f
@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
||||
|
||||
if OFFLINE_MODE:
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
####################################
|
||||
# AUDIT LOGGING
|
||||
####################################
|
||||
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
|
||||
# Where to store log file
|
||||
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||
# Maximum size of a file before rotating into a new log file
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||
# METADATA | REQUEST | REQUEST_RESPONSE
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
|
||||
try:
|
||||
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
|
||||
except ValueError:
|
||||
MAX_BODY_LOG_SIZE = 2048
|
||||
|
||||
# Comma separated list for urls to exclude from audit
|
||||
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
|
||||
","
|
||||
)
|
||||
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
|
||||
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
|
||||
|
@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils import logger
|
||||
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
|
||||
from open_webui.utils.logger import start_logger
|
||||
from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
@ -304,8 +307,11 @@ from open_webui.config import (
|
||||
reset_config,
|
||||
)
|
||||
from open_webui.env import (
|
||||
AUDIT_EXCLUDED_PATHS,
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
MAX_BODY_LOG_SIZE,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
@ -390,6 +396,7 @@ https://github.com/open-webui/open-webui
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
start_logger()
|
||||
if RESET_CONFIG_ON_START:
|
||||
reset_config()
|
||||
|
||||
@ -891,6 +898,19 @@ app.include_router(
|
||||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||
|
||||
|
||||
try:
|
||||
audit_level = AuditLevel(AUDIT_LOG_LEVEL)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
|
||||
audit_level = AuditLevel.NONE
|
||||
|
||||
if audit_level != AuditLevel.NONE:
|
||||
app.add_middleware(
|
||||
AuditLoggingMiddleware,
|
||||
audit_level=audit_level,
|
||||
excluded_paths=AUDIT_EXCLUDED_PATHS,
|
||||
max_body_size=MAX_BODY_LOG_SIZE,
|
||||
)
|
||||
##################################
|
||||
#
|
||||
# Chat Endpoints
|
||||
|
249
backend/open_webui/utils/audit.py
Normal file
249
backend/open_webui/utils/audit.py
Normal file
@ -0,0 +1,249 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from asgiref.typing import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
Scope as ASGIScope,
|
||||
)
|
||||
from loguru import logger
|
||||
from starlette.requests import Request
|
||||
|
||||
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
|
||||
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Logger
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuditLogEntry:
|
||||
# `Metadata` audit level properties
|
||||
id: str
|
||||
user: dict[str, Any]
|
||||
audit_level: str
|
||||
verb: str
|
||||
request_uri: str
|
||||
user_agent: Optional[str] = None
|
||||
source_ip: Optional[str] = None
|
||||
# `Request` audit level properties
|
||||
request_object: Any = None
|
||||
# `Request Response` level
|
||||
response_object: Any = None
|
||||
response_status_code: Optional[int] = None
|
||||
|
||||
|
||||
class AuditLevel(str, Enum):
|
||||
NONE = "NONE"
|
||||
METADATA = "METADATA"
|
||||
REQUEST = "REQUEST"
|
||||
REQUEST_RESPONSE = "REQUEST_RESPONSE"
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly.
|
||||
|
||||
Parameters:
|
||||
logger (Logger): An instance of Loguru’s logger.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: "Logger"):
|
||||
self.logger = logger.bind(auditable=True)
|
||||
|
||||
def write(
|
||||
self,
|
||||
audit_entry: AuditLogEntry,
|
||||
*,
|
||||
log_level: str = "INFO",
|
||||
extra: Optional[dict] = None,
|
||||
):
|
||||
|
||||
entry = asdict(audit_entry)
|
||||
|
||||
if extra:
|
||||
entry["extra"] = extra
|
||||
|
||||
self.logger.log(
|
||||
log_level,
|
||||
"",
|
||||
**entry,
|
||||
)
|
||||
|
||||
|
||||
class AuditContext:
|
||||
"""
|
||||
Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
|
||||
|
||||
Attributes:
|
||||
request_body (bytearray): Accumulated request payload.
|
||||
response_body (bytearray): Accumulated response payload.
|
||||
max_body_size (int): Maximum number of bytes to capture.
|
||||
metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
|
||||
self.request_body = bytearray()
|
||||
self.response_body = bytearray()
|
||||
self.max_body_size = max_body_size
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
|
||||
def add_request_chunk(self, chunk: bytes):
|
||||
if len(self.request_body) < self.max_body_size:
|
||||
self.request_body.extend(
|
||||
chunk[: self.max_body_size - len(self.request_body)]
|
||||
)
|
||||
|
||||
def add_response_chunk(self, chunk: bytes):
|
||||
if len(self.response_body) < self.max_body_size:
|
||||
self.response_body.extend(
|
||||
chunk[: self.max_body_size - len(self.response_body)]
|
||||
)
|
||||
|
||||
|
||||
class AuditLoggingMiddleware:
|
||||
"""
|
||||
ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
|
||||
"""
|
||||
|
||||
AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGI3Application,
|
||||
*,
|
||||
excluded_paths: Optional[list[str]] = None,
|
||||
max_body_size: int = MAX_BODY_LOG_SIZE,
|
||||
audit_level: AuditLevel = AuditLevel.NONE,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.audit_logger = AuditLogger(logger)
|
||||
self.excluded_paths = excluded_paths or []
|
||||
self.max_body_size = max_body_size
|
||||
self.audit_level = audit_level
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: ASGIScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
request = Request(scope=cast(MutableMapping, scope))
|
||||
|
||||
if self._should_skip_auditing(request):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async with self._audit_context(request) as context:
|
||||
|
||||
async def send_wrapper(message: ASGISendEvent) -> None:
|
||||
if self.audit_level == AuditLevel.REQUEST_RESPONSE:
|
||||
await self._capture_response(message, context)
|
||||
|
||||
await send(message)
|
||||
|
||||
original_receive = receive
|
||||
|
||||
async def receive_wrapper() -> ASGIReceiveEvent:
|
||||
nonlocal original_receive
|
||||
message = await original_receive()
|
||||
|
||||
if self.audit_level in (
|
||||
AuditLevel.REQUEST,
|
||||
AuditLevel.REQUEST_RESPONSE,
|
||||
):
|
||||
await self._capture_request(message, context)
|
||||
|
||||
return message
|
||||
|
||||
await self.app(scope, receive_wrapper, send_wrapper)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _audit_context(
|
||||
self, request: Request
|
||||
) -> AsyncGenerator[AuditContext, None]:
|
||||
"""
|
||||
async context manager that ensures that an audit log entry is recorded after the request is processed.
|
||||
"""
|
||||
context = AuditContext()
|
||||
try:
|
||||
yield context
|
||||
finally:
|
||||
await self._log_audit_entry(request, context)
|
||||
|
||||
async def _get_authenticated_user(self, request: Request) -> UserModel:
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
assert auth_header
|
||||
user = get_current_user(request, get_http_authorization_cred(auth_header))
|
||||
|
||||
return user
|
||||
|
||||
def _should_skip_auditing(self, request: Request) -> bool:
|
||||
if (
|
||||
request.method not in {"POST", "PUT", "PATCH", "DELETE"}
|
||||
or AUDIT_LOG_LEVEL == "NONE"
|
||||
or not request.headers.get("authorization")
|
||||
):
|
||||
return True
|
||||
# match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
|
||||
pattern = re.compile(
|
||||
r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
|
||||
)
|
||||
if pattern.match(request.url.path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
|
||||
if message["type"] == "http.request":
|
||||
body = message.get("body", b"")
|
||||
context.add_request_chunk(body)
|
||||
|
||||
async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
|
||||
if message["type"] == "http.response.start":
|
||||
context.metadata["response_status_code"] = message["status"]
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
context.add_response_chunk(body)
|
||||
|
||||
async def _log_audit_entry(self, request: Request, context: AuditContext):
|
||||
try:
|
||||
user = await self._get_authenticated_user(request)
|
||||
|
||||
entry = AuditLogEntry(
|
||||
id=str(uuid.uuid4()),
|
||||
user=user.model_dump(include={"id", "name", "email", "role"}),
|
||||
audit_level=self.audit_level.value,
|
||||
verb=request.method,
|
||||
request_uri=str(request.url),
|
||||
response_status_code=context.metadata.get("response_status_code", None),
|
||||
source_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
request_object=context.request_body.decode("utf-8", errors="replace"),
|
||||
response_object=context.response_body.decode("utf-8", errors="replace"),
|
||||
)
|
||||
|
||||
self.audit_logger.write(entry)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit entry: {str(e)}")
|
140
backend/open_webui/utils/logger.py
Normal file
140
backend/open_webui/utils/logger.py
Normal file
@ -0,0 +1,140 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from open_webui.env import (
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
AUDIT_LOG_LEVEL,
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def stdout_format(record: "Record") -> str:
|
||||
"""
|
||||
Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
|
||||
|
||||
Parameters:
|
||||
record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
|
||||
Returns:
|
||||
str: A formatted log string intended for stdout.
|
||||
"""
|
||||
record["extra"]["extra_json"] = json.dumps(record["extra"])
|
||||
return (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level> - {extra[extra_json]}"
|
||||
"\n{exception}"
|
||||
)
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
"""
|
||||
Intercepts log records from Python's standard logging module
|
||||
and redirects them to Loguru's logger.
|
||||
"""
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Called by the standard logging module for each log event.
|
||||
It transforms the standard `LogRecord` into a format compatible with Loguru
|
||||
and passes it to Loguru's logger.
|
||||
"""
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
frame, depth = sys._getframe(6), 6
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
def file_format(record: "Record"):
|
||||
"""
|
||||
Formats audit log records into a structured JSON string for file output.
|
||||
|
||||
Parameters:
|
||||
record (Record): A Loguru record containing extra audit data.
|
||||
Returns:
|
||||
str: A JSON-formatted string representing the audit data.
|
||||
"""
|
||||
|
||||
audit_data = {
|
||||
"id": record["extra"].get("id", ""),
|
||||
"timestamp": int(record["time"].timestamp()),
|
||||
"user": record["extra"].get("user", dict()),
|
||||
"audit_level": record["extra"].get("audit_level", ""),
|
||||
"verb": record["extra"].get("verb", ""),
|
||||
"request_uri": record["extra"].get("request_uri", ""),
|
||||
"response_status_code": record["extra"].get("response_status_code", 0),
|
||||
"source_ip": record["extra"].get("source_ip", ""),
|
||||
"user_agent": record["extra"].get("user_agent", ""),
|
||||
"request_object": record["extra"].get("request_object", b""),
|
||||
"response_object": record["extra"].get("response_object", b""),
|
||||
"extra": record["extra"].get("extra", {}),
|
||||
}
|
||||
|
||||
record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
|
||||
return "{extra[file_extra]}\n"
|
||||
|
||||
|
||||
def start_logger():
|
||||
"""
|
||||
Initializes and configures Loguru's logger with distinct handlers:
|
||||
|
||||
A console (stdout) handler for general log messages (excluding those marked as auditable).
|
||||
An optional file handler for audit logs if audit logging is enabled.
|
||||
Additionally, this function reconfigures Python’s standard logging to route through Loguru and adjusts logging levels for Uvicorn.
|
||||
|
||||
Parameters:
|
||||
enable_audit_logging (bool): Determines whether audit-specific log entries should be recorded to file.
|
||||
"""
|
||||
logger.remove()
|
||||
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=GLOBAL_LOG_LEVEL,
|
||||
format=stdout_format,
|
||||
filter=lambda record: "auditable" not in record["extra"],
|
||||
)
|
||||
|
||||
if AUDIT_LOG_LEVEL != "NONE":
|
||||
try:
|
||||
logger.add(
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
level="INFO",
|
||||
rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
compression="zip",
|
||||
format=file_format,
|
||||
filter=lambda record: record["extra"].get("auditable") is True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize audit log file handler: {str(e)}")
|
||||
|
||||
logging.basicConfig(
|
||||
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
|
||||
)
|
||||
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
|
||||
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
|
||||
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
|
||||
uvicorn_logger.handlers = []
|
||||
for uvicorn_logger_name in ["uvicorn.access"]:
|
||||
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
|
||||
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
|
||||
uvicorn_logger.handlers = [InterceptHandler()]
|
||||
|
||||
logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
@ -31,6 +31,9 @@ APScheduler==3.10.4
|
||||
|
||||
RestrictedPython==8.0
|
||||
|
||||
loguru==0.7.2
|
||||
asgiref==3.8.1
|
||||
|
||||
# AI libraries
|
||||
openai
|
||||
anthropic
|
||||
|
Loading…
Reference in New Issue
Block a user