Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move logic for validating node names #2127

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey)
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
Expand Down Expand Up @@ -105,24 +105,13 @@ func (h *Headscale) handleRegister(

logInfo("Node not found in database, creating new")

givenName, err := h.db.GenerateGivenName(
machineKey,
regReq.Hostinfo.Hostname,
)
if err != nil {
logErr(err, "Failed to generate given name for node")

return
}

// The node did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
Expand Down Expand Up @@ -354,21 +343,8 @@ func (h *Headscale) handleAuthKey(
} else {
now := time.Now().UTC()

givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed to generate given name for node")

return
}

nodeToRegister := types.Node{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
Expand Down
72 changes: 34 additions & 38 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
})
}

func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).Find(&nodes).Error; err != nil {
return nil, err
}

return nodes, nil
}

func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name)
Expand Down Expand Up @@ -242,9 +228,9 @@ func SetTags(
}

// RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it.
// and renames it. If the name is not unique, it will return an error.
func RenameNode(tx *gorm.DB,
nodeID uint64, newName string,
nodeID types.NodeID, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
Expand All @@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB,
return fmt.Errorf("renaming node: %w", err)
}

uniq, err := isUnqiueName(tx, newName)
if err != nil {
return fmt.Errorf("checking if name is unique: %w", err)
}

if !uniq {
return fmt.Errorf("name is not unique: %s", newName)
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err)
}
Expand Down Expand Up @@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.IPv4 = ipv4
node.IPv6 = ipv6

if node.GivenName == "" {
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
if err != nil {
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
}

node.GivenName = givenName
}

if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
}
Expand Down Expand Up @@ -642,40 +646,32 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil
}

func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
return GenerateGivenName(rx, mkey, suppliedName)
})
func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
nodes := types.Nodes{}
if err := tx.
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
return false, err
}

return len(nodes) == 0, nil
}

func GenerateGivenName(
func ensureUniqueGivenName(
tx *gorm.DB,
mkey key.MachinePublic,
suppliedName string,
name string,
) (string, error) {
givenName, err := generateGivenName(suppliedName, false)
givenName, err := generateGivenName(name, false)
if err != nil {
return "", err
}

// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
nodes, err := listNodesByGivenName(tx, givenName)
unique, err := isUnqiueName(tx, givenName)
if err != nil {
return "", err
}

var nodeFound *types.Node
for idx, node := range nodes {
if node.GivenName == givenName {
nodeFound = nodes[idx]
}
}

if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
postfixedName, err := generateGivenName(suppliedName, true)
if !unique {
postfixedName, err := generateGivenName(name, true)
if err != nil {
return "", err
}
Expand Down
143 changes: 98 additions & 45 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
Expand Down Expand Up @@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
}

func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)

pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)

_, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil)

nodeKey := key.NewNode()
machineKey := key.NewMachine()

machineKey2 := key.NewMachine()

node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "hostname-1",
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}

trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)

givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-2", comment)

givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
comment = check.Commentf("Same user, same node, same hostname, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-1", comment)

givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
}

func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
Expand Down Expand Up @@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) {
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
}

func TestRenameNode(t *testing.T) {
db, err := newTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}

user, err := db.CreateUser("test")
assert.NoError(t, err)

user2, err := db.CreateUser("test2")
assert.NoError(t, err)

node := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}

node2 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}

err = db.DB.Save(&node).Error
assert.NoError(t, err)

err = db.DB.Save(&node2).Error
assert.NoError(t, err)

err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node, nil, nil)
if err != nil {
return err
}
_, err = RegisterNode(tx, node2, nil, nil)
return err
})
assert.NoError(t, err)

nodes, err := db.ListNodes()
assert.NoError(t, err)

assert.Len(t, nodes, 2)

t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName)
t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName)

assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName)
assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname)
assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName)
assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname)
assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname)
assert.Len(t, nodes[0].Hostname, 4)
assert.Len(t, nodes[1].Hostname, 4)
assert.Len(t, nodes[0].GivenName, 4)
assert.Len(t, nodes[1].GivenName, 13)

// Nodes can be renamed to a unique name
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "newname")
})
assert.NoError(t, err)

nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")

// Nodes can reuse name that is no longer used
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[1].ID, "test")
})
assert.NoError(t, err)

nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")
assert.Equal(t, nodes[1].GivenName, "test")

// Nodes cannot be renamed to used names
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
}
8 changes: 1 addition & 7 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode(
tx,
request.GetNodeId(),
types.NodeID(request.GetNodeId()),
request.GetNewName(),
)
if err != nil {
Expand Down Expand Up @@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err
}

givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
if err != nil {
return nil, err
}

nodeKey := key.NewNode()

newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
GivenName: givenName,
User: *user,

Expiry: &time.Time{},
Expand Down
Loading