diff --git a/pkg/config/tlscfg/cert_watcher.go b/pkg/config/tlscfg/cert_watcher.go index 82d0872c03f..5383e153e34 100644 --- a/pkg/config/tlscfg/cert_watcher.go +++ b/pkg/config/tlscfg/cert_watcher.go @@ -15,10 +15,13 @@ package tlscfg import ( + "crypto/sha256" "crypto/tls" "crypto/x509" "fmt" "io" + "os" + "path" "path/filepath" "sync" @@ -31,11 +34,15 @@ import ( // The certificate and key can be obtained via certWatcher.certificate. // The consumers of this API should use GetCertificate or GetClientCertificate from tls.Config to supply the certificate to the config. type certWatcher struct { - opts Options - watcher *fsnotify.Watcher - cert *tls.Certificate - logger *zap.Logger - mu *sync.RWMutex + mu sync.RWMutex + opts Options + logger *zap.Logger + watcher *fsnotify.Watcher + cert *tls.Certificate + caHash string + clientCAHash string + certHash string + keyHash string } var _ io.Closer = (*certWatcher)(nil) @@ -50,21 +57,24 @@ func newCertWatcher(opts Options, logger *zap.Logger) (*certWatcher, error) { } cert = &c } + watcher, err := fsnotify.NewWatcher() if err != nil { return nil, err } - if err := addCertsToWatch(watcher, opts); err != nil { + + w := &certWatcher{ + opts: opts, + logger: logger, + cert: cert, + watcher: watcher, + } + + if err := w.setupWatchedPaths(); err != nil { watcher.Close() return nil, err } - return &certWatcher{ - cert: cert, - opts: opts, - watcher: watcher, - logger: logger, - mu: &sync.RWMutex{}, - }, nil + return w, nil } func (w *certWatcher) Close() error { @@ -77,81 +87,152 @@ func (w *certWatcher) certificate() *tls.Certificate { return w.cert } +// setupWatchedPaths retrieves hashes of all configured certificates +// and adds their parent directories to the watcher. +func (w *certWatcher) setupWatchedPaths() error { + uniqueDirs := make(map[string]bool) + addPath := func(certPath string, hashPtr *string) error { + if certPath == "" { + return nil + } + if h, err := hashFile(certPath); err == nil { + *hashPtr = h + } else { + return err + } + dir := path.Dir(certPath) + if _, ok := uniqueDirs[dir]; !ok { + w.watcher.Add(dir) + uniqueDirs[dir] = true + } + return nil + } + + if err := addPath(w.opts.CAPath, &w.caHash); err != nil { + return err + } + if err := addPath(w.opts.ClientCAPath, &w.clientCAHash); err != nil { + return err + } + if err := addPath(w.opts.CertPath, &w.certHash); err != nil { + return err + } + if err := addPath(w.opts.KeyPath, &w.keyHash); err != nil { + return err + } + return nil +} + +// watchChangesLoop waits for notifications of changes in the watched directories +// and attempts to reload all certificates that changed. +// +// Write and Rename events indicate that some files might have changed and reload might be necessary. +// Remove event indicates that the file was deleted and we should write an error to log. +// +// Reasoning: +// +// Write event is sent if the file content is rewritten. +// +// Usually files are not rewritten, but they are updated by swapping them with new +// ones by calling Rename. That avoids files being read while they are not yet +// completely written but it also means that inotify on file level will not work: +// watch is invalidated when the old file is deleted. +// +// If reading from Kubernetes Secret volumes the target files are symbolic links +// to files in a different directory. That directory is swapped with a new one, +// while the symbolic links remain the same. This guarantees atomic swap for all +// files at once, but it also means any Rename event in the directory might +// indicate that the files were replaced, even if event.Name is not any of the +// files we are monitoring. We check the hashes of the files to detect if they +// were really changed. func (w *certWatcher) watchChangesLoop(rootCAs, clientCAs *x509.CertPool) { for { select { case event, ok := <-w.watcher.Events: if !ok { - return - } - // ignore if the event is a chmod event (permission or owner changes) - if event.Op&fsnotify.Chmod == fsnotify.Chmod { - continue + return // channel closed means the watcher is closed } - if event.Op&fsnotify.Remove == fsnotify.Remove { - w.logger.Warn("Certificate has been removed, using the last known version", - zap.String("certificate", event.Name)) - continue - } - - w.logger.Info("Loading modified certificate", - zap.String("certificate", event.Name), - zap.String("event", event.Op.String())) - var err error - switch event.Name { - case w.opts.CAPath: - err = addCertToPool(w.opts.CAPath, rootCAs) - case w.opts.ClientCAPath: - err = addCertToPool(w.opts.ClientCAPath, clientCAs) - case w.opts.CertPath, w.opts.KeyPath: - w.mu.Lock() - c, e := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath)) - if e == nil { - w.cert = &c - } - w.mu.Unlock() - err = e + w.logger.Debug("Received event", zap.String("event", event.String())) + if event.Op&fsnotify.Write == fsnotify.Write || + event.Op&fsnotify.Rename == fsnotify.Rename || + event.Op&fsnotify.Remove == fsnotify.Remove { + w.attemptReload(rootCAs, clientCAs) } - if err == nil { - w.logger.Info("Loaded modified certificate", - zap.String("certificate", event.Name), - zap.String("event", event.Op.String())) - } else { - w.logger.Error("Failed to load certificate", - zap.String("certificate", event.Name), - zap.String("event", event.Op.String()), - zap.Error(err)) + case err, ok := <-w.watcher.Errors: + if !ok { + return // channel closed means the watcher is closed } - case err := <-w.watcher.Errors: w.logger.Error("Watcher got error", zap.Error(err)) } } } -func addCertsToWatch(watcher *fsnotify.Watcher, opts Options) error { - if len(opts.CAPath) != 0 { - err := watcher.Add(opts.CAPath) - if err != nil { - return err +// attemptReload checks if the watched files have been modified and reloads them if necessary. +func (w *certWatcher) attemptReload(rootCAs, clientCAs *x509.CertPool) { + w.reloadIfModified(w.opts.CAPath, &w.caHash, rootCAs) + w.reloadIfModified(w.opts.ClientCAPath, &w.clientCAHash, clientCAs) + + isCertModified, newCertHash := w.isModified(w.opts.CertPath, w.certHash) + isKeyModified, newKeyHash := w.isModified(w.opts.KeyPath, w.keyHash) + if isCertModified || isKeyModified { + c, err := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath)) + if err == nil { + w.mu.Lock() + w.cert = &c + w.certHash = newCertHash + w.keyHash = newKeyHash + w.mu.Unlock() + w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.CertPath)) + w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.KeyPath)) + } else { + w.logger.Error( + "Failed to load certificate pair", + zap.String("certificate", w.opts.CertPath), + zap.String("key", w.opts.KeyPath), + zap.Error(err), + ) } } - if len(opts.ClientCAPath) != 0 { - err := watcher.Add(opts.ClientCAPath) - if err != nil { - return err +} + +func (w *certWatcher) reloadIfModified(certPath string, certHash *string, certPool *x509.CertPool) { + if mod, newHash := w.isModified(certPath, *certHash); mod { + if err := addCertToPool(certPath, certPool); err == nil { + w.mu.Lock() + *certHash = newHash + w.mu.Unlock() + w.logger.Info("Loaded modified certificate", zap.String("certificate", certPath)) + } else { + w.logger.Error("Failed to load certificate", zap.String("certificate", certPath), zap.Error(err)) } } - if len(opts.CertPath) != 0 { - err := watcher.Add(opts.CertPath) - if err != nil { - return err - } +} + +// isModified returns true if the file has been modified since the last check. +func (w *certWatcher) isModified(file string, previousHash string) (bool, string) { + if file == "" { + return false, "" } - if len(opts.KeyPath) != 0 { - err := watcher.Add(opts.KeyPath) - if err != nil { - return err - } + hash, err := hashFile(file) + if err != nil { + w.logger.Warn("Certificate has been removed, using the last known version", zap.String("certificate", file)) + return false, "" } - return nil + return previousHash != hash, hash +} + +// hashFile returns the SHA256 hash of the file. +func hashFile(file string) (string, error) { + f, err := os.Open(filepath.Clean(file)) + if err != nil { + return "", err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + + return fmt.Sprintf("%x", h.Sum(nil)), nil } diff --git a/pkg/config/tlscfg/cert_watcher_test.go b/pkg/config/tlscfg/cert_watcher_test.go index 1c81b9f50e2..0030ec7db18 100644 --- a/pkg/config/tlscfg/cert_watcher_test.go +++ b/pkg/config/tlscfg/cert_watcher_test.go @@ -17,6 +17,7 @@ package tlscfg import ( "crypto/tls" "crypto/x509" + "fmt" "os" "path/filepath" "testing" @@ -36,29 +37,42 @@ const ( clientCert = "./testdata/example-client-cert.pem" clientKey = "./testdata/example-client-key.pem" - caCert = "./testdata/example-CA-cert.pem" - badCaCert = "./testdata/bad-CA-cert.txt" + caCert = "./testdata/example-CA-cert.pem" + wrongCaCert = "./testdata/wrong-CA-cert.pem" + badCaCert = "./testdata/bad-CA-cert.txt" ) -func TestReload(t *testing.T) { - // copy certs to temp so we can modify them - certFile, err := os.CreateTemp("", "cert.crt") - require.NoError(t, err) - defer os.Remove(certFile.Name()) - certData, err := os.ReadFile(serverCert) +func copyToTempFile(t *testing.T, pattern string, filename string) (file *os.File, closeFn func()) { + tempFile, err := os.CreateTemp("", pattern) require.NoError(t, err) - _, err = certFile.Write(certData) + + data, err := os.ReadFile(filename) require.NoError(t, err) - certFile.Close() - keyFile, err := os.CreateTemp("", "key.crt") + _, err = tempFile.Write(data) require.NoError(t, err) - defer os.Remove(keyFile.Name()) - keyData, err := os.ReadFile(serverKey) + require.NoError(t, tempFile.Close()) + + return tempFile, func() { + // ignore error because some tests may remove the files earlier + _ = os.Remove(tempFile.Name()) + } +} + +func copyFile(t *testing.T, dest string, src string) { + certData, err := os.ReadFile(src) require.NoError(t, err) - _, err = keyFile.Write(keyData) + err = syncWrite(dest, certData, 0o644) require.NoError(t, err) - keyFile.Close() +} + +func TestReload(t *testing.T) { + // copy certs to temp so we can modify them + certFile, certFileCloseFn := copyToTempFile(t, "cert.crt", serverCert) + defer certFileCloseFn() + + keyFile, keyFileCloseFn := copyToTempFile(t, "key.crt", serverKey) + defer keyFileCloseFn() zcore, logObserver := observer.New(zapcore.InfoLevel) logger := zap.New(zcore) @@ -81,39 +95,27 @@ func TestReload(t *testing.T) { assert.Equal(t, &cert, watcher.certificate()) // Write the client's public key. - certData, err = os.ReadFile(clientCert) - require.NoError(t, err) - err = syncWrite(certFile.Name(), certData, 0o644) - require.NoError(t, err) - - waitUntil(func() bool { - // Logged when the cert is reloaded with mismatching client public key and existing server private key. - return logObserver.FilterMessage("Failed to load certificate"). - FilterField(zap.String("certificate", certFile.Name())).Len() > 0 - }, 2000, time.Millisecond*10) + copyFile(t, certFile.Name(), clientCert) - assert.True(t, logObserver. - FilterMessage("Failed to load certificate"). - FilterField(zap.String("certificate", certFile.Name())).Len() > 0, - "Unable to locate 'Failed to load certificate' in log. All logs: %v", logObserver.All()) + assertLogs(t, + func() bool { + // Logged when the cert is reloaded with mismatching client public key and existing server private key. + return logObserver.FilterMessage("Failed to load certificate pair"). + FilterField(zap.String("certificate", certFile.Name())).Len() > 0 + }, + "Unable to locate 'Failed to load certificate pair' in log. All logs: %v", logObserver) // Write the client's private key. - keyData, err = os.ReadFile(clientKey) - require.NoError(t, err) - err = syncWrite(keyFile.Name(), keyData, 0o644) - require.NoError(t, err) - - waitUntil(func() bool { - // Logged when the client private key is modified in the cert which enables successful reloading of - // the cert as both private and public keys now match. - return logObserver.FilterMessage("Loaded modified certificate"). - FilterField(zap.String("certificate", keyFile.Name())).Len() > 0 - }, 2000, time.Millisecond*10) - - assert.True(t, logObserver. - FilterMessage("Loaded modified certificate"). - FilterField(zap.String("certificate", keyFile.Name())).Len() > 0, - "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver.All()) + copyFile(t, keyFile.Name(), clientKey) + + assertLogs(t, + func() bool { + // Logged when the client private key is modified in the cert which enables successful reloading of + // the cert as both private and public keys now match. + return logObserver.FilterMessage("Loaded modified certificate"). + FilterField(zap.String("certificate", keyFile.Name())).Len() > 0 + }, + "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) cert, err = tls.LoadX509KeyPair(filepath.Clean(clientCert), clientKey) require.NoError(t, err) @@ -122,23 +124,10 @@ func TestReload(t *testing.T) { func TestReload_ca_certs(t *testing.T) { // copy certs to temp so we can modify them - caFile, err := os.CreateTemp("", "cert.crt") - require.NoError(t, err) - defer os.Remove(caFile.Name()) - caData, err := os.ReadFile(caCert) - require.NoError(t, err) - _, err = caFile.Write(caData) - require.NoError(t, err) - caFile.Close() - - clientCaFile, err := os.CreateTemp("", "key.crt") - require.NoError(t, err) - defer os.Remove(clientCaFile.Name()) - clientCaData, err := os.ReadFile(caCert) - require.NoError(t, err) - _, err = clientCaFile.Write(clientCaData) - require.NoError(t, err) - clientCaFile.Close() + caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) + defer caFileCloseFn() + clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) + defer clientCaFileClostFn() zcore, logObserver := observer.New(zapcore.InfoLevel) logger := zap.New(zcore) @@ -154,46 +143,29 @@ func TestReload_ca_certs(t *testing.T) { require.NoError(t, err) go watcher.watchChangesLoop(certPool, certPool) - // update the content with client certs - caData, err = os.ReadFile(caCert) - require.NoError(t, err) - err = syncWrite(caFile.Name(), caData, 0o644) - require.NoError(t, err) - clientCaData, err = os.ReadFile(caCert) - require.NoError(t, err) - err = syncWrite(clientCaFile.Name(), clientCaData, 0o644) - require.NoError(t, err) + // update the content with different certs to trigger reload. + copyFile(t, caFile.Name(), wrongCaCert) + copyFile(t, clientCaFile.Name(), wrongCaCert) - waitUntil(func() bool { - return logObserver.FilterField(zap.String("certificate", caFile.Name())).Len() > 0 - }, 100, time.Millisecond*200) - assert.True(t, logObserver.FilterField(zap.String("certificate", caFile.Name())).Len() > 0) + assertLogs(t, + func() bool { + return logObserver.FilterField(zap.String("certificate", caFile.Name())).Len() > 0 + }, + "Unable to locate 'certificate' in log. All logs: %v", logObserver) - waitUntil(func() bool { - return logObserver.FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0 - }, 100, time.Millisecond*200) - assert.True(t, logObserver.FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0) + assertLogs(t, + func() bool { + return logObserver.FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0 + }, + "Unable to locate 'certificate' in log. All logs: %v", logObserver) } func TestReload_err_cert_update(t *testing.T) { // copy certs to temp so we can modify them - certFile, err := os.CreateTemp("", "cert.crt") - require.NoError(t, err) - defer os.Remove(certFile.Name()) - certData, err := os.ReadFile(serverCert) - require.NoError(t, err) - _, err = certFile.Write(certData) - require.NoError(t, err) - certFile.Close() - - keyFile, err := os.CreateTemp("", "key.crt") - require.NoError(t, err) - defer os.Remove(keyFile.Name()) - keyData, err := os.ReadFile(serverKey) - require.NoError(t, err) - _, err = keyFile.Write(keyData) - require.NoError(t, err) - keyFile.Close() + certFile, certFileCloseFn := copyToTempFile(t, "cert.crt", serverCert) + defer certFileCloseFn() + keyFile, keyFileCloseFn := copyToTempFile(t, "cert.crt", serverKey) + defer keyFileCloseFn() zcore, logObserver := observer.New(zapcore.InfoLevel) logger := zap.New(zcore) @@ -215,20 +187,15 @@ func TestReload_err_cert_update(t *testing.T) { require.NoError(t, err) assert.Equal(t, &serverCert, watcher.certificate()) - // update the content with client certs - certData, err = os.ReadFile(badCaCert) - require.NoError(t, err) - err = syncWrite(certFile.Name(), certData, 0o644) - require.NoError(t, err) - keyData, err = os.ReadFile(clientKey) - require.NoError(t, err) - err = syncWrite(keyFile.Name(), keyData, 0o644) - require.NoError(t, err) + // update the content with bad client certs + copyFile(t, certFile.Name(), badCaCert) + copyFile(t, keyFile.Name(), clientKey) - waitUntil(func() bool { - return logObserver.FilterMessage("Failed to load certificate").Len() > 0 - }, 100, time.Millisecond*200) - assert.True(t, logObserver.FilterField(zap.String("certificate", certFile.Name())).Len() > 0) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Failed to load certificate pair"). + FilterField(zap.String("certificate", certFile.Name())).Len() > 0 + }, "Unable to locate 'Failed to load certificate pair' in log. All logs: %v", logObserver) assert.Equal(t, &serverCert, watcher.certificate()) } @@ -242,10 +209,143 @@ func TestReload_err_watch(t *testing.T) { assert.Nil(t, watcher) } +func TestReload_kubernetes_secret_update(t *testing.T) { + mountDir, err := os.MkdirTemp("", "secret-mountpoint") + require.NoError(t, err) + defer os.RemoveAll(mountDir) + + // Create directory layout before update: + // + // /secret-mountpoint/ca.crt # symbolic link to ..data/ca.crt + // /secret-mountpoint/tls.crt # symbolic link to ..data/tls.crt + // /secret-mountpoint/tls.key # symbolic link to ..data/tls.key + // /secret-mountpoint/..data # symbolic link to ..timestamp-1 + // /secret-mountpoint/..timestamp-1 # directory + // /secret-mountpoint/..timestamp-1/ca.crt # initial version of ca.crt + // /secret-mountpoint/..timestamp-1/tls.crt # initial version of tls.crt + // /secret-mountpoint/..timestamp-1/tls.key # initial version of tls.key + + err = os.Symlink("..timestamp-1", filepath.Join(mountDir, "..data")) + require.NoError(t, err) + err = os.Symlink(filepath.Join("..data", "ca.crt"), filepath.Join(mountDir, "ca.crt")) + require.NoError(t, err) + err = os.Symlink(filepath.Join("..data", "tls.crt"), filepath.Join(mountDir, "tls.crt")) + require.NoError(t, err) + err = os.Symlink(filepath.Join("..data", "tls.key"), filepath.Join(mountDir, "tls.key")) + require.NoError(t, err) + + timestamp1Dir := filepath.Join(mountDir, "..timestamp-1") + createTimestampDir(t, timestamp1Dir, caCert, serverCert, serverKey) + + opts := Options{ + CAPath: filepath.Join(mountDir, "ca.crt"), + ClientCAPath: filepath.Join(mountDir, "ca.crt"), + CertPath: filepath.Join(mountDir, "tls.crt"), + KeyPath: filepath.Join(mountDir, "tls.key"), + } + + zcore, logObserver := observer.New(zapcore.InfoLevel) + logger := zap.New(zcore) + watcher, err := newCertWatcher(opts, logger) + require.NoError(t, err) + defer watcher.Close() + + certPool := x509.NewCertPool() + require.NoError(t, err) + go watcher.watchChangesLoop(certPool, certPool) + + expectedCert, err := tls.LoadX509KeyPair(serverCert, serverKey) + require.NoError(t, err) + + assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, + "certificate should be updated: %v", logObserver.All()) + + // After the update, the directory looks like following: + // + // /secret-mountpoint/ca.crt # symbolic link to ..data/ca.crt + // /secret-mountpoint/tls.crt # symbolic link to ..data/tls.crt + // /secret-mountpoint/tls.key # symbolic link to ..data/tls.key + // /secret-mountpoint/..data # symbolic link to ..timestamp-2 + // /secret-mountpoint/..timestamp-2 # new directory + // /secret-mountpoint/..timestamp-2/ca.crt # new version of ca.crt + // /secret-mountpoint/..timestamp-2/tls.crt # new version of tls.crt + // /secret-mountpoint/..timestamp-2/tls.key # new version of tls.key + logObserver.TakeAll() + + timestamp2Dir := filepath.Join(mountDir, "..timestamp-2") + createTimestampDir(t, timestamp2Dir, caCert, clientCert, clientKey) + + err = os.Symlink("..timestamp-2", filepath.Join(mountDir, "..data_tmp")) + require.NoError(t, err) + + os.Rename(filepath.Join(mountDir, "..data_tmp"), filepath.Join(mountDir, "..data")) + require.NoError(t, err) + err = os.RemoveAll(timestamp1Dir) + require.NoError(t, err) + + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Loaded modified certificate"). + FilterField(zap.String("certificate", opts.CertPath)).Len() > 0 + }, + "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) + + expectedCert, err = tls.LoadX509KeyPair(clientCert, clientKey) + require.NoError(t, err) + assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, + "certificate should be updated: %v", logObserver.All()) + + // Make third update to make sure that the watcher is still working. + logObserver.TakeAll() + + timestamp3Dir := filepath.Join(mountDir, "..timestamp-3") + createTimestampDir(t, timestamp3Dir, caCert, serverCert, serverKey) + err = os.Symlink("..timestamp-3", filepath.Join(mountDir, "..data_tmp")) + require.NoError(t, err) + os.Rename(filepath.Join(mountDir, "..data_tmp"), filepath.Join(mountDir, "..data")) + require.NoError(t, err) + err = os.RemoveAll(timestamp2Dir) + require.NoError(t, err) + + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Loaded modified certificate"). + FilterField(zap.String("certificate", opts.CertPath)).Len() > 0 + }, + "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) + + expectedCert, err = tls.LoadX509KeyPair(serverCert, serverKey) + require.NoError(t, err) + assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, + "certificate should be updated: %v", logObserver.All()) +} + +func createTimestampDir(t *testing.T, dir string, ca, cert, key string) { + t.Helper() + err := os.MkdirAll(dir, 0o700) + require.NoError(t, err) + + data, err := os.ReadFile(ca) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(dir, "ca.crt"), data, 0o600) + require.NoError(t, err) + data, err = os.ReadFile(cert) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(dir, "tls.crt"), data, 0o600) + require.NoError(t, err) + data, err = os.ReadFile(key) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(dir, "tls.key"), data, 0o600) + require.NoError(t, err) +} + func TestAddCertsToWatch_err(t *testing.T) { watcher, err := fsnotify.NewWatcher() require.NoError(t, err) defer watcher.Close() + w := &certWatcher{ + watcher: watcher, + } tests := []struct { opts Options @@ -278,30 +378,18 @@ func TestAddCertsToWatch_err(t *testing.T) { }, } for _, test := range tests { - err := addCertsToWatch(watcher, test.opts) + w.opts = test.opts + err := w.setupWatchedPaths() require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") } } func TestAddCertsToWatch_remove_ca(t *testing.T) { - caFile, err := os.CreateTemp("", "ca.crt") - require.NoError(t, err) - defer os.Remove(caFile.Name()) - caData, err := os.ReadFile(caCert) - require.NoError(t, err) - _, err = caFile.Write(caData) - require.NoError(t, err) - caFile.Close() - - clientCaFile, err := os.CreateTemp("", "clientCa.crt") - require.NoError(t, err) - defer os.Remove(clientCaFile.Name()) - clientCaData, err := os.ReadFile(caCert) - require.NoError(t, err) - _, err = clientCaFile.Write(clientCaData) - require.NoError(t, err) - clientCaFile.Close() + caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) + defer caFileCloseFn() + clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) + defer clientCaFileClostFn() zcore, logObserver := observer.New(zapcore.InfoLevel) logger := zap.New(zcore) @@ -319,20 +407,31 @@ func TestAddCertsToWatch_remove_ca(t *testing.T) { require.NoError(t, os.Remove(caFile.Name())) require.NoError(t, os.Remove(clientCaFile.Name())) - waitUntil(func() bool { - return logObserver.FilterMessage("Certificate has been removed, using the last known version").Len() >= 2 - }, 100, time.Millisecond*100) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Certificate has been removed, using the last known version").Len() >= 2 + }, + "Unable to locate 'Certificate has been removed' in log. All logs: %v", logObserver) assert.True(t, logObserver.FilterMessage("Certificate has been removed, using the last known version").FilterField(zap.String("certificate", caFile.Name())).Len() > 0) assert.True(t, logObserver.FilterMessage("Certificate has been removed, using the last known version").FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0) } -func waitUntil(f func() bool, iterations int, sleepInterval time.Duration) { - for i := 0; i < iterations; i++ { - if f() { - return - } - time.Sleep(sleepInterval) - } +type delayedFormat struct { + fn func() interface{} +} + +func (df delayedFormat) String() string { + return fmt.Sprintf("%v", df.fn()) +} + +func assertLogs(t *testing.T, f func() bool, errorMsg string, logObserver *observer.ObservedLogs) { + assert.Eventuallyf(t, f, + 10*time.Second, 10*time.Millisecond, + errorMsg, + delayedFormat{ + fn: func() interface{} { return logObserver.All() }, + }, + ) } // syncWrite ensures data is written to the given filename and flushed to disk. @@ -348,3 +447,42 @@ func syncWrite(filename string, data []byte, perm os.FileMode) error { } return f.Sync() } + +func TestReload_err_ca_cert_update(t *testing.T) { + // copy certs to temp so we can modify them + caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) + defer caFileCloseFn() + clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) + defer clientCaFileClostFn() + + zcore, logObserver := observer.New(zapcore.InfoLevel) + logger := zap.New(zcore) + opts := Options{ + CAPath: caFile.Name(), + ClientCAPath: clientCaFile.Name(), + } + watcher, err := newCertWatcher(opts, logger) + require.NoError(t, err) + defer watcher.Close() + + certPool := x509.NewCertPool() + require.NoError(t, err) + go watcher.watchChangesLoop(certPool, certPool) + + // update the content with bad certs. + copyFile(t, caFile.Name(), badCaCert) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Failed to load certificate"). + FilterField(zap.String("certificate", caFile.Name())).Len() > 0 + }, + "Unable to locate 'certificate' in log. All logs: %v", logObserver) + + copyFile(t, clientCaFile.Name(), badCaCert) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Failed to load certificate"). + FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0 + }, + "Unable to locate 'Failed to load certificate' in log. All logs: %v", logObserver) +}