From bc1c1f5ce87201048fc82bac76b547160eda6730 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 21 Oct 2022 14:42:37 +0200 Subject: [PATCH] Fix most nil pointers, actually make it check for unique across headscale Signed-off-by: Kristoffer Dalby --- grpcv1.go | 2 +- integration/general_test.go | 8 ++++---- machine.go | 31 +++++++++++++++++++++++-------- protocol_common.go | 7 +++++-- 4 files changed, 33 insertions(+), 15 deletions(-) diff --git a/grpcv1.go b/grpcv1.go index 6c9e50bf39..9fac9affec 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -479,7 +479,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( Hostname: "DebugTestMachine", } - givenName, err := api.h.GenerateGivenName(namespace.Name, request.GetKey(), request.GetName()) + givenName, err := api.h.GenerateGivenName(request.GetKey(), request.GetName()) if err != nil { return nil, err } diff --git a/integration/general_test.go b/integration/general_test.go index 2e7689ace6..2487e2b8a7 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -71,8 +71,8 @@ func TestPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - err = scenario.Shutdown() - if err != nil { - t.Errorf("failed to tear down scenario: %s", err) - } + // err = scenario.Shutdown() + // if err != nil { + // t.Errorf("failed to tear down scenario: %s", err) + // } } diff --git a/machine.go b/machine.go index 2e4f9caaa1..39e5110ac3 100644 --- a/machine.go +++ b/machine.go @@ -332,6 +332,15 @@ func (h *Headscale) ListMachines() ([]Machine, error) { return machines, nil } +func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) { + machines := []Machine{} + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Where("given_name = ?", givenName).Error; err != nil { + return nil, err + } + + return machines, nil +} + // GetMachine finds a Machine by name and namespace and returns the Machine struct. func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { machines, err := h.ListMachinesInNamespace(namespace) @@ -1061,21 +1070,27 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s return normalizedHostname, nil } -func (h *Headscale) GenerateGivenName(namespace string, machineKey string, suppliedName string) (string, error) { +func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) { givenName, err := h.generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machine, _ := h.GetMachineByGivenName(namespace, givenName) - if machine != nil && machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := h.generateGivenName(suppliedName, true) - if err != nil { - return "", err - } + machines, err := h.ListMachinesByGivenName(givenName) + if err != nil { + return "", err + } - givenName = postfixedName + for _, machine := range machines { + if machine.MachineKey != machineKey && machine.GivenName == givenName { + postfixedName, err := h.generateGivenName(suppliedName, true) + if err != nil { + return "", err + } + + givenName = postfixedName + } } return givenName, nil diff --git a/protocol_common.go b/protocol_common.go index 42d413efb8..c6bc2ee5a5 100644 --- a/protocol_common.go +++ b/protocol_common.go @@ -150,7 +150,10 @@ func (h *Headscale) handleRegisterCommon( Bool("noise", machineKey.IsZero()). Msg("New machine not yet in the database") - givenName, err := h.GenerateGivenName(machine.Namespace.Name, machine.MachineKey, registerRequest.Hostinfo.Hostname) + givenName, err := h.GenerateGivenName( + machineKey.String(), + registerRequest.Hostinfo.Hostname, + ) if err != nil { log.Error(). Caller(). @@ -374,7 +377,7 @@ func (h *Headscale) handleAuthKeyCommon( } else { now := time.Now().UTC() - givenName, err := h.GenerateGivenName(machine.Namespace.Name, MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) + givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). Caller().