diff --git a/changelog/17598.txt b/changelog/17598.txt new file mode 100644 index 000000000000..8171255f8e95 --- /dev/null +++ b/changelog/17598.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core/config: reload service registration configuration on SIGHUP +``` diff --git a/command/server.go b/command/server.go index d8c5654a3717..fb2c04e62085 100644 --- a/command/server.go +++ b/command/server.go @@ -1685,6 +1685,15 @@ func (c *ServerCommand) Run(args []string) int { } } + // notify ServiceRegistration that a configuration reload has occurred + if sr := coreConfig.GetServiceRegistration(); sr != nil { + var srConfig *map[string]string + if config.ServiceRegistration != nil { + srConfig = &config.ServiceRegistration.Config + } + sr.NotifyConfigurationReload(srConfig) + } + if err := core.ReloadCensus(); err != nil { c.UI.Error(err.Error()) } diff --git a/serviceregistration/consul/consul_service_registration.go b/serviceregistration/consul/consul_service_registration.go index a3534e4ff37b..c59ed775cc99 100644 --- a/serviceregistration/consul/consul_service_registration.go +++ b/serviceregistration/consul/consul_service_registration.go @@ -51,7 +51,7 @@ const ( // reconcileTimeout is how often Vault should query Consul to detect // and fix any state drift. - reconcileTimeout = 60 * time.Second + DefaultReconcileTimeout = 60 * time.Second // metaExternalSource is a metadata value for external-source that can be // used by the Consul UI. @@ -64,9 +64,11 @@ var hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]* // Vault to Consul. type serviceRegistration struct { Client *api.Client + config *api.Config logger log.Logger serviceLock sync.RWMutex + registeredServiceID string redirectHost string redirectPort int64 serviceName string @@ -74,6 +76,7 @@ type serviceRegistration struct { serviceAddress *string disableRegistration bool checkTimeout time.Duration + reconcileTimeout time.Duration notifyActiveCh chan struct{} notifySealedCh chan struct{} @@ -92,90 +95,9 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr. return nil, errors.New("logger is required") } - // Allow admins to disable consul integration - disableReg, ok := conf["disable_registration"] - var disableRegistration bool - if ok && disableReg != "" { - b, err := parseutil.ParseBool(disableReg) - if err != nil { - return nil, fmt.Errorf("failed parsing disable_registration parameter: %w", err) - } - disableRegistration = b - } - if logger.IsDebug() { - logger.Debug("config disable_registration set", "disable_registration", disableRegistration) - } - - // Get the service name to advertise in Consul - service, ok := conf["service"] - if !ok { - service = DefaultServiceName - } - if !hostnameRegex.MatchString(service) { - return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes") - } - if logger.IsDebug() { - logger.Debug("config service set", "service", service) - } - - // Get the additional tags to attach to the registered service name - tags := conf["service_tags"] - if logger.IsDebug() { - logger.Debug("config service_tags set", "service_tags", tags) - } - - // Get the service-specific address to override the use of the HA redirect address - var serviceAddr *string - serviceAddrStr, ok := conf["service_address"] - if ok { - serviceAddr = &serviceAddrStr - } - if logger.IsDebug() { - logger.Debug("config service_address set", "service_address", serviceAddrStr) - } - - checkTimeout := defaultCheckTimeout - checkTimeoutStr, ok := conf["check_timeout"] - if ok { - d, err := parseutil.ParseDurationSecond(checkTimeoutStr) - if err != nil { - return nil, err - } - - min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) - if min < checkMinBuffer { - return nil, fmt.Errorf("consul check_timeout must be greater than %v", min) - } - - checkTimeout = d - if logger.IsDebug() { - logger.Debug("config check_timeout set", "check_timeout", d) - } - } - - // Configure the client - consulConf := api.DefaultConfig() - // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore - consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount - - SetupSecureTLS(context.Background(), consulConf, conf, logger, false) - - consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} - client, err := api.NewClient(consulConf) - if err != nil { - return nil, fmt.Errorf("client setup failed: %w", err) - } - // Setup the backend c := &serviceRegistration{ - Client: client, - - logger: logger, - serviceName: service, - serviceTags: strutil.ParseDedupAndSortStrings(tags, ","), - serviceAddress: serviceAddr, - checkTimeout: checkTimeout, - disableRegistration: disableRegistration, + logger: logger, notifyActiveCh: make(chan struct{}), notifySealedCh: make(chan struct{}), @@ -187,7 +109,11 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr. isPerfStandby: atomicB.NewBool(state.IsPerformanceStandby), isInitialized: atomicB.NewBool(state.IsInitialized), } - return c, nil + + c.serviceLock.Lock() + defer c.serviceLock.Unlock() + err := c.merge(conf) + return c, err } func SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error { @@ -270,6 +196,112 @@ func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGro return nil } +func (c *serviceRegistration) merge(conf map[string]string) error { + // Allow admins to disable consul integration + disableReg, ok := conf["disable_registration"] + var disableRegistration bool + if ok && disableReg != "" { + b, err := parseutil.ParseBool(disableReg) + if err != nil { + return fmt.Errorf("failed parsing disable_registration parameter: %w", err) + } + disableRegistration = b + } + if c.logger.IsDebug() { + c.logger.Debug("config disable_registration set", "disable_registration", disableRegistration) + } + + // Get the service name to advertise in Consul + service, ok := conf["service"] + if !ok { + service = DefaultServiceName + } + if !hostnameRegex.MatchString(service) { + return errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes") + } + if c.logger.IsDebug() { + c.logger.Debug("config service set", "service", service) + } + + // Get the additional tags to attach to the registered service name + tags := conf["service_tags"] + if c.logger.IsDebug() { + c.logger.Debug("config service_tags set", "service_tags", tags) + } + + // Get the service-specific address to override the use of the HA redirect address + var serviceAddr *string + serviceAddrStr, ok := conf["service_address"] + if ok { + serviceAddr = &serviceAddrStr + } + if c.logger.IsDebug() { + c.logger.Debug("config service_address set", "service_address", serviceAddrStr) + } + + checkTimeout := defaultCheckTimeout + checkTimeoutStr, ok := conf["check_timeout"] + if ok { + d, err := parseutil.ParseDurationSecond(checkTimeoutStr) + if err != nil { + return err + } + + min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) + if min < checkMinBuffer { + return fmt.Errorf("consul check_timeout must be greater than %v", min) + } + + checkTimeout = d + if c.logger.IsDebug() { + c.logger.Debug("config check_timeout set", "check_timeout", d) + } + } + + reconcileTimeout := DefaultReconcileTimeout + reconcileTimeoutStr, ok := conf["reconcile_timeout"] + if ok { + d, err := parseutil.ParseDurationSecond(reconcileTimeoutStr) + if err != nil { + return err + } + + min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) + if min < checkMinBuffer { + return fmt.Errorf("consul reconcile_timeout must be greater than %v", min) + } + + reconcileTimeout = d + if c.logger.IsDebug() { + c.logger.Debug("config reconcile_timeout set", "reconcile_timeout", d) + } + } + + // Configure the client + consulConf := api.DefaultConfig() + // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore + consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount + + SetupSecureTLS(context.Background(), consulConf, conf, c.logger, false) + + consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} + client, err := api.NewClient(consulConf) + if err != nil { + return fmt.Errorf("client setup failed: %w", err) + } + + c.Client = client + c.config = consulConf + c.serviceName = service + c.serviceTags = strutil.ParseDedupAndSortStrings(tags, ",") + c.serviceAddress = serviceAddr + c.checkTimeout = checkTimeout + c.disableRegistration = disableRegistration + c.reconcileTimeout = reconcileTimeout + + return nil +} + func (c *serviceRegistration) NotifyActiveStateChange(isActive bool) error { c.isActive.Store(isActive) select { @@ -322,6 +354,25 @@ func (c *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) e return nil } +func (c *serviceRegistration) NotifyConfigurationReload(conf *map[string]string) error { + c.serviceLock.Lock() + defer c.serviceLock.Unlock() + if conf == nil { + if c.logger.IsDebug() { + c.logger.Debug("registration is now empty, deregistering service from consul") + } + c.disableRegistration = true + err := c.deregisterService() + c.Client = nil + return err + } else { + if c.logger.IsDebug() { + c.logger.Debug("service registration configuration received, merging with existing configuation") + } + return c.merge(*conf) + } +} + func (c *serviceRegistration) checkDuration() time.Duration { return durationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) } @@ -363,7 +414,6 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow // and end of a handler's life (or after a handler wakes up from // sleeping during a back-off/retry). var shutdown atomicB.Bool - var registeredServiceID string checkLock := new(int32) serviceRegLock := new(int32) @@ -383,16 +433,19 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow checkTimer.Reset(0) case <-reconcileTimer.C: // Unconditionally rearm the reconcileTimer - reconcileTimer.Reset(reconcileTimeout - randomStagger(reconcileTimeout/checkJitterFactor)) + c.serviceLock.RLock() + reconcileTimer.Reset(c.reconcileTimeout - randomStagger(c.reconcileTimeout/checkJitterFactor)) + disableRegistration := c.disableRegistration + c.serviceLock.RUnlock() // Abort if service discovery is disabled or a // reconcile handler is already active - if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { + if !disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { // Enter handler with serviceRegLock held go func() { defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0) for !shutdown.Load() { - serviceID, err := c.reconcileConsul(registeredServiceID) + serviceID, err := c.reconcileConsul() if err != nil { if c.logger.IsWarn() { c.logger.Warn("reconcile unable to talk with Consul backend", "error", err) @@ -402,7 +455,7 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow } c.serviceLock.Lock() - registeredServiceID = serviceID + c.registeredServiceID = serviceID c.serviceLock.Unlock() return @@ -411,19 +464,29 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow } case <-checkTimer.C: checkTimer.Reset(c.checkDuration()) + c.serviceLock.RLock() + disableRegistration := c.disableRegistration + c.serviceLock.RUnlock() + // Abort if service discovery is disabled or a // reconcile handler is active - if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { + if !disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { // Enter handler with checkLock held go func() { defer atomic.CompareAndSwapInt32(checkLock, 1, 0) for !shutdown.Load() { - if err := c.runCheck(c.isSealed.Load()); err != nil { - if c.logger.IsWarn() { - c.logger.Warn("check unable to talk with Consul backend", "error", err) + c.serviceLock.RLock() + registeredServiceID := c.registeredServiceID + c.serviceLock.RUnlock() + + if registeredServiceID != "" { + if err := c.runCheck(c.isSealed.Load()); err != nil { + if c.logger.IsWarn() { + c.logger.Warn("check unable to talk with Consul backend", "error", err) + } + time.Sleep(consulRetryInterval) + continue } - time.Sleep(consulRetryInterval) - continue } return } @@ -435,13 +498,23 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow } } - c.serviceLock.RLock() - defer c.serviceLock.RUnlock() - if err := c.Client.Agent().ServiceDeregister(registeredServiceID); err != nil { - if c.logger.IsWarn() { - c.logger.Warn("service deregistration failed", "error", err) + c.serviceLock.Lock() + defer c.serviceLock.Unlock() + c.deregisterService() +} + +func (c *serviceRegistration) deregisterService() error { + if c.registeredServiceID != "" { + if err := c.Client.Agent().ServiceDeregister(c.registeredServiceID); err != nil { + if c.logger.IsWarn() { + c.logger.Warn("service deregistration failed", "error", err) + } + return err } + c.registeredServiceID = "" } + + return nil } // checkID returns the ID used for a Consul Check. Assume at least a read @@ -458,10 +531,12 @@ func (c *serviceRegistration) serviceID() string { // reconcileConsul queries the state of Vault Core and Consul and fixes up // Consul's state according to what's in Vault. reconcileConsul is called -// without any locks held and can be run concurrently, therefore no changes +// with a read lock and can be run concurrently, therefore no changes // to serviceRegistration can be made in this method (i.e. wtb const receiver for // compiler enforced safety). -func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (serviceID string, err error) { +func (c *serviceRegistration) reconcileConsul() (serviceID string, err error) { + c.serviceLock.RLock() + defer c.serviceLock.RUnlock() agent := c.Client.Agent() catalog := c.Client.Catalog() @@ -483,7 +558,7 @@ func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (servi var reregister bool switch { - case currentVaultService == nil, registeredServiceID == "": + case currentVaultService == nil, c.registeredServiceID == "": reregister = true default: switch { diff --git a/serviceregistration/consul/consul_service_registration_test.go b/serviceregistration/consul/consul_service_registration_test.go index bd41890be8f1..8dcc3d66952b 100644 --- a/serviceregistration/consul/consul_service_registration_test.go +++ b/serviceregistration/consul/consul_service_registration_test.go @@ -63,6 +63,17 @@ func TestConsul_ServiceRegistration(t *testing.T) { t.Fatal(err) } + // update the agent's ACL token so that we can successfully deregister the + // service later in the test + _, err = client.Agent().UpdateAgentACLToken(config.Token, nil) + if err != nil { + t.Fatal(err) + } + _, err = client.Agent().UpdateDefaultACLToken(config.Token, nil) + if err != nil { + t.Fatal(err) + } + // waitForServices waits for the services in the Consul catalog to // reach an expected value, returning the delta if that doesn't happen in time. waitForServices := func(t *testing.T, expected map[string][]string) map[string][]string { @@ -92,10 +103,13 @@ func TestConsul_ServiceRegistration(t *testing.T) { // Create a ServiceRegistration that points to our consul instance logger := logging.NewVaultLogger(log.Trace) - sd, err := NewServiceRegistration(map[string]string{ + srConfig := map[string]string{ "address": config.Address(), "token": config.Token, - }, logger, sr.State{}) + // decrease reconcile timeout to make test run faster + "reconcile_timeout": "1s", + } + sd, err := NewServiceRegistration(srConfig, logger, sr.State{}) if err != nil { t.Fatal(err) } @@ -147,6 +161,58 @@ func TestConsul_ServiceRegistration(t *testing.T) { "consul": {}, "vault": {"active", "initialized"}, }) + + // change the token and trigger reload + if sd.(*serviceRegistration).config.Token == "" { + t.Fatal("expected service registration token to not be '' before configuration reload") + } + + srConfigWithoutToken := make(map[string]string) + for k, v := range srConfig { + srConfigWithoutToken[k] = v + } + srConfigWithoutToken["token"] = "" + err = sd.NotifyConfigurationReload(&srConfigWithoutToken) + if err != nil { + t.Fatal(err) + } + + if sd.(*serviceRegistration).config.Token != "" { + t.Fatal("expected service registration token to be '' after configuration reload") + } + + // reconfigure the configuration back to its original state and verify vault is registered + err = sd.NotifyConfigurationReload(&srConfig) + if err != nil { + t.Fatal(err) + } + + waitForServices(t, map[string][]string{ + "consul": {}, + "vault": {"active", "initialized"}, + }) + + // send 'nil' configuration to verify that the service is deregistered + err = sd.NotifyConfigurationReload(nil) + if err != nil { + t.Fatal(err) + } + + waitForServices(t, map[string][]string{ + "consul": {}, + }) + + // reconfigure the configuration back to its original state and verify vault + // is re-registered + err = sd.NotifyConfigurationReload(&srConfig) + if err != nil { + t.Fatal(err) + } + + waitForServices(t, map[string][]string{ + "consul": {}, + "vault": {"active", "initialized"}, + }) } func TestConsul_ServiceAddress(t *testing.T) { diff --git a/serviceregistration/kubernetes/service_registration.go b/serviceregistration/kubernetes/service_registration.go index 1c22888016a3..8b52023b001d 100644 --- a/serviceregistration/kubernetes/service_registration.go +++ b/serviceregistration/kubernetes/service_registration.go @@ -106,6 +106,10 @@ func (r *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) e return nil } +func (c *serviceRegistration) NotifyConfigurationReload(conf *map[string]string) error { + return nil +} + func getRequiredField(logger hclog.Logger, config map[string]string, envVar, configParam string) (string, error) { value := "" switch { diff --git a/serviceregistration/service_registration.go b/serviceregistration/service_registration.go index 4eb560793d42..394892e84768 100644 --- a/serviceregistration/service_registration.go +++ b/serviceregistration/service_registration.go @@ -96,4 +96,14 @@ type ServiceRegistration interface { // the implementation's responsibility to retry updating state // in the face of errors. NotifyInitializedStateChange(isInitialized bool) error + + // NotifyConfigurationReload is used by Core to notify that the Vault + // configuration has been reloaded. + // If errors are returned, Vault only logs a warning, so it is + // the implementation's responsibility to retry updating state + // in the face of errors. + // + // If the passed in conf is nil, it is assumed that the service registration + // configuration no longer exits and should be deregistered. + NotifyConfigurationReload(conf *map[string]string) error } diff --git a/vault/core_test.go b/vault/core_test.go index 7493a25fbcee..8b34dd3340d7 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -3286,11 +3286,12 @@ func TestCore_HandleRequest_TokenCreate_RegisterAuthFailure(t *testing.T) { // mockServiceRegistration helps test whether standalone ServiceRegistration works type mockServiceRegistration struct { - notifyActiveCount int - notifySealedCount int - notifyPerfCount int - notifyInitCount int - runDiscoveryCount int + notifyActiveCount int + notifySealedCount int + notifyPerfCount int + notifyInitCount int + notifyConfigurationReload int + runDiscoveryCount int } func (m *mockServiceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error { @@ -3318,6 +3319,11 @@ func (m *mockServiceRegistration) NotifyInitializedStateChange(isInitialized boo return nil } +func (m *mockServiceRegistration) NotifyConfigurationReload(config *map[string]string) error { + m.notifyConfigurationReload++ + return nil +} + // TestCore_ServiceRegistration tests whether standalone ServiceRegistration works func TestCore_ServiceRegistration(t *testing.T) { // Make a mock service discovery @@ -3374,10 +3380,11 @@ func TestCore_ServiceRegistration(t *testing.T) { // Vault should be registered, unsealed, and active if diff := deep.Equal(sr, &mockServiceRegistration{ - runDiscoveryCount: 1, - notifyActiveCount: 1, - notifySealedCount: 1, - notifyInitCount: 1, + runDiscoveryCount: 1, + notifyActiveCount: 1, + notifySealedCount: 1, + notifyInitCount: 1, + notifyConfigurationReload: 1, }); diff != nil { t.Fatal(diff) }