mirror of
https://github.com/h44z/wg-portal
synced 2025-02-26 05:49:14 +00:00
fix REST API permission checks (#209)
This commit is contained in:
parent
81e696fc7d
commit
1b4b5ff161
@ -44,7 +44,8 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
|
||||
// @Router /auth/providers [get]
|
||||
func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
providers := e.app.Authenticator.GetExternalLoginProviders(c.Request.Context())
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
providers := e.app.Authenticator.GetExternalLoginProviders(ctx)
|
||||
|
||||
c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers))
|
||||
}
|
||||
@ -69,7 +70,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
|
||||
var email *string
|
||||
|
||||
if currentSession.LoggedIn {
|
||||
uid := string(currentSession.UserIdentifier)
|
||||
uid := currentSession.UserIdentifier
|
||||
f := currentSession.Firstname
|
||||
l := currentSession.Lastname
|
||||
e := currentSession.Email
|
||||
@ -134,7 +135,8 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(c.Request.Context(), provider)
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider)
|
||||
if err != nil {
|
||||
if autoRedirect {
|
||||
redirectToReturn()
|
||||
@ -292,7 +294,8 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := e.app.Authenticator.PlainLogin(c.Request.Context(), loginData.Username, loginData.Password)
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
|
||||
return
|
||||
|
@ -19,7 +19,7 @@ func (e interfaceEndpoint) GetName() string {
|
||||
}
|
||||
|
||||
func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/interface", e.authenticator.LoggedIn())
|
||||
apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin))
|
||||
|
||||
apiGroup.GET("/prepare", e.handlePrepareGet())
|
||||
apiGroup.GET("/all", e.handleAllGet())
|
||||
@ -45,7 +45,8 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut
|
||||
// @Router /interface/prepare [get]
|
||||
func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
in, err := e.app.PrepareInterface(c.Request.Context())
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
in, err := e.app.PrepareInterface(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
@ -68,7 +69,8 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
|
||||
// @Router /interface/all [get]
|
||||
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context())
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
@ -92,6 +94,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
|
||||
// @Router /interface/get/{id} [get]
|
||||
func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
@ -100,7 +103,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
|
||||
iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
@ -124,6 +127,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
|
||||
// @Router /interface/config/{id} [get]
|
||||
func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
@ -132,7 +136,7 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetInterfaceConfig(c.Request.Context(), domain.InterfaceIdentifier(id))
|
||||
config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
|
@ -21,11 +21,11 @@ func (e peerEndpoint) GetName() string {
|
||||
func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/peer", e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/iface/:iface/all", e.handleAllGet())
|
||||
apiGroup.GET("/iface/:iface/stats", e.handleStatsGet())
|
||||
apiGroup.GET("/iface/:iface/prepare", e.handlePrepareGet())
|
||||
apiGroup.POST("/iface/:iface/new", e.handleCreatePost())
|
||||
apiGroup.POST("/iface/:iface/multiplenew", e.handleCreateMultiplePost())
|
||||
apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet())
|
||||
apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(ScopeAdmin), e.handlePrepareGet())
|
||||
apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost())
|
||||
apiGroup.GET("/config-qr/:id", e.handleQrCodeGet())
|
||||
apiGroup.POST("/config-mail", e.handleEmailPost())
|
||||
apiGroup.GET("/config/:id", e.handleConfigGet())
|
||||
@ -298,6 +298,8 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc {
|
||||
// @Router /peer/config/{id} [get]
|
||||
func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
@ -306,7 +308,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetPeerConfig(c.Request.Context(), domain.PeerIdentifier(id))
|
||||
config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
@ -339,6 +341,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
|
||||
// @Router /peer/config-qr/{id} [get]
|
||||
func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
id := Base64UrlDecode(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, model.Error{
|
||||
@ -347,7 +350,7 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
config, err := e.app.GetPeerConfigQrCode(c.Request.Context(), domain.PeerIdentifier(id))
|
||||
config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{
|
||||
Code: http.StatusInternalServerError, Message: err.Error(),
|
||||
@ -392,11 +395,13 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := domain.SetUserInfoFromGin(c)
|
||||
|
||||
peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
|
||||
for i := range req.Identifiers {
|
||||
peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
|
||||
}
|
||||
err = e.app.SendPeerEmail(c.Request.Context(), req.LinkOnly, peerIds...)
|
||||
err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
|
||||
return
|
||||
|
@ -20,13 +20,13 @@ func (e userEndpoint) GetName() string {
|
||||
func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
|
||||
apiGroup := g.Group("/user", e.authenticator.LoggedIn())
|
||||
|
||||
apiGroup.GET("/all", e.handleAllGet())
|
||||
apiGroup.GET("/:id", e.handleSingleGet())
|
||||
apiGroup.PUT("/:id", e.handleUpdatePut())
|
||||
apiGroup.DELETE("/:id", e.handleDelete())
|
||||
apiGroup.POST("/new", e.handleCreatePost())
|
||||
apiGroup.GET("/:id/peers", e.handlePeersGet())
|
||||
apiGroup.GET("/:id/stats", e.handleStatsGet())
|
||||
apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
|
||||
apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet())
|
||||
apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut())
|
||||
apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete())
|
||||
apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
|
||||
apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet())
|
||||
apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet())
|
||||
}
|
||||
|
||||
// handleAllGet returns a gorm handler function.
|
||||
|
@ -58,6 +58,31 @@ func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
|
||||
func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := h.Session.GetData(c)
|
||||
|
||||
if session.IsAdmin {
|
||||
c.Next() // Admins can do everything
|
||||
return
|
||||
}
|
||||
|
||||
sessionUserId := domain.UserIdentifier(session.UserIdentifier)
|
||||
requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter)))
|
||||
|
||||
if sessionUserId != requestUserId {
|
||||
// Abort the request with the appropriate error code
|
||||
c.Abort()
|
||||
c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
// Continue down the chain to handler etc
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func UserHasScopes(session SessionData, scopes ...Scope) bool {
|
||||
// No scopes give, so the check should succeed
|
||||
if len(scopes) == 0 {
|
||||
|
@ -150,6 +150,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
|
||||
}
|
||||
|
||||
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
|
||||
user, err := a.users.GetUser(ctx, id)
|
||||
if err != nil {
|
||||
return false
|
||||
@ -187,6 +188,8 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin
|
||||
}
|
||||
|
||||
func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) {
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
|
||||
var ldapUserInfo *domain.AuthenticatorUserInfo
|
||||
var ldapProvider domain.LdapAuthenticator
|
||||
|
||||
@ -315,6 +318,7 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
|
||||
return nil, fmt.Errorf("unable to parse user information: %w", err)
|
||||
}
|
||||
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to process user information: %w", err)
|
||||
|
@ -109,6 +109,10 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
|
||||
}
|
||||
|
||||
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch interface %s: %w", id, err)
|
||||
@ -123,6 +127,10 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
|
||||
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.tplHandler.GetPeerConfig(peer)
|
||||
}
|
||||
|
||||
@ -132,6 +140,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
|
||||
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfgData, err := m.tplHandler.GetPeerConfig(peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get peer config for %s: %w", id, err)
|
||||
@ -172,6 +184,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
|
||||
}
|
||||
|
||||
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.fsRepo == nil {
|
||||
return fmt.Errorf("peristing configuration is not supported")
|
||||
}
|
||||
|
@ -44,6 +44,10 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma
|
||||
return fmt.Errorf("failed to fetch peer %s: %w", peerId, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.UserIdentifier == "" {
|
||||
logrus.Debugf("skipping peer email for %s, no user linked", peerId)
|
||||
continue
|
||||
|
@ -43,6 +43,10 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
|
||||
}
|
||||
|
||||
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := m.NewUser(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -58,6 +62,10 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
||||
return errors.New("missing user identifier")
|
||||
}
|
||||
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
|
||||
u.Identifier = user.Identifier
|
||||
u.Email = user.Email
|
||||
@ -83,6 +91,10 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := m.users.GetUser(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load peer %s: %w", id, err)
|
||||
@ -95,6 +107,10 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
|
||||
}
|
||||
|
||||
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users, err := m.users.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load users: %w", err)
|
||||
@ -123,6 +139,10 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
}
|
||||
|
||||
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingUser, err := m.users.GetUser(ctx, user.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
|
||||
@ -153,6 +173,10 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
|
||||
}
|
||||
|
||||
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingUser, err := m.users.GetUser(ctx, user.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
|
||||
@ -182,6 +206,10 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
|
||||
}
|
||||
|
||||
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingUser, err := m.users.GetUser(ctx, id)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return fmt.Errorf("unable to find user %s: %w", id, err)
|
||||
|
@ -47,7 +47,8 @@ func (m Manager) handleUserCreationEvent(user *domain.User) {
|
||||
logrus.Errorf("handling new user event for %s", user.Identifier)
|
||||
|
||||
if m.cfg.Core.CreateDefaultPeer {
|
||||
err := m.CreateDefaultPeer(context.Background(), user)
|
||||
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
|
||||
err := m.CreateDefaultPeer(ctx, user)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err)
|
||||
return
|
||||
|
@ -13,6 +13,10 @@ import (
|
||||
)
|
||||
|
||||
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -22,14 +26,26 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
||||
}
|
||||
|
||||
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return m.db.GetInterfaceAndPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.db.GetAllInterfaces(ctx)
|
||||
}
|
||||
|
||||
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
interfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
|
||||
@ -48,6 +64,10 @@ func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interfa
|
||||
}
|
||||
|
||||
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -95,6 +115,10 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
}
|
||||
|
||||
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
|
||||
@ -122,6 +146,10 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
|
||||
}
|
||||
|
||||
func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
interfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -201,6 +229,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
||||
}
|
||||
|
||||
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentUser := domain.GetUserInfo(ctx)
|
||||
|
||||
kp, err := domain.NewFreshKeypair()
|
||||
@ -277,6 +309,10 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
|
||||
}
|
||||
|
||||
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
|
||||
@ -298,6 +334,10 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
}
|
||||
|
||||
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
|
||||
@ -316,6 +356,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
}
|
||||
|
||||
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", id, err)
|
||||
|
@ -12,6 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch all interfaces: %w", err)
|
||||
@ -49,10 +53,18 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error
|
||||
}
|
||||
|
||||
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.db.GetUserPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err // TODO: self provisioning?
|
||||
}
|
||||
|
||||
currentUser := domain.GetUserInfo(ctx)
|
||||
|
||||
iface, err := m.db.GetInterface(ctx, id)
|
||||
@ -128,10 +140,18 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
|
||||
return nil, fmt.Errorf("unable to find peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
|
||||
@ -153,6 +173,10 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
}
|
||||
|
||||
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newPeers []*domain.Peer
|
||||
|
||||
for _, id := range r.UserIdentifiers {
|
||||
@ -192,6 +216,10 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, existingPeer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil {
|
||||
return nil, fmt.Errorf("update not allowed: %w", err)
|
||||
}
|
||||
@ -210,6 +238,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("unable to find peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||
@ -231,6 +263,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
|
||||
|
||||
peerIds := make([]domain.PeerIdentifier, len(peers))
|
||||
for i, peer := range peers {
|
||||
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerIds[i] = peer.Identifier
|
||||
}
|
||||
|
||||
@ -238,6 +274,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
|
||||
}
|
||||
|
||||
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := m.db.GetUserPeers(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", id, err)
|
||||
|
@ -3,6 +3,7 @@ package domain
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@ -72,3 +73,29 @@ func GetUserInfo(ctx context.Context) *ContextUserInfo {
|
||||
|
||||
return DefaultContextUserInfo()
|
||||
}
|
||||
|
||||
func ValidateUserAccessRights(ctx context.Context, requiredUser UserIdentifier) error {
|
||||
sessionUser := GetUserInfo(ctx)
|
||||
|
||||
if sessionUser.IsAdmin {
|
||||
return nil // Admins can do everything
|
||||
}
|
||||
|
||||
if sessionUser.Id == requiredUser {
|
||||
return nil // User can access own data
|
||||
}
|
||||
|
||||
logrus.Warnf("insufficient permissions for %s (want %s), stack: %s", sessionUser.Id, requiredUser, GetStackTrace())
|
||||
return fmt.Errorf("insufficient permissions")
|
||||
}
|
||||
|
||||
func ValidateAdminAccessRights(ctx context.Context) error {
|
||||
sessionUser := GetUserInfo(ctx)
|
||||
|
||||
if sessionUser.IsAdmin {
|
||||
return nil
|
||||
}
|
||||
|
||||
logrus.Warnf("insufficient admin permissions for %s, stack: %s", sessionUser.Id, GetStackTrace())
|
||||
return fmt.Errorf("insufficient permissions")
|
||||
}
|
||||
|
@ -1,6 +1,18 @@
|
||||
package domain
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("record not found")
|
||||
var ErrNotUnique = errors.New("record not unique")
|
||||
|
||||
// GetStackTrace returns a stack trace of the current goroutine. The stack trace has at most 1024 bytes.
|
||||
func GetStackTrace() string {
|
||||
b := make([]byte, 1024)
|
||||
n := runtime.Stack(b, false)
|
||||
s := string(b[:n])
|
||||
|
||||
return s
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user