mirror of
https://github.com/open-webui/open-webui
synced 2025-01-19 01:06:45 +00:00
feat: add audit logging to file
This commit is contained in:
parent
770be671d2
commit
e0520382bd
@ -405,3 +405,12 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
||||
|
||||
# Mirror offline mode into the HF_HUB_OFFLINE environment variable.
if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"

####################################
# AUDIT LOGGING
####################################

# Upper bound (in bytes) used when slicing request/response bodies for logging.
MAX_BODY_LOG_SIZE = 2048

# The audit log file lives directly inside the data directory.
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"

# Only the literal value "true" (in any casing) switches audit logging on.
ENABLE_AUDIT_LOGS = os.environ.get("ENABLE_AUDIT_LOGS", "false").lower() == "true"

# Handed to the log sink as its rotation threshold; defaults to "10MB".
AUDIT_LOG_FILE_ROTATION_SIZE = os.environ.get("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")

# Audit verbosity name, normalised to upper case; defaults to "METADATA".
AUDIT_LOG_LEVEL = os.environ.get("AUDIT_LOG_LEVEL", "METADATA").upper()
|
||||
|
@ -8,13 +8,15 @@ import shutil
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
from aiocache import cached
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -39,12 +41,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from starlette.types import Message
|
||||
|
||||
|
||||
from open_webui.models.audits import AuditLevel
|
||||
from open_webui.utils.logger import AuditLogger, start_logger
|
||||
from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
@ -262,8 +268,11 @@ from open_webui.config import (
|
||||
reset_config,
|
||||
)
|
||||
from open_webui.env import (
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
ENABLE_AUDIT_LOGS,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
MAX_BODY_LOG_SIZE,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
@ -296,6 +305,8 @@ from open_webui.utils.access_control import has_access
|
||||
from open_webui.utils.auth import (
|
||||
decode_token,
|
||||
get_admin_user,
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
get_verified_user,
|
||||
)
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
@ -342,6 +353,8 @@ https://github.com/open-webui/open-webui
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
start_logger(ENABLE_AUDIT_LOGS)
|
||||
|
||||
if RESET_CONFIG_ON_START:
|
||||
reset_config()
|
||||
|
||||
@ -731,6 +744,156 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Captures the path segment that follows the version segment ("/v1/<resource>").
# "?" and "#" are excluded from the capture so that a query string or fragment
# is not reported as part of the resource name ("/api/v1/chats?x=1" -> "chats").
URL_RESOURCE_PATTERN = re.compile(r"/v\d+/([^/?#]+)")
# Captures the segment that follows "/api/" (e.g. "v1" in "/api/v1/chats").
URL_API_VERSION_PATTERN = re.compile(r"/api/(\w+)/")
|
||||
|
||||
|
||||
class AuditLoggerMiddleware(BaseHTTPMiddleware):
    """
    Schedule an audit entry for every authenticated, non-GET request.

    The request body and the response body are read completely so that they
    can be attached to the entry (cut to MAX_BODY_LOG_SIZE bytes) when the
    configured audit level asks for them.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """
        Forward the request and attach an audit-writing background task.

        The request is passed through untouched when it is a GET, carries no
        Authorization header, resolves to audit level NONE, or when the user
        cannot be determined from the Authorization header.
        """
        if request.method == "GET":
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return await call_next(request)

        audit_log_level = AuditLevel(AUDIT_LOG_LEVEL)
        if audit_log_level == AuditLevel.NONE:
            # AuditLevel._missing_ returns NONE for values it does not know,
            # so a mistyped level must not produce audit entries either.
            return await call_next(request)

        try:
            user = get_current_user(
                request,
                get_http_authorization_cred(auth_header),
            )
        except Exception:
            # Best effort: a failed user lookup must never block the request.
            logger.error(
                f"Audit Logger - logging failed while retrieving user for request URL: {str(request.url)}"
            )
            return await call_next(request)

        audit_logger = AuditLogger(logger, user)

        user_agent = request.headers.get("user-agent", "")
        source_ip = request.client.host if request.client else ""
        request_uri = str(request.url)
        api_version, resource = self._get_api_version_and_resource(request_uri)

        raw_request_body = await request.body()

        async def _receive() -> Message:
            # Replays the body that was just read so the downstream handler
            # receives it again as a single, final message.
            return {
                "type": "http.request",
                "body": raw_request_body,
                "more_body": False,
            }

        modified_request = Request(request.scope, receive=_receive)

        request_log_data = None
        if audit_log_level in (AuditLevel.REQUEST, AuditLevel.REQUEST_RESPONSE):
            request_log_data = self._build_request_log_data(
                request, raw_request_body[:MAX_BODY_LOG_SIZE]
            )

        response = await call_next(modified_request)

        # NOTE(review): the whole response body is collected in memory before
        # anything is returned, so streamed responses are delivered only once
        # the downstream handler has finished producing them.
        response, response_content = await self._read_response_content(response)

        response_log_data = None
        if audit_log_level == AuditLevel.REQUEST_RESPONSE:
            response_log_data = self._build_response_log_data(
                response, response_content[:MAX_BODY_LOG_SIZE]
            )

        # The entry is written by a background task attached to the response
        # instead of inline in this coroutine.
        response.background = BackgroundTask(
            audit_logger.write,
            api_version,
            VERSION,
            request.method,
            audit_log_level,
            resource,
            source_ip,
            user_agent,
            request_uri,
            request_object=request_log_data,
            response_object=response_log_data,
        )
        return response

    def _build_request_log_data(self, request: Request, body: bytes) -> dict:
        """
        Build a dictionary representing the request data for logging.

        Headers are copied as received; AuditLogger.write is the place where
        sensitive headers are expected to be removed.
        """
        request_data: dict[str, Any] = {
            "headers": dict(request.headers),
            "query_params": dict(request.query_params),
            "path_params": dict(request.path_params),
            # Bodies are not guaranteed to be valid UTF-8 (and may have been
            # cut mid-character), hence errors="replace".
            "body": body.decode("utf-8", errors="replace"),
        }
        return request_data

    def _build_response_log_data(self, response: Response, content: bytes) -> dict:
        """Build a dictionary representing the response data for logging."""
        response_data: dict[str, Any] = {
            "headers": dict(response.headers),
            "status_code": response.status_code,
            "media_type": response.media_type,
            "content": content.decode("utf-8", errors="replace"),
        }
        return response_data

    def _get_api_version_and_resource(self, url: str) -> tuple[str, str]:
        """
        Extract ``(api_version, resource)`` from a request URL.

        Either element is an empty string when its pattern does not match.
        """
        resource_match = re.search(URL_RESOURCE_PATTERN, url)
        resource = resource_match.group(1) if resource_match else ""

        api_version_match = re.search(URL_API_VERSION_PATTERN, url)
        api_version = api_version_match.group(1) if api_version_match else ""

        return api_version, resource

    async def _read_response_content(
        self, response: Response
    ) -> tuple[Response, bytes]:
        """
        Read the entire response content into memory and return a new regular
        Response along with the content bytes.

        A StreamingResponse has its body iterator drained; any other response
        already exposes its body.
        """
        if isinstance(response, StreamingResponse):
            content = b"".join(
                [cast(bytes, chunk) async for chunk in response.body_iterator]
            )
        else:
            content = response.body

        return self._rebuild_response(response, content), content

    def _rebuild_response(self, response: Response, content: bytes) -> Response:
        """Create a buffered Response carrying the status and headers of `response`."""
        rebuilt = Response(
            content=content,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )
        # dict() keeps a single value per header name, which silently drops
        # repeated headers (for example several Set-Cookie lines). Re-append
        # every original header pair that did not survive the conversion.
        present = set(rebuilt.raw_headers)
        rebuilt.raw_headers.extend(
            pair for pair in map(tuple, response.raw_headers) if pair not in present
        )
        return rebuilt
|
||||
|
||||
|
||||
# Audit entries are only persisted when ENABLE_AUDIT_LOGS makes start_logger
# register the audit file sink (the stdout sink filters them out). Without that
# sink the middleware would buffer every non-GET request and response for
# entries nobody receives, so it is installed only when both settings agree.
if ENABLE_AUDIT_LOGS and AUDIT_LOG_LEVEL != "NONE":
    app.add_middleware(AuditLoggerMiddleware)

app.mount("/ws", socket_app)
|
||||
|
||||
|
25
backend/open_webui/models/audits.py
Normal file
25
backend/open_webui/models/audits.py
Normal file
@ -0,0 +1,25 @@
|
||||
import enum
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Enum
|
||||
|
||||
|
||||
class UserAuditInfo(BaseModel):
    """Subset of user attributes that is written into each audit entry."""

    id: str
    name: str
    email: str
    role: str
    # Explicit default: an Optional annotation alone still makes the field
    # required, so input lacking an "oauth_sub" key would fail validation.
    oauth_sub: Optional[str] = None

    # from_attributes lets the model be validated from attribute-bearing
    # objects in addition to plain dicts.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AuditLevel(str, enum.Enum):
    """
    Verbosity of an audit entry.

    Built on the standard-library ``enum.Enum`` (not SQLAlchemy's ``Enum``
    column type) so that members, lookup by value and ``_missing_`` behave as
    an enumeration. Members are also ``str`` instances equal to their value.
    """

    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"

    @classmethod
    def _missing_(cls, value):
        """Map any unrecognised value to ``NONE`` instead of raising ValueError."""
        return cls.NONE
|
174
backend/open_webui/utils/logger.py
Normal file
174
backend/open_webui/utils/logger.py
Normal file
@ -0,0 +1,174 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from open_webui.models.audits import UserAuditInfo
|
||||
from open_webui.env import (
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Logger, Message, Record
|
||||
|
||||
|
||||
def stdout_format(record: "Record") -> str:
    """
    Return the format template for the stdout sink.

    The record's ``extra`` mapping is serialised to JSON and stored back under
    ``extra["extra_json"]`` so that the returned template can reference it.

    Parameters:
        record: The log record; its ``extra`` mapping is updated in place.

    Returns:
        The template string for the record.
    """
    # default=str mirrors file_format: a value json cannot encode is rendered
    # via str() instead of raising TypeError in the middle of a log call.
    record["extra"]["extra_json"] = json.dumps(record["extra"], default=str)
    return (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        "<level>{message}</level> - {extra[extra_json]}"
        "\n{exception}"
    )
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
    """
    Intercepts log records from Python's standard logging module
    and redirects them to Loguru's logger.
    """

    def emit(self, record):
        """
        Called by the standard logging module for each log event.
        It transforms the standard `LogRecord` into a format compatible with Loguru
        and passes it to Loguru's logger.
        """
        # Use the Loguru level of the same name when it exists; otherwise fall
        # back to the numeric level of the standard record.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk up from this frame past every frame that belongs to the logging
        # module, counting how far away the code that issued the log call is.
        # Starting at this frame (instead of a fixed depth of 6) avoids a
        # ValueError on shallow call stacks and stays correct when the number
        # of logging-internal frames differs.
        frame, depth = sys._getframe(), 0
        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )
|
||||
|
||||
|
||||
class AuditLogger:
    """
    Writes audit entries through a logger bound with ``auditable=True``.

    The ``auditable`` marker is what log sinks can filter on to separate audit
    entries from regular log output.
    """

    # Request headers that carry credentials. Names are compared in lower case
    # because header mappings built from an incoming request use lower-cased
    # keys, so a lookup for "Authorization" would never match.
    SENSITIVE_HEADERS = frozenset({"authorization", "cookie"})

    def __init__(self, logger: "Logger", user: "UserModel"):
        """
        Parameters:
            logger: Logger to bind the ``auditable`` marker to.
            user: Default user recorded with each entry.
        """
        self.logger = logger.bind(auditable=True)
        self.user = user

    def write(
        self,
        api_version: str,
        open_webui_version,
        http_method: str,
        audit_level: str,
        resource: str,
        source_ip: str,
        user_agent: str,
        request_uri: str,
        *,
        user: Optional["UserModel"] = None,
        log_level: str = "INFO",
        request_object: Optional[dict] = None,
        response_object: Optional[dict] = None,
        extra: Optional[dict] = None,
    ):
        """
        Emit one audit entry.

        Parameters:
            api_version, open_webui_version, http_method, audit_level,
            resource, source_ip, user_agent, request_uri: Recorded verbatim.
            user: Overrides the user given at construction time.
            log_level: Level the entry is logged at.
            request_object: Optional request details; credential-bearing
                headers are stripped from a copy before logging.
            response_object: Optional response details.
            extra: Optional additional fields; recorded under "extra" and also
                merged into the top-level fields.
        """
        user = user or self.user

        if request_object and "headers" in request_object:
            # Work on a copy so the caller's dictionaries stay untouched.
            request_object = {
                **request_object,
                "headers": {
                    name: value
                    for name, value in request_object["headers"].items()
                    if name.lower() not in self.SENSITIVE_HEADERS
                },
            }

        log_extra = {
            "user": user.model_dump(),
            "api_version": api_version,
            "open_webui_version": open_webui_version,
            "http_method": http_method,
            "audit_level": audit_level,
            "resource": resource,
            "source_ip": source_ip,
            "user_agent": user_agent,
            "request_uri": request_uri,
            "request_object": request_object,
            "response_object": response_object,
            "extra": extra,
        }

        if extra:
            log_extra.update(extra)

        event = self._format_event(resource, http_method)

        self.logger.log(
            log_level,
            event,
            **log_extra,
        )

    def _format_event(self, resource: str, method: str) -> str:
        """Return the event name used as the log message."""
        return f"RESOURCE_{resource}_{method}_EVENT"
|
||||
|
||||
|
||||
def file_format(record: "Record"):
    """
    Return the format template for the audit file sink.

    The audit fields are collected from the record, serialised as a single
    JSON object and stored under ``extra["file_extra"]``; the returned
    template emits exactly that value followed by a newline.
    """
    fields = record["extra"]

    # Validating through UserAuditInfo limits the logged user to the
    # attributes that model declares.
    audit_user = UserAuditInfo.model_validate(fields.get("user", dict()))

    audit_data = {
        "timestamp": int(record["time"].timestamp()),
        "user": audit_user.model_dump(),
    }
    for key in ("api_version", "http_method", "audit_level"):
        audit_data[key] = fields.get(key)

    # The log level comes from the record itself, not from the extra fields.
    audit_data["log_level"] = record["level"].name

    for key in (
        "resource",
        "source_ip",
        "user_agent",
        "request_uri",
        "request_object",
        "response_object",
    ):
        audit_data[key] = fields.get(key)
    audit_data["extra"] = fields.get("extra", {})

    # default=str renders values json cannot encode natively via str().
    fields["file_extra"] = json.dumps(audit_data, default=str)
    return "{extra[file_extra]}\n"
|
||||
|
||||
|
||||
def start_logger(enable_audit_logging: bool):
    """
    Configure Loguru sinks and route standard-library logging through Loguru.

    All existing Loguru sinks are removed first. A stdout sink then receives
    every record that is not marked ``auditable``; when `enable_audit_logging`
    is true, a rotating, zip-compressed file sink receives the records that
    are. Finally the standard logging module and the uvicorn loggers are
    reconfigured so that their records pass through InterceptHandler.

    Parameters:
        enable_audit_logging: Whether to register the audit file sink.
    """
    logger.remove()

    logger.add(
        sys.stdout,
        level=GLOBAL_LOG_LEVEL,
        format=stdout_format,
        filter=lambda record: "auditable" not in record["extra"],
    )

    if enable_audit_logging:
        logger.add(
            AUDIT_LOGS_FILE_PATH,
            # Deliberately not GLOBAL_LOG_LEVEL: AuditLogger.write logs at
            # INFO by default, so a global level of WARNING or above would
            # silently discard every audit entry.
            level="INFO",
            rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
            compression="zip",
            format=file_format,
            filter=lambda record: record["extra"].get("auditable") is True,
        )

    # force=True replaces any handlers already installed on the root logger.
    logging.basicConfig(
        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
    )

    # These loggers lose their own handlers.
    for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
        uvicorn_logger.handlers = []

    # The access logger gets its own InterceptHandler.
    for uvicorn_logger_name in ["uvicorn.access"]:
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
        uvicorn_logger.handlers = [InterceptHandler()]

    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
Loading…
Reference in New Issue
Block a user