diff --git a/pkg/config/api_server_config.go b/pkg/config/api_server_config.go index bfee1fc7c..29ca6e4d2 100644 --- a/pkg/config/api_server_config.go +++ b/pkg/config/api_server_config.go @@ -16,6 +16,7 @@ package config import ( "context" + "sync" "time" "github.com/google/exposure-notifications-verification-server/pkg/cache" @@ -42,7 +43,6 @@ type APIServerConfig struct { APIKeyCacheDuration time.Duration `env:"API_KEY_CACHE_DURATION,default=5m"` // Verification Token Config - // Currently this does not easily support rotation. TODO(mikehelmick) - add support. VerificationTokenDuration time.Duration `env:"VERIFICATION_TOKEN_DURATION,default=24h"` // Token signing @@ -53,6 +53,10 @@ type APIServerConfig struct { // Rate limiting configuration RateLimit ratelimit.Config + + // cached allowed public keys + allowedTokenPublicKeys map[string]string + mu sync.RWMutex } // NewAPIServerConfig returns the environment config for the API server. @@ -65,6 +69,34 @@ func NewAPIServerConfig(ctx context.Context) (*APIServerConfig, error) { return &config, nil } +// AllowedTokenPublicKeys returns a map of 'kid' to the KMS KeyID reference. +// This represents the keys that are allowed to be used to verify tokens, +// the TokenSigningKey/TokenSigningKeyID. +func (c *APIServerConfig) AllowedTokenPublicKeys() map[string]string { + { + c.mu.RLock() + if len(c.allowedTokenPublicKeys) != 0 { + c.mu.RUnlock() + return c.allowedTokenPublicKeys + } + c.mu.RUnlock() + } + + c.mu.Lock() + defer c.mu.Unlock() + // handle race condition that could occur between lock upgrade. + if len(c.allowedTokenPublicKeys) != 0 { + return c.allowedTokenPublicKeys + } + + c.allowedTokenPublicKeys = make(map[string]string) + + for i, kid := range c.TokenSigning.TokenSigningKeyID { + c.allowedTokenPublicKeys[kid] = c.TokenSigning.TokenSigningKey[i] + } + return c.allowedTokenPublicKeys +} + func (c *APIServerConfig) Validate() error { fields := []struct { Var time.Duration @@ -79,6 +111,10 @@ func (c *APIServerConfig) Validate() error { } } + if err := c.TokenSigning.Validate(); err != nil { + return err + } + return nil } diff --git a/pkg/config/token_signing_config.go b/pkg/config/token_signing_config.go index 016881283..b7b0edaf2 100644 --- a/pkg/config/token_signing_config.go +++ b/pkg/config/token_signing_config.go @@ -15,6 +15,8 @@ package config import ( + "fmt" + "github.com/google/exposure-notifications-server/pkg/keys" ) @@ -25,7 +27,22 @@ type TokenSigningConfig struct { // configuration. Keys keys.Config `env:",prefix=TOKEN_"` - TokenSigningKey string `env:"TOKEN_SIGNING_KEY, required"` - TokenSigningKeyID string `env:"TOKEN_SIGNING_KEY_ID, default=v1"` - TokenIssuer string `env:"TOKEN_ISSUER, default=diagnosis-verification-example"` + TokenSigningKey []string `env:"TOKEN_SIGNING_KEY, required"` + TokenSigningKeyID []string `env:"TOKEN_SIGNING_KEY_ID, default=v1"` + TokenIssuer string `env:"TOKEN_ISSUER, default=diagnosis-verification-example"` +} + +func (t *TokenSigningConfig) ActiveKey() string { + return t.TokenSigningKey[0] +} + +func (t *TokenSigningConfig) ActiveKeyID() string { + return t.TokenSigningKeyID[0] +} + +func (t *TokenSigningConfig) Validate() error { + if len(t.TokenSigningKey) != len(t.TokenSigningKeyID) { + return fmt.Errorf("TOKEN_SIGNING_KEY and TOKEN_SIGNING_KEY_ID must be lists of the same length") + } + return nil } diff --git a/pkg/controller/certapi/certapi.go b/pkg/controller/certapi/certapi.go index bdf1206f8..123bf508f 100644 --- a/pkg/controller/certapi/certapi.go +++ b/pkg/controller/certapi/certapi.go @@ -78,8 +78,9 @@ func New(ctx context.Context, config *config.APIServerConfig, db *database.Datab } // Parses and validates the token against the configured keyID and public key. -// If the token si valid the token id (`tid') and subject (`sub`) claims are returned. -func (c *Controller) validateToken(ctx context.Context, verToken string, publicKey crypto.PublicKey) (string, *database.Subject, error) { +// A map of valid 'kid' values is supported. +// If the token is valid the token id (`tid') and subject (`sub`) claims are returned. +func (c *Controller) validateToken(ctx context.Context, verToken string, publicKeys map[string]crypto.PublicKey) (string, *database.Subject, error) { // Parse and validate the verification token. token, err := jwt.ParseWithClaims(verToken, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) { kidHeader := token.Header[verifyapi.KeyIDHeader] @@ -87,10 +88,11 @@ func (c *Controller) validateToken(ctx context.Context, verToken string, publicK if !ok { return nil, fmt.Errorf("missing 'kid' header in token") } - if kid == c.config.TokenSigning.TokenSigningKeyID { - return publicKey, nil + publicKey, ok := publicKeys[kid] + if !ok { + return nil, fmt.Errorf("no public key for specified 'kid' not found: %v", kid) } - return nil, fmt.Errorf("no public key for specified 'kid' not found: %v", kid) + return publicKey, nil }) if err != nil { stats.Record(ctx, c.metrics.TokenInvalid.M(1), c.metrics.CertificateErrors.M(1)) diff --git a/pkg/controller/certapi/certificate.go b/pkg/controller/certapi/certificate.go index 05269438d..02240c01a 100644 --- a/pkg/controller/certapi/certificate.go +++ b/pkg/controller/certapi/certificate.go @@ -15,6 +15,7 @@ package certapi import ( + "crypto" "errors" "net/http" "time" @@ -58,12 +59,15 @@ func (c *Controller) HandleCertificate() http.Handler { stats.Record(ctx, c.metrics.Attempts.M(1)) // Get the public key for the token. - publicKey, err := c.pubKeyCache.GetPublicKey(ctx, c.config.TokenSigning.TokenSigningKey, c.kms) - if err != nil { - c.logger.Errorw("failed to get public key", "error", err) - stats.Record(ctx, c.metrics.CertificateErrors.M(1)) - c.h.RenderJSON(w, http.StatusInternalServerError, api.InternalError()) - return + allowedPublicKeys := make(map[string]crypto.PublicKey) + for kid, keyRef := range c.config.AllowedTokenPublicKeys() { + publicKey, err := c.pubKeyCache.GetPublicKey(ctx, keyRef, c.kms) + if err != nil { + c.logger.Errorw("failed to get public key", "error", err) + c.h.RenderJSON(w, http.StatusInternalServerError, api.InternalError()) + return + } + allowedPublicKeys[kid] = publicKey } var request api.VerificationCertificateRequest @@ -75,7 +79,7 @@ func (c *Controller) HandleCertificate() http.Handler { } // Parse and validate the verification token. - tokenID, subject, err := c.validateToken(ctx, request.VerificationToken, publicKey) + tokenID, subject, err := c.validateToken(ctx, request.VerificationToken, allowedPublicKeys) if err != nil { stats.Record(ctx, c.metrics.CertificateErrors.M(1)) c.h.RenderJSON(w, http.StatusBadRequest, api.Error(err).WithCode(api.ErrTokenInvalid)) diff --git a/pkg/controller/verifyapi/verify.go b/pkg/controller/verifyapi/verify.go index a7da0a0e2..41003327b 100644 --- a/pkg/controller/verifyapi/verify.go +++ b/pkg/controller/verifyapi/verify.go @@ -70,7 +70,7 @@ func (c *Controller) HandleVerify() http.Handler { } // Get the signer based on Key configuration. - signer, err := c.kms.NewSigner(ctx, c.config.TokenSigning.TokenSigningKey) + signer, err := c.kms.NewSigner(ctx, c.config.TokenSigning.ActiveKey()) if err != nil { c.logger.Errorw("failed to get signer", "error", err) stats.Record(ctx, c.metrics.CodeVerificationError.M(1)) @@ -123,7 +123,7 @@ func (c *Controller) HandleVerify() http.Handler { Subject: subject.String(), } token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) - token.Header[verifyapi.KeyIDHeader] = c.config.TokenSigning.TokenSigningKeyID + token.Header[verifyapi.KeyIDHeader] = c.config.TokenSigning.ActiveKeyID() signedJWT, err := jwthelper.SignJWT(token, signer) if err != nil { stats.Record(ctx, c.metrics.CodeVerificationError.M(1))