From 2ba1e0f7ddfff777a99bf36c175e70d800d1e782 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 15 Apr 2024 15:11:09 +0200 Subject: [PATCH] add backfill ip function Updates #614 Signed-off-by: Kristoffer Dalby --- hscontrol/db/ip.go | 68 +++++++++++++++ hscontrol/db/ip_test.go | 183 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+) diff --git a/hscontrol/db/ip.go b/hscontrol/db/ip.go index 965baf3d85f..4e979dccb53 100644 --- a/hscontrol/db/ip.go +++ b/hscontrol/db/ip.go @@ -239,3 +239,71 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) { return ip, nil } + +// BackfillNodeIPs will take a database transaction, and +// iterate through all of the current nodes in headscale +// and ensure it has IP addresses according to the current +// configuration. +// This means that if both IPv4 and IPv6 is set in the +// config, and some nodes are missing that type of IP, +// it will be added. +// If a prefix type has been removed (IPv4 or IPv6), it +// will remove the IPs in that family from the node. +func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) error { + return db.Write(func(tx *gorm.DB) error { + if i == nil { + return errors.New("backfilling IPs: ip allocator was nil") + } + + nodes, err := ListNodes(tx) + if err != nil { + return fmt.Errorf("listing nodes to backfill IPs: %w", err) + } + + for _, node := range nodes { + changed := false + // IPv4 prefix is set, but node ip is missing, alloc + if i.prefix4 != nil && node.IPv4 == nil { + ret4, err := i.next(i.prev4, i.prefix4) + if err != nil { + return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err) + } + + node.IPv4 = ret4 + changed = true + } + + // IPv6 prefix is set, but node ip is missing, alloc + if i.prefix6 != nil && node.IPv6 == nil { + ret6, err := i.next(i.prev6, i.prefix6) + if err != nil { + return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err) + } + + node.IPv6 = ret6 + changed = true + } + + // IPv4 prefix is not set, but node has IP, remove + if i.prefix4 == nil && node.IPv4 != nil { + node.IPv4 = nil + changed = true + } + + // IPv6 prefix is not set, but node has IP, remove + if i.prefix6 == nil && node.IPv6 != nil { + node.IPv6 = nil + changed = true + } + + if changed { + err := tx.Save(node).Error + if err != nil { + return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err) + } + } + } + + return nil + }) +} diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 1c8fb5e4a73..a43f107f2a6 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -7,6 +7,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -18,6 +19,10 @@ var mpp = func(pref string) *netip.Prefix { var na = func(pref string) netip.Addr { return netip.MustParseAddr(pref) } +var nap = func(pref string) *netip.Addr { + n := na(pref) + return &n +} func TestIPAllocatorSequential(t *testing.T) { tests := []struct { @@ -277,3 +282,181 @@ func TestIPAllocatorRandom(t *testing.T) { }) } } + +func TestBackfillIPAddresses(t *testing.T) { + tests := []struct { + name string + dbFunc func() *HSDatabase + + prefix4 *netip.Prefix + prefix6 *netip.Prefix + want types.Nodes + }{ + { + name: "simple-backfill-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t, "simple-backfill-ipv6") + + db.DB.Save(&types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + IPv4: nap("100.64.0.1"), + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + { + name: "simple-backfill-ipv4", + dbFunc: func() *HSDatabase { + db := dbForTest(t, "simple-backfill-ipv4") + + db.DB.Save(&types.Node{ + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + IPv4: nap("100.64.0.1"), + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + { + name: "simple-backfill-remove-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t, "simple-backfill-remove-ipv4") + + db.DB.Save(&types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + + want: types.Nodes{ + &types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + IPv4: nap("100.64.0.1"), + }, + }, + }, + { + name: "simple-backfill-remove-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t, "simple-backfill-remove-ipv6") + + db.DB.Save(&types.Node{ + IPv4DatabaseField: sql.NullString{ + Valid: true, + String: "100.64.0.1", + }, + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + }) + + return db + }, + + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv6DatabaseField: sql.NullString{ + Valid: true, + String: "fd7a:115c:a1e0::1", + }, + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + } + + comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{}, + "ID", + "MachineKeyDatabaseField", + "NodeKeyDatabaseField", + "DiscoKeyDatabaseField", + "Endpoints", + "HostinfoDatabaseField", + "Hostinfo", + "Routes", + "CreatedAt", + "UpdatedAt", + )) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.dbFunc() + + alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, false) + if err != nil { + t.Fatalf("failed to set up ip alloc: %s", err) + } + + err = db.BackfillNodeIPs(alloc) + if err != nil { + t.Fatalf("failed to backfill: %s", err) + } + + got, err := db.ListNodes() + if err != nil { + t.Fatalf("failed to get nodes: %s", err) + } + + if diff := cmp.Diff(tt.want, got, comps...); diff != "" { + t.Errorf("Backfill unexpected result (-want +got):\n%s", diff) + } + }) + } +}