From 21bab00496979e49b73174b798faaa79acda5ae9 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 28 May 2024 22:46:23 -0700 Subject: [PATCH] refac --- main.py | 23 +++++++++------- utils.py | 29 -------------------- utils/auth.py | 65 ++++++++++++++++++++++++++++++++++++++++++++ utils/main.py | 75 ++++++++++++++++----------------------------------- 4 files changed, 101 insertions(+), 91 deletions(-) delete mode 100644 utils.py create mode 100644 utils/auth.py diff --git a/main.py b/main.py index dafbbe7..bd79f31 100644 --- a/main.py +++ b/main.py @@ -8,22 +8,22 @@ from pydantic import BaseModel, ConfigDict from typing import List, Union, Generator, Iterator -import time -import json -import uuid +from utils.auth import bearer_security, get_current_user +from utils.main import get_last_user_message, stream_message_template -from utils import get_last_user_message, stream_message_template + +from contextlib import asynccontextmanager +from concurrent.futures import ThreadPoolExecutor from schemas import FilterForm, OpenAIChatCompletionForm import sys import os import importlib.util - import logging - -from contextlib import asynccontextmanager -from concurrent.futures import ThreadPoolExecutor +import time +import json +import uuid #################################### @@ -460,5 +460,8 @@ async def get_status(): @app.post("/v1/restart") @app.post("/restart") -def restart_server(): - sys.exit(42) # Use a distinctive code to indicate a restart request +def restart_server(user: str = Depends(get_current_user)): + + print(user) + + return True diff --git a/utils.py b/utils.py deleted file mode 100644 index a50cddf..0000000 --- a/utils.py +++ /dev/null @@ -1,29 +0,0 @@ -import uuid -import time - -from typing import List -from schemas import OpenAIChatMessage - - -def stream_message_template(model: str, message: str): - return { - "id": f"{model}-{str(uuid.uuid4())}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": {"content": message}, - "logprobs": None, - "finish_reason": None, - } - ], - } - - -def get_last_user_message(messages: List[dict]) -> str: - for message in reversed(messages): - if message.role == "user": - return message.content - return None diff --git a/utils/auth.py b/utils/auth.py new file mode 100644 index 0000000..df03ad8 --- /dev/null +++ b/utils/auth.py @@ -0,0 +1,65 @@ +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import HTTPException, status, Depends + +from pydantic import BaseModel +from typing import Union, Optional + + +from passlib.context import CryptContext +from datetime import datetime, timedelta +import jwt +import logging +import os + +import requests +import uuid + +SESSION_SECRET = os.getenv("SESSION_SECRET", " ") +ALGORITHM = "HS256" + +############## +# Auth Utils +############## + +bearer_security = HTTPBearer() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password, hashed_password): + return ( + pwd_context.verify(plain_password, hashed_password) if hashed_password else None + ) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: + payload = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + payload.update({"exp": expire}) + + encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> Optional[dict]: + try: + decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM]) + return decoded + except Exception as e: + return None + + +def extract_token_from_auth_header(auth_header: str): + return auth_header[len("Bearer ") :] + + +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(bearer_security), +) -> Optional[dict]: + token = credentials.credentials + return token diff --git a/utils/main.py b/utils/main.py index 9edd330..a50cddf 100644 --- a/utils/main.py +++ b/utils/main.py @@ -1,58 +1,29 @@ -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from fastapi import HTTPException, status, Depends - -from pydantic import BaseModel -from typing import Union, Optional - - -from passlib.context import CryptContext -from datetime import datetime, timedelta -import jwt -import logging -import os - -import requests import uuid +import time -SESSION_SECRET = os.getenv("SESSION_SECRET", " ") -ALGORITHM = "HS256" - -############## -# Auth Utils -############## - -bearer_security = HTTPBearer() -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +from typing import List +from schemas import OpenAIChatMessage -def verify_password(plain_password, hashed_password): - return ( - pwd_context.verify(plain_password, hashed_password) if hashed_password else None - ) +def stream_message_template(model: str, message: str): + return { + "id": f"{model}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": message}, + "logprobs": None, + "finish_reason": None, + } + ], + } -def get_password_hash(password): - return pwd_context.hash(password) - - -def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: - payload = data.copy() - - if expires_delta: - expire = datetime.utcnow() + expires_delta - payload.update({"exp": expire}) - - encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM) - return encoded_jwt - - -def decode_token(token: str) -> Optional[dict]: - try: - decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM]) - return decoded - except Exception as e: - return None - - -def extract_token_from_auth_header(auth_header: str): - return auth_header[len("Bearer ") :] +def get_last_user_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message.role == "user": + return message.content + return None