From d9a74e04b102413526a39810533855c01decfa22 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Fri, 13 Jan 2023 09:17:51 -0800 Subject: [PATCH] oidc: restrict use of context.Background() --- oidc/jwks.go | 2 +- oidc/oidc.go | 66 ++++++++++++++++++++++++++++++----------------- oidc/oidc_test.go | 6 ++--- oidc/verify.go | 16 +++++++++++- 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/oidc/jwks.go b/oidc/jwks.go index 50dad7e0..524f4e67 100644 --- a/oidc/jwks.go +++ b/oidc/jwks.go @@ -60,7 +60,7 @@ func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) if now == nil { now = time.Now } - return &RemoteKeySet{jwksURL: jwksURL, ctx: cloneContext(ctx), now: now} + return &RemoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now} } // RemoteKeySet is a KeySet implementation that validates JSON web tokens against diff --git a/oidc/oidc.go b/oidc/oidc.go index ae73eb02..aa267daa 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -14,6 +14,7 @@ import ( "mime" "net/http" "strings" + "sync" "time" "golang.org/x/oauth2" @@ -58,15 +59,11 @@ func ClientContext(ctx context.Context, client *http.Client) context.Context { return context.WithValue(ctx, oauth2.HTTPClient, client) } -// cloneContext copies a context's bag-of-values into a new context that isn't -// associated with its cancellation. This is used to initialize remote keys sets -// which run in the background and aren't associated with the initial context. -func cloneContext(ctx context.Context) context.Context { - cp := context.Background() +func getClient(ctx context.Context) *http.Client { if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { - cp = ClientContext(cp, c) + return c } - return cp + return nil } // InsecureIssuerURLContext allows discovery to work when the issuer_url reported @@ -90,7 +87,7 @@ func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Con func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) { client := http.DefaultClient - if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + if c := getClient(ctx); c != nil { client = c } return client.Do(req.WithContext(ctx)) @@ -102,12 +99,33 @@ type Provider struct { authURL string tokenURL string userInfoURL string + jwksURL string algorithms []string // Raw claims returned by the server. rawClaims []byte - remoteKeySet KeySet + // Guards all of the following fields. + mu sync.Mutex + // HTTP client specified from the initial NewProvider request. This is used + // when creating the common key set. + client *http.Client + // A key set that uses context.Background() and is shared between all code paths + // that don't have a convinent way of supplying a unique context. + commonRemoteKeySet KeySet +} + +func (p *Provider) remoteKeySet() KeySet { + p.mu.Lock() + defer p.mu.Unlock() + if p.commonRemoteKeySet == nil { + ctx := context.Background() + if p.client != nil { + ctx = ClientContext(ctx, p.client) + } + p.commonRemoteKeySet = NewRemoteKeySet(ctx, p.jwksURL) + } + return p.commonRemoteKeySet } type providerJSON struct { @@ -167,12 +185,13 @@ type ProviderConfig struct { // through discovery. func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider { return &Provider{ - issuer: p.IssuerURL, - authURL: p.AuthURL, - tokenURL: p.TokenURL, - userInfoURL: p.UserInfoURL, - algorithms: p.Algorithms, - remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL), + issuer: p.IssuerURL, + authURL: p.AuthURL, + tokenURL: p.TokenURL, + userInfoURL: p.UserInfoURL, + jwksURL: p.JWKSURL, + algorithms: p.Algorithms, + client: getClient(ctx), } } @@ -221,13 +240,14 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) { } } return &Provider{ - issuer: issuerURL, - authURL: p.AuthURL, - tokenURL: p.TokenURL, - userInfoURL: p.UserInfoURL, - algorithms: algs, - rawClaims: body, - remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL), + issuer: issuerURL, + authURL: p.AuthURL, + tokenURL: p.TokenURL, + userInfoURL: p.UserInfoURL, + jwksURL: p.JWKSURL, + algorithms: algs, + rawClaims: body, + client: getClient(ctx), }, nil } @@ -317,7 +337,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) ct := resp.Header.Get("Content-Type") mediaType, _, parseErr := mime.ParseMediaType(ct) if parseErr == nil && mediaType == "application/jwt" { - payload, err := p.remoteKeySet.VerifySignature(ctx, string(body)) + payload, err := p.remoteKeySet().VerifySignature(ctx, string(body)) if err != nil { return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err) } diff --git a/oidc/oidc_test.go b/oidc/oidc_test.go index 26352082..bbbbda42 100644 --- a/oidc/oidc_test.go +++ b/oidc/oidc_test.go @@ -326,15 +326,15 @@ func TestNewProvider(t *testing.T) { } } -func TestCloneContext(t *testing.T) { +func TestGetClient(t *testing.T) { ctx := context.Background() - if _, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); ok { + if c := getClient(ctx); c != nil { t.Errorf("cloneContext(): expected no *http.Client from empty context") } c := &http.Client{} ctx = ClientContext(ctx, c) - if got, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); !ok || c != got { + if got := getClient(ctx); got == nil || c != got { t.Errorf("cloneContext(): expected *http.Client from context") } } diff --git a/oidc/verify.go b/oidc/verify.go index ade86157..fee9ee89 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -120,8 +120,22 @@ type Config struct { InsecureSkipSignatureCheck bool } +// VerifierContext returns an IDTokenVerifier that uses the provider's key set to +// verify JWTs. As opposed to Verifier, the context is used for all requests to +// the upstream JWKs endpoint. +func (p *Provider) VerifierContext(ctx context.Context, config *Config) *IDTokenVerifier { + return p.newVerifier(NewRemoteKeySet(ctx, p.jwksURL), config) +} + // Verifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs. +// +// The returned verifier uses a background context for all requests to the upstream +// JWKs endpoint. To control that context, use VerifierContext instead. func (p *Provider) Verifier(config *Config) *IDTokenVerifier { + return p.newVerifier(p.remoteKeySet(), config) +} + +func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier { if len(config.SupportedSigningAlgs) == 0 && len(p.algorithms) > 0 { // Make a copy so we don't modify the config values. cp := &Config{} @@ -129,7 +143,7 @@ func (p *Provider) Verifier(config *Config) *IDTokenVerifier { cp.SupportedSigningAlgs = p.algorithms config = cp } - return NewVerifier(p.issuer, p.remoteKeySet, config) + return NewVerifier(p.issuer, keySet, config) } func parseJWT(p string) ([]byte, error) {