diff --git a/v2/internal/test/ca.go b/v2/internal/test/ca.go index 7e69bfe8..f4f80929 100644 --- a/v2/internal/test/ca.go +++ b/v2/internal/test/ca.go @@ -40,16 +40,6 @@ type CA struct { jwtKid string } -type CertificateOption interface { - apply(*x509.Certificate) -} - -type certificateOption func(*x509.Certificate) - -func (co certificateOption) apply(c *x509.Certificate) { - co(c) -} - func NewCA(tb testing.TB, td spiffeid.TrustDomain) *CA { cert, key := CreateCACertificate(tb, nil, nil) return &CA{ @@ -62,7 +52,7 @@ func NewCA(tb testing.TB, td spiffeid.TrustDomain) *CA { } } -func (ca *CA) ChildCA(options ...CertificateOption) *CA { +func (ca *CA) ChildCA(options ...SVIDOption) *CA { cert, key := CreateCACertificate(ca.tb, ca.cert, ca.key, options...) return &CA{ tb: ca.tb, @@ -74,21 +64,23 @@ func (ca *CA) ChildCA(options ...CertificateOption) *CA { } } -func (ca *CA) CreateX509SVID(id spiffeid.ID, options ...CertificateOption) *x509svid.SVID { +func (ca *CA) CreateX509SVID(id spiffeid.ID, options ...SVIDOption) *x509svid.SVID { cert, key := CreateX509SVID(ca.tb, ca.cert, ca.key, id, options...) - return &x509svid.SVID{ + svid := &x509svid.SVID{ ID: id, Certificates: append([]*x509.Certificate{cert}, ca.chain(false)...), PrivateKey: key, } + applyX509SVIDOptions(svid, options...) + return svid } -func (ca *CA) CreateX509Certificate(options ...CertificateOption) ([]*x509.Certificate, crypto.Signer) { +func (ca *CA) CreateX509Certificate(options ...SVIDOption) ([]*x509.Certificate, crypto.Signer) { cert, key := CreateX509Certificate(ca.tb, ca.cert, ca.key, options...) return append([]*x509.Certificate{cert}, ca.chain(false)...), key } -func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string) *jwtsvid.SVID { +func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string, options ...SVIDOption) *jwtsvid.SVID { claims := jwt.Claims{ Subject: id.String(), Issuer: "FAKECA", @@ -114,6 +106,9 @@ func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string) *jwtsvid.SVID { svid, err := jwtsvid.ParseInsecure(signedToken, audience) require.NoError(ca.tb, err) + + applyJWTSVIDOptions(svid, options...) + return svid } @@ -146,7 +141,7 @@ func (ca *CA) JWTBundle() *jwtbundle.Bundle { return jwtbundle.FromJWTAuthorities(ca.td, ca.JWTAuthorities()) } -func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...CertificateOption) (*x509.Certificate, crypto.Signer) { +func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...SVIDOption) (*x509.Certificate, crypto.Signer) { now := time.Now() serial := NewSerial(tb) key := NewEC256Key(tb) @@ -161,7 +156,7 @@ func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey cryp NotAfter: now.Add(time.Hour), } - applyOptions(tmpl, options...) + applyCertOptions(tmpl, options...) if parent == nil { parent = tmpl @@ -170,7 +165,7 @@ func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey cryp return CreateCertificate(tb, tmpl, parent, key.Public(), parentKey), key } -func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...CertificateOption) (*x509.Certificate, crypto.Signer) { +func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...SVIDOption) (*x509.Certificate, crypto.Signer) { now := time.Now() serial := NewSerial(tb) key := NewEC256Key(tb) @@ -184,12 +179,12 @@ func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey cr KeyUsage: x509.KeyUsageDigitalSignature, } - applyOptions(tmpl, options...) + applyCertOptions(tmpl, options...) return CreateCertificate(tb, tmpl, parent, key.Public(), parentKey), key } -func CreateX509SVID(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, id spiffeid.ID, options ...CertificateOption) (*x509.Certificate, crypto.Signer) { +func CreateX509SVID(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, id spiffeid.ID, options ...SVIDOption) (*x509.Certificate, crypto.Signer) { serial := NewSerial(tb) options = append(options, WithSerial(serial), @@ -230,46 +225,105 @@ func NewSerial(tb testing.TB) *big.Int { return new(big.Int).SetBytes(b) } -func WithSerial(serial *big.Int) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.SerialNumber = serial - }) +type SVIDOption struct { + certificateOption func(*x509.Certificate) + x509SvidOption func(*x509svid.SVID) + jwtSvidOption func(*jwtsvid.SVID) +} + +func (s SVIDOption) applyJWTSVIDOption(svid *jwtsvid.SVID) { + if s.jwtSvidOption != nil { + s.jwtSvidOption(svid) + } +} + +func (s SVIDOption) applyCertOption(certificate *x509.Certificate) { + if s.certificateOption != nil { + s.certificateOption(certificate) + } +} + +func (s SVIDOption) applyX509SVIDOption(svid *x509svid.SVID) { + if s.x509SvidOption != nil { + s.x509SvidOption(svid) + } +} + +func WithSerial(serial *big.Int) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.SerialNumber = serial + }, + } +} + +func WithKeyUsage(keyUsage x509.KeyUsage) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.KeyUsage = keyUsage + }, + } +} + +func WithLifetime(notBefore, notAfter time.Time) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.NotBefore = notBefore + c.NotAfter = notAfter + }, + } +} + +func WithIPAddresses(ips ...net.IP) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.IPAddresses = ips + }, + } } -func WithKeyUsage(keyUsage x509.KeyUsage) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.KeyUsage = keyUsage - }) +func WithURIs(uris ...*url.URL) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.URIs = uris + }, + } } -func WithLifetime(notBefore, notAfter time.Time) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.NotBefore = notBefore - c.NotAfter = notAfter - }) +func WithSubject(subject pkix.Name) SVIDOption { + return SVIDOption{ + certificateOption: func(c *x509.Certificate) { + c.Subject = subject + }, + } } -func WithIPAddresses(ips ...net.IP) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.IPAddresses = ips - }) +func WithHint(hint string) SVIDOption { + return SVIDOption{ + x509SvidOption: func(svid *x509svid.SVID) { + svid.Hint = hint + }, + jwtSvidOption: func(svid *jwtsvid.SVID) { + svid.Hint = hint + }, + } } -func WithURIs(uris ...*url.URL) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.URIs = uris - }) +func applyCertOptions(c *x509.Certificate, options ...SVIDOption) { + for _, opt := range options { + opt.applyCertOption(c) + } } -func WithSubject(subject pkix.Name) CertificateOption { - return certificateOption(func(c *x509.Certificate) { - c.Subject = subject - }) +func applyX509SVIDOptions(svid *x509svid.SVID, options ...SVIDOption) { + for _, opt := range options { + opt.applyX509SVIDOption(svid) + } } -func applyOptions(c *x509.Certificate, options ...CertificateOption) { +func applyJWTSVIDOptions(svid *jwtsvid.SVID, options ...SVIDOption) { for _, opt := range options { - opt.apply(c) + opt.applyJWTSVIDOption(svid) } } diff --git a/v2/internal/test/fakeworkloadapi/workload_api.go b/v2/internal/test/fakeworkloadapi/workload_api.go index 45c79950..0668545d 100644 --- a/v2/internal/test/fakeworkloadapi/workload_api.go +++ b/v2/internal/test/fakeworkloadapi/workload_api.go @@ -213,6 +213,7 @@ func (r *X509SVIDResponse) ToProto(tb testing.TB) *workload.X509SVIDResponse { X509Svid: x509util.ConcatRawCertsFromCerts(svid.Certificates), X509SvidKey: keyDER, Bundle: bundle, + Hint: svid.Hint, }) } for _, v := range r.FederatedBundles { diff --git a/v2/svid/jwtsvid/svid.go b/v2/svid/jwtsvid/svid.go index a9c5e6e4..ddbfac34 100644 --- a/v2/svid/jwtsvid/svid.go +++ b/v2/svid/jwtsvid/svid.go @@ -28,6 +28,9 @@ type SVID struct { Expiry time.Time // Claims is the parsed claims from token Claims map[string]interface{} + // Hint is an operator-specified string used to provide guidance on how this + // identity should be used by a workload when more than one SVID is returned. + Hint string // token is the serialized JWT token token string diff --git a/v2/svid/x509svid/svid.go b/v2/svid/x509svid/svid.go index 5fecffe8..4ac51dae 100644 --- a/v2/svid/x509svid/svid.go +++ b/v2/svid/x509svid/svid.go @@ -26,6 +26,10 @@ type SVID struct { // PrivateKey is the private key for the X509-SVID. PrivateKey crypto.Signer + + // Hint is an operator-specified string used to provide guidance on how this + // identity should be used by a workload when more than one SVID is returned. + Hint string } // Load loads the X509-SVID from PEM encoded files on disk. certFile and diff --git a/v2/workloadapi/client.go b/v2/workloadapi/client.go index 3328a98f..7a9685cf 100644 --- a/v2/workloadapi/client.go +++ b/v2/workloadapi/client.go @@ -426,7 +426,7 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) { // parseX509SVIDs parses one or all of the SVIDs in the response. If firstOnly // is true, then only the first SVID in the response is parsed and returned. -// Otherwise all SVIDs are parsed and returned. +// Otherwise, all SVIDs are parsed and returned. func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) { n := len(resp.Svids) if n == 0 { @@ -436,10 +436,20 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi n = 1 } + hints := make(map[string]struct{}, n) svids := make([]*x509svid.SVID, 0, n) for i := 0; i < n; i++ { svid := resp.Svids[i] + // In the event of more than one X509SVID message with the same hint value set, then the first message in the + // list SHOULD be selected. + if _, ok := hints[svid.Hint]; ok && svid.Hint != "" { + continue + } + + hints[svid.Hint] = struct{}{} + s, err := x509svid.ParseRaw(svid.X509Svid, svid.X509SvidKey) + s.Hint = svid.Hint if err != nil { return nil, err } @@ -506,7 +516,7 @@ func parseX509BundlesResponse(resp *workload.X509BundlesResponse) (*x509bundle.S // parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly // is true, then only the first SVID in the response is parsed and returned. -// Otherwise all SVIDs are parsed and returned. +// Otherwise, all SVIDs are parsed and returned. func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) { n := len(resp.Svids) if n == 0 { @@ -516,10 +526,19 @@ func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly n = 1 } + hints := make(map[string]struct{}, n) svids := make([]*jwtsvid.SVID, 0, n) for i := 0; i < n; i++ { svid := resp.Svids[i] + // In the event of more than one X509SVID message with the same hint value set, then the first message in the + // list SHOULD be selected. + if _, ok := hints[svid.Hint]; ok && svid.Hint != "" { + continue + } + hints[svid.Hint] = struct{}{} + s, err := jwtsvid.ParseInsecure(svid.Svid, audience) + s.Hint = svid.Hint if err != nil { return nil, err } diff --git a/v2/workloadapi/client_test.go b/v2/workloadapi/client_test.go index d7951982..13da9025 100644 --- a/v2/workloadapi/client_test.go +++ b/v2/workloadapi/client_test.go @@ -20,11 +20,13 @@ import ( ) var ( - td = spiffeid.RequireTrustDomainFromString("example.org") - federatedTD = spiffeid.RequireTrustDomainFromString("federated.test") - fooID = spiffeid.RequireFromPath(td, "/foo") - barID = spiffeid.RequireFromPath(td, "/bar") - bazID = spiffeid.RequireFromPath(td, "/baz") + td = spiffeid.RequireTrustDomainFromString("example.org") + federatedTD = spiffeid.RequireTrustDomainFromString("federated.test") + fooID = spiffeid.RequireFromPath(td, "/foo") + barID = spiffeid.RequireFromPath(td, "/bar") + bazID = spiffeid.RequireFromPath(td, "/baz") + hintInternal = "internal usage" + hintExternal = "external usage" ) func TestFetchX509SVID(t *testing.T) { @@ -34,17 +36,18 @@ func TestFetchX509SVID(t *testing.T) { c, err := New(context.Background(), WithAddr(wl.Addr())) require.NoError(t, err) defer c.Close() + hint := "internal usage" resp := &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), - SVIDs: makeX509SVIDs(ca, fooID, barID), + SVIDs: makeX509SVIDs(ca, hint, fooID, barID), } wl.SetX509SVIDResponse(resp) svid, err := c.FetchX509SVID(context.Background()) require.NoError(t, err) - assertX509SVID(t, svid, fooID, resp.SVIDs[0].Certificates) + assertX509SVID(t, svid, fooID, resp.SVIDs[0].Certificates, hint) } func TestFetchX509SVIDs(t *testing.T) { @@ -55,18 +58,27 @@ func TestFetchX509SVIDs(t *testing.T) { require.NoError(t, err) defer c.Close() + fooSVID := ca.CreateX509SVID(fooID, test.WithHint(hintInternal)) + barSVID := ca.CreateX509SVID(barID, test.WithHint(hintExternal)) + duplicatedHintSVID := ca.CreateX509SVID(bazID, test.WithHint(hintInternal)) + emptyHintSVID1 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty1"), test.WithHint("")) + emptyHintSVID2 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty2"), test.WithHint("")) + resp := &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), - SVIDs: makeX509SVIDs(ca, fooID, barID), + SVIDs: []*x509svid.SVID{fooSVID, barSVID, duplicatedHintSVID, emptyHintSVID1, emptyHintSVID2}, } wl.SetX509SVIDResponse(resp) svids, err := c.FetchX509SVIDs(context.Background()) require.NoError(t, err) - require.Len(t, svids, 2) - assertX509SVID(t, svids[0], fooID, resp.SVIDs[0].Certificates) - assertX509SVID(t, svids[1], barID, resp.SVIDs[1].Certificates) + // Assert that the response contains the expected SVIDs, and does not contain the SVID with duplicated hint + require.Len(t, svids, 4) + assertX509SVID(t, svids[0], fooID, resp.SVIDs[0].Certificates, hintInternal) + assertX509SVID(t, svids[1], barID, resp.SVIDs[1].Certificates, hintExternal) + assertX509SVID(t, svids[2], emptyHintSVID1.ID, resp.SVIDs[3].Certificates, "") + assertX509SVID(t, svids[3], emptyHintSVID2.ID, resp.SVIDs[4].Certificates, "") } func TestFetchX509Bundles(t *testing.T) { @@ -148,7 +160,13 @@ func TestFetchX509Context(t *testing.T) { require.NoError(t, err) defer c.Close() - svids := makeX509SVIDs(ca, fooID, barID) + fooSVID := ca.CreateX509SVID(fooID, test.WithHint(hintInternal)) + barSVID := ca.CreateX509SVID(barID, test.WithHint(hintExternal)) + duplicatedHintSVID := ca.CreateX509SVID(bazID, test.WithHint(hintInternal)) + emptyHintSVID1 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty1"), test.WithHint("")) + emptyHintSVID2 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty2"), test.WithHint("")) + + svids := []*x509svid.SVID{fooSVID, barSVID, duplicatedHintSVID, emptyHintSVID1, emptyHintSVID2} resp := &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), @@ -161,9 +179,11 @@ func TestFetchX509Context(t *testing.T) { require.NoError(t, err) // inspect svids - require.Len(t, x509Ctx.SVIDs, 2) - assertX509SVID(t, x509Ctx.SVIDs[0], fooID, resp.SVIDs[0].Certificates) - assertX509SVID(t, x509Ctx.SVIDs[1], barID, resp.SVIDs[1].Certificates) + require.Len(t, x509Ctx.SVIDs, 4) + assertX509SVID(t, x509Ctx.SVIDs[0], fooID, resp.SVIDs[0].Certificates, hintInternal) + assertX509SVID(t, x509Ctx.SVIDs[1], barID, resp.SVIDs[1].Certificates, hintExternal) + assertX509SVID(t, x509Ctx.SVIDs[2], emptyHintSVID1.ID, resp.SVIDs[3].Certificates, "") + assertX509SVID(t, x509Ctx.SVIDs[3], emptyHintSVID2.ID, resp.SVIDs[4].Certificates, "") // inspect bundles assert.Equal(t, 2, x509Ctx.Bundles.Len()) @@ -205,10 +225,18 @@ func TestWatchX509Context(t *testing.T) { require.Len(t, tw.Errors(), 1) require.Len(t, tw.X509Contexts(), 0) + fooSVID := ca.CreateX509SVID(fooID, test.WithHint(hintInternal)) + barSVID := ca.CreateX509SVID(barID, test.WithHint(hintExternal)) + duplicatedHintSVID := ca.CreateX509SVID(bazID, test.WithHint(hintInternal)) + emptyHintSVID1 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty1"), test.WithHint("")) + emptyHintSVID2 := ca.CreateX509SVID(spiffeid.RequireFromPath(td, "/empty2"), test.WithHint("")) + + svids := []*x509svid.SVID{fooSVID, barSVID, duplicatedHintSVID, emptyHintSVID1, emptyHintSVID2} + // test first update resp := &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), - SVIDs: makeX509SVIDs(ca, fooID, barID), + SVIDs: svids, FederatedBundles: []*x509bundle.Bundle{federatedCA.X509Bundle()}, } wl.SetX509SVIDResponse(resp) @@ -219,18 +247,21 @@ func TestWatchX509Context(t *testing.T) { require.Len(t, tw.X509Contexts(), 1) update := tw.X509Contexts()[len(tw.X509Contexts())-1] // inspect svids - require.Len(t, update.SVIDs, 2) - assertX509SVID(t, update.SVIDs[0], fooID, resp.SVIDs[0].Certificates) - assertX509SVID(t, update.SVIDs[1], barID, resp.SVIDs[1].Certificates) + require.Len(t, update.SVIDs, 4) + assertX509SVID(t, update.SVIDs[0], fooID, resp.SVIDs[0].Certificates, hintInternal) + assertX509SVID(t, update.SVIDs[1], barID, resp.SVIDs[1].Certificates, hintExternal) + assertX509SVID(t, update.SVIDs[2], emptyHintSVID1.ID, resp.SVIDs[3].Certificates, "") + assertX509SVID(t, update.SVIDs[3], emptyHintSVID2.ID, resp.SVIDs[4].Certificates, "") // inspect bundles assert.Equal(t, 2, update.Bundles.Len()) assertX509Bundle(t, update.Bundles, td, ca.X509Bundle()) assertX509Bundle(t, update.Bundles, federatedTD, federatedCA.X509Bundle()) + bazSVID := ca.CreateX509SVID(bazID, test.WithHint(hintExternal)) // test second update resp = &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), - SVIDs: makeX509SVIDs(ca, bazID), + SVIDs: []*x509svid.SVID{bazSVID}, } wl.SetX509SVIDResponse(resp) @@ -242,7 +273,7 @@ func TestWatchX509Context(t *testing.T) { update = tw.X509Contexts()[len(tw.X509Contexts())-1] // inspect svids require.Len(t, update.SVIDs, 1) - assertX509SVID(t, update.SVIDs[0], bazID, resp.SVIDs[0].Certificates) + assertX509SVID(t, update.SVIDs[0], bazID, resp.SVIDs[0].Certificates, hintExternal) // inspect bundles assert.Equal(t, 1, update.Bundles.Len()) assertX509Bundle(t, update.Bundles, td, ca.X509Bundle()) @@ -266,8 +297,8 @@ func TestFetchJWTSVID(t *testing.T) { subjectID := spiffeid.RequireFromPath(td, "/subject") audienceID := spiffeid.RequireFromPath(td, "/audience") extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") - token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() - respJWT := makeJWTSVIDResponse([]string{token}, subjectID) + svid := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("internal usage")) + respJWT := makeJWTSVIDResponse(svid) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -279,7 +310,7 @@ func TestFetchJWTSVID(t *testing.T) { jwtSvid, err := c.FetchJWTSVID(context.Background(), params) require.NoError(t, err) - assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid, subjectID, svid.Marshal(), svid.Hint, audienceID.String(), extraAudienceID.String()) } func TestFetchJWTSVIDs(t *testing.T) { @@ -291,11 +322,15 @@ func TestFetchJWTSVIDs(t *testing.T) { subjectID := spiffeid.RequireFromPath(td, "/subject") extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject") + duplicatedHintID := spiffeid.RequireFromPath(td, "/somePath") audienceID := spiffeid.RequireFromPath(td, "/audience") extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience") - subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() - extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal() - respJWT := makeJWTSVIDResponse([]string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID) + subjectSVID := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("internal usage")) + extraSubjectSVID := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("external usage")) + duplicatedHintSVID := ca.CreateJWTSVID(duplicatedHintID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("internal usage")) + emptyHintSVID1 := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("")) + emptyHintSVID2 := ca.CreateJWTSVID(duplicatedHintID, []string{audienceID.String(), extraAudienceID.String()}, test.WithHint("")) + respJWT := makeJWTSVIDResponse(subjectSVID, extraSubjectSVID, duplicatedHintSVID, emptyHintSVID1, emptyHintSVID2) wl.SetJWTSVIDResponse(respJWT) params := jwtsvid.Params{ @@ -307,8 +342,12 @@ func TestFetchJWTSVIDs(t *testing.T) { jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params) require.NoError(t, err) - assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String()) - assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String()) + // Assert that the response contains the expected SVIDs, and does not contain the SVID with duplicated hint + require.Len(t, jwtSvid, 4) + assertJWTSVID(t, jwtSvid[0], subjectID, subjectSVID.Marshal(), subjectSVID.Hint, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectSVID.Marshal(), extraSubjectSVID.Hint, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid[2], emptyHintSVID1.ID, emptyHintSVID1.Marshal(), emptyHintSVID1.Hint, audienceID.String(), extraAudienceID.String()) + assertJWTSVID(t, jwtSvid[3], emptyHintSVID2.ID, emptyHintSVID2.Marshal(), emptyHintSVID2.Hint, audienceID.String(), extraAudienceID.String()) } func TestFetchJWTBundles(t *testing.T) { @@ -394,14 +433,14 @@ func TestValidateJWTSVID(t *testing.T) { jwtSvid, err := c.ValidateJWTSVID(context.Background(), token.Marshal(), audience[0]) assert.NoError(t, err) - assertJWTSVID(t, jwtSvid, workloadID, token.Marshal(), audience...) + assertJWTSVID(t, jwtSvid, workloadID, token.Marshal(), "", audience...) }) t.Run("second audience is valid", func(t *testing.T) { jwtSvid, err := c.ValidateJWTSVID(context.Background(), token.Marshal(), audience[1]) assert.NoError(t, err) - assertJWTSVID(t, jwtSvid, workloadID, token.Marshal(), audience...) + assertJWTSVID(t, jwtSvid, workloadID, token.Marshal(), "", audience...) }) t.Run("invalid audience returns error", func(t *testing.T) { @@ -412,31 +451,33 @@ func TestValidateJWTSVID(t *testing.T) { }) } -func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID { +func makeX509SVIDs(ca *test.CA, hint string, ids ...spiffeid.ID) []*x509svid.SVID { svids := []*x509svid.SVID{} for _, id := range ids { - svids = append(svids, ca.CreateX509SVID(id)) + svids = append(svids, ca.CreateX509SVID(id, test.WithHint(hint))) } return svids } -func makeJWTSVIDResponse(token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse { - svids := []*workload.JWTSVID{} - for i, id := range ids { - svid := &workload.JWTSVID{ - SpiffeId: id.String(), - Svid: token[i], +func makeJWTSVIDResponse(svids ...*jwtsvid.SVID) *workload.JWTSVIDResponse { + respSVIDS := []*workload.JWTSVID{} + for _, svid := range svids { + respSVID := &workload.JWTSVID{ + SpiffeId: svid.ID.String(), + Svid: svid.Marshal(), + Hint: svid.Hint, } - svids = append(svids, svid) + respSVIDS = append(respSVIDS, respSVID) } return &workload.JWTSVIDResponse{ - Svids: svids, + Svids: respSVIDS, } } -func assertX509SVID(tb testing.TB, svid *x509svid.SVID, spiffeID spiffeid.ID, certificates []*x509.Certificate) { +func assertX509SVID(tb testing.TB, svid *x509svid.SVID, spiffeID spiffeid.ID, certificates []*x509.Certificate, hint string) { assert.Equal(tb, spiffeID, svid.ID) assert.Equal(tb, certificates, svid.Certificates) + assert.Equal(tb, hint, svid.Hint) assert.NotEmpty(tb, svid.PrivateKey) } @@ -452,12 +493,13 @@ func assertJWTBundle(tb testing.TB, bundleSet *jwtbundle.Set, trustDomain spiffe assert.Equal(tb, b, expectedBundle) } -func assertJWTSVID(t testing.TB, jwtSvid *jwtsvid.SVID, subjectID spiffeid.ID, token string, audience ...string) { +func assertJWTSVID(t testing.TB, jwtSvid *jwtsvid.SVID, subjectID spiffeid.ID, token, hint string, audience ...string) { assert.Equal(t, subjectID.String(), jwtSvid.ID.String()) assert.Equal(t, audience, jwtSvid.Audience) assert.NotNil(t, jwtSvid.Claims) assert.NotEmpty(t, jwtSvid.Expiry) assert.Equal(t, token, jwtSvid.Marshal()) + assert.Equal(t, hint, jwtSvid.Hint) } type testWatcher struct { diff --git a/v2/workloadapi/client_windows_test.go b/v2/workloadapi/client_windows_test.go index c73b4e38..2d503c55 100644 --- a/v2/workloadapi/client_windows_test.go +++ b/v2/workloadapi/client_windows_test.go @@ -26,12 +26,12 @@ func TestWithNamedPipeName(t *testing.T) { resp := &fakeworkloadapi.X509SVIDResponse{ Bundle: ca.X509Bundle(), - SVIDs: makeX509SVIDs(ca, fooID, barID), + SVIDs: makeX509SVIDs(ca, "internal", fooID, barID), } wl.SetX509SVIDResponse(resp) svid, err := c.FetchX509SVID(context.Background()) require.NoError(t, err) - assertX509SVID(t, svid, fooID, resp.SVIDs[0].Certificates) + assertX509SVID(t, svid, fooID, resp.SVIDs[0].Certificates, "internal") } func TestWithNamedPipeNameError(t *testing.T) {