diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py index b26236ef8..04f6d4bbd 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/web/models/auths.py @@ -24,6 +24,7 @@ class Auth(Model): email = CharField() password = CharField() active = BooleanField() + api_key = CharField(null=True, unique=True) class Meta: database = DB @@ -34,6 +35,7 @@ class AuthModel(BaseModel): email: str password: str active: bool = True + api_key: Optional[str] = None #################### @@ -45,6 +47,8 @@ class Token(BaseModel): token: str token_type: str +class ApiKey(BaseModel): + api_key: Optional[str] = None class UserResponse(BaseModel): id: str @@ -122,6 +126,21 @@ class AuthsTable: except: return None + def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_api_key: {api_key}") + # if no api_key, return None + if not api_key: + return None + try: + auth = Auth.get(Auth.api_key == api_key, Auth.active == True) + if auth: + user = Users.get_user_by_id(auth.id) + return user + else: + return None + except: + return None + def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: query = Auth.update(password=new_password).where(Auth.id == id) @@ -140,6 +159,22 @@ class AuthsTable: except: return False + def update_api_key_by_id(self, id: str, api_key: str) -> str: + try: + query = Auth.update(api_key=api_key).where(Auth.id == id) + result = query.execute() + + return True if result == 1 else False + except: + return False + + def get_api_key_by_id(self, id: str) -> Optional[str]: + try: + auth = Auth.get(Auth.id == id) + return auth.api_key + except: + return None + def delete_auth_by_id(self, id: str) -> bool: try: # Delete User diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index d881ec746..abbdc1c42 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -1,12 +1,8 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union +from fastapi import Request +from fastapi import Depends, HTTPException, status -from fastapi import APIRouter, status +from fastapi import APIRouter from pydantic import BaseModel -import time -import uuid import re from apps.web.models.auths import ( @@ -17,6 +13,7 @@ from apps.web.models.auths import ( UserResponse, SigninResponse, Auths, + ApiKey ) from apps.web.models.users import Users @@ -25,6 +22,7 @@ from utils.utils import ( get_current_user, get_admin_user, create_token, + create_api_key ) from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook @@ -249,3 +247,40 @@ async def update_token_expires_duration( return request.app.state.JWT_EXPIRES_IN else: return request.app.state.JWT_EXPIRES_IN + + +############################ +# API Key +############################ + + +# create api key +@router.post("/api_key", response_model=ApiKey) +async def create_api_key_(user=Depends(get_current_user)): + api_key = create_api_key() + success = Auths.update_api_key_by_id(user.id, api_key) + if success: + return { + "api_key": api_key, + } + else: + raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR) + + +# delete api key +@router.delete("/api_key", response_model=bool) +async def delete_api_key(user=Depends(get_current_user)): + success = Auths.update_api_key_by_id(user.id, None) + return success + + +# get api key +@router.get("/api_key", response_model=ApiKey) +async def get_api_key(user=Depends(get_current_user)): + api_key = Auths.get_api_key_by_id(user.id, None) + if api_key: + return { + "api_key": api_key, + } + else: + raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) diff --git a/backend/constants.py b/backend/constants.py index 42c5c85eb..1adfe353f 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -58,5 +58,6 @@ class ERROR_MESSAGES(str, Enum): RATE_LIMIT_EXCEEDED = "API rate limit exceeded" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" - OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" + OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" + CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 32724af39..071bb0ef9 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,6 +1,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends from apps.web.models.users import Users + from pydantic import BaseModel from typing import Union, Optional from constants import ERROR_MESSAGES @@ -8,6 +9,7 @@ from passlib.context import CryptContext from datetime import datetime, timedelta import requests import jwt +import uuid import logging import config @@ -58,6 +60,11 @@ def extract_token_from_auth_header(auth_header: str): return auth_header[len("Bearer ") :] +def create_api_key(): + key = str(uuid.uuid4()).replace("-", "") + return f"sk-{key}" + + def get_http_authorization_cred(auth_header: str): try: scheme, credentials = auth_header.split(" ") @@ -69,6 +76,10 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), ): + # auth by api key + if auth_token.credentials.startswith("sk-"): + return get_current_user_by_api_key(auth_token.credentials) + # auth by jwt token data = decode_token(auth_token.credentials) if data != None and "id" in data: user = Users.get_user_by_id(data["id"]) @@ -84,6 +95,16 @@ def get_current_user( detail=ERROR_MESSAGES.UNAUTHORIZED, ) +def get_current_user_by_api_key(api_key: str): + from apps.web.models.auths import Auths + + user = Auths.authenticate_user_by_api_key(api_key) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + return user def get_verified_user(user=Depends(get_current_user)): if user.role not in {"user", "admin"}: