import re
import os
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Dict


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware that adds environment-configured security headers.

    The header set is rebuilt from the environment on every request by
    ``set_security_headers`` and merged into the downstream response's
    headers.
    """

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers.update(set_security_headers())
        return response


def set_security_headers() -> Dict[str, str]:
    """
    Build security response headers from environment variables.

    Each supported environment variable is paired with a setter that
    validates the value and returns a one-entry ``{header: value}`` dict.
    Variables that are unset or empty are skipped, so the matching header is
    not emitted. A value that is set but invalid is replaced by that header's
    safe default rather than being dropped.

    Environment variable                 Header
      CACHE_CONTROL                      Cache-Control
      HSTS                               Strict-Transport-Security
      REFERRER_POLICY                    Referrer-Policy
      XCONTENT_TYPE                      X-Content-Type-Options
      XDOWNLOAD_OPTIONS                  X-Download-Options
      XFRAME_OPTIONS                     X-Frame-Options
      XPERMITTED_CROSS_DOMAIN_POLICIES   X-Permitted-Cross-Domain-Policies

    Returns:
        dict: A mapping of header name to header value for every configured
        header.
    """
    options: Dict[str, str] = {}
    header_setters = {
        'CACHE_CONTROL': set_cache_control,
        'HSTS': set_hsts,
        'REFERRER_POLICY': set_referrer,
        'XCONTENT_TYPE': set_xcontent_type,
        'XDOWNLOAD_OPTIONS': set_xdownload_options,
        'XFRAME_OPTIONS': set_xframe,
        'XPERMITTED_CROSS_DOMAIN_POLICIES': set_xpermitted_cross_domain_policies,
    }
    for env_var, setter in header_setters.items():
        value = os.environ.get(env_var)
        if value:
            # Every setter returns a dict, so it can be merged directly.
            options.update(setter(value))
    return options


# Validation patterns, compiled once. They are applied with ``fullmatch``
# rather than ``^...$`` because ``$`` also matches just before a trailing
# newline, which would let "DENY\n" through into a header value.
_HSTS_PATTERN = re.compile(
    # RFC 6797 permits optional whitespace around the ';' separators.
    r'max-age=\d+(\s*;\s*includeSubDomains)?(\s*;\s*preload)?',
    re.IGNORECASE,
)
_XFRAME_PATTERN = re.compile(r'DENY|SAMEORIGIN', re.IGNORECASE)
_REFERRER_PATTERN = re.compile(
    r'no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin'
    r'|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url',
    re.IGNORECASE,
)
_CACHE_DIRECTIVE = (
    r'(public|private|no-cache|no-store|must-revalidate|proxy-revalidate'
    r'|max-age=\d+|s-maxage=\d+|no-transform|immutable)'
)
_CACHE_CONTROL_PATTERN = re.compile(
    _CACHE_DIRECTIVE + r'(,\s*' + _CACHE_DIRECTIVE + r')*',
    re.IGNORECASE,
)
_XPERMITTED_PATTERN = re.compile(
    r'none|master-only|by-content-type|by-ftp-filename',
    re.IGNORECASE,
)


def _validated(value: str, pattern: "re.Pattern[str]", default: str) -> str:
    """Return *value* stripped of surrounding whitespace if the whole of it
    matches *pattern*; otherwise return *default*."""
    value = value.strip()
    return value if pattern.fullmatch(value) else default


# Set HTTP Strict Transport Security(HSTS) response header
def set_hsts(value: str) -> Dict[str, str]:
    """Return the Strict-Transport-Security header.

    Invalid values fall back to ``max-age=31536000;includeSubDomains``.
    Previously an invalid value returned a bare string, which made the
    caller's ``dict.update`` raise ``ValueError``.
    """
    return {
        'Strict-Transport-Security': _validated(
            value, _HSTS_PATTERN, 'max-age=31536000;includeSubDomains'
        )
    }


# Set X-Frame-Options response header
def set_xframe(value: str) -> Dict[str, str]:
    """Return the X-Frame-Options header (DENY or SAMEORIGIN; default DENY).

    The value is matched case-insensitively and emitted in canonical
    uppercase.
    """
    return {"X-Frame-Options": _validated(value, _XFRAME_PATTERN, 'DENY').upper()}


# Set Referrer-Policy response header
def set_referrer(value: str) -> Dict[str, str]:
    """Return the Referrer-Policy header (default ``no-referrer``).

    The value is matched case-insensitively and emitted in lowercase, the
    form in which the policy tokens are defined.
    """
    return {
        'Referrer-Policy': _validated(
            value, _REFERRER_PATTERN, 'no-referrer'
        ).lower()
    }


# Set Cache-Control response header
def set_cache_control(value: str) -> Dict[str, str]:
    """Return the Cache-Control header.

    Accepts a comma-separated list of known directives. Anything else falls
    back to ``no-store, max-age=0``.
    """
    return {
        'Cache-Control': _validated(
            value, _CACHE_CONTROL_PATTERN, 'no-store, max-age=0'
        )
    }


# Set X-Download-Options response header
def set_xdownload_options(value: str) -> Dict[str, str]:
    """Return the X-Download-Options header.

    ``noopen`` is the only supported value, so *value* only acts as an
    on-switch and its content is ignored.
    """
    return {'X-Download-Options': 'noopen'}


# Set X-Content-Type-Options response header
def set_xcontent_type(value: str) -> Dict[str, str]:
    """Return the X-Content-Type-Options header.

    ``nosniff`` is the only supported value, so *value* only acts as an
    on-switch and its content is ignored.
    """
    return {'X-Content-Type-Options': 'nosniff'}


# Set X-Permitted-Cross-Domain-Policies response header
def set_xpermitted_cross_domain_policies(value: str) -> Dict[str, str]:
    """Return the X-Permitted-Cross-Domain-Policies header (default ``none``).

    The value is matched case-insensitively and emitted in lowercase.
    """
    return {
        'X-Permitted-Cross-Domain-Policies': _validated(
            value, _XPERMITTED_PATTERN, 'none'
        ).lower()
    }