Skip to content

Commit

Permalink
create DB struct
Browse files Browse the repository at this point in the history
This is step one in detaching the Database layer from Headscale (h). The
ultimate goal is to have all function that does database operations in
its own package, and keep the business logic and writing separate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
  • Loading branch information
kradalby committed May 26, 2023
1 parent b01f1f1 commit 14e29a7
Show file tree
Hide file tree
Showing 48 changed files with 1,739 additions and 1,580 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ integration_test/etc/config.dump.yaml
# MkDocs
.cache
/site

__debug_bin
4 changes: 2 additions & 2 deletions cmd/headscale/cli/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -83,7 +83,7 @@ var listAPIKeys = &cobra.Command{
}

tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), hscontrol.Base10),
strconv.FormatUint(key.GetId(), util.Base10),
key.GetPrefix(),
expiration,
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
Expand Down
4 changes: 2 additions & 2 deletions cmd/headscale/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -93,7 +93,7 @@ var createNodeCmd = &cobra.Command{

return
}
if !hscontrol.NodePublicKeyRegex.Match([]byte(machineKey)) {
if !util.NodePublicKeyRegex.Match([]byte(machineKey)) {
err = errPreAuthKeyMalformed
ErrorOutput(
err,
Expand Down
8 changes: 4 additions & 4 deletions cmd/headscale/cli/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -529,15 +529,15 @@ func nodesToPtables(

var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(hscontrol.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
machineKey = key.MachinePublic{}
}

var nodeKey key.NodePublic
err = nodeKey.UnmarshalText(
[]byte(hscontrol.NodePublicKeyEnsurePrefix(machine.NodeKey)),
[]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -596,7 +596,7 @@ func nodesToPtables(
}

nodeData := []string{
strconv.FormatUint(machine.Id, hscontrol.Base10),
strconv.FormatUint(machine.Id, util.Base10),
machine.Name,
machine.GetGivenName(),
machineKey.ShortString(),
Expand Down
6 changes: 2 additions & 4 deletions cmd/headscale/cli/users.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package cli

import (
"errors"
"fmt"

survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
Expand All @@ -20,9 +20,7 @@ func init() {
userCmd.AddCommand(renameUserCmd)
}

const (
errMissingParameter = hscontrol.Error("missing parameters")
)
var errMissingParameter = errors.New("missing parameters")

var userCmd = &cobra.Command{
Use: "users",
Expand Down
5 changes: 3 additions & 2 deletions cmd/headscale/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -39,7 +40,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
// We are doing this here, as in the future could be cool to have it also hot-reload

if cfg.ACL.PolicyPath != "" {
aclPath := hscontrol.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
err = app.LoadACLPolicyFromPath(aclPath)
if err != nil {
log.Fatal().
Expand Down Expand Up @@ -98,7 +99,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
grpcOptions = append(
grpcOptions,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(hscontrol.GrpcSocketDialer),
grpc.WithContextDialer(util.GrpcSocketDialer),
)
} else {
// If we are not connecting to a local server, require an API key for authentication
Expand Down
5 changes: 3 additions & 2 deletions cmd/headscale/headscale_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/spf13/viper"
"gopkg.in/check.v1"
)
Expand Down Expand Up @@ -64,7 +65,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
c.Assert(
hscontrol.GetFileMode("unix_socket_permission"),
util.GetFileMode("unix_socket_permission"),
check.Equals,
fs.FileMode(0o770),
)
Expand Down Expand Up @@ -107,7 +108,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
c.Assert(
hscontrol.GetFileMode("unix_socket_permission"),
util.GetFileMode("unix_socket_permission"),
check.Equals,
fs.FileMode(0o770),
)
Expand Down
38 changes: 17 additions & 21 deletions hscontrol/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson"
"go4.org/netipx"
Expand All @@ -20,21 +21,16 @@ import (
"tailscale.com/tailcfg"
)

const (
errEmptyPolicy = Error("empty policy")
errInvalidAction = Error("invalid action")
errInvalidGroup = Error("invalid group")
errInvalidTag = Error("invalid tag")
errInvalidPortFormat = Error("invalid port format")
errWildcardIsNeeded = Error("wildcard as port is required for the protocol")
var (
errEmptyPolicy = errors.New("empty policy")
errInvalidAction = errors.New("invalid action")
errInvalidGroup = errors.New("invalid group")
errInvalidTag = errors.New("invalid tag")
errInvalidPortFormat = errors.New("invalid port format")
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
)

const (
Base8 = 8
Base10 = 10
BitSize16 = 16
BitSize32 = 32
BitSize64 = 64
portRangeBegin = 0
portRangeEnd = 65535
expectedTokenItems = 2
Expand Down Expand Up @@ -123,7 +119,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error {
}

func (h *Headscale) UpdateACLRules() error {
machines, err := h.ListMachines()
machines, err := h.db.ListMachines()
if err != nil {
return err
}
Expand Down Expand Up @@ -230,7 +226,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
return nil, errEmptyPolicy
}

machines, err := h.ListMachines()
machines, err := h.db.ListMachines()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -570,7 +566,7 @@ func excludeCorrectlyTaggedNodes(
for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
ns := append(owners, user)
if contains(ns, user) {
if util.StringOrPrefixListContains(ns, user) {
tags = append(tags, tag)
}
}
Expand All @@ -580,7 +576,7 @@ func excludeCorrectlyTaggedNodes(

found := false
for _, t := range hi.RequestTags {
if contains(tags, t) {
if util.StringOrPrefixListContains(tags, t) {
found = true

break
Expand Down Expand Up @@ -614,7 +610,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
Expand All @@ -624,11 +620,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
})

case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -754,7 +750,7 @@ func (pol *ACLPolicy) getIPsFromTag(

// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
if util.StringOrPrefixListContains(machine.ForcedTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
Expand Down Expand Up @@ -783,7 +779,7 @@ func (pol *ACLPolicy) getIPsFromTag(
machines := filterMachinesByUser(machines, user)
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
Expand Down
Loading

0 comments on commit 14e29a7

Please sign in to comment.