diff --git a/internal/app/api/v0/handlers/endpoint_authentication.go b/internal/app/api/v0/handlers/endpoint_authentication.go index 2b15b0d..054d4d1 100644 --- a/internal/app/api/v0/handlers/endpoint_authentication.go +++ b/internal/app/api/v0/handlers/endpoint_authentication.go @@ -44,7 +44,8 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti // @Router /auth/providers [get] func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc { return func(c *gin.Context) { - providers := e.app.Authenticator.GetExternalLoginProviders(c.Request.Context()) + ctx := domain.SetUserInfoFromGin(c) + providers := e.app.Authenticator.GetExternalLoginProviders(ctx) c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers)) } @@ -69,7 +70,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc { var email *string if currentSession.LoggedIn { - uid := string(currentSession.UserIdentifier) + uid := currentSession.UserIdentifier f := currentSession.Firstname l := currentSession.Lastname e := currentSession.Email @@ -134,7 +135,8 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc { return } - authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(c.Request.Context(), provider) + ctx := domain.SetUserInfoFromGin(c) + authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider) if err != nil { if autoRedirect { redirectToReturn() @@ -292,7 +294,8 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc { return } - user, err := e.app.Authenticator.PlainLogin(c.Request.Context(), loginData.Username, loginData.Password) + ctx := domain.SetUserInfoFromGin(c) + user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password) if err != nil { c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"}) return diff --git a/internal/app/api/v0/handlers/endpoint_interfaces.go b/internal/app/api/v0/handlers/endpoint_interfaces.go index 707f8de..d9bb218 100644 --- a/internal/app/api/v0/handlers/endpoint_interfaces.go +++ b/internal/app/api/v0/handlers/endpoint_interfaces.go @@ -19,7 +19,7 @@ func (e interfaceEndpoint) GetName() string { } func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/interface", e.authenticator.LoggedIn()) + apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin)) apiGroup.GET("/prepare", e.handlePrepareGet()) apiGroup.GET("/all", e.handleAllGet()) @@ -45,7 +45,8 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut // @Router /interface/prepare [get] func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { return func(c *gin.Context) { - in, err := e.app.PrepareInterface(c.Request.Context()) + ctx := domain.SetUserInfoFromGin(c) + in, err := e.app.PrepareInterface(ctx) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -68,7 +69,8 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { // @Router /interface/all [get] func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { return func(c *gin.Context) { - interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context()) + ctx := domain.SetUserInfoFromGin(c) + interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -92,6 +94,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { // @Router /interface/get/{id} [get] func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { return func(c *gin.Context) { + ctx := domain.SetUserInfoFromGin(c) id := Base64UrlDecode(c.Param("id")) if id == "" { c.JSON(http.StatusBadRequest, model.Error{ @@ -100,7 +103,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { return } - iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id)) + iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -124,6 +127,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { // @Router /interface/config/{id} [get] func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc { return func(c *gin.Context) { + ctx := domain.SetUserInfoFromGin(c) id := Base64UrlDecode(c.Param("id")) if id == "" { c.JSON(http.StatusBadRequest, model.Error{ @@ -132,7 +136,7 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc { return } - config, err := e.app.GetInterfaceConfig(c.Request.Context(), domain.InterfaceIdentifier(id)) + config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), diff --git a/internal/app/api/v0/handlers/endpoint_peers.go b/internal/app/api/v0/handlers/endpoint_peers.go index 50ca0cb..80593de 100644 --- a/internal/app/api/v0/handlers/endpoint_peers.go +++ b/internal/app/api/v0/handlers/endpoint_peers.go @@ -21,11 +21,11 @@ func (e peerEndpoint) GetName() string { func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { apiGroup := g.Group("/peer", e.authenticator.LoggedIn()) - apiGroup.GET("/iface/:iface/all", e.handleAllGet()) - apiGroup.GET("/iface/:iface/stats", e.handleStatsGet()) - apiGroup.GET("/iface/:iface/prepare", e.handlePrepareGet()) - apiGroup.POST("/iface/:iface/new", e.handleCreatePost()) - apiGroup.POST("/iface/:iface/multiplenew", e.handleCreateMultiplePost()) + apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) + apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet()) + apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(ScopeAdmin), e.handlePrepareGet()) + apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) + apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost()) apiGroup.GET("/config-qr/:id", e.handleQrCodeGet()) apiGroup.POST("/config-mail", e.handleEmailPost()) apiGroup.GET("/config/:id", e.handleConfigGet()) @@ -298,6 +298,8 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc { // @Router /peer/config/{id} [get] func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { return func(c *gin.Context) { + ctx := domain.SetUserInfoFromGin(c) + id := Base64UrlDecode(c.Param("id")) if id == "" { c.JSON(http.StatusBadRequest, model.Error{ @@ -306,7 +308,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { return } - config, err := e.app.GetPeerConfig(c.Request.Context(), domain.PeerIdentifier(id)) + config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -339,6 +341,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { // @Router /peer/config-qr/{id} [get] func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc { return func(c *gin.Context) { + ctx := domain.SetUserInfoFromGin(c) id := Base64UrlDecode(c.Param("id")) if id == "" { c.JSON(http.StatusBadRequest, model.Error{ @@ -347,7 +350,7 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc { return } - config, err := e.app.GetPeerConfigQrCode(c.Request.Context(), domain.PeerIdentifier(id)) + config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -392,11 +395,13 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc { return } + ctx := domain.SetUserInfoFromGin(c) + peerIds := make([]domain.PeerIdentifier, len(req.Identifiers)) for i := range req.Identifiers { peerIds[i] = domain.PeerIdentifier(req.Identifiers[i]) } - err = e.app.SendPeerEmail(c.Request.Context(), req.LinkOnly, peerIds...) + err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return diff --git a/internal/app/api/v0/handlers/endpoint_users.go b/internal/app/api/v0/handlers/endpoint_users.go index 30cc084..aac3169 100644 --- a/internal/app/api/v0/handlers/endpoint_users.go +++ b/internal/app/api/v0/handlers/endpoint_users.go @@ -20,13 +20,13 @@ func (e userEndpoint) GetName() string { func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { apiGroup := g.Group("/user", e.authenticator.LoggedIn()) - apiGroup.GET("/all", e.handleAllGet()) - apiGroup.GET("/:id", e.handleSingleGet()) - apiGroup.PUT("/:id", e.handleUpdatePut()) - apiGroup.DELETE("/:id", e.handleDelete()) - apiGroup.POST("/new", e.handleCreatePost()) - apiGroup.GET("/:id/peers", e.handlePeersGet()) - apiGroup.GET("/:id/stats", e.handleStatsGet()) + apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) + apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet()) + apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut()) + apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete()) + apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) + apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet()) + apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet()) } // handleAllGet returns a gorm handler function. diff --git a/internal/app/api/v0/handlers/middleware_authentication.go b/internal/app/api/v0/handlers/middleware_authentication.go index bf034d2..f15446b 100644 --- a/internal/app/api/v0/handlers/middleware_authentication.go +++ b/internal/app/api/v0/handlers/middleware_authentication.go @@ -58,6 +58,31 @@ func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc { } } +// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted. +func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc { + return func(c *gin.Context) { + session := h.Session.GetData(c) + + if session.IsAdmin { + c.Next() // Admins can do everything + return + } + + sessionUserId := domain.UserIdentifier(session.UserIdentifier) + requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter))) + + if sessionUserId != requestUserId { + // Abort the request with the appropriate error code + c.Abort() + c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) + return + } + + // Continue down the chain to handler etc + c.Next() + } +} + func UserHasScopes(session SessionData, scopes ...Scope) bool { // No scopes give, so the check should succeed if len(scopes) == 0 { diff --git a/internal/app/auth/auth.go b/internal/app/auth/auth.go index 5a1c76d..89637c6 100644 --- a/internal/app/auth/auth.go +++ b/internal/app/auth/auth.go @@ -150,6 +150,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo } func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool { + ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context user, err := a.users.GetUser(ctx, id) if err != nil { return false @@ -187,6 +188,8 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin } func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) { + ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists + var ldapUserInfo *domain.AuthenticatorUserInfo var ldapProvider domain.LdapAuthenticator @@ -315,6 +318,7 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, return nil, fmt.Errorf("unable to parse user information: %w", err) } + ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled()) if err != nil { return nil, fmt.Errorf("unable to process user information: %w", err) diff --git a/internal/app/configfile/manager.go b/internal/app/configfile/manager.go index b2037a9..0f69531 100644 --- a/internal/app/configfile/manager.go +++ b/internal/app/configfile/manager.go @@ -109,6 +109,10 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier) } func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id) if err != nil { return nil, fmt.Errorf("failed to fetch interface %s: %w", id, err) @@ -123,6 +127,10 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err) } + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return nil, err + } + return m.tplHandler.GetPeerConfig(peer) } @@ -132,6 +140,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err) } + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return nil, err + } + cfgData, err := m.tplHandler.GetPeerConfig(peer) if err != nil { return nil, fmt.Errorf("failed to get peer config for %s: %w", id, err) @@ -172,6 +184,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi } func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + if m.fsRepo == nil { return fmt.Errorf("peristing configuration is not supported") } diff --git a/internal/app/mail/manager.go b/internal/app/mail/manager.go index c970791..3f54bd4 100644 --- a/internal/app/mail/manager.go +++ b/internal/app/mail/manager.go @@ -44,6 +44,10 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma return fmt.Errorf("failed to fetch peer %s: %w", peerId, err) } + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return err + } + if peer.UserIdentifier == "" { logrus.Debugf("skipping peer email for %s, no user linked", peerId) continue diff --git a/internal/app/users/user_manager.go b/internal/app/users/user_manager.go index 0a80860..4937aef 100644 --- a/internal/app/users/user_manager.go +++ b/internal/app/users/user_manager.go @@ -43,6 +43,10 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase } func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + err := m.NewUser(ctx, user) if err != nil { return err @@ -58,6 +62,10 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error { return errors.New("missing user identifier") } + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + err := m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) { u.Identifier = user.Identifier u.Email = user.Email @@ -83,6 +91,10 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) { } func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { + if err := domain.ValidateUserAccessRights(ctx, id); err != nil { + return nil, err + } + user, err := m.users.GetUser(ctx, id) if err != nil { return nil, fmt.Errorf("unable to load peer %s: %w", id, err) @@ -95,6 +107,10 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain } func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + users, err := m.users.GetAllUsers(ctx) if err != nil { return nil, fmt.Errorf("unable to load users: %w", err) @@ -123,6 +139,10 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) { } func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) { + if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil { + return nil, err + } + existingUser, err := m.users.GetUser(ctx, user.Identifier) if err != nil { return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err) @@ -153,6 +173,10 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use } func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + existingUser, err := m.users.GetUser(ctx, user.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err) @@ -182,6 +206,10 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use } func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + existingUser, err := m.users.GetUser(ctx, id) if err != nil && !errors.Is(err, domain.ErrNotFound) { return fmt.Errorf("unable to find user %s: %w", id, err) diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index 6c406f9..d720470 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -47,7 +47,8 @@ func (m Manager) handleUserCreationEvent(user *domain.User) { logrus.Errorf("handling new user event for %s", user.Identifier) if m.cfg.Core.CreateDefaultPeer { - err := m.CreateDefaultPeer(context.Background(), user) + ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) + err := m.CreateDefaultPeer(ctx, user) if err != nil { logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err) return diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 7c00a15..ecf368c 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -13,6 +13,10 @@ import ( ) func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + physicalInterfaces, err := m.wg.GetInterfaces(ctx) if err != nil { return nil, err @@ -22,14 +26,26 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical } func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, nil, err + } + return m.db.GetInterfaceAndPeers(ctx, id) } func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + return m.db.GetAllInterfaces(ctx) } func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, nil, err + } + interfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err) @@ -48,6 +64,10 @@ func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interfa } func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return 0, err + } + physicalInterfaces, err := m.wg.GetInterfaces(ctx) if err != nil { return 0, err @@ -95,6 +115,10 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter } func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + existingInterface, err := m.db.GetInterface(ctx, in.Identifier) if err != nil { return fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) @@ -122,6 +146,10 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er } func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + interfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { return err @@ -201,6 +229,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool } func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + currentUser := domain.GetUserInfo(ctx) kp, err := domain.NewFreshKeypair() @@ -277,6 +309,10 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error } func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + existingInterface, err := m.db.GetInterface(ctx, in.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) @@ -298,6 +334,10 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do } func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, nil, err + } + existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier) if err != nil { return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) @@ -316,6 +356,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do } func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + existingInterface, err := m.db.GetInterface(ctx, id) if err != nil { return fmt.Errorf("unable to find interface %s: %w", id, err) diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index d4b09ee..d8c55ca 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -12,6 +12,10 @@ import ( ) func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return err + } + existingInterfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { return fmt.Errorf("failed to fetch all interfaces: %w", err) @@ -49,10 +53,18 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error } func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { + if err := domain.ValidateUserAccessRights(ctx, id); err != nil { + return nil, err + } + return m.db.GetUserPeers(ctx, id) } func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err // TODO: self provisioning? + } + currentUser := domain.GetUserInfo(ctx) iface, err := m.db.GetInterface(ctx, id) @@ -128,10 +140,18 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain return nil, fmt.Errorf("unable to find peer %s: %w", id, err) } + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return nil, err + } + return peer, nil } func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return nil, err + } + existingPeer, err := m.db.GetPeer(ctx, peer.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err) @@ -153,6 +173,10 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee } func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) { + if err := domain.ValidateAdminAccessRights(ctx); err != nil { + return nil, err + } + var newPeers []*domain.Peer for _, id := range r.UserIdentifiers { @@ -192,6 +216,10 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err) } + if err := domain.ValidateUserAccessRights(ctx, existingPeer.UserIdentifier); err != nil { + return nil, err + } + if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil { return nil, fmt.Errorf("update not allowed: %w", err) } @@ -210,6 +238,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return fmt.Errorf("unable to find peer %s: %w", id, err) } + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return err + } + err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id) if err != nil { return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err) @@ -231,6 +263,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier peerIds := make([]domain.PeerIdentifier, len(peers)) for i, peer := range peers { + if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil { + return nil, err + } + peerIds[i] = peer.Identifier } @@ -238,6 +274,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier } func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) { + if err := domain.ValidateUserAccessRights(ctx, id); err != nil { + return nil, err + } + peers, err := m.db.GetUserPeers(ctx, id) if err != nil { return nil, fmt.Errorf("failed to fetch peers for user %s: %w", id, err) diff --git a/internal/domain/context.go b/internal/domain/context.go index b391db4..26726db 100644 --- a/internal/domain/context.go +++ b/internal/domain/context.go @@ -3,6 +3,7 @@ package domain import ( "context" "fmt" + "github.com/sirupsen/logrus" "github.com/gin-gonic/gin" ) @@ -72,3 +73,29 @@ func GetUserInfo(ctx context.Context) *ContextUserInfo { return DefaultContextUserInfo() } + +func ValidateUserAccessRights(ctx context.Context, requiredUser UserIdentifier) error { + sessionUser := GetUserInfo(ctx) + + if sessionUser.IsAdmin { + return nil // Admins can do everything + } + + if sessionUser.Id == requiredUser { + return nil // User can access own data + } + + logrus.Warnf("insufficient permissions for %s (want %s), stack: %s", sessionUser.Id, requiredUser, GetStackTrace()) + return fmt.Errorf("insufficient permissions") +} + +func ValidateAdminAccessRights(ctx context.Context) error { + sessionUser := GetUserInfo(ctx) + + if sessionUser.IsAdmin { + return nil + } + + logrus.Warnf("insufficient admin permissions for %s, stack: %s", sessionUser.Id, GetStackTrace()) + return fmt.Errorf("insufficient permissions") +} diff --git a/internal/domain/errors.go b/internal/domain/errors.go index 12821e5..48978a5 100644 --- a/internal/domain/errors.go +++ b/internal/domain/errors.go @@ -1,6 +1,18 @@ package domain -import "errors" +import ( + "errors" + "runtime" +) var ErrNotFound = errors.New("record not found") var ErrNotUnique = errors.New("record not unique") + +// GetStackTrace returns a stack trace of the current goroutine. The stack trace has at most 1024 bytes. +func GetStackTrace() string { + b := make([]byte, 1024) + n := runtime.Stack(b, false) + s := string(b[:n]) + + return s +}