diff --git a/internal/adapters/database.go b/internal/adapters/database.go index 8ddc31f..447b452 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -601,7 +601,7 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d return result, nil } -func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context) (map[domain.Cidr][]domain.Cidr, error) { +func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) { var peerIps []struct { domain.Cidr PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"` @@ -628,14 +628,26 @@ func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context) (map[domain.Cidr][]do return nil, fmt.Errorf("failed to fetch interface IP's: %w", err) } - result := make(map[domain.Cidr][]domain.Cidr) + result := make(map[domain.Cidr][]domain.Cidr, len(subnets)) for _, ip := range interfaceIps { - networkAddr := ip.Cidr.NetworkAddr() - result[networkAddr] = append(result[networkAddr], ip.Cidr) + var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group) + for _, s := range subnets { + if s.Contains(ip.Cidr) { + subnet = s + break + } + } + result[subnet] = append(result[subnet], ip.Cidr) } for _, ip := range peerIps { - networkAddr := ip.Cidr.NetworkAddr() - result[networkAddr] = append(result[networkAddr], ip.Cidr) + var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group) + for _, s := range subnets { + if s.Contains(ip.Cidr) { + subnet = s + break + } + } + result[subnet] = append(result[subnet], ip.Cidr) } return result, nil } diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go index 00b0b4b..3ae2ab1 100644 --- a/internal/app/wireguard/repos.go +++ b/internal/app/wireguard/repos.go @@ -21,7 +21,7 @@ type InterfaceAndPeerDatabaseRepo interface { SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error DeletePeer(ctx context.Context, id domain.PeerIdentifier) error GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - GetUsedIpsPerSubnet(ctx context.Context) (map[domain.Cidr][]domain.Cidr, error) + GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) } type StatisticsDatabaseRepo interface { diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 2aa4958..d4b09ee 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -292,13 +292,17 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { } func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) { + if iface.PeerDefNetworkStr == "" { + return []domain.Cidr{}, nil // cannot suggest new ip addresses if there is no subnet + } + networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr) if err != nil { err = fmt.Errorf("failed to parse default network address: %w", err) return } - existingIps, err := m.db.GetUsedIpsPerSubnet(ctx) + existingIps, err := m.db.GetUsedIpsPerSubnet(ctx, networks) if err != nil { err = fmt.Errorf("failed to get existing IP addresses: %w", err) return diff --git a/internal/domain/ip.go b/internal/domain/ip.go index 35b6e75..1d48a68 100644 --- a/internal/domain/ip.go +++ b/internal/domain/ip.go @@ -199,3 +199,10 @@ func CidrsToStringSlice(slice []Cidr) []string { return cidrs } + +func (c Cidr) Contains(other Cidr) bool { + _, subnet, _ := net.ParseCIDR(c.String()) + otherIP, _, _ := net.ParseCIDR(other.String()) + + return subnet.Contains(otherIP) +}