This commit is contained in:
Timothy J. Baek 2024-05-28 09:50:17 -07:00
parent b870fd1118
commit e231333bcd

View File

@ -12,6 +12,7 @@ import mimetypes
from fastapi import FastAPI, Request, Depends, status from fastapi import FastAPI, Request, Depends, status
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -123,15 +124,6 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Custom middleware to add security headers # Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware): # class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next): # async def dispatch(self, request: Request, call_next):
@ -276,10 +268,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
except: except:
pass pass
print(sorted_filters)
for filter in sorted_filters: for filter in sorted_filters:
r = None
try: try:
urlIdx = filter["urlIdx"] urlIdx = filter["urlIdx"]
@ -303,7 +293,20 @@ class PipelineMiddleware(BaseHTTPMiddleware):
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") print(f"Connection error: {e}")
pass
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
@ -328,6 +331,15 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware) app.add_middleware(PipelineMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0: