diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index f16f2ea6e..e862c6b35 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -405,3 +405,12 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" if OFFLINE_MODE: os.environ["HF_HUB_OFFLINE"] = "1" + +#################################### +# AUDIT LOGGING +#################################### +ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true" +AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log" +AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") +AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "METADATA").upper() +MAX_BODY_LOG_SIZE = 2048 diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 52337a640..a22182b06 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,13 +8,15 @@ import shutil import sys import time import random +import re from contextlib import asynccontextmanager from urllib.parse import urlencode, parse_qs, urlparse +from loguru import logger from pydantic import BaseModel from sqlalchemy import text -from typing import Optional +from typing import Any, Optional, cast from aiocache import cached import aiohttp import requests @@ -39,12 +41,16 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from starlette.background import BackgroundTask from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse +from starlette.types import Message +from open_webui.models.audits import AuditLevel +from open_webui.utils.logger import AuditLogger, start_logger from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, @@ -262,8 +268,11 @@ from open_webui.config import ( reset_config, ) from open_webui.env import ( + AUDIT_LOG_LEVEL, CHANGELOG, + ENABLE_AUDIT_LOGS, GLOBAL_LOG_LEVEL, + MAX_BODY_LOG_SIZE, SAFE_MODE, SRC_LOG_LEVELS, VERSION, @@ -296,6 +305,8 @@ from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( decode_token, get_admin_user, + get_current_user, + get_http_authorization_cred, get_verified_user, ) from open_webui.utils.oauth import oauth_manager @@ -342,6 +353,8 @@ https://github.com/open-webui/open-webui @asynccontextmanager async def lifespan(app: FastAPI): + start_logger(ENABLE_AUDIT_LOGS) + if RESET_CONFIG_ON_START: reset_config() @@ -731,6 +744,156 @@ app.add_middleware( allow_headers=["*"], ) +URL_RESOURCE_PATTERN = re.compile(r"/v\d+/([^/]+)") +URL_API_VERSION_PATTERN = re.compile(r"/api/(\w+)/") + + +class AuditLoggerMiddleware(BaseHTTPMiddleware): + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + if request.method == "GET": + return await call_next(request) + + auth_header = request.headers.get("Authorization") + if not auth_header: + return await call_next(request) + + audit_log_level = AuditLevel(AUDIT_LOG_LEVEL) + + try: + user = get_current_user( + request, + get_http_authorization_cred(auth_header), + ) + except Exception: + logger.error( + f"Audit Logger - logging failed while retrieving user for request URL: {str(request.url)}" + ) + return await call_next(request) + + audit_logger = AuditLogger(logger, user) + + user_agent = request.headers.get("user-agent", "") + source_ip = request.client.host if request.client else "" + request_uri = str(request.url) + api_version, resource = self._get_api_version_and_resource(request_uri) + + raw_request_body = await request.body() + truncated_request_body = raw_request_body[:MAX_BODY_LOG_SIZE] + + async def _receive() -> Message: + return { + "type": "http.request", + "body": raw_request_body, + "more_body": False, + } + + modified_request = Request(request.scope, receive=_receive) + + request_log_data = None + if audit_log_level in (AuditLevel.REQUEST, AuditLevel.REQUEST_RESPONSE): + request_log_data = self._build_request_log_data( + request, truncated_request_body + ) + + response = await call_next(modified_request) + + response, response_content = await self._read_response_content(response) + truncated_response_content = response_content[:MAX_BODY_LOG_SIZE] + + response_log_data = None + if audit_log_level == AuditLevel.REQUEST_RESPONSE: + response_log_data = self._build_response_log_data( + response, truncated_response_content + ) + + task = BackgroundTask( + audit_logger.write, + api_version, + VERSION, + request.method, + audit_log_level, + resource, + source_ip, + user_agent, + request_uri, + request_object=request_log_data, + response_object=response_log_data, + ) + + return Response( + content=response_content, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + background=task, + ) + + def _build_request_log_data(self, request: Request, body: bytes) -> dict: + """ + Build a dictionary representing the request data for logging. + Sensitive headers are removed later in the AuditLogger. + """ + request_data: dict[str, Any] = { + "headers": dict(request.headers), + "query_params": dict(request.query_params), + "path_params": dict(request.path_params), + "body": body.decode("utf-8", errors="replace"), + } + return request_data + + def _build_response_log_data(self, response: Response, content: bytes) -> dict: + response_data = { + "headers": dict(response.headers), + "status_code": response.status_code, + "media_type": response.media_type, + "content": content.decode("utf-8", errors="replace"), + } + return response_data + + def _get_api_version_and_resource(self, url: str) -> tuple[str, str]: + resource_match = re.search(URL_RESOURCE_PATTERN, url) + resource = resource_match.group(1) if resource_match else "" + + api_version_match = re.search(URL_API_VERSION_PATTERN, url) + api_version = api_version_match.group(1) if api_version_match else "" + + return api_version, resource + + async def _read_response_content( + self, response: Response + ) -> tuple[Response, bytes]: + """ + Read the entire response content from a StreamingResponse into memory + and return a new regular Response along with the content bytes. + """ + if not isinstance(response, StreamingResponse): + body = response.body + new_response = Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + return new_response, body + + content = b"".join( + [cast(bytes, chunk) async for chunk in response.body_iterator] + ) + + new_response = Response( + content=content, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + return new_response, content + + +if not AUDIT_LOG_LEVEL == "NONE": + app.add_middleware(AuditLoggerMiddleware) app.mount("/ws", socket_app) diff --git a/backend/open_webui/models/audits.py b/backend/open_webui/models/audits.py new file mode 100644 index 000000000..42ea07ac3 --- /dev/null +++ b/backend/open_webui/models/audits.py @@ -0,0 +1,25 @@ +from typing import Optional + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import Enum + + +class UserAuditInfo(BaseModel): + id: str + name: str + email: str + role: str + oauth_sub: Optional[str] + + model_config = ConfigDict(from_attributes=True) + + +class AuditLevel(str, Enum): + NONE = "NONE" + METADATA = "METADATA" + REQUEST = "REQUEST" + REQUEST_RESPONSE = "REQUEST_RESPONSE" + + @classmethod + def _missing_(cls, value): + return cls.NONE diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py new file mode 100644 index 000000000..071d497ec --- /dev/null +++ b/backend/open_webui/utils/logger.py @@ -0,0 +1,174 @@ +import json +import logging +import sys +from typing import TYPE_CHECKING, Optional + +from loguru import logger + +from open_webui.models.audits import UserAuditInfo +from open_webui.env import ( + AUDIT_LOG_FILE_ROTATION_SIZE, + AUDIT_LOGS_FILE_PATH, + GLOBAL_LOG_LEVEL, +) +from open_webui.models.users import UserModel + + +if TYPE_CHECKING: + from loguru import Logger, Message, Record + + +def stdout_format(record: "Record") -> str: + record["extra"]["extra_json"] = json.dumps(record["extra"]) + return ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} - " + "{message} - {extra[extra_json]}" + "\n{exception}" + ) + + +class InterceptHandler(logging.Handler): + """ + Intercepts log records from Python's standard logging module + and redirects them to Loguru's logger. + """ + + def emit(self, record): + """ + Called by the standard logging module for each log event. + It transforms the standard `LogRecord` into a format compatible with Loguru + and passes it to Loguru's logger. + """ + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + frame, depth = sys._getframe(6), 6 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +class AuditLogger: + def __init__(self, logger: "Logger", user: UserModel): + self.logger = logger.bind(auditable=True) + self.user = user + + def write( + self, + api_version: str, + open_webui_version, + http_method: str, + audit_level: str, + resource: str, + source_ip: str, + user_agent: str, + request_uri: str, + *, + user: Optional[UserModel] = None, + log_level: str = "INFO", + request_object: Optional[dict] = None, + response_object: Optional[dict] = None, + extra: Optional[dict] = None, + ): + + user = user or self.user + + if request_object and "headers" in request_object: + request_object["headers"].pop("Authorization", None) + + log_extra = { + "user": user.model_dump(), + "api_version": api_version, + "open_webui_version": open_webui_version, + "http_method": http_method, + "audit_level": audit_level, + "resource": resource, + "source_ip": source_ip, + "user_agent": user_agent, + "request_uri": request_uri, + "request_object": request_object, + "response_object": response_object, + "extra": extra, + } + + if extra: + log_extra.update(extra) + + event = self._format_event(resource, http_method) + + self.logger.log( + log_level, + event, + **log_extra, + ) + + def _format_event(self, resource: str, method: str) -> str: + return f"RESOURCE_{resource}_{method}_EVENT" + + +def file_format(record: "Record"): + + user = record["extra"].get("user", dict()) + user_audit_info = UserAuditInfo.model_validate(user) + + audit_data = { + "timestamp": int(record["time"].timestamp()), + "user": user_audit_info.model_dump(), + "api_version": record["extra"].get("api_version"), + "http_method": record["extra"].get("http_method"), + "audit_level": record["extra"].get("audit_level"), + "log_level": record["level"].name, + "resource": record["extra"].get("resource"), + "source_ip": record["extra"].get("source_ip"), + "user_agent": record["extra"].get("user_agent"), + "request_uri": record["extra"].get("request_uri"), + "request_object": record["extra"].get("request_object"), + "response_object": record["extra"].get("response_object"), + "extra": record["extra"].get("extra", {}), + } + + record["extra"]["file_extra"] = json.dumps(audit_data, default=str) + return "{extra[file_extra]}\n" + + +def start_logger(enable_audit_logging: bool): + logger.remove() + + logger.add( + sys.stdout, + level=GLOBAL_LOG_LEVEL, + format=stdout_format, + filter=lambda record: "auditable" not in record["extra"], + ) + + if enable_audit_logging: + logger.add( + AUDIT_LOGS_FILE_PATH, + level=GLOBAL_LOG_LEVEL, + rotation=AUDIT_LOG_FILE_ROTATION_SIZE, + compression="zip", + format=file_format, + filter=lambda record: record["extra"].get("auditable") is True, + ) + + logging.basicConfig( + handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True + ) + for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]: + uvicorn_logger = logging.getLogger(uvicorn_logger_name) + uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) + uvicorn_logger.handlers = [] + for uvicorn_logger_name in ["uvicorn.access"]: + uvicorn_logger = logging.getLogger(uvicorn_logger_name) + uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) + uvicorn_logger.handlers = [InterceptHandler()] + + logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")