diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 3ea343f8a5..94b285e79a 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -12,8 +12,8 @@ on: # Stable release tags - v[0-9]+.[0-9]+.[0-9]+ paths: - - 'docs/**' - - 'mkdocs.yml' + - "docs/**" + - "mkdocs.yml" workflow_dispatch: jobs: diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index dbd3cb9779..f74dcac145 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -43,6 +43,7 @@ jobs: - TestPolicyBrokenConfigCommand - TestDERPVerifyEndpoint - TestResolveMagicDNS + - TestResolveMagicDNSExtraRecordsPath - TestValidateResolvConf - TestDERPServerScenario - TestDERPServerWebsocketScenario diff --git a/CHANGELOG.md b/CHANGELOG.md index c217355920..83fb142f1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,19 +33,19 @@ When automatic migration is enabled (`map_legacy_users: true`), Headscale will f - If `strip_email_domain: true` (the default): the Headscale username matches the "username" part of their email address. - If `strip_email_domain: false`: the Headscale username matches the _whole_ email address. -On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.** + On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.** -Like with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login. + Like with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login. -A _successful_ automated migration should otherwise be transparent to users. + A _successful_ automated migration should otherwise be transparent to users. -Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username, but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account. + Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username, but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account. -Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers. + Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers. -Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources. + Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources. -Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`. + Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`. ##### What happens when automatic migration is disabled? @@ -95,6 +95,7 @@ This will also affect the way you [reference users in policies](https://github.c - Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232) - Loosened up `server_url` and `base_domain` check. It was overly strict in some cases. [#2248](https://github.com/juanfont/headscale/pull/2248) - CLI for managing users now accepts `--identifier` in addition to `--name`, usage of `--identifier` is recommended [#2261](https://github.com/juanfont/headscale/pull/2261) +- Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262) ## 0.23.0 (2024-09-18) diff --git a/Dockerfile.integration b/Dockerfile.integration index cf55bd7476..735cdba588 100644 --- a/Dockerfile.integration +++ b/Dockerfile.integration @@ -8,7 +8,7 @@ ENV GOPATH /go WORKDIR /go/src/headscale RUN apt-get update \ - && apt-get install --no-install-recommends --yes less jq sqlite3 \ + && apt-get install --no-install-recommends --yes less jq sqlite3 dnsutils \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean RUN mkdir -p /var/run/headscale diff --git a/Makefile b/Makefile index 96aff1fd96..fb22e7bb52 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,10 @@ fmt-prettier: prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' fmt-go: - golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES) + # TODO(kradalby): Reeval if we want to use 88 in the future. + # golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES) + gofumpt -l -w . + golangci-lint run --fix fmt-proto: clang-format -i $(PROTO_SOURCES) diff --git a/flake.nix b/flake.nix index 853eb34b59..6e84031269 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-OPgL2q13Hus6o9Npcp2bFiDiBZvbi/x8YVH6dU5q5fg="; + vendorHash = "sha256-NyXMSIVcmPlUhE3LmEsYZQxJdz+e435r+GZC8umQKqQ="; subPackages = ["cmd/headscale"]; diff --git a/go.mod b/go.mod index d880cfde8e..627804cdbe 100644 --- a/go.mod +++ b/go.mod @@ -117,7 +117,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/fgprof v0.9.5 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/fxamacker/cbor/v2 v2.6.0 // indirect github.com/gaissmai/bart v0.11.1 // indirect github.com/glebarez/go-sqlite v1.22.0 // indirect diff --git a/go.sum b/go.sum index 1149bab906..bc51d24057 100644 --- a/go.sum +++ b/go.sum @@ -157,6 +157,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= +github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= diff --git a/hscontrol/app.go b/hscontrol/app.go index 1651b8f211..629a2eb3b0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -27,6 +27,7 @@ import ( "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/derp" derpServer "github.com/juanfont/headscale/hscontrol/derp/server" + "github.com/juanfont/headscale/hscontrol/dns" "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/policy" @@ -88,8 +89,9 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - polManOnce sync.Once - polMan policy.PolicyManager + polManOnce sync.Once + polMan policy.PolicyManager + extraRecordMan *dns.ExtraRecordsMan mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -184,7 +186,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } app.authProvider = authProvider - if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS + if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS // TODO(kradalby): revisit why this takes a list. var magicDNSDomains []dnsname.FQDN @@ -196,11 +198,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } // we might have routes already from Split DNS - if app.cfg.DNSConfig.Routes == nil { - app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver) + if app.cfg.TailcfgDNSConfig.Routes == nil { + app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver) } for _, d := range magicDNSDomains { - app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil + app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil } } @@ -237,23 +239,38 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, target, http.StatusFound) } -// expireExpiredNodes expires nodes that have an explicit expiry set -// after that expiry time has passed. -func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { - ticker := time.NewTicker(every) +func (h *Headscale) scheduledTasks(ctx context.Context) { + expireTicker := time.NewTicker(updateInterval) + defer expireTicker.Stop() - lastCheck := time.Unix(0, 0) - var update types.StateUpdate - var changed bool + lastExpiryCheck := time.Unix(0, 0) + + derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) + defer derpTicker.Stop() + // If we dont want auto update, just stop the ticker + if !h.cfg.DERP.AutoUpdate { + derpTicker.Stop() + } + + var extraRecordsUpdate <-chan []tailcfg.DNSRecord + if h.extraRecordMan != nil { + extraRecordsUpdate = h.extraRecordMan.UpdateCh() + } else { + extraRecordsUpdate = make(chan []tailcfg.DNSRecord) + } for { select { case <-ctx.Done(): - ticker.Stop() + log.Info().Caller().Msg("scheduled task worker is shutting down.") return - case <-ticker.C: + + case <-expireTicker.C: + var update types.StateUpdate + var changed bool + if err := h.db.Write(func(tx *gorm.DB) error { - lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) + lastExpiryCheck, update, changed = db.ExpireExpiredNodes(tx, lastExpiryCheck) return nil }); err != nil { @@ -267,24 +284,8 @@ func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") h.nodeNotifier.NotifyAll(ctx, update) } - } - } -} - -// scheduledDERPMapUpdateWorker refreshes the DERPMap stored on the global object -// at a set interval. -func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { - log.Info(). - Dur("frequency", h.cfg.DERP.UpdateFrequency). - Msg("Setting up a DERPMap update worker") - ticker := time.NewTicker(h.cfg.DERP.UpdateFrequency) - - for { - select { - case <-cancelChan: - return - case <-ticker.C: + case <-derpTicker.C: log.Info().Msg("Fetching DERPMap updates") h.DERPMap = derp.GetDERPMap(h.cfg.DERP) if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { @@ -297,6 +298,19 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { Type: types.StateDERPUpdated, DERPMap: h.DERPMap, }) + + case records, ok := <-extraRecordsUpdate: + if !ok { + continue + } + h.cfg.TailcfgDNSConfig.ExtraRecords = records + + ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + // TODO(kradalby): We can probably do better than sending a full update here, + // but for now this will ensure that all of the nodes get the new records. + Type: types.StateFullUpdate, + }) } } } @@ -568,12 +582,6 @@ func (h *Headscale) Serve() error { go h.DERPServer.ServeSTUN() } - if h.cfg.DERP.AutoUpdate { - derpMapCancelChannel := make(chan struct{}) - defer func() { derpMapCancelChannel <- struct{}{} }() - go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) - } - if len(h.DERPMap.Regions) == 0 { return errEmptyInitialDERPMap } @@ -591,9 +599,21 @@ func (h *Headscale) Serve() error { h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout) } - expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) - defer expireNodeCancel() - go h.expireExpiredNodes(expireNodeCtx, updateInterval) + if h.cfg.DNSConfig.ExtraRecordsPath != "" { + h.extraRecordMan, err = dns.NewExtraRecordsManager(h.cfg.DNSConfig.ExtraRecordsPath) + if err != nil { + return fmt.Errorf("setting up extrarecord manager: %w", err) + } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() + defer h.extraRecordMan.Close() + } + + // Start all scheduled tasks, e.g. expiring nodes, derp updates and + // records updates + scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) + defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { zerolog.RespLog = true @@ -847,7 +867,7 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") - expireNodeCancel() + scheduleCancel() h.ephemeralGC.Close() // Gracefully shut down servers diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 2b23aad3d7..b4923ccb5c 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -390,7 +390,6 @@ func (h *Headscale) handleAuthKey( http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index bafe1e1b76..95c82160b7 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -373,6 +373,5 @@ func TestConstraints(t *testing.T) { tt.run(t, db.DB.Debug()) }) - } } diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go new file mode 100644 index 0000000000..73f646ba98 --- /dev/null +++ b/hscontrol/dns/extrarecords.go @@ -0,0 +1,155 @@ +package dns + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "os" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/util/set" +) + +type ExtraRecordsMan struct { + mu sync.RWMutex + records set.Set[tailcfg.DNSRecord] + watcher *fsnotify.Watcher + path string + + updateCh chan []tailcfg.DNSRecord + closeCh chan struct{} + hashes map[string][32]byte +} + +// NewExtraRecordsManager creates a new ExtraRecordsMan and starts watching the file at the given path. +func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("creating watcher: %w", err) + } + + fi, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("getting file info: %w", err) + } + + if fi.IsDir() { + return nil, fmt.Errorf("path is a directory, only file is supported: %s", path) + } + + records, hash, err := readExtraRecordsFromPath(path) + if err != nil { + return nil, fmt.Errorf("reading extra records from path: %w", err) + } + + er := &ExtraRecordsMan{ + watcher: watcher, + path: path, + records: set.SetOf(records), + hashes: map[string][32]byte{ + path: hash, + }, + closeCh: make(chan struct{}), + updateCh: make(chan []tailcfg.DNSRecord), + } + + err = watcher.Add(path) + if err != nil { + return nil, fmt.Errorf("adding path to watcher: %w", err) + } + + log.Trace().Caller().Strs("watching", watcher.WatchList()).Msg("started filewatcher") + + return er, nil +} + +func (e *ExtraRecordsMan) Records() []tailcfg.DNSRecord { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.records.Slice() +} + +func (e *ExtraRecordsMan) Run() { + for { + select { + case <-e.closeCh: + return + case event, ok := <-e.watcher.Events: + if !ok { + log.Error().Caller().Msgf("file watcher event channel closing") + return + } + + log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event") + if event.Name != e.path { + continue + } + e.updateRecords() + + case err, ok := <-e.watcher.Errors: + if !ok { + log.Error().Caller().Msgf("file watcher error channel closing") + return + } + log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err) + } + } +} + +func (e *ExtraRecordsMan) Close() { + e.watcher.Close() + close(e.closeCh) +} + +func (e *ExtraRecordsMan) UpdateCh() <-chan []tailcfg.DNSRecord { + return e.updateCh +} + +func (e *ExtraRecordsMan) updateRecords() { + records, newHash, err := readExtraRecordsFromPath(e.path) + if err != nil { + log.Error().Caller().Err(err).Msgf("reading extra records from path: %s", e.path) + return + } + + e.mu.Lock() + defer e.mu.Unlock() + + // If there has not been any change, ignore the update. + if oldHash, ok := e.hashes[e.path]; ok { + if newHash == oldHash { + return + } + } + + oldCount := e.records.Len() + + e.records = set.SetOf(records) + e.hashes[e.path] = newHash + + log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() +} + +// readExtraRecordsFromPath reads a JSON file of tailcfg.DNSRecord +// and returns the records and the hash of the file. +func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("reading path: %s, err: %w", path, err) + } + + var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) + } + + hash := sha256.Sum256(b) + + return records, hash, nil +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 51c96f8c87..e18276ad6f 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -116,11 +116,11 @@ func generateDNSConfig( cfg *types.Config, node *types.Node, ) *tailcfg.DNSConfig { - if cfg.DNSConfig == nil { + if cfg.TailcfgDNSConfig == nil { return nil } - dnsConfig := cfg.DNSConfig.Clone() + dnsConfig := cfg.TailcfgDNSConfig.Clone() addNextDNSMetadata(dnsConfig.Resolvers, node) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 4ee8c6444e..55ab2ccbf7 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -117,7 +117,7 @@ func TestDNSConfigMapResponse(t *testing.T) { got := generateDNSConfig( &types.Config{ - DNSConfig: &dnsConfigOrig, + TailcfgDNSConfig: &dnsConfigOrig, }, nodeInShared1, ) @@ -349,7 +349,7 @@ func Test_fullMapResponse(t *testing.T) { derpMap: &tailcfg.DERPMap{}, cfg: &types.Config{ BaseDomain: "", - DNSConfig: &tailcfg.DNSConfig{}, + TailcfgDNSConfig: &tailcfg.DNSConfig{}, LogTail: types.LogTailConfig{Enabled: false}, RandomizeClientPort: false, }, @@ -381,7 +381,7 @@ func Test_fullMapResponse(t *testing.T) { derpMap: &tailcfg.DERPMap{}, cfg: &types.Config{ BaseDomain: "", - DNSConfig: &tailcfg.DNSConfig{}, + TailcfgDNSConfig: &tailcfg.DNSConfig{}, LogTail: types.LogTailConfig{Enabled: false}, RandomizeClientPort: false, }, @@ -424,7 +424,7 @@ func Test_fullMapResponse(t *testing.T) { derpMap: &tailcfg.DERPMap{}, cfg: &types.Config{ BaseDomain: "", - DNSConfig: &tailcfg.DNSConfig{}, + TailcfgDNSConfig: &tailcfg.DNSConfig{}, LogTail: types.LogTailConfig{Enabled: false}, RandomizeClientPort: false, }, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 9d7f1fedfb..96c008ab12 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -187,7 +187,7 @@ func TestTailNode(t *testing.T) { polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, - DNSConfig: tt.dnsConfig, + TailcfgDNSConfig: tt.dnsConfig, RandomizeClientPort: false, } got, err := tailNode( diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 1db1ec079f..14191d23d0 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -447,7 +447,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // This check is for legacy, if the user cannot be found by the OIDC identifier // look it up by username. This should only be needed once. - // This branch will presist for a number of versions after the OIDC migration and + // This branch will persist for a number of versions after the OIDC migration and // then be removed following a deprecation. // TODO(kradalby): Remove when strip_email_domain and migration is removed // after #2170 is cleaned up. @@ -536,7 +536,7 @@ func renderOIDCCallbackTemplate( // TODO(kradalby): Reintroduce when strip_email_domain is removed // after #2170 is cleaned up -// DEPRECATED: DO NOT USE +// DEPRECATED: DO NOT USE. func getUserName( claims *types.OIDCClaims, stripEmaildomain bool, diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 2af3989646..5c4b2c6ac2 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -72,7 +72,14 @@ type Config struct { ACMEURL string ACMEEmail string - DNSConfig *tailcfg.DNSConfig + // DNSConfig is the headscale representation of the DNS configuration. + // It is kept in the config update for some settings that are + // not directly converted into a tailcfg.DNSConfig. + DNSConfig DNSConfig + + // TailcfgDNSConfig is the tailcfg representation of the DNS configuration, + // it can be used directly when sending Netmaps to clients. + TailcfgDNSConfig *tailcfg.DNSConfig UnixSocket string UnixSocketPermission fs.FileMode @@ -90,11 +97,12 @@ type Config struct { } type DNSConfig struct { - MagicDNS bool `mapstructure:"magic_dns"` - BaseDomain string `mapstructure:"base_domain"` - Nameservers Nameservers - SearchDomains []string `mapstructure:"search_domains"` - ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"` + MagicDNS bool `mapstructure:"magic_dns"` + BaseDomain string `mapstructure:"base_domain"` + Nameservers Nameservers + SearchDomains []string `mapstructure:"search_domains"` + ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"` + ExtraRecordsPath string `mapstructure:"extra_records_path"` } type Nameservers struct { @@ -253,7 +261,6 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("dns.nameservers.global", []string{}) viper.SetDefault("dns.nameservers.split", map[string]string{}) viper.SetDefault("dns.search_domains", []string{}) - viper.SetDefault("dns.extra_records", []tailcfg.DNSRecord{}) viper.SetDefault("derp.server.enabled", false) viper.SetDefault("derp.server.stun.enabled", true) @@ -344,6 +351,10 @@ func validateServerConfig() error { } } + if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") { + log.Fatal().Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. Please remove one of them from your config file") + } + // Collect any validation errors and return them all at once var errorText string if (viper.GetString("tls_letsencrypt_hostname") != "") && @@ -586,6 +597,7 @@ func dns() (DNSConfig, error) { dns.Nameservers.Global = viper.GetStringSlice("dns.nameservers.global") dns.Nameservers.Split = viper.GetStringMapStringSlice("dns.nameservers.split") dns.SearchDomains = viper.GetStringSlice("dns.search_domains") + dns.ExtraRecordsPath = viper.GetString("dns.extra_records_path") if viper.IsSet("dns.extra_records") { var extraRecords []tailcfg.DNSRecord @@ -871,7 +883,8 @@ func LoadServerConfig() (*Config, error) { TLS: tlsConfig(), - DNSConfig: dnsToTailcfgDNS(dnsConfig), + DNSConfig: dnsConfig, + TailcfgDNSConfig: dnsToTailcfgDNS(dnsConfig), ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 58382ca5ab..511528df58 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -280,9 +280,9 @@ func TestReadConfigFromEnv(t *testing.T) { // "foo.bar.com": {"1.1.1.1"}, }, }, - ExtraRecords: []tailcfg.DNSRecord{ - // {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, - }, + // ExtraRecords: []tailcfg.DNSRecord{ + // {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + // }, SearchDomains: []string{"test.com", "bar.com"}, }, }, diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index bf43eb507a..c6861c9e31 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -189,7 +189,6 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // NormalizeToFQDNRules will replace forbidden chars in user // it can also return an error if the user doesn't respect RFC 952 and 1123. func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { - name = strings.ToLower(name) name = strings.ReplaceAll(name, "'", "") atIdx := strings.Index(name, "@") diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 54aa05fbee..52d28054dc 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -154,7 +154,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { } sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].Id < listUsers[j].Id + return listUsers[i].GetId() < listUsers[j].GetId() }) if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { @@ -514,7 +514,7 @@ func TestOIDC024UserCreation(t *testing.T) { assertNoErr(t, err) sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].Id < listUsers[j].Id + return listUsers[i].GetId() < listUsers[j].GetId() }) if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { diff --git a/integration/dns_test.go b/integration/dns_test.go index efe702e9d9..7ae1c82bf6 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -1,6 +1,7 @@ package integration import ( + "encoding/json" "fmt" "strings" "testing" @@ -9,6 +10,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "tailscale.com/tailcfg" ) func TestResolveMagicDNS(t *testing.T) { @@ -81,6 +83,93 @@ func TestResolveMagicDNS(t *testing.T) { } } +func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + "magicdns1": 1, + "magicdns2": 1, + } + + const erPath = "/tmp/extra_records.json" + + extraRecords := []tailcfg.DNSRecord{ + { + Name: "test.myvpn.example.com", + Type: "A", + Value: "6.6.6.6", + }, + } + b, _ := json.Marshal(extraRecords) + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{ + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add python3 curl bind-tools ; update-ca-certificates ; tailscaled --tun=tsdev", + }), + }, + hsic.WithTestName("extrarecords"), + hsic.WithConfigEnv(map[string]string{ + // Disable global nameservers to make the test run offline. + "HEADSCALE_DNS_NAMESERVERS_GLOBAL": "", + "HEADSCALE_DNS_EXTRA_RECORDS_PATH": erPath, + }), + hsic.WithFileInContainer(erPath, b), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + // assertClientsState(t, allClients) + + // Poor mans cache + _, err = scenario.ListTailscaleClientsFQDNs() + assertNoErrListFQDN(t, err) + + _, err = scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") + } + + extraRecords = append(extraRecords, tailcfg.DNSRecord{ + Name: "otherrecord.myvpn.example.com", + Type: "A", + Value: "7.7.7.7", + }) + b2, _ := json.Marshal(extraRecords) + + hs, err := scenario.Headscale() + assertNoErr(t, err) + + // Write it to a separate file to ensure Docker's API doesnt + // do anything unexpected and rather move it into place to trigger + // a reload. + err = hs.WriteFile(erPath+"2", b2) + assertNoErr(t, err) + _, err = hs.Execute([]string{"mv", erPath + "2", erPath}) + assertNoErr(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") + assertCommandOutputContains(t, client, []string{"dig", "otherrecord.myvpn.example.com"}, "7.7.7.7") + } +} + // TestValidateResolvConf validates that the resolv.conf file // ends up as expected in our Tailscale containers. // All the containers are based on Alpine, meaning Tailscale diff --git a/integration/utils.go b/integration/utils.go index ec6aeecf79..0c151ae87c 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -3,6 +3,7 @@ package integration import ( "bufio" "bytes" + "fmt" "io" "os" "strings" @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/cenkalti/backoff/v4" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -302,6 +304,30 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) { assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) } +// assertCommandOutputContains executes a command for a set time and asserts that the output +// reaches a desired state. +// It should be used instead of sleeping before executing. +func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { + t.Helper() + + err := backoff.Retry(func() error { + stdout, stderr, err := c.Execute(command) + if err != nil { + return fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) + } + + if !strings.Contains(stdout, contains) { + return fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + } + + return nil + }, backoff.NewExponentialBackOff( + backoff.WithMaxElapsedTime(10*time.Second)), + ) + + assert.NoError(t, err) +} + func isSelfClient(client TailscaleClient, addr string) bool { if addr == client.Hostname() { return true