Skip to content

Commit

Permalink
feat: reload TLS certs without restart
Browse files Browse the repository at this point in the history
Omni now watches the gRPC/HTTP TLS certificate and key files for changes and
reloads them without requiring a restart.

Fixes: #508

Signed-off-by: Artem Chernyshev <artem.chernyshev@talos-systems.com>
  • Loading branch information
Unix4ever committed Aug 27, 2024
1 parent 00ae084 commit a32a6fa
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 14 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
github.com/dustin/go-humanize v1.0.1
github.com/emicklei/dot v1.6.2
github.com/felixge/httpsnoop v1.0.4
github.com/fsnotify/fsnotify v1.7.0
github.com/gertd/go-pluralize v0.2.1
github.com/go-jose/go-jose/v4 v4.0.4
github.com/go-logr/zapr v1.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA=
Expand Down
2 changes: 1 addition & 1 deletion internal/backend/proxy_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (prx *httpProxy) Run(ctx context.Context, next http.Handler, logger *zap.Lo
}),
Addr: prx.bindAddr,
},
certData: certData{
certData: &certData{
certFile: prx.certFile,
keyFile: prx.keyFile,
},
Expand Down
150 changes: 137 additions & 13 deletions internal/backend/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package backend
import (
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
Expand All @@ -20,6 +21,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"

Expand All @@ -32,6 +34,7 @@ import (
"github.com/cosi-project/runtime/pkg/state"
protobufserver "github.com/cosi-project/runtime/pkg/state/protobuf/server"
"github.com/crewjam/saml/samlsp"
"github.com/fsnotify/fsnotify"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
Expand All @@ -40,7 +43,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
service "github.com/siderolabs/discovery-service/pkg/service"
"github.com/siderolabs/gen/value"
"github.com/siderolabs/go-api-signature/pkg/pgp"
"github.com/siderolabs/go-retry/retry"
talosconstants "github.com/siderolabs/talos/pkg/machinery/constants"
Expand Down Expand Up @@ -269,7 +271,12 @@ func (s *Server) Run(ctx context.Context) error {
),
grpc.MaxRecvMsgSize(constants.GRPCMaxMessageSize),
)
crtData := certData{certFile: s.certFile, keyFile: s.keyFile}

var crtData *certData

if s.certFile != "" && s.keyFile != "" {
crtData = &certData{certFile: s.certFile, keyFile: s.keyFile}
}

workloadProxyHandler, err := s.workloadProxyHandler(mux)
if err != nil {
Expand Down Expand Up @@ -765,7 +772,7 @@ func runK8sProxyServer(
ctx context.Context,
bindAddress string,
oidcStorage oidcStore,
data certData,
data *certData,
runtimeState state.State,
wrapper k8sproxy.MiddlewareWrapper,
logger *zap.Logger,
Expand Down Expand Up @@ -813,7 +820,7 @@ func runK8sProxyServer(
}, logger)
}

func runAPIServer(ctx context.Context, handler http.Handler, bindAddress string, data certData, logger *zap.Logger) error {
func runAPIServer(ctx context.Context, handler http.Handler, bindAddress string, data *certData, logger *zap.Logger) error {
srv := &http.Server{
Addr: bindAddress,
Handler: handler,
Expand Down Expand Up @@ -847,21 +854,138 @@ func setRealIPRequest(req *http.Request) *http.Request {
}

// server couples an HTTP server with the optional TLS material it serves.
//
// certData is a pointer because certData embeds a mutex and must not be
// copied; a nil certData means the server is run without TLS.
type server struct {
	server   *http.Server
	certData *certData
}

// certData holds the paths of a TLS certificate/key pair together with the
// key pair that was most recently loaded from them.
//
// It contains a mutex, so it must always be used through a pointer.
type certData struct {
	cert     tls.Certificate // last successfully loaded key pair; guarded by mu
	certFile string          // path of the certificate file
	keyFile  string          // path of the private key file
	mu       sync.Mutex      // guards cert and loaded
	loaded   bool            // true once load has succeeded at least once; guarded by mu
}

// load reads the key pair from certFile and keyFile and, on success, makes it
// the current certificate.
//
// On failure the previously loaded certificate (if any) is left untouched, so
// a half-written or broken file on disk does not affect what is being served.
func (c *certData) load() error {
	// Read and parse before taking the lock so that concurrent getCert calls
	// are never blocked on disk I/O.
	cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
	if err != nil {
		return fmt.Errorf("error loading TLS key pair (cert %q, key %q): %w", c.certFile, c.keyFile, err)
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	c.cert = cert
	c.loaded = true

	return nil
}

// getCert returns the most recently loaded certificate, or an error if load
// has not succeeded yet.
//
// The returned value is a snapshot: it does not alias the internal state, so
// the caller may keep using it after the lock is released even if load swaps
// in a new certificate concurrently.
func (c *certData) getCert() (*tls.Certificate, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.loaded {
		return nil, errors.New("the cert wasn't loaded yet")
	}

	// Copy under the lock; returning &c.cert would let the caller read the
	// field while load overwrites it.
	cert := c.cert

	return &cert, nil
}

// runWatcher watches certFile and keyFile and reloads the key pair whenever
// either of them changes.
//
// It blocks until ctx is cancelled (returning nil) or the watcher fails
// (returning an error). Failed reloads are only logged: the previously loaded
// certificate keeps being served until a valid pair appears on disk.
func (c *certData) runWatcher(ctx context.Context, logger *zap.Logger) error {
	w, err := fsnotify.NewWatcher()
	if err != nil {
		return fmt.Errorf("error creating fsnotify watcher: %w", err)
	}

	defer w.Close() //nolint:errcheck

	for _, file := range []string{c.certFile, c.keyFile} {
		if err = w.Add(file); err != nil {
			return fmt.Errorf("error adding watch for file %s: %w", file, err)
		}
	}

	reload := func() {
		if loadErr := c.load(); loadErr != nil {
			logger.Error("failed to load certs", zap.Error(loadErr))

			return
		}

		logger.Info("reloaded certs")
	}

	// Pick up any change made while no watch was in place, e.g. a file that was
	// deleted and re-created while a previous watcher run had already failed.
	reload()

	handleEvent := func(e fsnotify.Event) error {
		var watchErr error

		// A removed or renamed file loses its watch, so it has to be
		// re-established on whatever now lives at that path.
		if e.Has(fsnotify.Remove) || e.Has(fsnotify.Rename) {
			if removeErr := w.Remove(e.Name); removeErr != nil {
				logger.Error("failed to remove file watch, it may have been deleted", zap.String("file", e.Name), zap.Error(removeErr))
			}

			if addErr := w.Add(e.Name); addErr != nil {
				watchErr = fmt.Errorf("error adding watch for file %s: %w", e.Name, addErr)
			}
		}

		// Reload on every event, even if re-adding the watch failed.
		reload()

		return watchErr
	}

	for {
		select {
		case e, ok := <-w.Events:
			if !ok {
				// A closed channel would otherwise yield zero-value events forever.
				return errors.New("fsnotify events channel was closed")
			}

			if eventErr := handleEvent(e); eventErr != nil {
				return eventErr
			}
		case watcherErr, ok := <-w.Errors:
			if !ok {
				return errors.New("fsnotify errors channel was closed")
			}

			return fmt.Errorf("received fsnotify error: %w", watcherErr)
		case <-ctx.Done():
			return nil
		}
	}
}

func (s *server) ListenAndServe() error {
if s.certFile != "" || s.keyFile != "" {
return s.server.ListenAndServeTLS(s.certFile, s.keyFile)
// ListenAndServe runs the HTTP server and blocks until it stops.
//
// Without certData the server is started in plain HTTP mode. With certData the
// key pair is loaded up front (a failure is returned immediately), the server
// is started with TLS, and a background watcher keeps the served certificate
// in sync with the files on disk. The watcher is restarted after a short delay
// if it fails and is stopped when ctx is cancelled or the server stops.
func (s *server) ListenAndServe(ctx context.Context, logger *zap.Logger) error {
	const watcherRestartDelay = 5 * time.Second

	if s.certData == nil {
		return s.server.ListenAndServe()
	}

	if err := s.certData.load(); err != nil {
		return err
	}

	// Resolve the certificate on every handshake so that reloads take effect
	// for new connections without restarting the listener.
	s.server.TLSConfig = &tls.Config{
		GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
			return s.certData.getCert()
		},
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	eg := panichandler.NewErrGroup()

	eg.Go(func() error {
		for {
			err := s.certData.runWatcher(ctx, logger)
			if err == nil {
				return nil
			}

			logger.Error("cert watcher crashed, restarting in 5 seconds", zap.Error(err))

			// Back off before restarting, but do not outlive the server.
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(watcherRestartDelay):
			}
		}
	})

	eg.Go(func() error {
		// Stop the watcher as soon as the server returns (bind failure or
		// shutdown); otherwise eg.Wait would block until the parent context
		// is cancelled and the server error would never be reported.
		defer cancel()

		// Empty file names: the certificate comes from TLSConfig.GetCertificate.
		return s.server.ListenAndServeTLS("", "")
	})

	return eg.Wait()
}

func (s *server) Shutdown(ctx context.Context) error {
Expand All @@ -883,7 +1007,7 @@ func runServer(ctx context.Context, srv *server, logger *zap.Logger) error {
errCh := make(chan error, 1)

panichandler.Go(func() {
errCh <- srv.ListenAndServe()
errCh <- srv.ListenAndServe(ctx, logger)
}, logger)

select {
Expand Down Expand Up @@ -1017,7 +1141,7 @@ func runGRPCServer(ctx context.Context, server *grpc.Server, transport *memconn.
return nil
}

func unifyHandler(handler http.Handler, grpcServer *grpc.Server, data certData) http.Handler {
func unifyHandler(handler http.Handler, grpcServer *grpc.Server, data *certData) http.Handler {
h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.ProtoMajor == 2 && strings.HasPrefix(
req.Header.Get("Content-Type"), "application/grpc") {
Expand All @@ -1031,7 +1155,7 @@ func unifyHandler(handler http.Handler, grpcServer *grpc.Server, data certData)
handler.ServeHTTP(w, req)
}))

if value.IsZero(data) {
if data == nil {
// If we don't have TLS data, wrap the handler in http2.Server
h = h2c.NewHandler(h, &http2.Server{})
}
Expand Down

0 comments on commit a32a6fa

Please sign in to comment.