wg-portal/internal/app/auth/auth.go
2025-01-11 18:55:23 +01:00

482 lines
13 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"net/url"
"path"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
evbus "github.com/vardius/message-bus"
)
// UserManager is the subset of user-management operations the Authenticator
// relies on: looking users up, creating new users and persisting changes.
type UserManager interface {
	// GetUser loads the user with the given identifier.
	GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
	// RegisterUser persists a newly created user.
	RegisterUser(ctx context.Context, user *domain.User) error
	// UpdateUser persists changes made to an existing user.
	UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
}
// Authenticator handles user logins via username/password (local database or
// LDAP) and via external OAuth2 / OpenID Connect providers. It publishes a
// login event on the message bus after each successful login.
type Authenticator struct {
	// authentication configuration (OIDC, OAuth and LDAP provider lists)
	cfg *config.Auth
	// message bus used to publish login events
	bus evbus.MessageBus
	// OIDC and plain OAuth providers, keyed by lower-cased provider name
	oauthAuthenticators map[string]domain.OauthAuthenticator
	// LDAP providers, keyed by lower-cased LDAP URL
	ldapAuthenticators map[string]domain.LdapAuthenticator
	// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
	callbackUrlPrefix string
	// user store used to look up, register and update users
	users UserManager
}
// NewAuthenticator creates an Authenticator for the given auth configuration.
//
// extUrl is the externally reachable base URL; "/api/v0" is appended to it to
// form the prefix for provider callback URLs. All configured external
// providers are set up immediately, using a context with a 10 second timeout.
// If any provider fails to initialize, the error is returned and no
// Authenticator is created.
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
	*Authenticator,
	error,
) {
	const providerSetupTimeout = 10 * time.Second

	authenticator := &Authenticator{
		cfg:               cfg,
		bus:               bus,
		users:             users,
		callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
	}

	setupCtx, cancelSetup := context.WithTimeout(context.Background(), providerSetupTimeout)
	defer cancelSetup()

	if err := authenticator.setupExternalAuthProviders(setupCtx); err != nil {
		return nil, err
	}
	return authenticator, nil
}
// setupExternalAuthProviders instantiates every configured external
// authentication provider (OpenID Connect, plain OAuth2 and LDAP) and stores
// it in the corresponding lookup map of the Authenticator.
//
// OIDC and plain OAuth providers share one namespace keyed by the lower-cased
// ProviderName, so a name may only be used once across both lists. LDAP
// providers are keyed by their lower-cased URL. ctx is handed to each
// provider constructor.
//
// An error is returned if the callback URL prefix cannot be parsed, if a
// provider identifier is used twice, or if any provider fails to initialize.
func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
	extUrl, err := url.Parse(a.callbackUrlPrefix)
	if err != nil {
		return fmt.Errorf("failed to parse external url: %w", err)
	}

	// callbackUrl builds <prefix>/auth/login/<providerId>/callback, the URL the
	// identity provider redirects back to after a login.
	callbackUrl := func(providerId string) string {
		redirectUrl := *extUrl // copy so the shared prefix URL is never modified
		redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback")
		return redirectUrl.String()
	}

	a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
	a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap))

	for i := range a.cfg.OpenIDConnect { // OIDC
		providerCfg := &a.cfg.OpenIDConnect[i]
		providerId := strings.ToLower(providerCfg.ProviderName)
		if _, exists := a.oauthAuthenticators[providerId]; exists {
			return fmt.Errorf("auth provider with name %s is already registered", providerId)
		}

		provider, setupErr := newOidcAuthenticator(ctx, callbackUrl(providerId), providerCfg)
		if setupErr != nil {
			return fmt.Errorf("failed to setup oidc authentication provider %s: %w", providerId, setupErr)
		}
		a.oauthAuthenticators[providerId] = provider
	}

	for i := range a.cfg.OAuth { // PLAIN OAUTH
		providerCfg := &a.cfg.OAuth[i]
		providerId := strings.ToLower(providerCfg.ProviderName)
		if _, exists := a.oauthAuthenticators[providerId]; exists {
			return fmt.Errorf("auth provider with name %s is already registered", providerId)
		}

		provider, setupErr := newPlainOauthAuthenticator(ctx, callbackUrl(providerId), providerCfg)
		if setupErr != nil {
			return fmt.Errorf("failed to setup oauth authentication provider %s: %w", providerId, setupErr)
		}
		a.oauthAuthenticators[providerId] = provider
	}

	for i := range a.cfg.Ldap { // LDAP
		providerCfg := &a.cfg.Ldap[i]
		providerId := strings.ToLower(providerCfg.URL)
		if _, exists := a.ldapAuthenticators[providerId]; exists {
			return fmt.Errorf("auth provider with name %s is already registered", providerId)
		}

		provider, setupErr := newLdapAuthenticator(ctx, providerCfg)
		if setupErr != nil {
			return fmt.Errorf("failed to setup ldap authentication provider %s: %w", providerId, setupErr)
		}
		a.ldapAuthenticators[providerId] = provider
	}

	return nil
}
// GetExternalLoginProviders lists the configured OIDC and plain OAuth login
// providers in configuration order, OIDC providers first.
//
// Each entry has the following fields:
//   - the lower-cased provider identifier;
//   - a display name, falling back to ProviderName when DisplayName is empty;
//   - the relative init and callback endpoint paths.
//
// LDAP providers are not part of the result.
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
	result := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))

	appendProvider := func(configuredName, displayName string) {
		id := strings.ToLower(configuredName)
		if displayName == "" {
			displayName = configuredName
		}
		result = append(result, domain.LoginProviderInfo{
			Identifier:  id,
			Name:        displayName,
			ProviderUrl: fmt.Sprintf("/auth/login/%s/init", id),
			CallbackUrl: fmt.Sprintf("/auth/login/%s/callback", id),
		})
	}

	for i := range a.cfg.OpenIDConnect {
		appendProvider(a.cfg.OpenIDConnect[i].ProviderName, a.cfg.OpenIDConnect[i].DisplayName)
	}
	for i := range a.cfg.OAuth {
		appendProvider(a.cfg.OAuth[i].ProviderName, a.cfg.OAuth[i].DisplayName)
	}

	return result
}
// IsUserValid reports whether the user with the given identifier exists and is
// neither disabled nor locked. The lookup runs with system-admin privileges.
// Any lookup error counts as "not valid".
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
	adminCtx := domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())

	user, err := a.users.GetUser(adminCtx, id)
	if err != nil {
		return false
	}

	return !user.IsDisabled() && !user.IsLocked()
}
// region password authentication
// PlainLogin authenticates a user by username and password.
//
// Leading and trailing whitespace is stripped from both values, and an error
// is returned if either ends up empty. Authentication is delegated to
// passwordAuthentication. On success a TopicAuthLogin event carrying the
// user's identifier is published on the message bus.
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
	// Validate form input
	trimmedUser := strings.TrimSpace(username)
	trimmedPassword := strings.TrimSpace(password)
	if trimmedUser == "" || trimmedPassword == "" {
		return nil, fmt.Errorf("missing username or password")
	}

	user, err := a.passwordAuthentication(ctx, domain.UserIdentifier(trimmedUser), trimmedPassword)
	if err != nil {
		return nil, fmt.Errorf("login failed: %w", err)
	}

	a.bus.Publish(app.TopicAuthLogin, user.Identifier)
	return user, nil
}
// passwordAuthentication verifies a username/password pair and returns the
// matching user.
//
// Users that already exist locally are authenticated against the backend
// recorded in their Source: the stored password for database users, or an
// LDAP provider for LDAP users. Locked or disabled local users are rejected
// before any credential check.
//
// Users that do not exist locally are searched in every LDAP provider that
// has registration enabled. The first provider that returns parsable user
// info is used. After a successful LDAP bind the user is processed
// (registered) locally via processUserInfo.
//
// The user lookup runs with system-admin privileges, because the caller is not
// authenticated yet.
func (a *Authenticator) passwordAuthentication(
	ctx context.Context,
	identifier domain.UserIdentifier,
	password string,
) (*domain.User, error) {
	// switch to admin user context to check if user exists
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())

	var ldapUserInfo *domain.AuthenticatorUserInfo
	var ldapProvider domain.LdapAuthenticator
	var userInDatabase = false
	var userSource domain.UserSource

	existingUser, err := a.users.GetUser(ctx, identifier)
	if err == nil {
		userInDatabase = true
		userSource = existingUser.Source
	}

	if userInDatabase && (existingUser.IsLocked() || existingUser.IsDisabled()) {
		return nil, errors.New("user is locked")
	}

	if !userInDatabase || userSource == domain.UserSourceLdap {
		// search user in ldap; unknown users are only looked up in providers with registration enabled
		for _, ldapAuth := range a.ldapAuthenticators {
			if !userInDatabase && !ldapAuth.RegistrationEnabled() {
				continue
			}

			// use the caller's context so request cancellation/deadlines also abort the LDAP lookup
			rawUserInfo, err := ldapAuth.GetUserInfo(ctx, identifier)
			if err != nil {
				if !errors.Is(err, domain.ErrNotFound) {
					logrus.Warnf("failed to fetch ldap user info for %s: %v", identifier, err)
				}
				continue // user not found / other ldap error
			}

			ldapUserInfo, err = ldapAuth.ParseUserInfo(rawUserInfo)
			if err != nil {
				// previously swallowed silently; keep trying other providers but leave a trace
				logrus.Warnf("failed to parse ldap user info for %s: %v", identifier, err)
				continue
			}

			// ldap user found
			userSource = domain.UserSourceLdap
			ldapProvider = ldapAuth
			break
		}
	}

	if userSource == "" {
		return nil, errors.New("user not found")
	}
	if userSource == domain.UserSourceLdap && ldapProvider == nil {
		return nil, errors.New("ldap provider not found")
	}

	switch userSource {
	case domain.UserSourceDatabase:
		err = existingUser.CheckPassword(password)
	case domain.UserSourceLdap:
		err = ldapProvider.PlaintextAuthentication(identifier, password)
	default:
		err = errors.New("no authentication backend available")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to authenticate: %w", err)
	}

	if userInDatabase {
		return existingUser, nil
	}

	// first successful ldap login of an unknown user: hand the ldap data to processUserInfo
	user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(),
		ldapProvider.RegistrationEnabled())
	if err != nil {
		return nil, fmt.Errorf("unable to process user information: %w", err)
	}
	return user, nil
}
// endregion password authentication
// region oauth authentication
// OauthLoginStep1 starts an authorization-code login with the OAuth2 / OIDC
// provider registered under providerId.
//
// It returns the following values:
//   - authCodeUrl: the provider's authorization URL;
//   - state: a random value;
//   - nonce: a random value for OIDC providers, and empty for plain OAuth
//     providers.
//
// The caller presumably stores state and nonce (e.g. in cookies) so the
// callback step can verify them.
//
// An error is returned in these cases:
//   - the provider is unknown;
//   - random value generation fails;
//   - the provider reports an authenticator type that is neither OAuth nor
//     OIDC.
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
	authCodeUrl, state, nonce string,
	err error,
) {
	oauthProvider, ok := a.oauthAuthenticators[providerId]
	if !ok {
		return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
	}

	// Prepare authentication flow, set state cookies
	state, err = a.randString(16)
	if err != nil {
		return "", "", "", fmt.Errorf("failed to generate state: %w", err)
	}

	switch providerType := oauthProvider.GetType(); providerType {
	case domain.AuthenticatorTypeOAuth:
		authCodeUrl = oauthProvider.AuthCodeURL(state)
	case domain.AuthenticatorTypeOidc:
		nonce, err = a.randString(16)
		if err != nil {
			return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
		}
		authCodeUrl = oauthProvider.AuthCodeURL(state, oidc.Nonce(nonce))
	default:
		// without this branch an unknown type silently yielded an empty redirect URL and a nil error
		return "", "", "", fmt.Errorf("unsupported authenticator type %v for oauth provider %s",
			providerType, providerId)
	}

	return authCodeUrl, state, nonce, nil
}
// randString returns nByte bytes from crypto/rand, encoded as unpadded
// URL-safe base64. A read failure from the random source is returned as is.
func (a *Authenticator) randString(nByte int) (string, error) {
	raw := make([]byte, nByte)

	_, err := io.ReadFull(rand.Reader, raw)
	if err != nil {
		return "", err
	}

	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
// OauthLoginStep2 completes an OAuth2 / OIDC login started by OauthLoginStep1.
//
// It performs the following steps:
//  1. Exchanges the authorization code for a token.
//  2. Fetches the raw user info from the provider, forwarding nonce.
//  3. Parses the raw user info.
//  4. Creates or updates the matching local user with system-admin
//     privileges.
//
// Locked or disabled users are rejected. On success a TopicAuthLogin event
// is published.
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
	provider, found := a.oauthAuthenticators[providerId]
	if !found {
		return nil, fmt.Errorf("missing oauth provider %s", providerId)
	}

	token, err := provider.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("unable to exchange code: %w", err)
	}

	rawInfo, err := provider.GetUserInfo(ctx, token, nonce)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch user information: %w", err)
	}

	info, err := provider.ParseUserInfo(rawInfo)
	if err != nil {
		return nil, fmt.Errorf("unable to parse user information: %w", err)
	}

	// the local user lookup/registration must run with system-admin privileges
	adminCtx := domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
	user, err := a.processUserInfo(adminCtx, info, domain.UserSourceOauth, provider.GetName(),
		provider.RegistrationEnabled())
	if err != nil {
		return nil, fmt.Errorf("unable to process user information: %w", err)
	}

	if user.IsLocked() || user.IsDisabled() {
		return nil, errors.New("user is locked")
	}

	a.bus.Publish(app.TopicAuthLogin, user.Identifier)
	return user, nil
}
// processUserInfo maps user information from an external provider onto a local
// user.
//
// If the user can be loaded, its profile fields are synchronized with userInfo
// and the user is returned. If the lookup fails and withReg is true, a new
// user is registered with the given source and provider name. If withReg is
// false, an error is returned. Note that any lookup error, not only "not
// found", is treated as a missing user.
func (a *Authenticator) processUserInfo(
	ctx context.Context,
	userInfo *domain.AuthenticatorUserInfo,
	source domain.UserSource,
	provider string,
	withReg bool,
) (*domain.User, error) {
	existing, lookupErr := a.users.GetUser(ctx, userInfo.Identifier)
	if lookupErr == nil {
		// known user: keep the local record in sync with the provider
		if err := a.updateExternalUser(ctx, existing, userInfo); err != nil {
			return nil, fmt.Errorf("failed to update user: %w", err)
		}
		return existing, nil
	}

	if !withReg {
		return nil, fmt.Errorf("registration disabled, cannot create missing user: %w", lookupErr)
	}

	registered, err := a.registerNewUser(ctx, userInfo, source, provider)
	if err != nil {
		return nil, fmt.Errorf("failed to register user: %w", err)
	}
	return registered, nil
}
// registerNewUser builds a local user from external user information and
// stores it via the UserManager.
//
// Source and ProviderName record where the user came from. The admin flag and
// the profile fields are copied from userInfo. A trace log entry is written on
// success.
func (a *Authenticator) registerNewUser(
	ctx context.Context,
	userInfo *domain.AuthenticatorUserInfo,
	source domain.UserSource,
	provider string,
) (*domain.User, error) {
	newUser := &domain.User{
		Identifier:   userInfo.Identifier,
		Source:       source,
		ProviderName: provider,
		IsAdmin:      userInfo.IsAdmin,

		Email:      userInfo.Email,
		Firstname:  userInfo.Firstname,
		Lastname:   userInfo.Lastname,
		Phone:      userInfo.Phone,
		Department: userInfo.Department,
	}

	if err := a.users.RegisterUser(ctx, newUser); err != nil {
		return nil, fmt.Errorf("failed to register new user: %w", err)
	}

	logrus.Tracef("registered user %s from external authentication provider, admin user: %t",
		newUser.Identifier, newUser.IsAdmin)
	return newUser, nil
}
// getAuthenticatorConfig returns a copy (not a pointer) of the configuration
// element of the OIDC or plain OAuth provider identified by id. The dynamic
// type of the result is the element type of a.cfg.OpenIDConnect or
// a.cfg.OAuth, so callers need a type switch.
//
// Providers are registered and exposed under their lower-cased ProviderName,
// so the comparison is case-insensitive. An exact-case ProviderName still
// matches as before. OIDC entries are checked first. Because duplicate
// lower-cased names are rejected during provider setup, the match is
// unambiguous.
func (a *Authenticator) getAuthenticatorConfig(id string) (interface{}, error) {
	for i := range a.cfg.OpenIDConnect {
		if strings.EqualFold(a.cfg.OpenIDConnect[i].ProviderName, id) {
			return a.cfg.OpenIDConnect[i], nil
		}
	}

	for i := range a.cfg.OAuth {
		if strings.EqualFold(a.cfg.OAuth[i].ProviderName, id) {
			return a.cfg.OAuth[i], nil
		}
	}

	return nil, fmt.Errorf("no configuration for Authenticator id %s", id)
}
// updateExternalUser copies profile data from an external provider onto an
// existing local user and persists it. It only persists when at least one of
// the fields Email, Firstname, Lastname, Phone, Department or IsAdmin differs.
//
// Locked or disabled users are left untouched. When nothing differs, no
// update is written. A trace log entry is emitted after a successful update.
func (a *Authenticator) updateExternalUser(
	ctx context.Context,
	existingUser *domain.User,
	userInfo *domain.AuthenticatorUserInfo,
) error {
	if existingUser.IsLocked() || existingUser.IsDisabled() {
		return nil // user is locked or disabled, do not update
	}

	differs := existingUser.Email != userInfo.Email ||
		existingUser.Firstname != userInfo.Firstname ||
		existingUser.Lastname != userInfo.Lastname ||
		existingUser.Phone != userInfo.Phone ||
		existingUser.Department != userInfo.Department ||
		existingUser.IsAdmin != userInfo.IsAdmin
	if !differs {
		return nil // nothing to update
	}

	// fields that already match are simply re-assigned the same value
	existingUser.Email = userInfo.Email
	existingUser.Firstname = userInfo.Firstname
	existingUser.Lastname = userInfo.Lastname
	existingUser.Phone = userInfo.Phone
	existingUser.Department = userInfo.Department
	existingUser.IsAdmin = userInfo.IsAdmin

	if _, err := a.users.UpdateUser(ctx, existingUser); err != nil {
		return fmt.Errorf("failed to update user: %w", err)
	}

	logrus.Tracef("updated user %s with data from external authentication provider, admin user: %t",
		existingUser.Identifier, existingUser.IsAdmin)
	return nil
}
// endregion oauth authentication