diff --git a/v2/workloadapi/client.go b/v2/workloadapi/client.go index 092b0ed5..228285ce 100644 --- a/v2/workloadapi/client.go +++ b/v2/workloadapi/client.go @@ -163,10 +163,29 @@ func (c *Client) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwts return nil, err } - if len(resp.Svids) == 0 { - return nil, errors.New("there were no SVIDs in the response") + svids, err := parseJWTSVIDs(resp, audience, true) + if err != nil { + return nil, err + } + + return svids[0], nil +} + +// FetchJWTSVIDs fetches all JWT-SVIDs. +func (c *Client) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) { + ctx, cancel := context.WithCancel(withHeader(ctx)) + defer cancel() + + audience := append([]string{params.Audience}, params.ExtraAudiences...) + resp, err := c.wlClient.FetchJWTSVID(ctx, &workload.JWTSVIDRequest{ + SpiffeId: params.Subject.String(), + Audience: audience, + }) + if err != nil { + return nil, err } - return jwtsvid.ParseInsecure(resp.Svids[0].Svid, audience) + + return parseJWTSVIDs(resp, audience, false) } // FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed @@ -357,6 +376,9 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) { // Otherwise all SVIDs are parsed and returned. func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) { n := len(resp.Svids) + if n == 0 { + return nil, errors.New("no SVIDs in response") + } if firstOnly { n = 1 } @@ -371,9 +393,6 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi svids = append(svids, s) } - if len(svids) == 0 { - return nil, errors.New("no SVIDs in response") - } return svids, nil } @@ -413,6 +432,31 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error) return x509bundle.FromX509Authorities(td, certs), nil } +// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly +// is true, then only the first SVID in the response is parsed and returned. +// Otherwise all SVIDs are parsed and returned. +func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) { + n := len(resp.Svids) + if n == 0 { + return nil, errors.New("there were no SVIDs in the response") + } + if firstOnly { + n = 1 + } + + svids := make([]*jwtsvid.SVID, 0, n) + for i := 0; i < n; i++ { + svid := resp.Svids[i] + s, err := jwtsvid.ParseInsecure(svid.Svid, audience) + if err != nil { + return nil, err + } + svids = append(svids, s) + } + + return svids, nil +} + func parseJWTSVIDBundles(resp *workload.JWTBundlesResponse) (*jwtbundle.Set, error) { bundles := []*jwtbundle.Bundle{} diff --git a/v2/workloadapi/client_test.go b/v2/workloadapi/client_test.go index 62443733..b8633469 100644 --- a/v2/workloadapi/client_test.go +++ b/v2/workloadapi/client_test.go @@ -232,8 +232,8 @@ func TestFetchJWTSVID(t *testing.T) { subjectID := spiffeid.RequireFromPath(td, "/subject") audienceID := spiffeid.RequireFromPath(td, "/audience") extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") - token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}) - respJWT := makeJWTSVIDResponse(subjectID.String(), token) + token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + respJWT := makeJWTSVIDResponse([]string{token}, subjectID) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -245,7 +245,36 @@ func TestFetchJWTSVID(t *testing.T) { jwtSvid, err := c.FetchJWTSVID(context.Background(), params) require.NoError(t, err) - assertJWTSVID(t, jwtSvid, subjectID, token.Marshal(), audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String()) +} + +func TestFetchJWTSVIDs(t *testing.T) { + ca := test.NewCA(t, td) + wl := fakeworkloadapi.New(t) + defer wl.Stop() + c, _ := New(context.Background(), WithAddr(wl.Addr())) + defer c.Close() + + subjectID := spiffeid.RequireFromPath(td, "/subject") + extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject") + audienceID := spiffeid.RequireFromPath(td, "/audience") + extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") + subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + respJWT := makeJWTSVIDResponse([]string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID) + wl.SetJWTSVIDResponse(respJWT) + + params := jwtsvid.Params{ + Subject: subjectID, + Audience: audienceID.String(), + ExtraAudiences: []string{extraAudienceID.String()}, + } + + jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params) + + require.NoError(t, err) + assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String()) } func TestFetchJWTBundles(t *testing.T) { @@ -357,12 +386,14 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID { return svids } -func makeJWTSVIDResponse(spiffeID string, token *jwtsvid.SVID) *workload.JWTSVIDResponse { - svids := []*workload.JWTSVID{ - { - SpiffeId: spiffeID, - Svid: token.Marshal(), - }, +func makeJWTSVIDResponse(token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse { + svids := []*workload.JWTSVID{} + for i, id := range ids { + svid := &workload.JWTSVID{ + SpiffeId: id.String(), + Svid: token[i], + } + svids = append(svids, svid) } return &workload.JWTSVIDResponse{ Svids: svids, diff --git a/v2/workloadapi/convenience.go b/v2/workloadapi/convenience.go index c4018b3c..5aef54ec 100644 --- a/v2/workloadapi/convenience.go +++ b/v2/workloadapi/convenience.go @@ -71,6 +71,16 @@ func FetchJWTSVID(ctx context.Context, params jwtsvid.Params, options ...ClientO return c.FetchJWTSVID(ctx, params) } +// FetchJWTSVID fetches all JWT-SVIDs. +func FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params, options ...ClientOption) ([]*jwtsvid.SVID, error) { + c, err := New(ctx, options...) + if err != nil { + return nil, err + } + defer c.Close() + return c.FetchJWTSVIDs(ctx, params) +} + // FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed // by a SPIFFE ID of the trust domain to which they belong. func FetchJWTBundles(ctx context.Context, options ...ClientOption) (*jwtbundle.Set, error) { diff --git a/v2/workloadapi/jwtsource.go b/v2/workloadapi/jwtsource.go index cea67e97..6bfe06e4 100644 --- a/v2/workloadapi/jwtsource.go +++ b/v2/workloadapi/jwtsource.go @@ -63,6 +63,15 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j return s.watcher.client.FetchJWTSVID(ctx, params) } +// FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters. +// It implements the jwtsvid.Source interface. +func (s *JWTSource) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) { + if err := s.checkClosed(); err != nil { + return nil, err + } + return s.watcher.client.FetchJWTSVIDs(ctx, params) +} + // GetJWTBundleForTrustDomain returns the JWT bundle for the given trust // domain. It implements the jwtbundle.Source interface. func (s *JWTSource) GetJWTBundleForTrustDomain(trustDomain spiffeid.TrustDomain) (*jwtbundle.Bundle, error) { diff --git a/v2/workloadapi/watcher.go b/v2/workloadapi/watcher.go index a6329c63..f110e073 100644 --- a/v2/workloadapi/watcher.go +++ b/v2/workloadapi/watcher.go @@ -13,6 +13,7 @@ type sourceClient interface { WatchX509Context(context.Context, X509ContextWatcher) error WatchJWTBundles(context.Context, JWTBundleWatcher) error FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) + FetchJWTSVIDs(context.Context, jwtsvid.Params) ([]*jwtsvid.SVID, error) Close() error }