Skip to content

Commit

Permalink
Rework getAvailableIp
Browse files Browse the repository at this point in the history
This commit reworks getAvailableIp with a "simpler" version that will
look for the first available IP address in our IP Prefix.

There is a couple of ideas behind this:

* Make the host IPs reasonably predictable and in within similar
  subnets, which should simplify ACLs for subnets
* The code is not random, but deterministic so we can have tests
* The code is a bit more understandable (no bit shift magic)
  • Loading branch information
kradalby committed Aug 2, 2021
1 parent 309f868 commit b5841c8
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 41 deletions.
2 changes: 1 addition & 1 deletion app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (s *Suite) ResetDB(c *check.C) {
c.Fatal(err)
}
cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("127.0.0.1/32"),
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}

h = Headscale{
Expand Down
1 change: 1 addition & 0 deletions cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
}
h.db.Save(&m)

Expand Down
99 changes: 59 additions & 40 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ package headscale

import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"

mathrand "math/rand"

"golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/types/wgkey"
)
Expand Down Expand Up @@ -78,47 +71,73 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil
}

func (h *Headscale) getAvailableIP() (*net.IP, error) {
i := 0
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefix := h.cfg.IPPrefix

usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}

// for _, ip := range usedIps {
// nextIP := ip.Next()

// if !containsIPs(usedIps, nextIP) && ipPrefix.Contains(nextIP) {
// return &nextIP, nil
// }
// }

// // If there are no IPs in use, we are starting fresh and
// // can issue IPs from the beginning of the prefix.
// ip := ipPrefix.IP()
// return &ip, nil

// return nil, fmt.Errorf("failed to find any available IP in %s", ipPrefix)

// Get the first IP in our prefix
ip := ipPrefix.IP()

for {
ip, err := getRandomIP(h.cfg.IPPrefix)
if err != nil {
return nil, err
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
}
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil

if ip.IsZero() &&
ip.IsLoopback() {
continue
}
i++
if i == 100 { // really random number
break

if !containsIPs(usedIps, ip) {
return &ip, nil
}

ip = ip.Next()
}
return nil, errors.New(fmt.Sprintf("Could not find an available IP address in %s", h.cfg.IPPrefix.String()))
}

func getRandomIP(ipPrefix netaddr.IPPrefix) (*net.IP, error) {
mathrand.Seed(time.Now().Unix())
ipo, ipnet, err := net.ParseCIDR(ipPrefix.String())
if err == nil {
ip := ipo.To4()
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet)
// fmt.Println("Final address is ", ip)
// fmt.Println("Broadcast address is ", ipb)
// fmt.Println("Network address is ", ipn)
r := mathrand.Uint32()
ipRaw := make([]byte, 4)
binary.LittleEndian.PutUint32(ipRaw, r)
// ipRaw[3] = 254
// fmt.Println("ipRaw is ", ipRaw)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
// fmt.Println("IP After: ", ip[i])
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)

ips := make([]netaddr.IP, len(addresses))
for index, addr := range addresses {
ip, err := netaddr.ParseIP(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
}

ips[index] = ip
}

return ips, nil
}

func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
// fmt.Println("FINAL IP: ", ip.String())
return &ip, nil
}

return nil, err
return false
}
105 changes: 105 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package headscale

import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)

func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()

c.Assert(err, check.IsNil)

expected := netaddr.MustParseIP("10.27.0.0")

c.Assert(ip.String(), check.Equals, expected.String())
}

func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)

pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)

ips, err := h.getUsedIPs()

c.Assert(err, check.IsNil)

expected := netaddr.MustParseIP("10.27.0.0")

c.Assert(ips[0], check.Equals, expected)
}

func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)

for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
}

ips, err := h.getUsedIPs()

c.Assert(err, check.IsNil)

c.Assert(len(ips), check.Equals, 350)

c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.0"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.9"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.44"))

expectedNextIP := netaddr.MustParseIP("10.27.1.94")
nextIP, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())

// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}

0 comments on commit b5841c8

Please sign in to comment.