From 26401ecde621ba77f724d09aea2507425dc97626 Mon Sep 17 00:00:00 2001 From: Yuhan Li Date: Thu, 24 Mar 2022 16:21:56 +0800 Subject: [PATCH 1/3] Add FetchJWTSVIDs function for workloadapi and jwtSource Signed-off-by: Yuhan Li --- v2/workloadapi/client.go | 50 ++++++++++++++++++++++++++++++++--- v2/workloadapi/client_test.go | 49 +++++++++++++++++++++++++++------- v2/workloadapi/convenience.go | 10 +++++++ v2/workloadapi/jwtsource.go | 9 +++++++ v2/workloadapi/watcher.go | 1 + 5 files changed, 107 insertions(+), 12 deletions(-) diff --git a/v2/workloadapi/client.go b/v2/workloadapi/client.go index 1e5fba34..af6ffd79 100644 --- a/v2/workloadapi/client.go +++ b/v2/workloadapi/client.go @@ -162,10 +162,29 @@ func (c *Client) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwts return nil, err } - if len(resp.Svids) == 0 { - return nil, errors.New("there were no SVIDs in the response") + svids, err := parseJWTSVIDs(resp, audience, true) + if err != nil { + return nil, err + } + + return svids[0], nil +} + +// FetchJWTSVIDs fetches all JWT-SVIDs. +func (c *Client) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) { + ctx, cancel := context.WithCancel(withHeader(ctx)) + defer cancel() + + audience := append([]string{params.Audience}, params.ExtraAudiences...) + resp, err := c.wlClient.FetchJWTSVID(ctx, &workload.JWTSVIDRequest{ + SpiffeId: params.Subject.String(), + Audience: audience, + }) + if err != nil { + return nil, err } - return jwtsvid.ParseInsecure(resp.Svids[0].Svid, audience) + + return parseJWTSVIDs(resp, audience, false) } // FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed @@ -425,6 +444,31 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error) return x509bundle.FromX509Authorities(td, certs), nil } +// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly +// is true, then only the first SVID in the response is parsed and returned. +// Otherwise all SVIDs are parsed and returned. +func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) { + n := len(resp.Svids) + if firstOnly { + n = 1 + } + + svids := make([]*jwtsvid.SVID, 0, n) + for i := 0; i < n; i++ { + svid := resp.Svids[i] + s, err := jwtsvid.ParseInsecure(svid.Svid, audience) + if err != nil { + return nil, err + } + svids = append(svids, s) + } + + if len(svids) == 0 { + return nil, errors.New("there were no SVIDs in the response") + } + return svids, nil +} + func parseJWTSVIDBundles(resp *workload.JWTBundlesResponse) (*jwtbundle.Set, error) { bundles := []*jwtbundle.Bundle{} diff --git a/v2/workloadapi/client_test.go b/v2/workloadapi/client_test.go index 62443733..161380d1 100644 --- a/v2/workloadapi/client_test.go +++ b/v2/workloadapi/client_test.go @@ -232,8 +232,8 @@ func TestFetchJWTSVID(t *testing.T) { subjectID := spiffeid.RequireFromPath(td, "/subject") audienceID := spiffeid.RequireFromPath(td, "/audience") extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") - token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}) - respJWT := makeJWTSVIDResponse(subjectID.String(), token) + token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + respJWT := makeJWTSVIDResponse(ca, []string{token}, subjectID) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -245,7 +245,36 @@ func TestFetchJWTSVID(t *testing.T) { jwtSvid, err := c.FetchJWTSVID(context.Background(), params) require.NoError(t, err) - assertJWTSVID(t, jwtSvid, subjectID, token.Marshal(), audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String()) +} + +func TestFetchJWTSVIDs(t *testing.T) { + ca := test.NewCA(t, td) + wl := fakeworkloadapi.New(t) + defer wl.Stop() + c, _ := New(context.Background(), WithAddr(wl.Addr())) + defer c.Close() + + subjectID := spiffeid.RequireFromPath(td, "/subject") + extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject") + audienceID := spiffeid.RequireFromPath(td, "/audience") + extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") + subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() + respJWT := makeJWTSVIDResponse(ca, []string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID) + wl.SetJWTSVIDResponse(respJWT) + + params := jwtsvid.Params{ + Subject: subjectID, + Audience: audienceID.String(), + ExtraAudiences: []string{extraAudienceID.String()}, + } + + jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params) + + require.NoError(t, err) + assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String()) } func TestFetchJWTBundles(t *testing.T) { @@ -357,12 +386,14 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID { return svids } -func makeJWTSVIDResponse(spiffeID string, token *jwtsvid.SVID) *workload.JWTSVIDResponse { - svids := []*workload.JWTSVID{ - { - SpiffeId: spiffeID, - Svid: token.Marshal(), - }, +func makeJWTSVIDResponse(ca *test.CA, token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse { + svids := []*workload.JWTSVID{} + for i, id := range ids { + svid := &workload.JWTSVID{ + SpiffeId: id.String(), + Svid: token[i], + } + svids = append(svids, svid) } return &workload.JWTSVIDResponse{ Svids: svids, diff --git a/v2/workloadapi/convenience.go b/v2/workloadapi/convenience.go index c4018b3c..5aef54ec 100644 --- a/v2/workloadapi/convenience.go +++ b/v2/workloadapi/convenience.go @@ -71,6 +71,16 @@ func FetchJWTSVID(ctx context.Context, params jwtsvid.Params, options ...ClientO return c.FetchJWTSVID(ctx, params) } +// FetchJWTSVID fetches all JWT-SVIDs. +func FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params, options ...ClientOption) ([]*jwtsvid.SVID, error) { + c, err := New(ctx, options...) + if err != nil { + return nil, err + } + defer c.Close() + return c.FetchJWTSVIDs(ctx, params) +} + // FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed // by a SPIFFE ID of the trust domain to which they belong. func FetchJWTBundles(ctx context.Context, options ...ClientOption) (*jwtbundle.Set, error) { diff --git a/v2/workloadapi/jwtsource.go b/v2/workloadapi/jwtsource.go index cea67e97..6bfe06e4 100644 --- a/v2/workloadapi/jwtsource.go +++ b/v2/workloadapi/jwtsource.go @@ -63,6 +63,15 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j return s.watcher.client.FetchJWTSVID(ctx, params) } +// FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters. +// It implements the jwtsvid.Source interface. +func (s *JWTSource) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) { + if err := s.checkClosed(); err != nil { + return nil, err + } + return s.watcher.client.FetchJWTSVIDs(ctx, params) +} + // GetJWTBundleForTrustDomain returns the JWT bundle for the given trust // domain. It implements the jwtbundle.Source interface. func (s *JWTSource) GetJWTBundleForTrustDomain(trustDomain spiffeid.TrustDomain) (*jwtbundle.Bundle, error) { diff --git a/v2/workloadapi/watcher.go b/v2/workloadapi/watcher.go index a6329c63..f110e073 100644 --- a/v2/workloadapi/watcher.go +++ b/v2/workloadapi/watcher.go @@ -13,6 +13,7 @@ type sourceClient interface { WatchX509Context(context.Context, X509ContextWatcher) error WatchJWTBundles(context.Context, JWTBundleWatcher) error FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) + FetchJWTSVIDs(context.Context, jwtsvid.Params) ([]*jwtsvid.SVID, error) Close() error } From 1282ba7bdc68edda1e721d188c132860bc64e07c Mon Sep 17 00:00:00 2001 From: Yuhan Li Date: Tue, 26 Apr 2022 09:27:34 +0800 Subject: [PATCH 2/3] Fix array out of bounds problem Signed-off-by: Yuhan Li --- v2/workloadapi/client.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/v2/workloadapi/client.go b/v2/workloadapi/client.go index 4e990518..b932d81c 100644 --- a/v2/workloadapi/client.go +++ b/v2/workloadapi/client.go @@ -375,6 +375,9 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) { // Otherwise all SVIDs are parsed and returned. func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) { n := len(resp.Svids) + if n == 0 { + return nil, errors.New("no SVIDs in response") + } if firstOnly { n = 1 } @@ -389,9 +392,6 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi svids = append(svids, s) } - if len(svids) == 0 { - return nil, errors.New("no SVIDs in response") - } return svids, nil } @@ -436,6 +436,9 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error) // Otherwise all SVIDs are parsed and returned. func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) { n := len(resp.Svids) + if n == 0 { + return nil, errors.New("there were no SVIDs in the response") + } if firstOnly { n = 1 } @@ -450,9 +453,6 @@ func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly svids = append(svids, s) } - if len(svids) == 0 { - return nil, errors.New("there were no SVIDs in the response") - } return svids, nil } From 9e22b552f2f921227d441da8777750780250f3b8 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Wed, 27 Apr 2022 15:46:32 -0600 Subject: [PATCH 3/3] fix linting Signed-off-by: Andrew Harding --- v2/workloadapi/client_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/v2/workloadapi/client_test.go b/v2/workloadapi/client_test.go index 161380d1..b8633469 100644 --- a/v2/workloadapi/client_test.go +++ b/v2/workloadapi/client_test.go @@ -233,7 +233,7 @@ func TestFetchJWTSVID(t *testing.T) { audienceID := spiffeid.RequireFromPath(td, "/audience") extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() - respJWT := makeJWTSVIDResponse(ca, []string{token}, subjectID) + respJWT := makeJWTSVIDResponse([]string{token}, subjectID) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -261,7 +261,7 @@ func TestFetchJWTSVIDs(t *testing.T) { extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() - respJWT := makeJWTSVIDResponse(ca, []string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID) + respJWT := makeJWTSVIDResponse([]string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -386,7 +386,7 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID { return svids } -func makeJWTSVIDResponse(ca *test.CA, token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse { +func makeJWTSVIDResponse(token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse { svids := []*workload.JWTSVID{} for i, id := range ids { svid := &workload.JWTSVID{