Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Watch directories for certificate hot-reload #4159

Merged
merged 3 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 152 additions & 71 deletions pkg/config/tlscfg/cert_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package tlscfg

import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"os"
"path"
"path/filepath"
"sync"

Expand All @@ -31,11 +34,15 @@ import (
// The certificate and key can be obtained via certWatcher.certificate.
// The consumers of this API should use GetCertificate or GetClientCertificate from tls.Config to supply the certificate to the config.
type certWatcher struct {
opts Options
watcher *fsnotify.Watcher
cert *tls.Certificate
logger *zap.Logger
mu *sync.RWMutex
mu sync.RWMutex
opts Options
logger *zap.Logger
watcher *fsnotify.Watcher
cert *tls.Certificate
caHash string
clientCAHash string
certHash string
keyHash string
}

var _ io.Closer = (*certWatcher)(nil)
Expand All @@ -50,21 +57,24 @@ func newCertWatcher(opts Options, logger *zap.Logger) (*certWatcher, error) {
}
cert = &c
}

watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
if err := addCertsToWatch(watcher, opts); err != nil {

w := &certWatcher{
opts: opts,
logger: logger,
cert: cert,
watcher: watcher,
}

if err := w.setupWatchedPaths(); err != nil {
watcher.Close()
return nil, err
}
return &certWatcher{
cert: cert,
opts: opts,
watcher: watcher,
logger: logger,
mu: &sync.RWMutex{},
}, nil
return w, nil
}

func (w *certWatcher) Close() error {
Expand All @@ -77,81 +87,152 @@ func (w *certWatcher) certificate() *tls.Certificate {
return w.cert
}

// setupWatchedPaths records the initial content hash of every configured
// certificate file (CA, client CA, certificate and key) and registers the
// parent directory of each file with the fsnotify watcher.
//
// Paths that are empty are skipped. Each directory is added to the watcher
// at most once, even when several configured files live in it.
//
// It returns the first error encountered, either from hashing a configured
// file or from adding its directory to the watcher. The caller is expected
// to treat any error as fatal for the watcher setup.
func (w *certWatcher) setupWatchedPaths() error {
	watchedDirs := make(map[string]struct{})

	// addPath stores the hash of certPath in *hashPtr and makes sure the
	// directory containing certPath is being watched.
	addPath := func(certPath string, hashPtr *string) error {
		if certPath == "" {
			return nil
		}
		h, err := hashFile(certPath)
		if err != nil {
			return err
		}
		*hashPtr = h

		// NOTE(review): path.Dir only splits on '/'; filepath.Dir would also
		// handle OS-specific separators. Kept as-is here so the file's "path"
		// import stays in use — consider switching both together.
		dir := path.Dir(certPath)
		if _, ok := watchedDirs[dir]; ok {
			return nil
		}
		// A directory that cannot be watched means changes to its files
		// would never be noticed, so surface the failure instead of
		// silently continuing without hot-reload.
		if err := w.watcher.Add(dir); err != nil {
			return fmt.Errorf("watching directory %q: %w", dir, err)
		}
		watchedDirs[dir] = struct{}{}
		return nil
	}

	targets := []struct {
		certPath string
		hashPtr  *string
	}{
		{w.opts.CAPath, &w.caHash},
		{w.opts.ClientCAPath, &w.clientCAHash},
		{w.opts.CertPath, &w.certHash},
		{w.opts.KeyPath, &w.keyHash},
	}
	for _, t := range targets {
		if err := addPath(t.certPath, t.hashPtr); err != nil {
			return err
		}
	}
	return nil
}

// watchChangesLoop waits for notifications of changes in the watched directories
// and attempts to reload all certificates that changed.
//
// Write and Rename events indicate that some files might have changed and a reload might be necessary.
// A Remove event indicates that a file was deleted; in that case a warning is logged and the
// last known version of the certificate stays in use.
//
// Reasoning:
//
// Write event is sent if the file content is rewritten.
//
// Usually files are not rewritten, but they are updated by swapping them with new
// ones by calling Rename. That avoids files being read while they are not yet
// completely written but it also means that inotify on file level will not work:
// watch is invalidated when the old file is deleted.
//
// When reading from Kubernetes Secret volumes, the target files are symbolic links
// to files in a different directory. That directory is swapped with a new one,
// while the symbolic links remain the same. This guarantees an atomic swap for all
// files at once, but it also means that any Rename event in the directory might
// indicate that the files were replaced, even if event.Name is not one of the
// files we are monitoring. We therefore compare the hashes of the files to detect
// whether they really changed.
func (w *certWatcher) watchChangesLoop(rootCAs, clientCAs *x509.CertPool) {
for {
select {
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// ignore if the event is a chmod event (permission or owner changes)
if event.Op&fsnotify.Chmod == fsnotify.Chmod {
continue
return // channel closed means the watcher is closed
}
if event.Op&fsnotify.Remove == fsnotify.Remove {
w.logger.Warn("Certificate has been removed, using the last known version",
zap.String("certificate", event.Name))
continue
}

w.logger.Info("Loading modified certificate",
zap.String("certificate", event.Name),
zap.String("event", event.Op.String()))
var err error
switch event.Name {
case w.opts.CAPath:
err = addCertToPool(w.opts.CAPath, rootCAs)
case w.opts.ClientCAPath:
err = addCertToPool(w.opts.ClientCAPath, clientCAs)
case w.opts.CertPath, w.opts.KeyPath:
w.mu.Lock()
c, e := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath))
if e == nil {
w.cert = &c
}
w.mu.Unlock()
err = e
w.logger.Debug("Received event", zap.String("event", event.String()))
if event.Op&fsnotify.Write == fsnotify.Write ||
event.Op&fsnotify.Rename == fsnotify.Rename ||
event.Op&fsnotify.Remove == fsnotify.Remove {
w.attemptReload(rootCAs, clientCAs)
}
if err == nil {
w.logger.Info("Loaded modified certificate",
zap.String("certificate", event.Name),
zap.String("event", event.Op.String()))
} else {
w.logger.Error("Failed to load certificate",
zap.String("certificate", event.Name),
zap.String("event", event.Op.String()),
zap.Error(err))
case err, ok := <-w.watcher.Errors:
if !ok {
return // channel closed means the watcher is closed
}
case err := <-w.watcher.Errors:
w.logger.Error("Watcher got error", zap.Error(err))
}
}
}

func addCertsToWatch(watcher *fsnotify.Watcher, opts Options) error {
if len(opts.CAPath) != 0 {
err := watcher.Add(opts.CAPath)
if err != nil {
return err
// attemptReload checks if the watched files have been modified and reloads them if necessary.
func (w *certWatcher) attemptReload(rootCAs, clientCAs *x509.CertPool) {
w.reloadIfModified(w.opts.CAPath, &w.caHash, rootCAs)
w.reloadIfModified(w.opts.ClientCAPath, &w.clientCAHash, clientCAs)

isCertModified, newCertHash := w.isModified(w.opts.CertPath, w.certHash)
isKeyModified, newKeyHash := w.isModified(w.opts.KeyPath, w.keyHash)
if isCertModified || isKeyModified {
c, err := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath))
if err == nil {
w.mu.Lock()
w.cert = &c
w.certHash = newCertHash
w.keyHash = newKeyHash
w.mu.Unlock()
w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.CertPath))
w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.KeyPath))
} else {
w.logger.Error(
"Failed to load certificate pair",
zap.String("certificate", w.opts.CertPath),
zap.String("key", w.opts.KeyPath),
zap.Error(err),
)
}
}
if len(opts.ClientCAPath) != 0 {
err := watcher.Add(opts.ClientCAPath)
if err != nil {
return err
}

func (w *certWatcher) reloadIfModified(certPath string, certHash *string, certPool *x509.CertPool) {
if mod, newHash := w.isModified(certPath, *certHash); mod {
if err := addCertToPool(certPath, certPool); err == nil {
w.mu.Lock()
*certHash = newHash
w.mu.Unlock()
w.logger.Info("Loaded modified certificate", zap.String("certificate", certPath))
} else {
w.logger.Error("Failed to load certificate", zap.String("certificate", certPath), zap.Error(err))
}
}
if len(opts.CertPath) != 0 {
err := watcher.Add(opts.CertPath)
if err != nil {
return err
}
}

// isModified returns true if the file has been modified since the last check.
func (w *certWatcher) isModified(file string, previousHash string) (bool, string) {
if file == "" {
return false, ""
}
if len(opts.KeyPath) != 0 {
err := watcher.Add(opts.KeyPath)
if err != nil {
return err
}
hash, err := hashFile(file)
if err != nil {
w.logger.Warn("Certificate has been removed, using the last known version", zap.String("certificate", file))
return false, ""
}
return nil
return previousHash != hash, hash
}

// hashFile streams the contents of the named file through SHA-256 and
// returns the digest as a lowercase hexadecimal string. The path is cleaned
// before it is opened. If the file cannot be opened or read, an empty string
// and the underlying error are returned.
func hashFile(file string) (string, error) {
	src, openErr := os.Open(filepath.Clean(file))
	if openErr != nil {
		return "", openErr
	}
	defer src.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, src)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}
Loading