diff --git a/go.mod b/go.mod index cb93ac5ce..1b2253f3c 100644 --- a/go.mod +++ b/go.mod @@ -55,6 +55,7 @@ require ( go.opencensus.io v0.22.6 go.uber.org/zap v1.16.0 golang.org/x/oauth2 v0.0.0-20210216194517-16ff1888fd2e // indirect + golang.org/x/sync v0.0.0-20201207232520-09787c993a3a golang.org/x/sys v0.0.0-20210216224549-f992740a1bac // indirect golang.org/x/text v0.3.5 golang.org/x/tools v0.1.0 diff --git a/pkg/config/stats_puller_config.go b/pkg/config/stats_puller_config.go index cc9c30b3c..526e4f50a 100644 --- a/pkg/config/stats_puller_config.go +++ b/pkg/config/stats_puller_config.go @@ -59,6 +59,10 @@ type StatsPullerConfig struct { // StatsPullerMinPeriod defines the period for which the stats puller will hold a lock // which prevents other calls from entering. StatsPullerMinPeriod time.Duration `env:"STATS_PULLER_MIN_PERIOD, default=5m"` + + // MaxWorkers is the maximum number of parallel workers to use when pulling + // statistics. The value must be greater than 0. + MaxWorkers int64 `env:"STATS_PULLER_MAX_WORKERS, default=5"` } // NewStatsPullerConfig returns the config for the stats-puller service. diff --git a/pkg/controller/statspuller/handle_pull.go b/pkg/controller/statspuller/handle_pull.go index e7c0463a2..20c1cd116 100644 --- a/pkg/controller/statspuller/handle_pull.go +++ b/pkg/controller/statspuller/handle_pull.go @@ -15,17 +15,25 @@ package statspuller import ( + "context" + "errors" + "fmt" "net/http" + "sync" "time" "github.com/dgrijalva/jwt-go" v1 "github.com/google/exposure-notifications-server/pkg/api/v1" "github.com/google/exposure-notifications-server/pkg/logging" "github.com/google/exposure-notifications-verification-server/internal/clients" + "github.com/google/exposure-notifications-verification-server/internal/project" "github.com/google/exposure-notifications-verification-server/pkg/controller" "github.com/google/exposure-notifications-verification-server/pkg/controller/certapi" "github.com/google/exposure-notifications-verification-server/pkg/database" "github.com/google/exposure-notifications-verification-server/pkg/jwthelper" + "github.com/hashicorp/go-multierror" + "github.com/sethvargo/go-retry" + "golang.org/x/sync/semaphore" ) const ( @@ -69,64 +77,36 @@ func (c *Controller) HandlePullStats() http.Handler { return } + var merr *multierror.Error + var merrLock sync.Mutex + sem := semaphore.NewWeighted(c.config.MaxWorkers) + var wg sync.WaitGroup for _, realmStat := range statsConfigs { - realmID := realmStat.RealmID - - var err error - client := c.defaultKeyServerClient - if realmStat.KeyServerURLOverride != "" { - client, err = clients.NewKeyServerClient( - realmStat.KeyServerURLOverride, - clients.WithTimeout(c.config.DownloadTimeout), - clients.WithMaxBodySize(c.config.FileSizeLimitBytes)) - if err != nil { - logger.Errorw("failed to create key server client", "error", err) - continue - } - } - - s, err := certapi.GetSignerForRealm(ctx, realmID, c.config.CertificateSigning, c.signerCache, c.db, c.kms) - if err != nil { - logger.Errorw("failed to retrieve signer for realm", "realmID", realmID, "error", err) - continue - } - - audience := c.config.KeyServerStatsAudience - if realmStat.KeyServerAudienceOverride != "" { - audience = realmStat.KeyServerAudienceOverride - } - - now := time.Now().UTC() - claims := &jwt.StandardClaims{ - Audience: audience, - ExpiresAt: now.Add(5 * time.Minute).UTC().Unix(), - IssuedAt: now.Unix(), - Issuer: s.Issuer, - } - token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) - token.Header["kid"] = s.KeyID - - signedJWT, err := jwthelper.SignJWT(token, s.Signer) - if err != nil { - logger.Errorw("failed to stat-pull token", "error", err) - continue - } - - resp, err := client.Stats(ctx, &v1.StatsRequest{}, signedJWT) - if err != nil { - logger.Errorw("failed make stats call", "error", err) - continue + if err := sem.Acquire(ctx, 1); err != nil { + controller.InternalError(w, r, c.h, fmt.Errorf("failed to acquire semaphore: %w", err)) + return } - for _, d := range resp.Days { - if d == nil { - continue + wg.Add(1) + go func(ctx context.Context, realmStat *database.KeyServerStats) { + defer sem.Release(1) + defer wg.Done() + if err := c.pullOneStat(ctx, realmStat); err != nil { + merrLock.Lock() + defer merrLock.Unlock() + merr = multierror.Append(merr, fmt.Errorf("failed to pull stats for realm %d: %w", realmStat.RealmID, err)) } - day := database.MakeKeyServerStatsDay(realmID, d) - if err = c.db.SaveKeyServerStatsDay(day); err != nil { - logger.Errorw("failed saving stats day", "error", err) - } - } + }(ctx, realmStat) + } + wg.Wait() + + if errs := merr.WrappedErrors(); len(errs) > 0 { + logger.Errorw("failed to pull stats", "errors", errs) + c.h.RenderJSON(w, http.StatusInternalServerError, &Result{ + OK: false, + Errors: project.ErrorsToStrings(errs), + }) + return } c.h.RenderJSON(w, http.StatusOK, &Result{ @@ -134,3 +114,73 @@ func (c *Controller) HandlePullStats() http.Handler { }) }) } + +func (c *Controller) pullOneStat(ctx context.Context, realmStat *database.KeyServerStats) error { + realmID := realmStat.RealmID + + client := c.defaultKeyServerClient + if realmStat.KeyServerURLOverride != "" { + var err error + client, err = clients.NewKeyServerClient( + realmStat.KeyServerURLOverride, + clients.WithTimeout(c.config.DownloadTimeout), + clients.WithMaxBodySize(c.config.FileSizeLimitBytes)) + if err != nil { + return fmt.Errorf("failed to create key server client: %w", err) + } + } + + s, err := certapi.GetSignerForRealm(ctx, realmID, c.config.CertificateSigning, c.signerCache, c.db, c.kms) + if err != nil { + return fmt.Errorf("failed to retrieve signer for realm %d: %w", realmID, err) + } + + audience := c.config.KeyServerStatsAudience + if realmStat.KeyServerAudienceOverride != "" { + audience = realmStat.KeyServerAudienceOverride + } + + now := time.Now().UTC() + claims := &jwt.StandardClaims{ + Audience: audience, + ExpiresAt: now.Add(5 * time.Minute).UTC().Unix(), + IssuedAt: now.Unix(), + Issuer: s.Issuer, + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = s.KeyID + + signedJWT, err := jwthelper.SignJWT(token, s.Signer) + if err != nil { + return fmt.Errorf("failed to stat-pull token: %w", err) + } + + // Attempt to download the stats with retries. We intentionally re-use the + // same JWT because it's valid for 5min and don't want the overhead of + // reconstructing and signing it. + var resp *v1.StatsResponse + b, _ := retry.NewConstant(500 * time.Millisecond) + b = retry.WithMaxRetries(3, b) + if err := retry.Do(ctx, b, func(ctx context.Context) error { + var err error + resp, err = client.Stats(ctx, &v1.StatsRequest{}, signedJWT) + if err != nil { + return retry.RetryableError(fmt.Errorf("failed to make stats call: %w", err)) + } + return nil + }); err != nil { + return errors.Unwrap(err) + } + + for _, d := range resp.Days { + if d == nil { + continue + } + day := database.MakeKeyServerStatsDay(realmID, d) + if err = c.db.SaveKeyServerStatsDay(day); err != nil { + return fmt.Errorf("failed to save stats day: %w", err) + } + } + + return nil +}