fix REST API permission checks (#209)

This commit is contained in:
Christoph Haas 2024-01-31 21:14:36 +01:00
parent 81e696fc7d
commit 1b4b5ff161
14 changed files with 239 additions and 26 deletions

View File

@ -44,7 +44,8 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
// @Router /auth/providers [get]
func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
return func(c *gin.Context) {
providers := e.app.Authenticator.GetExternalLoginProviders(c.Request.Context())
ctx := domain.SetUserInfoFromGin(c)
providers := e.app.Authenticator.GetExternalLoginProviders(ctx)
c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers))
}
@ -69,7 +70,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
var email *string
if currentSession.LoggedIn {
uid := string(currentSession.UserIdentifier)
uid := currentSession.UserIdentifier
f := currentSession.Firstname
l := currentSession.Lastname
e := currentSession.Email
@ -134,7 +135,8 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
return
}
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(c.Request.Context(), provider)
ctx := domain.SetUserInfoFromGin(c)
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider)
if err != nil {
if autoRedirect {
redirectToReturn()
@ -292,7 +294,8 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
return
}
user, err := e.app.Authenticator.PlainLogin(c.Request.Context(), loginData.Username, loginData.Password)
ctx := domain.SetUserInfoFromGin(c)
user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
return

View File

@ -19,7 +19,7 @@ func (e interfaceEndpoint) GetName() string {
}
func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/interface", e.authenticator.LoggedIn())
apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin))
apiGroup.GET("/prepare", e.handlePrepareGet())
apiGroup.GET("/all", e.handleAllGet())
@ -45,7 +45,8 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut
// @Router /interface/prepare [get]
func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
return func(c *gin.Context) {
in, err := e.app.PrepareInterface(c.Request.Context())
ctx := domain.SetUserInfoFromGin(c)
in, err := e.app.PrepareInterface(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -68,7 +69,8 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
// @Router /interface/all [get]
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context())
ctx := domain.SetUserInfoFromGin(c)
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -92,6 +94,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
// @Router /interface/get/{id} [get]
func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
@ -100,7 +103,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return
}
iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -124,6 +127,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
// @Router /interface/config/{id} [get]
func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
@ -132,7 +136,7 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
return
}
config, err := e.app.GetInterfaceConfig(c.Request.Context(), domain.InterfaceIdentifier(id))
config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),

View File

@ -21,11 +21,11 @@ func (e peerEndpoint) GetName() string {
func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/peer", e.authenticator.LoggedIn())
apiGroup.GET("/iface/:iface/all", e.handleAllGet())
apiGroup.GET("/iface/:iface/stats", e.handleStatsGet())
apiGroup.GET("/iface/:iface/prepare", e.handlePrepareGet())
apiGroup.POST("/iface/:iface/new", e.handleCreatePost())
apiGroup.POST("/iface/:iface/multiplenew", e.handleCreateMultiplePost())
apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet())
apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(ScopeAdmin), e.handlePrepareGet())
apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost())
apiGroup.GET("/config-qr/:id", e.handleQrCodeGet())
apiGroup.POST("/config-mail", e.handleEmailPost())
apiGroup.GET("/config/:id", e.handleConfigGet())
@ -298,6 +298,8 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc {
// @Router /peer/config/{id} [get]
func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
@ -306,7 +308,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
return
}
config, err := e.app.GetPeerConfig(c.Request.Context(), domain.PeerIdentifier(id))
config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -339,6 +341,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
// @Router /peer/config-qr/{id} [get]
func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := domain.SetUserInfoFromGin(c)
id := Base64UrlDecode(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, model.Error{
@ -347,7 +350,7 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
return
}
config, err := e.app.GetPeerConfigQrCode(c.Request.Context(), domain.PeerIdentifier(id))
config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -392,11 +395,13 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
return
}
ctx := domain.SetUserInfoFromGin(c)
peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
for i := range req.Identifiers {
peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
}
err = e.app.SendPeerEmail(c.Request.Context(), req.LinkOnly, peerIds...)
err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...)
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return

View File

@ -20,13 +20,13 @@ func (e userEndpoint) GetName() string {
func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
apiGroup := g.Group("/user", e.authenticator.LoggedIn())
apiGroup.GET("/all", e.handleAllGet())
apiGroup.GET("/:id", e.handleSingleGet())
apiGroup.PUT("/:id", e.handleUpdatePut())
apiGroup.DELETE("/:id", e.handleDelete())
apiGroup.POST("/new", e.handleCreatePost())
apiGroup.GET("/:id/peers", e.handlePeersGet())
apiGroup.GET("/:id/stats", e.handleStatsGet())
apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet())
apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut())
apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete())
apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet())
apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet())
}
// handleAllGet returns a gorm handler function.

View File

@ -58,6 +58,31 @@ func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
}
}
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc {
return func(c *gin.Context) {
session := h.Session.GetData(c)
if session.IsAdmin {
c.Next() // Admins can do everything
return
}
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter)))
if sessionUserId != requestUserId {
// Abort the request with the appropriate error code
c.Abort()
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
return
}
// Continue down the chain to handler etc
c.Next()
}
}
func UserHasScopes(session SessionData, scopes ...Scope) bool {
// No scopes give, so the check should succeed
if len(scopes) == 0 {

View File

@ -150,6 +150,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
}
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
user, err := a.users.GetUser(ctx, id)
if err != nil {
return false
@ -187,6 +188,8 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin
}
func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) {
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
var ldapUserInfo *domain.AuthenticatorUserInfo
var ldapProvider domain.LdapAuthenticator
@ -315,6 +318,7 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
return nil, fmt.Errorf("unable to parse user information: %w", err)
}
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled())
if err != nil {
return nil, fmt.Errorf("unable to process user information: %w", err)

View File

@ -109,6 +109,10 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
}
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch interface %s: %w", id, err)
@ -123,6 +127,10 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
}
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
return m.tplHandler.GetPeerConfig(peer)
}
@ -132,6 +140,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
}
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
cfgData, err := m.tplHandler.GetPeerConfig(peer)
if err != nil {
return nil, fmt.Errorf("failed to get peer config for %s: %w", id, err)
@ -172,6 +184,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
}
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
if m.fsRepo == nil {
return fmt.Errorf("peristing configuration is not supported")
}

