diff --git a/api.go b/api.go index 92501a4e49..8629573d5e 100644 --- a/api.go +++ b/api.go @@ -75,15 +75,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - c.String(http.StatusInternalServerError, ":(") - return - } - var m Machine - if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Println("New Machine!") m = Machine{ Expiry: &req.Expiry, @@ -91,14 +84,14 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Name: req.Hostinfo.Hostname, NodeKey: wgkey.Key(req.NodeKey).HexString(), } - if err := db.Create(&m).Error; err != nil { + if err := h.db.Create(&m).Error; err != nil { log.Printf("Could not create row: %s", err) return } } if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, db, mKey, req, m) + h.handleAuthKey(c, h.db, mKey, req, m) return } @@ -138,7 +131,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { log.Printf("[%s] We have the OldNodeKey in the database. This is a key refresh", m.Name) m.NodeKey = wgkey.Key(req.NodeKey).HexString() - db.Save(&m) + h.db.Save(&m) resp.AuthURL = "" resp.User = *m.Namespace.toUser() @@ -204,13 +197,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { return } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return - } var m Machine - if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString()) return } @@ -234,7 +222,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { m.Endpoints = datatypes.JSON(endpoints) m.LastSeen = &now } - db.Save(&m) + h.db.Save(&m) pollData := make(chan []byte, 1) update := make(chan []byte, 1) @@ -303,7 +291,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { } now := time.Now().UTC() m.LastSeen = &now - db.Save(&m) + h.db.Save(&m) return true case <-update: @@ -322,7 +310,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { log.Printf("[%s] The client has closed the connection", m.Name) now := time.Now().UTC() m.LastSeen = &now - db.Save(&m) + h.db.Save(&m) h.pollMu.Lock() cancelKeepAlive <- []byte{} delete(h.clientsPolling, m.ID) diff --git a/app.go b/app.go index 0cdc31051a..8ff602995f 100644 --- a/app.go +++ b/app.go @@ -12,6 +12,7 @@ import ( "github.com/gin-gonic/gin" "golang.org/x/crypto/acme/autocert" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) @@ -43,6 +44,7 @@ type Config struct { // Headscale represents the base app of the service type Headscale struct { cfg Config + db *gorm.DB dbString string dbType string dbDebug bool @@ -87,6 +89,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) { if err != nil { return nil, err } + h.clientsPolling = make(map[uint64]chan []byte) return &h, nil } @@ -107,12 +110,6 @@ func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) { } func (h *Headscale) expireEphemeralNodesWorker() { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return - } - namespaces, err := h.ListNamespaces() if err != nil { log.Printf("Error listing namespaces: %s", err) @@ -127,7 +124,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { for _, m := range *machines { if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { log.Printf("[%s] Ephemeral client removed from database\n", m.Name) - err = db.Unscoped().Delete(m).Error + err = h.db.Unscoped().Delete(m).Error if err != nil { log.Printf("[%s] 🤮 Cannot delete ephemeral machine from the database: %s", m.Name, err) } diff --git a/app_test.go b/app_test.go index 6323655ed9..ad633334da 100644 --- a/app_test.go +++ b/app_test.go @@ -47,4 +47,9 @@ func (s *Suite) ResetDB(c *check.C) { if err != nil { c.Fatal(err) } + db, err := h.openDB() + if err != nil { + c.Fatal(err) + } + h.db = db } diff --git a/cli.go b/cli.go index 2ab1061f0b..9c5b66e5ad 100644 --- a/cli.go +++ b/cli.go @@ -2,7 +2,6 @@ package headscale import ( "errors" - "log" "gorm.io/gorm" "tailscale.com/types/wgkey" @@ -18,13 +17,9 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err if err != nil { return nil, err } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } + m := Machine{} - if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errors.New("Machine not found") } @@ -40,6 +35,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err m.NamespaceID = ns.ID m.Registered = true m.RegisterMethod = "cli" - db.Save(&m) + h.db.Save(&m) return &m, nil } diff --git a/cli_test.go b/cli_test.go index 3268e1afa6..9616b4a23c 100644 --- a/cli_test.go +++ b/cli_test.go @@ -8,11 +8,6 @@ func (s *Suite) TestRegisterMachine(c *check.C) { n, err := h.CreateNamespace("test") c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } - m := Machine{ ID: 0, MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", @@ -21,7 +16,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) { Name: "testmachine", NamespaceID: n.ID, } - db.Save(&m) + h.db.Save(&m) _, err = h.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) diff --git a/db.go b/db.go index d7ce66e109..99d45f8627 100644 --- a/db.go +++ b/db.go @@ -17,10 +17,12 @@ type KV struct { } func (h *Headscale) initDB() error { - db, err := h.db() + db, err := h.openDB() if err != nil { return err } + h.db = db + if h.dbType == "postgres" { db.Exec("create extension if not exists \"uuid-ossp\";") } @@ -45,7 +47,7 @@ func (h *Headscale) initDB() error { return err } -func (h *Headscale) db() (*gorm.DB, error) { +func (h *Headscale) openDB() (*gorm.DB, error) { var db *gorm.DB var err error switch h.dbType { @@ -69,12 +71,8 @@ func (h *Headscale) db() (*gorm.DB, error) { } func (h *Headscale) getValue(key string) (string, error) { - db, err := h.db() - if err != nil { - return "", err - } var row KV - if result := db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", errors.New("not found") } return row.Value, nil @@ -85,16 +83,13 @@ func (h *Headscale) setValue(key string, value string) error { Key: key, Value: value, } - db, err := h.db() - if err != nil { - return err - } - _, err = h.getValue(key) + + _, err := h.getValue(key) if err == nil { - db.Model(&kv).Where("key = ?", key).Update("value", value) + h.db.Model(&kv).Where("key = ?", key).Update("value", value) return nil } - db.Create(kv) + h.db.Create(kv) return nil } diff --git a/machine.go b/machine.go index 59dbf16fd4..1892219dcc 100644 --- a/machine.go +++ b/machine.go @@ -154,14 +154,9 @@ func (m Machine) toNode() (*tailcfg.Node, error) { } func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } machines := []Machine{} - if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered", + if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { log.Printf("Error accessing db: %s", err) return nil, err diff --git a/machine_test.go b/machine_test.go index 25c9e95162..d9a472cbf3 100644 --- a/machine_test.go +++ b/machine_test.go @@ -11,11 +11,6 @@ func (s *Suite) TestGetMachine(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } - _, err = h.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) @@ -30,7 +25,7 @@ func (s *Suite) TestGetMachine(c *check.C) { RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - db.Save(&m) + h.db.Save(&m) m1, err := h.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) diff --git a/namespaces.go b/namespaces.go index afdbb9fbae..ddfbee99a4 100644 --- a/namespaces.go +++ b/namespaces.go @@ -25,18 +25,12 @@ type Namespace struct { // CreateNamespace creates a new Namespace. Returns error if could not be created // or another namespace already exists func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } - n := Namespace{} - if err := db.Where("name = ?", name).First(&n).Error; err == nil { + if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { return nil, errorNamespaceExists } n.Name = name - if err := db.Create(&n).Error; err != nil { + if err := h.db.Create(&n).Error; err != nil { log.Printf("Could not create row: %s", err) return nil, err } @@ -46,12 +40,6 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { // DestroyNamespace destroys a Namespace. Returns error if the Namespace does // not exist or if there are machines associated with it. func (h *Headscale) DestroyNamespace(name string) error { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return err - } - n, err := h.GetNamespace(name) if err != nil { return errorNamespaceNotFound @@ -65,7 +53,7 @@ func (h *Headscale) DestroyNamespace(name string) error { return errorNamespaceNotEmpty } - if result := db.Unscoped().Delete(&n); result.Error != nil { + if result := h.db.Unscoped().Delete(&n); result.Error != nil { return err } @@ -74,14 +62,8 @@ func (h *Headscale) DestroyNamespace(name string) error { // GetNamespace fetches a namespace by name func (h *Headscale) GetNamespace(name string) (*Namespace, error) { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } - n := Namespace{} - if result := db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errorNamespaceNotFound } return &n, nil @@ -89,13 +71,8 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) { // ListNamespaces gets all the existing namespaces func (h *Headscale) ListNamespaces() (*[]Namespace, error) { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } namespaces := []Namespace{} - if err := db.Find(&namespaces).Error; err != nil { + if err := h.db.Find(&namespaces).Error; err != nil { return nil, err } return &namespaces, nil @@ -107,14 +84,9 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) { if err != nil { return nil, err } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } machines := []Machine{} - if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { return nil, err } return &machines, nil @@ -126,13 +98,8 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error if err != nil { return err } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return err - } m.NamespaceID = n.ID - db.Save(&m) + h.db.Save(&m) return nil } diff --git a/namespaces_test.go b/namespaces_test.go index fd6045e3f4..9168b20ba1 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -30,10 +30,6 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } m := Machine{ ID: 0, MachineKey: "foo", @@ -45,7 +41,7 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) { RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - db.Save(&m) + h.db.Save(&m) err = h.DestroyNamespace("test") c.Assert(err, check.Equals, errorNamespaceNotEmpty) diff --git a/preauth_keys.go b/preauth_keys.go index f0346e6885..7cffceae89 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "encoding/hex" "errors" - "log" "time" "gorm.io/gorm" @@ -34,12 +33,6 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, epheme return nil, err } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } - now := time.Now().UTC() kstr, err := h.generateKey() if err != nil { @@ -55,7 +48,7 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, epheme CreatedAt: &now, Expiration: expiration, } - db.Save(&k) + h.db.Save(&k) return &k, nil } @@ -66,14 +59,9 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) if err != nil { return nil, err } - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } keys := []PreAuthKey{} - if err := db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { + if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { return nil, err } return &keys, nil @@ -82,13 +70,8 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { - db, err := h.db() - if err != nil { - return nil, err - } - pak := PreAuthKey{} - if result := db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, errorAuthKeyNotFound } @@ -101,7 +84,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { } machines := []Machine{} - if err := db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { return nil, err } diff --git a/preauth_keys_test.go b/preauth_keys_test.go index 471e4632c4..6f1369c5a7 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -73,10 +73,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } m := Machine{ ID: 0, MachineKey: "foo", @@ -88,7 +84,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - db.Save(&m) + h.db.Save(&m) p, err := h.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed) @@ -102,10 +98,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } m := Machine{ ID: 1, MachineKey: "foo", @@ -117,7 +109,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - db.Save(&m) + h.db.Save(&m) p, err := h.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) @@ -143,10 +135,6 @@ func (*Suite) TestEphemeralKey(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, false, true, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } now := time.Now() m := Machine{ ID: 0, @@ -160,7 +148,7 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - db.Save(&m) + h.db.Save(&m) _, err = h.checkKeyValidity(pak.Key) // Ephemeral keys are by definition reusable diff --git a/routes.go b/routes.go index 8b09e3f5e8..a02bed306a 100644 --- a/routes.go +++ b/routes.go @@ -3,7 +3,6 @@ package headscale import ( "encoding/json" "errors" - "log" "gorm.io/datatypes" "inet.af/netaddr" @@ -42,15 +41,9 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr for _, rIP := range hi.RoutableIPs { if rIP == route { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return nil, err - } - routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest m.EnabledRoutes = datatypes.JSON(routes) - db.Save(&m) + h.db.Save(&m) // THIS IS COMPLETELY USELESS. // The peers map is stored in memory in the server process. diff --git a/routes_test.go b/routes_test.go index 19dab34e73..a05b7e16e6 100644 --- a/routes_test.go +++ b/routes_test.go @@ -16,11 +16,6 @@ func (s *Suite) TestGetRoutes(c *check.C) { pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) c.Assert(err, check.IsNil) - db, err := h.db() - if err != nil { - c.Fatal(err) - } - _, err = h.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) @@ -45,7 +40,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostinfo), } - db.Save(&m) + h.db.Save(&m) r, err := h.GetNodeRoutes("test", "testmachine") c.Assert(err, check.IsNil) diff --git a/utils.go b/utils.go index c2271aa9e1..f21063b02e 100644 --- a/utils.go +++ b/utils.go @@ -78,10 +78,6 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err } func (h *Headscale) getAvailableIP() (*net.IP, error) { - db, err := h.db() - if err != nil { - return nil, err - } i := 0 for { ip, err := getRandomIP() @@ -89,7 +85,7 @@ func (h *Headscale) getAvailableIP() (*net.IP, error) { return nil, err } m := Machine{} - if result := db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { return ip, nil } i++