diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 96e288d77..ba546a2eb 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
 
 if OFFLINE_MODE:
     os.environ["HF_HUB_OFFLINE"] = "1"
+
+####################################
+# AUDIT LOGGING
+####################################
+ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
+# Where to store the audit log file
+AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
+# Maximum size a log file may reach before it is rotated into a new file
+AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
+# NONE | METADATA | REQUEST | REQUEST_RESPONSE
+AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
+try:
+    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
+except ValueError:
+    MAX_BODY_LOG_SIZE = 2048
+
+# Comma-separated list of URL paths to exclude from audit logging
+AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
+    ","
+)
+AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
+AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 346d28d6c..e2f97ddda 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
 
 
+from loguru import logger
+from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
+from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
@@ -298,8 +301,11 @@ from open_webui.config import (
     reset_config,
 )
 from open_webui.env import (
+    AUDIT_EXCLUDED_PATHS,
+    AUDIT_LOG_LEVEL,
     CHANGELOG,
     GLOBAL_LOG_LEVEL,
+    MAX_BODY_LOG_SIZE,
     SAFE_MODE,
     SRC_LOG_LEVELS,
     VERSION,
@@ -384,6 +390,7 @@ https://github.com/open-webui/open-webui
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    start_logger()
     if RESET_CONFIG_ON_START:
         reset_config()
 
@@ -879,6 +886,19 @@ app.include_router(
 
 app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
 
+try:
+    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
+except ValueError as e:
+    logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
Error: {e}") + audit_level = AuditLevel.NONE + +if audit_level != AuditLevel.NONE: + app.add_middleware( + AuditLoggingMiddleware, + audit_level=audit_level, + excluded_paths=AUDIT_EXCLUDED_PATHS, + max_body_size=MAX_BODY_LOG_SIZE, + ) ################################## # # Chat Endpoints diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py new file mode 100644 index 000000000..95c0745a9 --- /dev/null +++ b/backend/open_webui/utils/audit.py @@ -0,0 +1,249 @@ +from contextlib import asynccontextmanager +from dataclasses import asdict, dataclass +from enum import Enum +import re +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + MutableMapping, + Optional, + cast, +) +import uuid + +from asgiref.typing import ( + ASGI3Application, + ASGIReceiveCallable, + ASGIReceiveEvent, + ASGISendCallable, + ASGISendEvent, + Scope as ASGIScope, +) +from loguru import logger +from starlette.requests import Request + +from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE +from open_webui.utils.auth import get_current_user, get_http_authorization_cred +from open_webui.models.users import UserModel + + +if TYPE_CHECKING: + from loguru import Logger + + +@dataclass(frozen=True) +class AuditLogEntry: + # `Metadata` audit level properties + id: str + user: dict[str, Any] + audit_level: str + verb: str + request_uri: str + user_agent: Optional[str] = None + source_ip: Optional[str] = None + # `Request` audit level properties + request_object: Any = None + # `Request Response` level + response_object: Any = None + response_status_code: Optional[int] = None + + +class AuditLevel(str, Enum): + NONE = "NONE" + METADATA = "METADATA" + REQUEST = "REQUEST" + REQUEST_RESPONSE = "REQUEST_RESPONSE" + + +class AuditLogger: + """ + A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly. + + Parameters: + logger (Logger): An instance of Loguru’s logger. + """ + + def __init__(self, logger: "Logger"): + self.logger = logger.bind(auditable=True) + + def write( + self, + audit_entry: AuditLogEntry, + *, + log_level: str = "INFO", + extra: Optional[dict] = None, + ): + + entry = asdict(audit_entry) + + if extra: + entry["extra"] = extra + + self.logger.log( + log_level, + "", + **entry, + ) + + +class AuditContext: + """ + Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage. + + Attributes: + request_body (bytearray): Accumulated request payload. + response_body (bytearray): Accumulated response payload. + max_body_size (int): Maximum number of bytes to capture. + metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.). 
+ """ + + def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE): + self.request_body = bytearray() + self.response_body = bytearray() + self.max_body_size = max_body_size + self.metadata: Dict[str, Any] = {} + + def add_request_chunk(self, chunk: bytes): + if len(self.request_body) < self.max_body_size: + self.request_body.extend( + chunk[: self.max_body_size - len(self.request_body)] + ) + + def add_response_chunk(self, chunk: bytes): + if len(self.response_body) < self.max_body_size: + self.response_body.extend( + chunk[: self.max_body_size - len(self.response_body)] + ) + + +class AuditLoggingMiddleware: + """ + ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle. + """ + + AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"} + + def __init__( + self, + app: ASGI3Application, + *, + excluded_paths: Optional[list[str]] = None, + max_body_size: int = MAX_BODY_LOG_SIZE, + audit_level: AuditLevel = AuditLevel.NONE, + ) -> None: + self.app = app + self.audit_logger = AuditLogger(logger) + self.excluded_paths = excluded_paths or [] + self.max_body_size = max_body_size + self.audit_level = audit_level + + async def __call__( + self, + scope: ASGIScope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope=cast(MutableMapping, scope)) + + if self._should_skip_auditing(request): + return await self.app(scope, receive, send) + + async with self._audit_context(request) as context: + + async def send_wrapper(message: ASGISendEvent) -> None: + if self.audit_level == AuditLevel.REQUEST_RESPONSE: + await self._capture_response(message, context) + + await send(message) + + original_receive = receive + + async def receive_wrapper() -> ASGIReceiveEvent: + nonlocal original_receive + message = await original_receive() + + if self.audit_level in ( + AuditLevel.REQUEST, + AuditLevel.REQUEST_RESPONSE, + ): + await self._capture_request(message, context) + + return message + + await self.app(scope, receive_wrapper, send_wrapper) + + @asynccontextmanager + async def _audit_context( + self, request: Request + ) -> AsyncGenerator[AuditContext, None]: + """ + async context manager that ensures that an audit log entry is recorded after the request is processed. + """ + context = AuditContext() + try: + yield context + finally: + await self._log_audit_entry(request, context) + + async def _get_authenticated_user(self, request: Request) -> UserModel: + + auth_header = request.headers.get("Authorization") + assert auth_header + user = get_current_user(request, get_http_authorization_cred(auth_header)) + + return user + + def _should_skip_auditing(self, request: Request) -> bool: + if ( + request.method not in {"POST", "PUT", "PATCH", "DELETE"} + or AUDIT_LOG_LEVEL == "NONE" + or not request.headers.get("authorization") + ): + return True + # match either /api//...(for the endpoint /api/chat case) or /api/v1//... 
+        pattern = re.compile(
+            r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
+        )
+        if pattern.match(request.url.path):
+            return True
+
+        return False
+
+    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
+        if message["type"] == "http.request":
+            body = message.get("body", b"")
+            context.add_request_chunk(body)
+
+    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
+        if message["type"] == "http.response.start":
+            context.metadata["response_status_code"] = message["status"]
+
+        elif message["type"] == "http.response.body":
+            body = message.get("body", b"")
+            context.add_response_chunk(body)
+
+    async def _log_audit_entry(self, request: Request, context: AuditContext):
+        try:
+            user = await self._get_authenticated_user(request)
+
+            entry = AuditLogEntry(
+                id=str(uuid.uuid4()),
+                user=user.model_dump(include={"id", "name", "email", "role"}),
+                audit_level=self.audit_level.value,
+                verb=request.method,
+                request_uri=str(request.url),
+                response_status_code=context.metadata.get("response_status_code", None),
+                source_ip=request.client.host if request.client else None,
+                user_agent=request.headers.get("user-agent"),
+                request_object=context.request_body.decode("utf-8", errors="replace"),
+                response_object=context.response_body.decode("utf-8", errors="replace"),
+            )
+
+            self.audit_logger.write(entry)
+        except Exception as e:
+            logger.error(f"Failed to log audit entry: {str(e)}")
diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py
new file mode 100644
index 000000000..255761006
--- /dev/null
+++ b/backend/open_webui/utils/logger.py
@@ -0,0 +1,140 @@
+import json
+import logging
+import sys
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+from open_webui.env import (
+    AUDIT_LOG_FILE_ROTATION_SIZE,
+    AUDIT_LOG_LEVEL,
+    AUDIT_LOGS_FILE_PATH,
+    GLOBAL_LOG_LEVEL,
+)
+
+
+if TYPE_CHECKING:
+    from loguru import Record
+
+
+def stdout_format(record: "Record") -> str:
+    """
+    Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
+
+    Parameters:
+        record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
+    Returns:
+        str: A formatted log string intended for stdout.
+    """
+    record["extra"]["extra_json"] = json.dumps(record["extra"])
+    return (
+        "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+        "{level: <8} | "
+        "{name}:{function}:{line} - "
+        "{message} - {extra[extra_json]}"
+        "\n{exception}"
+    )
+
+
+class InterceptHandler(logging.Handler):
+    """
+    Intercepts log records from Python's standard logging module
+    and redirects them to Loguru's logger.
+    """
+
+    def emit(self, record):
+        """
+        Called by the standard logging module for each log event.
+        It transforms the standard `LogRecord` into a format compatible with Loguru
+        and passes it to Loguru's logger.
+        """
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        frame, depth = sys._getframe(6), 6
+        while frame and frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+def file_format(record: "Record"):
+    """
+    Formats audit log records into a structured JSON string for file output.
+
+    Parameters:
+        record (Record): A Loguru record containing extra audit data.
+    Returns:
+        str: A JSON-formatted string representing the audit data.
+    """
+
+    audit_data = {
+        "id": record["extra"].get("id", ""),
+        "timestamp": int(record["time"].timestamp()),
+        "user": record["extra"].get("user", dict()),
+        "audit_level": record["extra"].get("audit_level", ""),
+        "verb": record["extra"].get("verb", ""),
+        "request_uri": record["extra"].get("request_uri", ""),
+        "response_status_code": record["extra"].get("response_status_code", 0),
+        "source_ip": record["extra"].get("source_ip", ""),
+        "user_agent": record["extra"].get("user_agent", ""),
+        "request_object": record["extra"].get("request_object", b""),
+        "response_object": record["extra"].get("response_object", b""),
+        "extra": record["extra"].get("extra", {}),
+    }
+
+    record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
+    return "{extra[file_extra]}\n"
+
+
+def start_logger():
+    """
+    Initializes and configures Loguru's logger with two distinct handlers:
+
+    - A console (stdout) handler for general log messages (excluding those marked as auditable).
+    - A file handler for audit logs, added only when audit logging is enabled.
+
+    Additionally, this function reconfigures Python’s standard logging to route through Loguru and adjusts logging levels for Uvicorn.
+
+    The function takes no parameters; audit file logging is controlled by AUDIT_LOG_LEVEL.
+    """
+    logger.remove()
+
+    logger.add(
+        sys.stdout,
+        level=GLOBAL_LOG_LEVEL,
+        format=stdout_format,
+        filter=lambda record: "auditable" not in record["extra"],
+    )
+
+    if AUDIT_LOG_LEVEL != "NONE":
+        try:
+            logger.add(
+                AUDIT_LOGS_FILE_PATH,
+                level="INFO",
+                rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
+                compression="zip",
+                format=file_format,
+                filter=lambda record: record["extra"].get("auditable") is True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize audit log file handler: {str(e)}")
+
+    logging.basicConfig(
+        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
+    )
+    for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
+        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
+        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
+        uvicorn_logger.handlers = []
+    for uvicorn_logger_name in ["uvicorn.access"]:
+        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
+        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
+        uvicorn_logger.handlers = [InterceptHandler()]
+
+    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 965741f78..a04f49105 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -31,6 +31,9 @@ APScheduler==3.10.4
 
 RestrictedPython==8.0
 
+loguru==0.7.2
+asgiref==3.8.1
+
 # AI libraries
 openai
 anthropic
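
Reviewer note (not part of the patch): below is a minimal sketch of how the resulting audit trail could be inspected locally. Only the JSON field names (taken from file_format() above) and the default audit.log location under DATA_DIR come from this change; the backend/data/audit.log path and the helper script itself are assumptions about a typical local setup, not something this PR ships.

# summarize_audit_log.py -- illustrative helper only, not included in this PR.
# Assumes the server was started with AUDIT_LOG_LEVEL set to METADATA, REQUEST,
# or REQUEST_RESPONSE, so that audit.log is being written under DATA_DIR.
import json

def summarize_audit_log(path: str = "backend/data/audit.log") -> None:
    """Print one line per audit entry: timestamp, user, verb, URI, and status code."""
    with open(path, encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            # One JSON object per line, as produced by file_format() in logger.py.
            entry = json.loads(raw)
            user = entry.get("user") or {}
            print(
                f'{entry.get("timestamp")} {user.get("email", "?")} '
                f'{entry.get("verb")} {entry.get("request_uri")} '
                f'-> {entry.get("response_status_code")}'
            )

if __name__ == "__main__":
    summarize_audit_log()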