refac: enforce api on all routes

This commit is contained in:
Timothy Jaeryang Baek 2024-11-21 20:38:50 -08:00
parent c98ca763bc
commit 1367d95750
2 changed files with 23 additions and 8 deletions

19
main.py
View File

@ -106,28 +106,31 @@ def get_all_pipelines():
return pipelines
def parse_frontmatter(content):
frontmatter = {}
for line in content.split('\n'):
if ':' in line:
key, value = line.split(':', 1)
for line in content.split("\n"):
if ":" in line:
key, value = line.split(":", 1)
frontmatter[key.strip().lower()] = value.strip()
return frontmatter
def install_frontmatter_requirements(requirements):
if requirements:
req_list = [req.strip() for req in requirements.split(',')]
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
print(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
else:
print("No requirements found in frontmatter.")
async def load_module_from_path(module_name, module_path):
try:
# Read the module content
with open(module_path, 'r') as file:
with open(module_path, "r") as file:
content = file.read()
# Parse frontmatter
@ -139,8 +142,8 @@ async def load_module_from_path(module_name, module_path):
frontmatter = parse_frontmatter(frontmatter_content)
# Install requirements if specified
if 'requirements' in frontmatter:
install_frontmatter_requirements(frontmatter['requirements'])
if "requirements" in frontmatter:
install_frontmatter_requirements(frontmatter["requirements"])
# Load the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -277,7 +280,7 @@ async def check_url(request: Request, call_next):
@app.get("/v1/models")
@app.get("/models")
async def get_models():
async def get_models(user: str = Depends(get_current_user)):
"""
Returns the available pipelines
"""

View File

@ -1,6 +1,7 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from pydantic import BaseModel
from typing import Union, Optional
@ -14,6 +15,10 @@ import os
import requests
import uuid
from config import API_KEY, PIPELINES_DIR
SESSION_SECRET = os.getenv("SESSION_SECRET", " ")
ALGORITHM = "HS256"
@ -62,4 +67,11 @@ def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
) -> Optional[dict]:
token = credentials.credentials
if token != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return token