From 225f8d45b6492e2c670880bbda449fd9b276b95a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 11 Sep 2024 14:44:00 +0200 Subject: [PATCH 1/2] move logic for validating node names this commits moves the generation of "given names" of nodes into the registration function, and adds validation of renames to RenameNode using the same logic. Fixes #2121 Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 26 +------ hscontrol/db/node.go | 72 +++++++++---------- hscontrol/db/node_test.go | 143 ++++++++++++++++++++++++++------------ hscontrol/grpcv1.go | 8 +-- 4 files changed, 134 insertions(+), 115 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index aaab03cebb..8b8557babd 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -66,7 +66,7 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { - logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey) + logInfo, logTrace, _ := logAuthFunc(regReq, machineKey) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) @@ -105,16 +105,6 @@ func (h *Headscale) handleRegister( logInfo("Node not found in database, creating new") - givenName, err := h.db.GenerateGivenName( - machineKey, - regReq.Hostinfo.Hostname, - ) - if err != nil { - logErr(err, "Failed to generate given name for node") - - return - } - // The node did not have a key to authenticate, which means // that we rely on a method that calls back some how (OpenID or CLI) // We create the node and then keep it around until a callback @@ -122,7 +112,6 @@ func (h *Headscale) handleRegister( newNode := types.Node{ MachineKey: machineKey, Hostname: regReq.Hostinfo.Hostname, - GivenName: givenName, NodeKey: regReq.NodeKey, LastSeen: &now, Expiry: &time.Time{}, @@ -354,21 +343,8 @@ func (h *Headscale) handleAuthKey( } else { now := time.Now().UTC() - givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname) - if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed to generate given name for node") - - return - } - nodeToRegister := types.Node{ Hostname: registerRequest.Hostinfo.Hostname, - GivenName: givenName, UserID: pak.User.ID, User: pak.User, MachineKey: machineKey, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index a9e78a4539..7a361fb723 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } -func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return getNode(rx, user, name) @@ -242,9 +228,9 @@ func SetTags( } // RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. +// and renames it. If the name is not unique, it will return an error. func RenameNode(tx *gorm.DB, - nodeID uint64, newName string, + nodeID types.NodeID, newName string, ) error { err := util.CheckForFQDNRules( newName, @@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB, return fmt.Errorf("renaming node: %w", err) } + uniq, err := isUnqiueName(tx, newName) + if err != nil { + return fmt.Errorf("checking if name is unique: %w", err) + } + + if !uniq { + return fmt.Errorf("name is not unique: %s", newName) + } + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } @@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.IPv4 = ipv4 node.IPv6 = ipv6 + if node.GivenName == "" { + givenName, err := ensureUniqueGivenName(tx, node.Hostname) + if err != nil { + return nil, fmt.Errorf("failed to ensure unique given name: %w", err) + } + + node.GivenName = givenName + } + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } @@ -642,40 +646,32 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { return normalizedHostname, nil } -func (hsdb *HSDatabase) GenerateGivenName( - mkey key.MachinePublic, - suppliedName string, -) (string, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { - return GenerateGivenName(rx, mkey, suppliedName) - }) +func isUnqiueName(tx *gorm.DB, name string) (bool, error) { + nodes := types.Nodes{} + if err := tx. + Where("given_name = ?", name, name).Find(&nodes).Error; err != nil { + return false, err + } + + return len(nodes) == 0, nil } -func GenerateGivenName( +func ensureUniqueGivenName( tx *gorm.DB, - mkey key.MachinePublic, - suppliedName string, + name string, ) (string, error) { - givenName, err := generateGivenName(suppliedName, false) + givenName, err := generateGivenName(name, false) if err != nil { return "", err } - // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := listNodesByGivenName(tx, givenName) + unique, err := isUnqiueName(tx, givenName) if err != nil { return "", err } - var nodeFound *types.Node - for idx, node := range nodes { - if node.GivenName == givenName { - nodeFound = nodes[idx] - } - } - - if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() { - postfixedName, err := generateGivenName(suppliedName, true) + if !unique { + postfixedName, err := generateGivenName(name, true) if err != nil { return "", err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 94cce13b55..bafb22ba30 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -19,6 +19,7 @@ import ( "github.com/puzpuzpuz/xsync/v3" "github.com/stretchr/testify/assert" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" @@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) { c.Assert(nodeFromDB.IsExpired(), check.Equals, true) } -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("user-1", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machineKey2 := key.NewMachine() - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) - - givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") - comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1") - comment = check.Commentf("Same user, same node, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1") - comment = check.Commentf("Same user, unique nodes, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment) -} - func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) { assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) } + +func TestRenameNode(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser("test") + assert.NoError(t, err) + + user2, err := db.CreateUser("test2") + assert.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + err = db.DB.Save(&node).Error + assert.NoError(t, err) + + err = db.DB.Save(&node2).Error + assert.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, node, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, node2, nil, nil) + return err + }) + assert.NoError(t, err) + + nodes, err := db.ListNodes() + assert.NoError(t, err) + + assert.Len(t, nodes, 2) + + t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName) + t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName) + + assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName) + assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName) + assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname) + assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName) + assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname) + assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname) + assert.Len(t, nodes[0].Hostname, 4) + assert.Len(t, nodes[1].Hostname, 4) + assert.Len(t, nodes[0].GivenName, 4) + assert.Len(t, nodes[1].GivenName, 13) + + // Nodes can be renamed to a unique name + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "newname") + }) + assert.NoError(t, err) + + nodes, err = db.ListNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, nodes[0].Hostname, "test") + assert.Equal(t, nodes[0].GivenName, "newname") + + // Nodes can reuse name that is no longer used + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[1].ID, "test") + }) + assert.NoError(t, err) + + nodes, err = db.ListNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, nodes[0].Hostname, "test") + assert.Equal(t, nodes[0].GivenName, "newname") + assert.Equal(t, nodes[1].GivenName, "test") + + // Nodes cannot be renamed to used names + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "test") + }) + assert.ErrorContains(t, err, "name is not unique") +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 3f985d9857..596748f274 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode( node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { err := db.RenameNode( tx, - request.GetNodeId(), + types.NodeID(request.GetNodeId()), request.GetNewName(), ) if err != nil { @@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode( return nil, err } - givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName()) - if err != nil { - return nil, err - } - nodeKey := key.NewNode() newNode := types.Node{ MachineKey: mkey, NodeKey: nodeKey.Public(), Hostname: request.GetName(), - GivenName: givenName, User: *user, Expiry: &time.Time{}, From d8c043225593abe06bef22838092be37acfad589 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 11 Sep 2024 17:15:50 +0200 Subject: [PATCH 2/2] fix double arg Signed-off-by: Kristoffer Dalby --- hscontrol/db/node.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 7a361fb723..c0f42de186 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -649,7 +649,7 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func isUnqiueName(tx *gorm.DB, name string) (bool, error) { nodes := types.Nodes{} if err := tx. - Where("given_name = ?", name, name).Find(&nodes).Error; err != nil { + Where("given_name = ?", name).Find(&nodes).Error; err != nil { return false, err }