This commit is contained in:
Timothy J. Baek 2024-05-28 22:46:23 -07:00
parent a0ba05ce03
commit 21bab00496
4 changed files with 101 additions and 91 deletions

23
main.py
View File

@ -8,22 +8,22 @@ from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator from typing import List, Union, Generator, Iterator
import time from utils.auth import bearer_security, get_current_user
import json from utils.main import get_last_user_message, stream_message_template
import uuid
from utils import get_last_user_message, stream_message_template
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from schemas import FilterForm, OpenAIChatCompletionForm from schemas import FilterForm, OpenAIChatCompletionForm
import sys import sys
import os import os
import importlib.util import importlib.util
import logging import logging
import time
from contextlib import asynccontextmanager import json
from concurrent.futures import ThreadPoolExecutor import uuid
#################################### ####################################
@ -460,5 +460,8 @@ async def get_status():
@app.post("/v1/restart") @app.post("/v1/restart")
@app.post("/restart") @app.post("/restart")
def restart_server(): def restart_server(user: str = Depends(get_current_user)):
sys.exit(42) # Use a distinctive code to indicate a restart request
print(user)
return True

View File

@ -1,29 +0,0 @@
import uuid
import time
from typing import List
from schemas import OpenAIChatMessage
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_last_user_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message.role == "user":
return message.content
return None

65
utils/auth.py Normal file
View File

@ -0,0 +1,65 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from pydantic import BaseModel
from typing import Union, Optional
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
import logging
import os
import requests
import uuid
SESSION_SECRET = os.getenv("SESSION_SECRET", " ")
ALGORITHM = "HS256"
##############
# Auth Utils
##############
bearer_security = HTTPBearer()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return (
pwd_context.verify(plain_password, hashed_password) if hashed_password else None
)
def get_password_hash(password):
return pwd_context.hash(password)
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
payload = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
payload.update({"exp": expire})
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt
def decode_token(token: str) -> Optional[dict]:
try:
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded
except Exception as e:
return None
def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
) -> Optional[dict]:
token = credentials.credentials
return token

View File

@ -1,58 +1,29 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from pydantic import BaseModel
from typing import Union, Optional
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
import logging
import os
import requests
import uuid import uuid
import time
SESSION_SECRET = os.getenv("SESSION_SECRET", " ") from typing import List
ALGORITHM = "HS256" from schemas import OpenAIChatMessage
##############
# Auth Utils
##############
bearer_security = HTTPBearer()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password): def stream_message_template(model: str, message: str):
return ( return {
pwd_context.verify(plain_password, hashed_password) if hashed_password else None "id": f"{model}-{str(uuid.uuid4())}",
) "object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_password_hash(password): def get_last_user_message(messages: List[dict]) -> str:
return pwd_context.hash(password) for message in reversed(messages):
if message.role == "user":
return message.content
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
payload = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
payload.update({"exp": expire})
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt
def decode_token(token: str) -> Optional[dict]:
try:
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded
except Exception as e:
return None return None
def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]