From e82245309edc4a52947a668ce5afcdb879365909 Mon Sep 17 00:00:00 2001 From: Matt Holt Date: Thu, 17 Aug 2023 11:24:25 -0600 Subject: [PATCH] Use context from ClientHello during GetCertificate (#249) * Use context from ClientHello during GetCertificate (see #247) * Avoid recursive ops during on-demand issuance --- handshake.go | 65 ++++++++++++++++++++++++++-------------------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/handshake.go b/handshake.go index 1ac85f6e..1e9928ba 100644 --- a/handshake.go +++ b/handshake.go @@ -46,8 +46,7 @@ import ( // GetCertificate will run in a new context, use GetCertificateWithContext to provide // a context. func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - ctx := context.TODO() // TODO: get a proper context? from somewhere... - return cfg.GetCertificateWithContext(ctx, clientHello) + return cfg.GetCertificateWithContext(clientHello.Context(), clientHello) } func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -276,7 +275,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client name := cfg.getNameFromClientHello(hello) // By this point, we need to load or obtain a certificate. If a swarm of requests comes in for the same - // domain, avoid pounding manager or storage thousands of times simultaneously. We do a similar sync + // domain, avoid pounding manager or storage thousands of times simultaneously. We use a similar sync // strategy for obtaining certificate during handshake. certLoadWaitChansMu.Lock() wait, ok := certLoadWaitChans[name] @@ -284,7 +283,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client // another goroutine is already loading the cert; just wait and we'll get it from the in-memory cache certLoadWaitChansMu.Unlock() - timeout := time.NewTimer(2 * time.Minute) // TODO: have Caddy use the context param to establish a timeout + timeout := time.NewTimer(2 * time.Minute) select { case <-timeout.C: return Certificate{}, fmt.Errorf("timed out waiting to load certificate for %s", name) @@ -480,6 +479,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli // wait for it to finish obtaining the cert and then we'll use it. obtainCertWaitChansMu.Unlock() + log.Debug("new certificate is needed, but is already being obtained; waiting for that issuance to complete", + zap.String("subject", name)) + // TODO: see if we can get a proper context in here, for true cancellation timeout := time.NewTimer(2 * time.Minute) select { @@ -489,7 +491,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli timeout.Stop() } - return cfg.loadCertFromStorage(ctx, log, hello) + // it should now be loaded in the cache, ready to go; if not, + // the goroutine in charge of that probably had an error + return cfg.getCertDuringHandshake(ctx, hello, false) } // looks like it's up to us to do all the work and obtain the cert. @@ -507,28 +511,28 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli log.Info("obtaining new certificate", zap.String("server_name", name)) - // TODO: we are only adding a timeout because we don't know if the context passed in is actually cancelable... + // set a timeout so we don't inadvertently hold a client handshake open too long // (timeout duration is based on https://caddy.community/t/zerossl-dns-challenge-failing-often-route53-plugin/13822/24?u=matt) var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, 180*time.Second) defer cancel() - // Obtain the certificate + // obtain the certificate (this puts it in storage) and if successful, + // load it from storage so we and any other waiting goroutine can use it + var cert Certificate err := cfg.ObtainCertAsync(ctx, name) + if err == nil { + // load from storage while others wait to make the op as atomic as possible + cert, err = cfg.loadCertFromStorage(ctx, log, hello) + if err != nil { + log.Error("loading newly-obtained certificate from storage", zap.String("server_name", name), zap.Error(err)) + } + } - // immediately unblock anyone waiting for it; doing this in - // a defer would risk deadlock because of the recursive call - // to getCertDuringHandshake below when we return! + // immediately unblock anyone waiting for it unblockWaiters() - if err != nil { - // shucks; failed to solve challenge on-demand - return Certificate{}, err - } - - // success; certificate was just placed on disk, so - // we need only restart serving the certificate - return cfg.loadCertFromStorage(ctx, log, hello) + return cert, err } // handshakeMaintenance performs a check on cert for expiration and OCSP validity. @@ -611,7 +615,7 @@ func (cfg *Config) handshakeMaintenance(ctx context.Context, hello *tls.ClientHe // // This function is safe for use by multiple concurrent goroutines. func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.ClientHelloInfo, currentCert Certificate) (Certificate, error) { - log := cfg.Logger.Named("on_demand") + log := logWithRemote(cfg.Logger.Named("on_demand"), hello) name := cfg.getNameFromClientHello(hello) timeLeft := time.Until(expiresAt(currentCert.Leaf)) @@ -651,7 +655,9 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien timeout.Stop() } - return cfg.loadCertFromStorage(ctx, log, hello) + // it should now be loaded in the cache, ready to go; if not, + // the goroutine in charge of that probably had an error + return cfg.getCertDuringHandshake(ctx, hello, false) } // looks like it's up to us to do all the work and renew the cert @@ -703,16 +709,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien } else { err = cfg.RenewCertAsync(ctx, name, false) if err == nil { - // even though the recursive nature of the dynamic cert loading - // would just call this function anyway, we do it here to - // make the replacement as atomic as possible. - newCert, err = cfg.CacheManagedCertificate(ctx, name) - if err != nil { - log.Error("loading renewed certificate", zap.String("server_name", name), zap.Error(err)) - } else { - // replace the old certificate with the new one - cfg.certCache.replaceCertificate(currentCert, newCert) - } + // load from storage while in lock to make the replacement as atomic as possible + newCert, err = cfg.reloadManagedCertificate(ctx, currentCert) } } @@ -722,11 +720,10 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien unblockWaiters() if err != nil { - log.Error("renewing and reloading certificate", zap.Error(err)) - return newCert, err + log.Error("renewing and reloading certificate", zap.String("server_name", name), zap.Error(err)) } - return cfg.loadCertFromStorage(ctx, log, hello) + return newCert, err } // if the certificate hasn't expired, we can serve what we have and renew in the background @@ -872,6 +869,8 @@ var ( obtainCertWaitChans = make(map[string]chan struct{}) obtainCertWaitChansMu sync.Mutex ) + +// TODO: this lockset should probably be per-cache var ( certLoadWaitChans = make(map[string]chan struct{}) certLoadWaitChansMu sync.Mutex