Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use context from ClientHello during GetCertificate #249

Merged
merged 3 commits into from
Aug 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ import (
// GetCertificate will run in a new context, use GetCertificateWithContext to provide
// a context.
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx := context.TODO() // TODO: get a proper context? from somewhere...
return cfg.GetCertificateWithContext(ctx, clientHello)
return cfg.GetCertificateWithContext(clientHello.Context(), clientHello)
}

func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expand Down Expand Up @@ -276,15 +275,15 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
name := cfg.getNameFromClientHello(hello)

// By this point, we need to load or obtain a certificate. If a swarm of requests comes in for the same
// domain, avoid pounding manager or storage thousands of times simultaneously. We do a similar sync
// domain, avoid pounding manager or storage thousands of times simultaneously. We use a similar sync
// strategy for obtaining certificate during handshake.
certLoadWaitChansMu.Lock()
wait, ok := certLoadWaitChans[name]
if ok {
// another goroutine is already loading the cert; just wait and we'll get it from the in-memory cache
certLoadWaitChansMu.Unlock()

timeout := time.NewTimer(2 * time.Minute) // TODO: have Caddy use the context param to establish a timeout
timeout := time.NewTimer(2 * time.Minute)
select {
case <-timeout.C:
return Certificate{}, fmt.Errorf("timed out waiting to load certificate for %s", name)
Expand Down Expand Up @@ -480,6 +479,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()

log.Debug("new certificate is needed, but is already being obtained; waiting for that issuance to complete",
zap.String("subject", name))

// TODO: see if we can get a proper context in here, for true cancellation
timeout := time.NewTimer(2 * time.Minute)
select {
Expand All @@ -489,7 +491,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
timeout.Stop()
}

return cfg.loadCertFromStorage(ctx, log, hello)
// it should now be loaded in the cache, ready to go; if not,
// the goroutine in charge of that probably had an error
return cfg.getCertDuringHandshake(ctx, hello, false)
}

// looks like it's up to us to do all the work and obtain the cert.
Expand All @@ -507,28 +511,28 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli

log.Info("obtaining new certificate", zap.String("server_name", name))

// TODO: we are only adding a timeout because we don't know if the context passed in is actually cancelable...
// set a timeout so we don't inadvertently hold a client handshake open too long
// (timeout duration is based on https://caddy.community/t/zerossl-dns-challenge-failing-often-route53-plugin/13822/24?u=matt)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 180*time.Second)
defer cancel()

// Obtain the certificate
// obtain the certificate (this puts it in storage) and if successful,
// load it from storage so we and any other waiting goroutine can use it
var cert Certificate
err := cfg.ObtainCertAsync(ctx, name)
if err == nil {
// load from storage while others wait to make the op as atomic as possible
cert, err = cfg.loadCertFromStorage(ctx, log, hello)
if err != nil {
log.Error("loading newly-obtained certificate from storage", zap.String("server_name", name), zap.Error(err))
}
}

// immediately unblock anyone waiting for it; doing this in
// a defer would risk deadlock because of the recursive call
// to getCertDuringHandshake below when we return!
// immediately unblock anyone waiting for it
unblockWaiters()

if err != nil {
// shucks; failed to solve challenge on-demand
return Certificate{}, err
}

// success; certificate was just placed on disk, so
// we need only restart serving the certificate
return cfg.loadCertFromStorage(ctx, log, hello)
return cert, err
}

// handshakeMaintenance performs a check on cert for expiration and OCSP validity.
Expand Down Expand Up @@ -611,7 +615,7 @@ func (cfg *Config) handshakeMaintenance(ctx context.Context, hello *tls.ClientHe
//
// This function is safe for use by multiple concurrent goroutines.
func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.ClientHelloInfo, currentCert Certificate) (Certificate, error) {
log := cfg.Logger.Named("on_demand")
log := logWithRemote(cfg.Logger.Named("on_demand"), hello)

name := cfg.getNameFromClientHello(hello)
timeLeft := time.Until(expiresAt(currentCert.Leaf))
Expand Down Expand Up @@ -651,7 +655,9 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
timeout.Stop()
}

return cfg.loadCertFromStorage(ctx, log, hello)
// it should now be loaded in the cache, ready to go; if not,
// the goroutine in charge of that probably had an error
return cfg.getCertDuringHandshake(ctx, hello, false)
}

// looks like it's up to us to do all the work and renew the cert
Expand Down Expand Up @@ -703,16 +709,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
} else {
err = cfg.RenewCertAsync(ctx, name, false)
if err == nil {
// even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to
// make the replacement as atomic as possible.
newCert, err = cfg.CacheManagedCertificate(ctx, name)
if err != nil {
log.Error("loading renewed certificate", zap.String("server_name", name), zap.Error(err))
} else {
// replace the old certificate with the new one
cfg.certCache.replaceCertificate(currentCert, newCert)
}
// load from storage while in lock to make the replacement as atomic as possible
newCert, err = cfg.reloadManagedCertificate(ctx, currentCert)
}
}

Expand All @@ -722,11 +720,10 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
unblockWaiters()

if err != nil {
log.Error("renewing and reloading certificate", zap.Error(err))
return newCert, err
log.Error("renewing and reloading certificate", zap.String("server_name", name), zap.Error(err))
}

return cfg.loadCertFromStorage(ctx, log, hello)
return newCert, err
}

// if the certificate hasn't expired, we can serve what we have and renew in the background
Expand Down Expand Up @@ -872,6 +869,8 @@ var (
obtainCertWaitChans = make(map[string]chan struct{})
obtainCertWaitChansMu sync.Mutex
)

// TODO: this lockset should probably be per-cache
var (
certLoadWaitChans = make(map[string]chan struct{})
certLoadWaitChansMu sync.Mutex
Expand Down