View File

@ -44,6 +44,10 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma
return fmt.Errorf("failed to fetch peer %s: %w", peerId, err)
}
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return err
}
if peer.UserIdentifier == "" {
logrus.Debugf("skipping peer email for %s, no user linked", peerId)
continue

View File

@ -43,6 +43,10 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
}
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
err := m.NewUser(ctx, user)
if err != nil {
return err
@ -58,6 +62,10 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
return errors.New("missing user identifier")
}
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
err := m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
u.Identifier = user.Identifier
u.Email = user.Email
@ -83,6 +91,10 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) {
}
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err
}
user, err := m.users.GetUser(ctx, id)
if err != nil {
return nil, fmt.Errorf("unable to load peer %s: %w", id, err)
@ -95,6 +107,10 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
}
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
users, err := m.users.GetAllUsers(ctx)
if err != nil {
return nil, fmt.Errorf("unable to load users: %w", err)
@ -123,6 +139,10 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
}
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
return nil, err
}
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil {
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
@ -153,6 +173,10 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
}
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
@ -182,6 +206,10 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
}
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
existingUser, err := m.users.GetUser(ctx, id)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return fmt.Errorf("unable to find user %s: %w", id, err)

View File

@ -47,7 +47,8 @@ func (m Manager) handleUserCreationEvent(user *domain.User) {
logrus.Errorf("handling new user event for %s", user.Identifier)
if m.cfg.Core.CreateDefaultPeer {
err := m.CreateDefaultPeer(context.Background(), user)
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
err := m.CreateDefaultPeer(ctx, user)
if err != nil {
logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err)
return

View File

@ -13,6 +13,10 @@ import (
)
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return nil, err
@ -22,14 +26,26 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
}
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err
}
return m.db.GetInterfaceAndPeers(ctx, id)
}
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
return m.db.GetAllInterfaces(ctx)
}
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err
}
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
@ -48,6 +64,10 @@ func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interfa
}
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return 0, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return 0, err
@ -95,6 +115,10 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
}
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
if err != nil {
return fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
@ -122,6 +146,10 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
}
func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return err
@ -201,6 +229,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
}
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
currentUser := domain.GetUserInfo(ctx)
kp, err := domain.NewFreshKeypair()
@ -277,6 +309,10 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
}
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
@ -298,6 +334,10 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
}
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err
}
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
if err != nil {
return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
@ -316,6 +356,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
}
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
existingInterface, err := m.db.GetInterface(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)

View File

@ -12,6 +12,10 @@ import (
)
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
}
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return fmt.Errorf("failed to fetch all interfaces: %w", err)
@ -49,10 +53,18 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error
}
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err
}
return m.db.GetUserPeers(ctx, id)
}
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err // TODO: self provisioning?
}
currentUser := domain.GetUserInfo(ctx)
iface, err := m.db.GetInterface(ctx, id)
@ -128,10 +140,18 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
return nil, fmt.Errorf("unable to find peer %s: %w", id, err)
}
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
return peer, nil
}
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
@ -153,6 +173,10 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
}
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
var newPeers []*domain.Peer
for _, id := range r.UserIdentifiers {
@ -192,6 +216,10 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
}
if err := domain.ValidateUserAccessRights(ctx, existingPeer.UserIdentifier); err != nil {
return nil, err
}
if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
}
@ -210,6 +238,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("unable to find peer %s: %w", id, err)
}
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return err
}
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
@ -231,6 +263,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
peerIds := make([]domain.PeerIdentifier, len(peers))
for i, peer := range peers {
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
peerIds[i] = peer.Identifier
}
@ -238,6 +274,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
}
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err
}
peers, err := m.db.GetUserPeers(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", id, err)

View File

@ -3,6 +3,7 @@ package domain
import (
"context"
"fmt"
"github.com/sirupsen/logrus"
"github.com/gin-gonic/gin"
)
@ -72,3 +73,29 @@ func GetUserInfo(ctx context.Context) *ContextUserInfo {
return DefaultContextUserInfo()
}
func ValidateUserAccessRights(ctx context.Context, requiredUser UserIdentifier) error {
sessionUser := GetUserInfo(ctx)
if sessionUser.IsAdmin {
return nil // Admins can do everything
}
if sessionUser.Id == requiredUser {
return nil // User can access own data
}
logrus.Warnf("insufficient permissions for %s (want %s), stack: %s", sessionUser.Id, requiredUser, GetStackTrace())
return fmt.Errorf("insufficient permissions")
}
func ValidateAdminAccessRights(ctx context.Context) error {
sessionUser := GetUserInfo(ctx)
if sessionUser.IsAdmin {
return nil
}
logrus.Warnf("insufficient admin permissions for %s, stack: %s", sessionUser.Id, GetStackTrace())
return fmt.Errorf("insufficient permissions")
}

View File

@ -1,6 +1,18 @@
package domain
import "errors"
import (
"errors"
"runtime"
)
var ErrNotFound = errors.New("record not found")
var ErrNotUnique = errors.New("record not unique")
// GetStackTrace returns a stack trace of the current goroutine. The stack trace has at most 1024 bytes.
func GetStackTrace() string {
b := make([]byte, 1024)
n := runtime.Stack(b, false)
s := string(b[:n])
return s
}