refac: enforce api on all routes

This commit is contained in:
Timothy Jaeryang Baek 2024-11-21 20:38:50 -08:00
parent c98ca763bc
commit 1367d95750
2 changed files with 23 additions and 8 deletions

19
main.py
View File

@ -106,28 +106,31 @@ def get_all_pipelines():
return pipelines return pipelines
def parse_frontmatter(content): def parse_frontmatter(content):
frontmatter = {} frontmatter = {}
for line in content.split('\n'): for line in content.split("\n"):
if ':' in line: if ":" in line:
key, value = line.split(':', 1) key, value = line.split(":", 1)
frontmatter[key.strip().lower()] = value.strip() frontmatter[key.strip().lower()] = value.strip()
return frontmatter return frontmatter
def install_frontmatter_requirements(requirements): def install_frontmatter_requirements(requirements):
if requirements: if requirements:
req_list = [req.strip() for req in requirements.split(',')] req_list = [req.strip() for req in requirements.split(",")]
for req in req_list: for req in req_list:
print(f"Installing requirement: {req}") print(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req]) subprocess.check_call([sys.executable, "-m", "pip", "install", req])
else: else:
print("No requirements found in frontmatter.") print("No requirements found in frontmatter.")
async def load_module_from_path(module_name, module_path): async def load_module_from_path(module_name, module_path):
try: try:
# Read the module content # Read the module content
with open(module_path, 'r') as file: with open(module_path, "r") as file:
content = file.read() content = file.read()
# Parse frontmatter # Parse frontmatter
@ -139,8 +142,8 @@ async def load_module_from_path(module_name, module_path):
frontmatter = parse_frontmatter(frontmatter_content) frontmatter = parse_frontmatter(frontmatter_content)
# Install requirements if specified # Install requirements if specified
if 'requirements' in frontmatter: if "requirements" in frontmatter:
install_frontmatter_requirements(frontmatter['requirements']) install_frontmatter_requirements(frontmatter["requirements"])
# Load the module # Load the module
spec = importlib.util.spec_from_file_location(module_name, module_path) spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -277,7 +280,7 @@ async def check_url(request: Request, call_next):
@app.get("/v1/models") @app.get("/v1/models")
@app.get("/models") @app.get("/models")
async def get_models(): async def get_models(user: str = Depends(get_current_user)):
""" """
Returns the available pipelines Returns the available pipelines
""" """

View File

@ -1,6 +1,7 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
@ -14,6 +15,10 @@ import os
import requests import requests
import uuid import uuid
from config import API_KEY, PIPELINES_DIR
SESSION_SECRET = os.getenv("SESSION_SECRET", " ") SESSION_SECRET = os.getenv("SESSION_SECRET", " ")
ALGORITHM = "HS256" ALGORITHM = "HS256"
@ -62,4 +67,11 @@ def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security), credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
) -> Optional[dict]: ) -> Optional[dict]:
token = credentials.credentials token = credentials.credentials
if token != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return token return token