Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support tls cert rotation #831

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ delete-gatekeeper:
helm delete gatekeeper --namespace ${GATEKEEPER_NAMESPACE}

.PHONY: test-e2e
test-e2e:
test-e2e: generate-rotation-certs
bats -t ${BATS_TESTS_FILE}

.PHONY: test-e2e-cli
Expand All @@ -134,6 +134,13 @@ test-e2e-cli: e2e-dependencies e2e-create-local-registry e2e-notaryv2-setup e2e-
generate-certs:
./scripts/generate-tls-certs.sh ${CERT_DIR} ${GATEKEEPER_NAMESPACE}

generate-rotation-certs:
mkdir -p .staging/rotation
mkdir -p .staging/rotation/gatekeeper

./scripts/generate-gk-tls-certs.sh .staging/rotation/gatekeeper ${GATEKEEPER_NAMESPACE}
./scripts/generate-tls-certs.sh .staging/rotation ${GATEKEEPER_NAMESPACE}

install-bats:
# Download and install bats
curl -sSLO https://github.com/bats-core/bats-core/archive/v${BATS_VERSION}.tar.gz && tar -zxvf v${BATS_VERSION}.tar.gz && bash bats-core-${BATS_VERSION}/install.sh ${GITHUB_WORKSPACE}
Expand Down
30 changes: 13 additions & 17 deletions httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package httpserver
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -145,23 +144,20 @@ func (server *Server) Run() error {

logrus.Info(fmt.Sprintf("%s: [%s:%s] [%s:%s]", "starting server using TLS", "certFile", certFile, "keyFile", keyFile))

if server.CaCertFile != "" {
caCert, err := os.ReadFile(server.CaCertFile)
if err != nil {
panic(err)
}

clientCAs := x509.NewCertPool()
clientCAs.AppendCertsFromPEM(caCert)

config := &tls.Config{
MinVersion: tls.VersionTLS13,
ClientCAs: clientCAs,
ClientAuth: tls.RequireAndVerifyClientCert,
}
svr.TLSConfig = config
logrus.Info(fmt.Sprintf("%s: [%s:%s] ", "loaded client CA certificate for mTLS", "CaFIle", server.CaCertFile))
tlsCertWatcher, err := NewTLSCertWatcher(certFile, keyFile, server.CaCertFile)
if err != nil {
return err
}
if err = tlsCertWatcher.Start(); err != nil {
return err
}
defer tlsCertWatcher.Stop()
binbin-li marked this conversation as resolved.
Show resolved Hide resolved

svr.TLSConfig = &tls.Config{
GetConfigForClient: tlsCertWatcher.GetConfigForClient,
MinVersion: tls.VersionTLS13,
}

if err := svr.ServeTLS(lsnr, certFile, keyFile); err != nil {
logrus.Errorf("failed to start server: %v", err)
return err
Expand Down
204 changes: 204 additions & 0 deletions httpserver/tlsManager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
Copyright The Ratify Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package httpserver

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/wait"
)

// This implementation is based on K8s certwatcher: https://github.com/kubernetes-sigs/controller-runtime/blob/main/pkg/certwatcher/certwatcher.go
type TLSCertWatcher struct {
sync.RWMutex
ratifyServerCert *tls.Certificate
clientCACert *x509.CertPool
watcher *fsnotify.Watcher

ratifyServerCertPath string
ratifyServerKeyPath string
clientCACertPath string
}

// NewTLSCertWatcher creates a new TLSCertWatcher for ratify tls cert/key paths and client CA cert path
func NewTLSCertWatcher(ratifyServerCertPath, ratifyServerKeyPath, clientCACertPath string) (*TLSCertWatcher, error) {
var err error
certWatcher := &TLSCertWatcher{
ratifyServerCertPath: ratifyServerCertPath,
ratifyServerKeyPath: ratifyServerKeyPath,
clientCACertPath: clientCACertPath,
}

if err = certWatcher.ReadCertificates(); err != nil {
return nil, err
}

certWatcher.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}

return certWatcher, nil
}

// Start adds the files to watcher and starts the certificate watcher routine
func (t *TLSCertWatcher) Start() error {
files := map[string]struct{}{t.ratifyServerCertPath: {}, t.ratifyServerKeyPath: {}}
if t.clientCACertPath != "" {
files[t.clientCACertPath] = struct{}{}
}

{
var watchErr error
deadlineCtx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if err := wait.PollUntilWithContext(deadlineCtx, 1*time.Second, func(ctx context.Context) (done bool, err error) {
for f := range files {
if err := t.watcher.Add(f); err != nil {
watchErr = err
return false, nil //nolint:nilerr // we want to keep trying.
}
// remove it from the set
delete(files, f)
}
return true, nil
}); err != nil {
return fmt.Errorf("failed to add watches: %w: %s", err, watchErr.Error())
}
}

logrus.Info("Starting TLS certificate watcher")
go t.Watch()

return nil
}

// Stop closes the watcher
func (t *TLSCertWatcher) Stop() {
if err := t.watcher.Close(); err != nil {
logrus.Errorf("error closing certificate watcher: %v", err)
}
}

// ReadCertificates reads the certificates from the cert/key paths
func (t *TLSCertWatcher) ReadCertificates() error {
if t.ratifyServerCertPath == "" || t.ratifyServerKeyPath == "" {
return fmt.Errorf("ratify server cert or key path is empty")
}

if t.clientCACertPath != "" {
caCert, err := os.ReadFile(t.clientCACertPath)
if err != nil {
return err
}

clientCAs := x509.NewCertPool()
clientCAs.AppendCertsFromPEM(caCert)
t.Lock()
t.clientCACert = clientCAs
t.Unlock()
}

ratifyServerCert, err := tls.LoadX509KeyPair(t.ratifyServerCertPath, t.ratifyServerKeyPath)
if err != nil {
return err
}
t.Lock()
t.ratifyServerCert = &ratifyServerCert
akashsinghal marked this conversation as resolved.
Show resolved Hide resolved
t.Unlock()
return nil
}

// GetConfigForClient returns the tls config for the client use in the TLS Config
func (t *TLSCertWatcher) GetConfigForClient(*tls.ClientHelloInfo) (*tls.Config, error) {
t.RLock()
defer t.RUnlock()

config := &tls.Config{
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{*t.ratifyServerCert},
GetConfigForClient: t.GetConfigForClient,
}

if t.clientCACert != nil {
config.ClientCAs = t.clientCACert
config.ClientAuth = tls.RequireAndVerifyClientCert
}
return config, nil
}

func (t *TLSCertWatcher) handleEvent(event fsnotify.Event) {
// Only care about events which may modify the contents of the file.
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
return
}

logrus.Infof("tls certificate rotation event: %v", event)

// If the file was removed, re-add the watch.
if isRemove(event) {
if err := t.watcher.Add(event.Name); err != nil {
logrus.Errorf("error re-watching file: %v", err)
}
}

if err := t.ReadCertificates(); err != nil {
logrus.Errorf("error re-reading certificates: %v", err)
}
}

// Watch watches the certificate files for changes and terminates on error/stop
func (t *TLSCertWatcher) Watch() {
for {
select {
case event, ok := <-t.watcher.Events:
// Channel is closed.
if !ok {
return
}

t.handleEvent(event)

case err, ok := <-t.watcher.Errors:
// Channel is closed.
if !ok {
return
}

logrus.Errorf("certificate watch error: %v", err)
}
}
}

func isWrite(event fsnotify.Event) bool {
return event.Op == fsnotify.Write
}

func isCreate(event fsnotify.Event) bool {
return event.Op == fsnotify.Create
}

func isRemove(event fsnotify.Event) bool {
return event.Op == fsnotify.Remove
}
Loading