mirror of
https://github.com/h44z/wg-portal
synced 2025-02-26 05:49:14 +00:00
fix change of peer identifier (public key) (#265)
This commit is contained in:
parent
6d86f15ff8
commit
3020fbca4e
@ -78,7 +78,7 @@ func main() {
|
||||
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard, metricsServer)
|
||||
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
|
||||
|
@ -4,15 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/utils"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/utils"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error {
|
||||
return nil // we probably don't have a V1 database =)
|
||||
}
|
||||
|
||||
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
|
||||
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
|
||||
lastVersion.Version)
|
||||
}
|
||||
|
||||
func (r *SqlRepo) migrate() error {
|
||||
@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie
|
||||
return &in, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
|
||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
in, err := r.GetInterface(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to load interface: %w", err)
|
||||
@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
|
||||
func (r *SqlRepo) SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||
) error {
|
||||
userInfo := domain.GetUserInfo(ctx)
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
in, err := r.getOrCreateInterface(userInfo, tx, id)
|
||||
@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
|
||||
func (r *SqlRepo) getOrCreateInterface(
|
||||
ui *domain.ContextUserInfo,
|
||||
tx *gorm.DB,
|
||||
id domain.InterfaceIdentifier,
|
||||
) (*domain.Interface, error) {
|
||||
var in domain.Interface
|
||||
|
||||
// interfaceDefaults will be applied to newly created interface records
|
||||
@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
|
||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
var peers []domain.Peer
|
||||
|
||||
searchValue := "%" + strings.ToLower(search) + "%"
|
||||
@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
|
||||
func (r *SqlRepo) SavePeer(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||
) error {
|
||||
userInfo := domain.GetUserInfo(ctx)
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
peer, err := r.getOrCreatePeer(userInfo, tx, id)
|
||||
@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
|
||||
*domain.Peer,
|
||||
error,
|
||||
) {
|
||||
var peer domain.Peer
|
||||
|
||||
// interfaceDefaults will be applied to newly created interface records
|
||||
@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) {
|
||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
|
||||
map[domain.Cidr][]domain.Cidr,
|
||||
error,
|
||||
) {
|
||||
var peerIps []struct {
|
||||
domain.Cidr
|
||||
PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
|
||||
@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
|
||||
func (r *SqlRepo) SaveUser(
|
||||
ctx context.Context,
|
||||
id domain.UserIdentifier,
|
||||
updateFunc func(u *domain.User) (*domain.User, error),
|
||||
) error {
|
||||
userInfo := domain.GetUserInfo(ctx)
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
|
||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
|
||||
*domain.User,
|
||||
error,
|
||||
) {
|
||||
var user domain.User
|
||||
|
||||
// userDefaults will be applied to newly created user records
|
||||
@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
|
||||
|
||||
// region statistics
|
||||
|
||||
func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
|
||||
func (r *SqlRepo) UpdateInterfaceStatus(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
||||
) error {
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
in, err := r.getOrCreateInterfaceStatus(tx, id)
|
||||
if err != nil {
|
||||
@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
|
||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
|
||||
*domain.InterfaceStatus,
|
||||
error,
|
||||
) {
|
||||
var in domain.InterfaceStatus
|
||||
|
||||
// defaults will be applied to newly created record
|
||||
@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
|
||||
func (r *SqlRepo) UpdatePeerStatus(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
||||
) error {
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
in, err := r.getOrCreatePeerStatus(tx, id)
|
||||
if err != nil {
|
||||
@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
|
||||
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion statistics
|
||||
|
||||
// region audit
|
||||
|
@ -10,3 +10,4 @@ const TopicRouteUpdate = "route:update"
|
||||
const TopicRouteRemove = "route:remove"
|
||||
const TopicInterfaceUpdated = "interface:updated"
|
||||
const TopicPeerInterfaceUpdated = "peer:interface:updated"
|
||||
const TopicPeerIdentifierUpdated = "peer:identifier:updated"
|
||||
|
@ -13,13 +13,21 @@ type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
|
||||
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
|
||||
SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error
|
||||
SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||
) error
|
||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
|
||||
SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error
|
||||
SavePeer(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||
) error
|
||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
@ -30,18 +38,40 @@ type StatisticsDatabaseRepo interface {
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
|
||||
UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error
|
||||
UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error
|
||||
UpdatePeerStatus(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
||||
) error
|
||||
UpdateInterfaceStatus(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
||||
) error
|
||||
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
||||
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
|
@ -5,14 +5,17 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
"github.com/sirupsen/logrus"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
)
|
||||
|
||||
type StatisticsCollector struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
|
||||
pingWaitGroup sync.WaitGroup
|
||||
pingJobs chan domain.Peer
|
||||
@ -22,14 +25,25 @@ type StatisticsCollector struct {
|
||||
ms MetricsServer
|
||||
}
|
||||
|
||||
func NewStatisticsCollector(cfg *config.Config, db StatisticsDatabaseRepo, wg InterfaceController, ms MetricsServer) (*StatisticsCollector, error) {
|
||||
return &StatisticsCollector{
|
||||
func NewStatisticsCollector(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
db StatisticsDatabaseRepo,
|
||||
wg InterfaceController,
|
||||
ms MetricsServer,
|
||||
) (*StatisticsCollector, error) {
|
||||
c := &StatisticsCollector{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
|
||||
db: db,
|
||||
wg: wg,
|
||||
ms: ms,
|
||||
}, nil
|
||||
}
|
||||
|
||||
c.connectToMessageBus()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
|
||||
@ -69,16 +83,17 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
|
||||
logrus.Warnf("failed to load physical interface %s for data collection: %v", in.Identifier, err)
|
||||
continue
|
||||
}
|
||||
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
|
||||
i.UpdatedAt = time.Now()
|
||||
i.BytesReceived = physicalInterface.BytesDownload
|
||||
i.BytesTransmitted = physicalInterface.BytesUpload
|
||||
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier,
|
||||
func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
|
||||
i.UpdatedAt = time.Now()
|
||||
i.BytesReceived = physicalInterface.BytesDownload
|
||||
i.BytesTransmitted = physicalInterface.BytesUpload
|
||||
|
||||
// Update prometheus metrics
|
||||
go c.updateInterfaceMetrics(*i)
|
||||
// Update prometheus metrics
|
||||
go c.updateInterfaceMetrics(*i)
|
||||
|
||||
return i, nil
|
||||
})
|
||||
return i, nil
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
|
||||
}
|
||||
@ -120,36 +135,43 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
for _, peer := range peers {
|
||||
err = c.db.UpdatePeerStatus(ctx, peer.Identifier, func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
|
||||
var lastHandshake *time.Time
|
||||
if !peer.LastHandshake.IsZero() {
|
||||
lastHandshake = &peer.LastHandshake
|
||||
}
|
||||
err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
|
||||
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
|
||||
var lastHandshake *time.Time
|
||||
if !peer.LastHandshake.IsZero() {
|
||||
lastHandshake = &peer.LastHandshake
|
||||
}
|
||||
|
||||
// calculate if session was restarted
|
||||
p.UpdatedAt = time.Now()
|
||||
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake)
|
||||
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
|
||||
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
|
||||
p.Endpoint = peer.Endpoint
|
||||
p.LastHandshake = lastHandshake
|
||||
// calculate if session was restarted
|
||||
p.UpdatedAt = time.Now()
|
||||
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload,
|
||||
lastHandshake)
|
||||
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
|
||||
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
|
||||
p.Endpoint = peer.Endpoint
|
||||
p.LastHandshake = lastHandshake
|
||||
|
||||
// Update prometheus metrics
|
||||
go c.updatePeerMetrics(ctx, *p)
|
||||
// Update prometheus metrics
|
||||
go c.updatePeerMetrics(ctx, *p)
|
||||
|
||||
return p, nil
|
||||
})
|
||||
return p, nil
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
|
||||
logrus.Warnf("failed to update peer status for %s: %v", peer.Identifier, err)
|
||||
} else {
|
||||
logrus.Tracef("updated peer status for %s", peer.Identifier)
|
||||
}
|
||||
logrus.Tracef("updated peer status for %s", peer.Identifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getSessionStartTime(oldStats domain.PeerStatus, newReceived, newTransmitted uint64, latestHandshake *time.Time) *time.Time {
|
||||
func getSessionStartTime(
|
||||
oldStats domain.PeerStatus,
|
||||
newReceived, newTransmitted uint64,
|
||||
latestHandshake *time.Time,
|
||||
) *time.Time {
|
||||
if latestHandshake == nil {
|
||||
return nil // currently not connected
|
||||
}
|
||||
@ -242,6 +264,28 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
for peer := range c.pingJobs {
|
||||
peerPingable := c.isPeerPingable(ctx, peer)
|
||||
logrus.Tracef("peer %s pingable: %t", peer.Identifier, peerPingable)
|
||||
|
||||
now := time.Now()
|
||||
err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
|
||||
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
|
||||
if peerPingable {
|
||||
p.IsPingable = true
|
||||
p.LastPing = &now
|
||||
} else {
|
||||
p.IsPingable = false
|
||||
p.LastPing = nil
|
||||
}
|
||||
|
||||
// Update prometheus metrics
|
||||
go c.updatePeerMetrics(ctx, *p)
|
||||
|
||||
return p, nil
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Warnf("failed to update peer ping status for %s: %v", peer.Identifier, err)
|
||||
} else {
|
||||
logrus.Tracef("updated peer ping status for %s", peer.Identifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -257,7 +301,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
|
||||
|
||||
pinger, err := probing.NewPinger(checkAddr)
|
||||
if err != nil {
|
||||
logrus.Tracef("failed to instatiate pinger for %s: %v", checkAddr, err)
|
||||
logrus.Tracef("failed to instatiate pinger for %s (%s): %v", peer.Identifier, checkAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
@ -267,7 +311,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
logrus.Tracef("pinger for %s exited unexpectedly: %v", checkAddr, err)
|
||||
logrus.Tracef("pinger for peer %s (%s) exited unexpectedly: %v", peer.Identifier, checkAddr, err)
|
||||
return false
|
||||
}
|
||||
stats := pinger.Statistics()
|
||||
@ -287,3 +331,18 @@ func (c *StatisticsCollector) updatePeerMetrics(ctx context.Context, status doma
|
||||
}
|
||||
c.ms.UpdatePeerMetrics(peer, status)
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) connectToMessageBus() {
|
||||
_ = c.bus.Subscribe(app.TopicPeerIdentifierUpdated, c.handlePeerIdentifierChangeEvent)
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) handlePeerIdentifierChangeEvent(oldIdentifier, newIdentifier domain.PeerIdentifier) {
|
||||
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
|
||||
|
||||
// remove potential left-over status data
|
||||
err := c.db.DeletePeerStatus(ctx, oldIdentifier)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to delete old peer status for migrated peer, %s -> %s: %v",
|
||||
oldIdentifier, newIdentifier, err)
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
func Test_getSessionStartTime(t *testing.T) {
|
||||
@ -66,7 +67,9 @@ func Test_getSessionStartTime(t *testing.T) {
|
||||
{
|
||||
name: "still connected",
|
||||
args: args{
|
||||
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10},
|
||||
oldStats: domain.PeerStatus{
|
||||
LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10,
|
||||
},
|
||||
newReceived: 100,
|
||||
newTransmitted: 100,
|
||||
lastHandshake: &now,
|
||||
@ -76,7 +79,9 @@ func Test_getSessionStartTime(t *testing.T) {
|
||||
{
|
||||
name: "no longer connected",
|
||||
args: args{
|
||||
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
|
||||
oldStats: domain.PeerStatus{
|
||||
LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100,
|
||||
},
|
||||
newReceived: 100,
|
||||
newTransmitted: 100,
|
||||
lastHandshake: &nowMinus3,
|
||||
@ -116,7 +121,9 @@ func Test_getSessionStartTime(t *testing.T) {
|
||||
{
|
||||
name: "reconnect (sent)",
|
||||
args: args{
|
||||
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100},
|
||||
oldStats: domain.PeerStatus{
|
||||
LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100,
|
||||
},
|
||||
newReceived: 100,
|
||||
newTransmitted: 10,
|
||||
lastHandshake: &now,
|
||||
@ -126,7 +133,8 @@ func Test_getSessionStartTime(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
|
||||
if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted,
|
||||
tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("getSessionStartTime() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -230,9 +230,31 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
return nil, fmt.Errorf("update not allowed: %w", err)
|
||||
}
|
||||
|
||||
err = m.savePeers(ctx, peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failure: %w", err)
|
||||
// handle peer identifier change (new public key)
|
||||
if existingPeer.Identifier != domain.PeerIdentifier(peer.Interface.PublicKey) {
|
||||
peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // set new identifier
|
||||
|
||||
// delete old peer
|
||||
err = m.DeletePeer(ctx, existingPeer.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete old peer %s for %s: %w",
|
||||
existingPeer.Identifier, peer.Identifier, err)
|
||||
}
|
||||
|
||||
// save new peer
|
||||
err = m.savePeers(ctx, peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failure for re-identified peer %s (was %s): %w",
|
||||
peer.Identifier, existingPeer.Identifier, err)
|
||||
}
|
||||
|
||||
// publish event
|
||||
m.bus.Publish(app.TopicPeerIdentifierUpdated, existingPeer.Identifier, peer.Identifier)
|
||||
} else { // normal update
|
||||
err = m.savePeers(ctx, peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failure: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
|
Loading…
Reference in New Issue
Block a user