From a0ba05ce038b71003063835fbf8d6952fdbe9e89 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 28 May 2024 22:37:54 -0700 Subject: [PATCH] feat: cli --- cli.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 11 ++++++-- utils/main.py | 58 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 cli.py create mode 100644 utils/main.py diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..9b834cb --- /dev/null +++ b/cli.py @@ -0,0 +1,70 @@ +import logging +from pathlib import Path +from typing import Optional + +import typer + + +import subprocess +import time + + +def start_process(app: str, host: str, port: int, reload: bool = False): + # Start the FastAPI application + command = [ + "uvicorn", + app, + "--host", + host, + "--port", + str(port), + "--forwarded-allow-ips", + "*", + ] + + if reload: + command.append("--reload") + + process = subprocess.Popen(command) + return process + + +main = typer.Typer() + + +@main.command() +def serve( + host: str = "0.0.0.0", + port: int = 9099, +): + while True: + process = start_process("main:app", host, port, reload=False) + process.wait() + + if process.returncode == 42: + print("Restarting due to restart request") + time.sleep(2) # optional delay to prevent tight restart loops + else: + print("Normal exit, stopping the manager") + break + + +@main.command() +def dev( + host: str = "0.0.0.0", + port: int = 9099, +): + while True: + process = start_process("main:app", host, port, reload=True) + process.wait() + + if process.returncode == 42: + print("Restarting due to restart request") + time.sleep(2) # optional delay to prevent tight restart loops + else: + print("Normal exit, stopping the manager") + break + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py index 8da299a..dafbbe7 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,8 @@ import uuid from utils import get_last_user_message, stream_message_template from schemas import FilterForm, OpenAIChatCompletionForm + +import sys import os import importlib.util @@ -24,8 +26,6 @@ from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor -import os - #################################### # Load .env file #################################### @@ -452,6 +452,13 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm): return await run_in_threadpool(job) +@app.get("/v1") @app.get("/") async def get_status(): return {"status": True} + + +@app.post("/v1/restart") +@app.post("/restart") +def restart_server(): + sys.exit(42) # Use a distinctive code to indicate a restart request diff --git a/utils/main.py b/utils/main.py new file mode 100644 index 0000000..9edd330 --- /dev/null +++ b/utils/main.py @@ -0,0 +1,58 @@ +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import HTTPException, status, Depends + +from pydantic import BaseModel +from typing import Union, Optional + + +from passlib.context import CryptContext +from datetime import datetime, timedelta +import jwt +import logging +import os + +import requests +import uuid + +SESSION_SECRET = os.getenv("SESSION_SECRET", " ") +ALGORITHM = "HS256" + +############## +# Auth Utils +############## + +bearer_security = HTTPBearer() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password, hashed_password): + return ( + pwd_context.verify(plain_password, hashed_password) if hashed_password else None + ) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: + payload = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + payload.update({"exp": expire}) + + encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> Optional[dict]: + try: + decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM]) + return decoded + except Exception as e: + return None + + +def extract_token_from_auth_header(auth_header: str): + return auth_header[len("Bearer ") :]