diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index fa3ec9c..948e702 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -78,7 +78,7 @@ func main() { wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) internal.AssertNoError(err) - statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard, metricsServer) + statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer) internal.AssertNoError(err) cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem) diff --git a/internal/adapters/database.go b/internal/adapters/database.go index 447b452..f4a778c 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -4,15 +4,16 @@ import ( "context" "errors" "fmt" - "github.com/sirupsen/logrus" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/utils" "os" "path/filepath" "strings" "time" + "github.com/sirupsen/logrus" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/utils" + "github.com/glebarez/sqlite" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" @@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error { return nil // we probably don't have a V1 database =) } - return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version) + return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", + lastVersion.Version) } func (r *SqlRepo) migrate() error { @@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie return &in, nil } -func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) { +func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.Interface, + []domain.Peer, + error, +) { in, err := r.GetInterface(ctx, id) if err != nil { return nil, nil, fmt.Errorf("failed to load interface: %w", err) @@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I return users, nil } -func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error { +func (r *SqlRepo) SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.Interface) (*domain.Interface, error), +) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { in, err := r.getOrCreateInterface(userInfo, tx, id) @@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi return nil } -func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) { +func (r *SqlRepo) getOrCreateInterface( + ui *domain.ContextUserInfo, + tx *gorm.DB, + id domain.InterfaceIdentifier, +) (*domain.Interface, error) { var in domain.Interface // interfaceDefaults will be applied to newly created interface records @@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden return peers, nil } -func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) { +func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ( + []domain.Peer, + error, +) { var peers []domain.Peer searchValue := "%" + strings.ToLower(search) + "%" @@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s return peers, nil } -func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error { +func (r *SqlRepo) SavePeer( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.Peer) (*domain.Peer, error), +) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { peer, err := r.getOrCreatePeer(userInfo, tx, id) @@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update return nil } -func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) { +func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) ( + *domain.Peer, + error, +) { var peer domain.Peer // interfaceDefaults will be applied to newly created interface records @@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d return result, nil } -func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) { +func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) ( + map[domain.Cidr][]domain.Cidr, + error, +) { var peerIps []struct { domain.Cidr PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"` @@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User, return users, nil } -func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error { +func (r *SqlRepo) SaveUser( + ctx context.Context, + id domain.UserIdentifier, + updateFunc func(u *domain.User) (*domain.User, error), +) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro return nil } -func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) { +func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) ( + *domain.User, + error, +) { var user domain.User // userDefaults will be applied to newly created user records @@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma // region statistics -func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error { +func (r *SqlRepo) UpdateInterfaceStatus( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error), +) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { in, err := r.getOrCreateInterfaceStatus(tx, id) if err != nil { @@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface return nil } -func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) { +func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) ( + *domain.InterfaceStatus, + error, +) { var in domain.InterfaceStatus // defaults will be applied to newly created record @@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus) return nil } -func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error { +func (r *SqlRepo) UpdatePeerStatus( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error), +) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { in, err := r.getOrCreatePeerStatus(tx, id) if err != nil { @@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error { return nil } +func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error { + err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error + if err != nil { + return err + } + + return nil +} + // endregion statistics // region audit diff --git a/internal/app/eventbus.go b/internal/app/eventbus.go index 8dfe5fd..064c2cc 100644 --- a/internal/app/eventbus.go +++ b/internal/app/eventbus.go @@ -10,3 +10,4 @@ const TopicRouteUpdate = "route:update" const TopicRouteRemove = "route:remove" const TopicInterfaceUpdated = "interface:updated" const TopicPeerInterfaceUpdated = "peer:interface:updated" +const TopicPeerIdentifierUpdated = "peer:identifier:updated" diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go index 83c4bce..67d88e1 100644 --- a/internal/app/wireguard/repos.go +++ b/internal/app/wireguard/repos.go @@ -13,13 +13,21 @@ type InterfaceAndPeerDatabaseRepo interface { GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) - SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error + SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.Interface) (*domain.Interface, error), + ) error DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error) - SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error + SavePeer( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.Peer) (*domain.Peer, error), + ) error DeletePeer(ctx context.Context, id domain.PeerIdentifier) error GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) @@ -30,18 +38,40 @@ type StatisticsDatabaseRepo interface { GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error - UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error + UpdatePeerStatus( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error), + ) error + UpdateInterfaceStatus( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error), + ) error + + DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error } type InterfaceController interface { GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) - SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error + GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( + *domain.PhysicalPeer, + error, + ) + SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), + ) error DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error + SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), + ) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error } diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go index 96ecc68..a9e2983 100644 --- a/internal/app/wireguard/statistics.go +++ b/internal/app/wireguard/statistics.go @@ -5,14 +5,17 @@ import ( "sync" "time" + "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" probing "github.com/prometheus-community/pro-bing" "github.com/sirupsen/logrus" + evbus "github.com/vardius/message-bus" ) type StatisticsCollector struct { cfg *config.Config + bus evbus.MessageBus pingWaitGroup sync.WaitGroup pingJobs chan domain.Peer @@ -22,14 +25,25 @@ type StatisticsCollector struct { ms MetricsServer } -func NewStatisticsCollector(cfg *config.Config, db StatisticsDatabaseRepo, wg InterfaceController, ms MetricsServer) (*StatisticsCollector, error) { - return &StatisticsCollector{ +func NewStatisticsCollector( + cfg *config.Config, + bus evbus.MessageBus, + db StatisticsDatabaseRepo, + wg InterfaceController, + ms MetricsServer, +) (*StatisticsCollector, error) { + c := &StatisticsCollector{ cfg: cfg, + bus: bus, db: db, wg: wg, ms: ms, - }, nil + } + + c.connectToMessageBus() + + return c, nil } func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) { @@ -69,16 +83,17 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) { logrus.Warnf("failed to load physical interface %s for data collection: %v", in.Identifier, err) continue } - err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) { - i.UpdatedAt = time.Now() - i.BytesReceived = physicalInterface.BytesDownload - i.BytesTransmitted = physicalInterface.BytesUpload + err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, + func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) { + i.UpdatedAt = time.Now() + i.BytesReceived = physicalInterface.BytesDownload + i.BytesTransmitted = physicalInterface.BytesUpload - // Update prometheus metrics - go c.updateInterfaceMetrics(*i) + // Update prometheus metrics + go c.updateInterfaceMetrics(*i) - return i, nil - }) + return i, nil + }) if err != nil { logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err) } @@ -120,36 +135,43 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { continue } for _, peer := range peers { - err = c.db.UpdatePeerStatus(ctx, peer.Identifier, func(p *domain.PeerStatus) (*domain.PeerStatus, error) { - var lastHandshake *time.Time - if !peer.LastHandshake.IsZero() { - lastHandshake = &peer.LastHandshake - } + err = c.db.UpdatePeerStatus(ctx, peer.Identifier, + func(p *domain.PeerStatus) (*domain.PeerStatus, error) { + var lastHandshake *time.Time + if !peer.LastHandshake.IsZero() { + lastHandshake = &peer.LastHandshake + } - // calculate if session was restarted - p.UpdatedAt = time.Now() - p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake) - p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server - p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server - p.Endpoint = peer.Endpoint - p.LastHandshake = lastHandshake + // calculate if session was restarted + p.UpdatedAt = time.Now() + p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, + lastHandshake) + p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server + p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server + p.Endpoint = peer.Endpoint + p.LastHandshake = lastHandshake - // Update prometheus metrics - go c.updatePeerMetrics(ctx, *p) + // Update prometheus metrics + go c.updatePeerMetrics(ctx, *p) - return p, nil - }) + return p, nil + }) if err != nil { - logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err) + logrus.Warnf("failed to update peer status for %s: %v", peer.Identifier, err) + } else { + logrus.Tracef("updated peer status for %s", peer.Identifier) } - logrus.Tracef("updated peer status for %s", peer.Identifier) } } } } } -func getSessionStartTime(oldStats domain.PeerStatus, newReceived, newTransmitted uint64, latestHandshake *time.Time) *time.Time { +func getSessionStartTime( + oldStats domain.PeerStatus, + newReceived, newTransmitted uint64, + latestHandshake *time.Time, +) *time.Time { if latestHandshake == nil { return nil // currently not connected } @@ -242,6 +264,28 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) { for peer := range c.pingJobs { peerPingable := c.isPeerPingable(ctx, peer) logrus.Tracef("peer %s pingable: %t", peer.Identifier, peerPingable) + + now := time.Now() + err := c.db.UpdatePeerStatus(ctx, peer.Identifier, + func(p *domain.PeerStatus) (*domain.PeerStatus, error) { + if peerPingable { + p.IsPingable = true + p.LastPing = &now + } else { + p.IsPingable = false + p.LastPing = nil + } + + // Update prometheus metrics + go c.updatePeerMetrics(ctx, *p) + + return p, nil + }) + if err != nil { + logrus.Warnf("failed to update peer ping status for %s: %v", peer.Identifier, err) + } else { + logrus.Tracef("updated peer ping status for %s", peer.Identifier) + } } } @@ -257,7 +301,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe pinger, err := probing.NewPinger(checkAddr) if err != nil { - logrus.Tracef("failed to instatiate pinger for %s: %v", checkAddr, err) + logrus.Tracef("failed to instatiate pinger for %s (%s): %v", peer.Identifier, checkAddr, err) return false } @@ -267,7 +311,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe pinger.Timeout = 2 * time.Second err = pinger.RunWithContext(ctx) // Blocks until finished. if err != nil { - logrus.Tracef("pinger for %s exited unexpectedly: %v", checkAddr, err) + logrus.Tracef("pinger for peer %s (%s) exited unexpectedly: %v", peer.Identifier, checkAddr, err) return false } stats := pinger.Statistics() @@ -287,3 +331,18 @@ func (c *StatisticsCollector) updatePeerMetrics(ctx context.Context, status doma } c.ms.UpdatePeerMetrics(peer, status) } + +func (c *StatisticsCollector) connectToMessageBus() { + _ = c.bus.Subscribe(app.TopicPeerIdentifierUpdated, c.handlePeerIdentifierChangeEvent) +} + +func (c *StatisticsCollector) handlePeerIdentifierChangeEvent(oldIdentifier, newIdentifier domain.PeerIdentifier) { + ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) + + // remove potential left-over status data + err := c.db.DeletePeerStatus(ctx, oldIdentifier) + if err != nil { + logrus.Errorf("failed to delete old peer status for migrated peer, %s -> %s: %v", + oldIdentifier, newIdentifier, err) + } +} diff --git a/internal/app/wireguard/statistics_test.go b/internal/app/wireguard/statistics_test.go index 23840e8..dc1ba43 100644 --- a/internal/app/wireguard/statistics_test.go +++ b/internal/app/wireguard/statistics_test.go @@ -1,10 +1,11 @@ package wireguard import ( - "github.com/h44z/wg-portal/internal/domain" "reflect" "testing" "time" + + "github.com/h44z/wg-portal/internal/domain" ) func Test_getSessionStartTime(t *testing.T) { @@ -66,7 +67,9 @@ func Test_getSessionStartTime(t *testing.T) { { name: "still connected", args: args{ - oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10}, + oldStats: domain.PeerStatus{ + LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10, + }, newReceived: 100, newTransmitted: 100, lastHandshake: &now, @@ -76,7 +79,9 @@ func Test_getSessionStartTime(t *testing.T) { { name: "no longer connected", args: args{ - oldStats: domain.PeerStatus{LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100}, + oldStats: domain.PeerStatus{ + LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100, + }, newReceived: 100, newTransmitted: 100, lastHandshake: &nowMinus3, @@ -116,7 +121,9 @@ func Test_getSessionStartTime(t *testing.T) { { name: "reconnect (sent)", args: args{ - oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100}, + oldStats: domain.PeerStatus{ + LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100, + }, newReceived: 100, newTransmitted: 10, lastHandshake: &now, @@ -126,7 +133,8 @@ func Test_getSessionStartTime(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) { + if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, + tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) { t.Errorf("getSessionStartTime() = %v, want %v", got, tt.want) } }) diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 38c7539..4eeb88b 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -230,9 +230,31 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee return nil, fmt.Errorf("update not allowed: %w", err) } - err = m.savePeers(ctx, peer) - if err != nil { - return nil, fmt.Errorf("update failure: %w", err) + // handle peer identifier change (new public key) + if existingPeer.Identifier != domain.PeerIdentifier(peer.Interface.PublicKey) { + peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // set new identifier + + // delete old peer + err = m.DeletePeer(ctx, existingPeer.Identifier) + if err != nil { + return nil, fmt.Errorf("failed to delete old peer %s for %s: %w", + existingPeer.Identifier, peer.Identifier, err) + } + + // save new peer + err = m.savePeers(ctx, peer) + if err != nil { + return nil, fmt.Errorf("update failure for re-identified peer %s (was %s): %w", + peer.Identifier, existingPeer.Identifier, err) + } + + // publish event + m.bus.Publish(app.TopicPeerIdentifierUpdated, existingPeer.Identifier, peer.Identifier) + } else { // normal update + err = m.savePeers(ctx, peer) + if err != nil { + return nil, fmt.Errorf("update failure: %w", err) + } } return peer, nil