feat: cli

This commit is contained in:
Timothy J. Baek 2024-05-28 22:37:54 -07:00
parent dc744bf0e0
commit a0ba05ce03
3 changed files with 137 additions and 2 deletions

70
cli.py Normal file
View File

@ -0,0 +1,70 @@
import logging
from pathlib import Path
from typing import Optional
import typer
import subprocess
import time
def start_process(app: str, host: str, port: int, reload: bool = False):
# Start the FastAPI application
command = [
"uvicorn",
app,
"--host",
host,
"--port",
str(port),
"--forwarded-allow-ips",
"*",
]
if reload:
command.append("--reload")
process = subprocess.Popen(command)
return process
main = typer.Typer()
@main.command()
def serve(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=False)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
@main.command()
def dev(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=True)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
if __name__ == "__main__":
main()

11
main.py
View File

@ -15,6 +15,8 @@ import uuid
from utils import get_last_user_message, stream_message_template from utils import get_last_user_message, stream_message_template
from schemas import FilterForm, OpenAIChatCompletionForm from schemas import FilterForm, OpenAIChatCompletionForm
import sys
import os import os
import importlib.util import importlib.util
@ -24,8 +26,6 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import os
#################################### ####################################
# Load .env file # Load .env file
#################################### ####################################
@ -452,6 +452,13 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
return await run_in_threadpool(job) return await run_in_threadpool(job)
@app.get("/v1")
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return {"status": True} return {"status": True}
@app.post("/v1/restart")
@app.post("/restart")
def restart_server():
sys.exit(42) # Use a distinctive code to indicate a restart request

58
utils/main.py Normal file
View File

@ -0,0 +1,58 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from pydantic import BaseModel
from typing import Union, Optional
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
import logging
import os
import requests
import uuid
SESSION_SECRET = os.getenv("SESSION_SECRET", " ")
ALGORITHM = "HS256"
##############
# Auth Utils
##############
bearer_security = HTTPBearer()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return (
pwd_context.verify(plain_password, hashed_password) if hashed_password else None
)
def get_password_hash(password):
return pwd_context.hash(password)
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
payload = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
payload.update({"exp": expire})
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt
def decode_token(token: str) -> Optional[dict]:
try:
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded
except Exception as e:
return None
def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]