diff --git a/backend/apps/web/internal/migrations/003_add_auth_api_key.py b/backend/apps/web/internal/migrations/003_add_auth_api_key.py new file mode 100644 index 000000000..07144f3ac --- /dev/null +++ b/backend/apps/web/internal/migrations/003_add_auth_api_key.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", api_key=pw.CharField(max_length=255, null=True, unique=True) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "api_key") diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py index 365958555..069865036 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/web/models/auths.py @@ -47,6 +47,10 @@ class Token(BaseModel): token_type: str +class ApiKey(BaseModel): + api_key: Optional[str] = None + + class UserResponse(BaseModel): id: str email: str @@ -123,6 +127,18 @@ class AuthsTable: except: return None + def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_api_key: {api_key}") + # if no api_key, return None + if not api_key: + return None + + try: + user = Users.get_user_by_api_key(api_key) + return user if user else None + except: + return False + def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py index 255c701df..a01e595e5 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/web/models/users.py @@ -20,6 +20,7 @@ class User(Model): role = CharField() profile_image_url = CharField() timestamp = DateField() + api_key = CharField(null=True, unique=True) class Meta: database = DB @@ -32,6 +33,7 @@ class UserModel(BaseModel): role: str = "pending" profile_image_url: str = "/user.png" timestamp: int # timestamp in epoch + api_key: Optional[str] = None #################### @@ -82,6 +84,13 @@ class UsersTable: except: return None + def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + try: + user = User.get(User.api_key == api_key) + return UserModel(**model_to_dict(user)) + except: + return None + def get_user_by_email(self, email: str) -> Optional[UserModel]: try: user = User.get(User.email == email) @@ -149,5 +158,21 @@ class UsersTable: except: return False + def update_user_api_key_by_id(self, id: str, api_key: str) -> str: + try: + query = User.update(api_key=api_key).where(User.id == id) + result = query.execute() + + return True if result == 1 else False + except: + return False + + def get_user_api_key_by_id(self, id: str) -> Optional[str]: + try: + user = User.get(User.id == id) + return user.api_key + except: + return None + Users = UsersTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index e938a58f5..293cb55b8 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -1,13 +1,10 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union +from fastapi import Request +from fastapi import Depends, HTTPException, status -from fastapi import APIRouter, status +from fastapi import APIRouter from pydantic import BaseModel -import time -import uuid import re +import uuid from apps.web.models.auths import ( SigninForm, @@ -17,6 +14,7 @@ from apps.web.models.auths import ( UserResponse, SigninResponse, Auths, + ApiKey, ) from apps.web.models.users import Users @@ -25,6 +23,7 @@ from utils.utils import ( get_current_user, get_admin_user, create_token, + create_api_key, ) from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook @@ -267,3 +266,40 @@ async def update_token_expires_duration( return request.app.state.JWT_EXPIRES_IN else: return request.app.state.JWT_EXPIRES_IN + + +############################ +# API Key +############################ + + +# create api key +@router.post("/api_key", response_model=ApiKey) +async def create_api_key_(user=Depends(get_current_user)): + api_key = create_api_key() + success = Users.update_user_api_key_by_id(user.id, api_key) + if success: + return { + "api_key": api_key, + } + else: + raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR) + + +# delete api key +@router.delete("/api_key", response_model=bool) +async def delete_api_key(user=Depends(get_current_user)): + success = Users.update_user_api_key_by_id(user.id, None) + return success + + +# get api key +@router.get("/api_key", response_model=ApiKey) +async def get_api_key(user=Depends(get_current_user)): + api_key = Users.get_user_api_key_by_id(user.id) + if api_key: + return { + "api_key": api_key, + } + else: + raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) diff --git a/backend/constants.py b/backend/constants.py index f8daf338d..da1ee0b3f 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -60,7 +60,8 @@ class ERROR_MESSAGES(str, Enum): RATE_LIMIT_EXCEEDED = "API rate limit exceeded" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" - OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" + OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" + CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 32724af39..49e15789f 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,6 +1,8 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends + from apps.web.models.users import Users + from pydantic import BaseModel from typing import Union, Optional from constants import ERROR_MESSAGES @@ -8,6 +10,7 @@ from passlib.context import CryptContext from datetime import datetime, timedelta import requests import jwt +import uuid import logging import config @@ -58,6 +61,11 @@ def extract_token_from_auth_header(auth_header: str): return auth_header[len("Bearer ") :] +def create_api_key(): + key = str(uuid.uuid4()).replace("-", "") + return f"sk-{key}" + + def get_http_authorization_cred(auth_header: str): try: scheme, credentials = auth_header.split(" ") @@ -69,6 +77,10 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), ): + # auth by api key + if auth_token.credentials.startswith("sk-"): + return get_current_user_by_api_key(auth_token.credentials) + # auth by jwt token data = decode_token(auth_token.credentials) if data != None and "id" in data: user = Users.get_user_by_id(data["id"]) @@ -85,6 +97,16 @@ def get_current_user( ) +def get_current_user_by_api_key(api_key: str): + user = Users.get_user_by_api_key(api_key) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + return user + + def get_verified_user(user=Depends(get_current_user)): if user.role not in {"user", "admin"}: raise HTTPException( diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts index 169998726..548a9418d 100644 --- a/src/lib/apis/auths/index.ts +++ b/src/lib/apis/auths/index.ts @@ -318,3 +318,78 @@ export const updateJWTExpiresDuration = async (token: string, duration: string) return res; }; + +export const createAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res.api_key; +}; + +export const getAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res.api_key; +}; + +export const deleteAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res; +}; diff --git a/src/lib/components/chat/Settings/Account.svelte b/src/lib/components/chat/Settings/Account.svelte index 3a2259a79..e0239fb0b 100644 --- a/src/lib/components/chat/Settings/Account.svelte +++ b/src/lib/components/chat/Settings/Account.svelte @@ -3,11 +3,13 @@ import { onMount, getContext } from 'svelte'; import { user } from '$lib/stores'; - import { updateUserProfile } from '$lib/apis/auths'; + import { updateUserProfile, createAPIKey, getAPIKey } from '$lib/apis/auths'; import UpdatePassword from './Account/UpdatePassword.svelte'; import { getGravatarUrl } from '$lib/apis/utils'; import { copyToClipboard } from '$lib/utils'; + import Plus from '$lib/components/icons/Plus.svelte'; + import Tooltip from '$lib/components/common/Tooltip.svelte'; const i18n = getContext('i18n'); @@ -15,8 +17,14 @@ let profileImageUrl = ''; let name = ''; + let showJWTToken = false; let JWTTokenCopied = false; + + let APIKey = ''; + let showAPIKey = false; + let APIKeyCopied = false; + let profileImageInputElement: HTMLInputElement; const submitHandler = async () => { @@ -33,9 +41,23 @@ return false; }; - onMount(() => { + const createAPIKeyHandler = async () => { + APIKey = await createAPIKey(localStorage.token); + if (APIKey) { + toast.success($i18n.t('API Key created.')); + } else { + toast.error($i18n.t('Failed to create API Key.')); + } + }; + + onMount(async () => { name = $user.name; profileImageUrl = $user.profile_image_url; + + APIKey = await getAPIKey(localStorage.token).catch((error) => { + console.log(error); + return ''; + }); }); @@ -170,41 +192,83 @@