feat: add audit logging to file

This commit is contained in:
Victor Ribeiro 2024-01-10 11:24:55 -03:00
parent 770be671d2
commit e0520382bd
4 changed files with 373 additions and 2 deletions

View File

@ -405,3 +405,12 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "METADATA").upper()
MAX_BODY_LOG_SIZE = 2048

View File

@ -8,13 +8,15 @@ import shutil
import sys
import time
import random
import re
from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import text
from typing import Optional
from typing import Any, Optional, cast
from aiocache import cached
import aiohttp
import requests
@ -39,12 +41,16 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from starlette.background import BackgroundTask
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from starlette.types import Message
from open_webui.models.audits import AuditLevel
from open_webui.utils.logger import AuditLogger, start_logger
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@ -262,8 +268,11 @@ from open_webui.config import (
reset_config,
)
from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
ENABLE_AUDIT_LOGS,
GLOBAL_LOG_LEVEL,
MAX_BODY_LOG_SIZE,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
@ -296,6 +305,8 @@ from open_webui.utils.access_control import has_access
from open_webui.utils.auth import (
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_verified_user,
)
from open_webui.utils.oauth import oauth_manager
@ -342,6 +353,8 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
start_logger(ENABLE_AUDIT_LOGS)
if RESET_CONFIG_ON_START:
reset_config()
@ -731,6 +744,156 @@ app.add_middleware(
allow_headers=["*"],
)
URL_RESOURCE_PATTERN = re.compile(r"/v\d+/([^/]+)")
URL_API_VERSION_PATTERN = re.compile(r"/api/(\w+)/")
class AuditLoggerMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if request.method == "GET":
return await call_next(request)
auth_header = request.headers.get("Authorization")
if not auth_header:
return await call_next(request)
audit_log_level = AuditLevel(AUDIT_LOG_LEVEL)
try:
user = get_current_user(
request,
get_http_authorization_cred(auth_header),
)
except Exception:
logger.error(
f"Audit Logger - logging failed while retrieving user for request URL: {str(request.url)}"
)
return await call_next(request)
audit_logger = AuditLogger(logger, user)
user_agent = request.headers.get("user-agent", "")
source_ip = request.client.host if request.client else ""
request_uri = str(request.url)
api_version, resource = self._get_api_version_and_resource(request_uri)
raw_request_body = await request.body()
truncated_request_body = raw_request_body[:MAX_BODY_LOG_SIZE]
async def _receive() -> Message:
return {
"type": "http.request",
"body": raw_request_body,
"more_body": False,
}
modified_request = Request(request.scope, receive=_receive)
request_log_data = None
if audit_log_level in (AuditLevel.REQUEST, AuditLevel.REQUEST_RESPONSE):
request_log_data = self._build_request_log_data(
request, truncated_request_body
)
response = await call_next(modified_request)
response, response_content = await self._read_response_content(response)
truncated_response_content = response_content[:MAX_BODY_LOG_SIZE]
response_log_data = None
if audit_log_level == AuditLevel.REQUEST_RESPONSE:
response_log_data = self._build_response_log_data(
response, truncated_response_content
)
task = BackgroundTask(
audit_logger.write,
api_version,
VERSION,
request.method,
audit_log_level,
resource,
source_ip,
user_agent,
request_uri,
request_object=request_log_data,
response_object=response_log_data,
)
return Response(
content=response_content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
background=task,
)
def _build_request_log_data(self, request: Request, body: bytes) -> dict:
"""
Build a dictionary representing the request data for logging.
Sensitive headers are removed later in the AuditLogger.
"""
request_data: dict[str, Any] = {
"headers": dict(request.headers),
"query_params": dict(request.query_params),
"path_params": dict(request.path_params),
"body": body.decode("utf-8", errors="replace"),
}
return request_data
def _build_response_log_data(self, response: Response, content: bytes) -> dict:
response_data = {
"headers": dict(response.headers),
"status_code": response.status_code,
"media_type": response.media_type,
"content": content.decode("utf-8", errors="replace"),
}
return response_data
def _get_api_version_and_resource(self, url: str) -> tuple[str, str]:
resource_match = re.search(URL_RESOURCE_PATTERN, url)
resource = resource_match.group(1) if resource_match else ""
api_version_match = re.search(URL_API_VERSION_PATTERN, url)
api_version = api_version_match.group(1) if api_version_match else ""
return api_version, resource
async def _read_response_content(
self, response: Response
) -> tuple[Response, bytes]:
"""
Read the entire response content from a StreamingResponse into memory
and return a new regular Response along with the content bytes.
"""
if not isinstance(response, StreamingResponse):
body = response.body
new_response = Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
return new_response, body
content = b"".join(
[cast(bytes, chunk) async for chunk in response.body_iterator]
)
new_response = Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
return new_response, content
if not AUDIT_LOG_LEVEL == "NONE":
app.add_middleware(AuditLoggerMiddleware)
app.mount("/ws", socket_app)

View File

@ -0,0 +1,25 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict
from sqlalchemy import Enum
class UserAuditInfo(BaseModel):
id: str
name: str
email: str
role: str
oauth_sub: Optional[str]
model_config = ConfigDict(from_attributes=True)
class AuditLevel(str, Enum):
NONE = "NONE"
METADATA = "METADATA"
REQUEST = "REQUEST"
REQUEST_RESPONSE = "REQUEST_RESPONSE"
@classmethod
def _missing_(cls, value):
return cls.NONE

View File

@ -0,0 +1,174 @@
import json
import logging
import sys
from typing import TYPE_CHECKING, Optional
from loguru import logger
from open_webui.models.audits import UserAuditInfo
from open_webui.env import (
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOGS_FILE_PATH,
GLOBAL_LOG_LEVEL,
)
from open_webui.models.users import UserModel
if TYPE_CHECKING:
from loguru import Logger, Message, Record
def stdout_format(record: "Record") -> str:
record["extra"]["extra_json"] = json.dumps(record["extra"])
return (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level> - {extra[extra_json]}"
"\n{exception}"
)
class InterceptHandler(logging.Handler):
"""
Intercepts log records from Python's standard logging module
and redirects them to Loguru's logger.
"""
def emit(self, record):
"""
Called by the standard logging module for each log event.
It transforms the standard `LogRecord` into a format compatible with Loguru
and passes it to Loguru's logger.
"""
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
frame, depth = sys._getframe(6), 6
while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
class AuditLogger:
def __init__(self, logger: "Logger", user: UserModel):
self.logger = logger.bind(auditable=True)
self.user = user
def write(
self,
api_version: str,
open_webui_version,
http_method: str,
audit_level: str,
resource: str,
source_ip: str,
user_agent: str,
request_uri: str,
*,
user: Optional[UserModel] = None,
log_level: str = "INFO",
request_object: Optional[dict] = None,
response_object: Optional[dict] = None,
extra: Optional[dict] = None,
):
user = user or self.user
if request_object and "headers" in request_object:
request_object["headers"].pop("Authorization", None)
log_extra = {
"user": user.model_dump(),
"api_version": api_version,
"open_webui_version": open_webui_version,
"http_method": http_method,
"audit_level": audit_level,
"resource": resource,
"source_ip": source_ip,
"user_agent": user_agent,
"request_uri": request_uri,
"request_object": request_object,
"response_object": response_object,
"extra": extra,
}
if extra:
log_extra.update(extra)
event = self._format_event(resource, http_method)
self.logger.log(
log_level,
event,
**log_extra,
)
def _format_event(self, resource: str, method: str) -> str:
return f"RESOURCE_{resource}_{method}_EVENT"
def file_format(record: "Record"):
user = record["extra"].get("user", dict())
user_audit_info = UserAuditInfo.model_validate(user)
audit_data = {
"timestamp": int(record["time"].timestamp()),
"user": user_audit_info.model_dump(),
"api_version": record["extra"].get("api_version"),
"http_method": record["extra"].get("http_method"),
"audit_level": record["extra"].get("audit_level"),
"log_level": record["level"].name,
"resource": record["extra"].get("resource"),
"source_ip": record["extra"].get("source_ip"),
"user_agent": record["extra"].get("user_agent"),
"request_uri": record["extra"].get("request_uri"),
"request_object": record["extra"].get("request_object"),
"response_object": record["extra"].get("response_object"),
"extra": record["extra"].get("extra", {}),
}
record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
return "{extra[file_extra]}\n"
def start_logger(enable_audit_logging: bool):
logger.remove()
logger.add(
sys.stdout,
level=GLOBAL_LOG_LEVEL,
format=stdout_format,
filter=lambda record: "auditable" not in record["extra"],
)
if enable_audit_logging:
logger.add(
AUDIT_LOGS_FILE_PATH,
level=GLOBAL_LOG_LEVEL,
rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
compression="zip",
format=file_format,
filter=lambda record: record["extra"].get("auditable") is True,
)
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
for uvicorn_logger_name in ["uvicorn.access"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")