Skip to content

Commit

Permalink
Add SVID hints on workload api client (#220)
Browse files Browse the repository at this point in the history
Signed-off-by: Guilherme Carvalho <guilhermbrsp@gmail.com>
Co-authored-by: Daniel Feldman <dfeldman.mn@gmail.com>
  • Loading branch information
guilhermocc and dfeldman authored Mar 31, 2023
1 parent acf23ce commit 4d7bbf4
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 94 deletions.
148 changes: 101 additions & 47 deletions v2/internal/test/ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,6 @@ type CA struct {
jwtKid string
}

type CertificateOption interface {
apply(*x509.Certificate)
}

type certificateOption func(*x509.Certificate)

func (co certificateOption) apply(c *x509.Certificate) {
co(c)
}

func NewCA(tb testing.TB, td spiffeid.TrustDomain) *CA {
cert, key := CreateCACertificate(tb, nil, nil)
return &CA{
Expand All @@ -62,7 +52,7 @@ func NewCA(tb testing.TB, td spiffeid.TrustDomain) *CA {
}
}

func (ca *CA) ChildCA(options ...CertificateOption) *CA {
func (ca *CA) ChildCA(options ...SVIDOption) *CA {
cert, key := CreateCACertificate(ca.tb, ca.cert, ca.key, options...)
return &CA{
tb: ca.tb,
Expand All @@ -74,21 +64,23 @@ func (ca *CA) ChildCA(options ...CertificateOption) *CA {
}
}

func (ca *CA) CreateX509SVID(id spiffeid.ID, options ...CertificateOption) *x509svid.SVID {
func (ca *CA) CreateX509SVID(id spiffeid.ID, options ...SVIDOption) *x509svid.SVID {
cert, key := CreateX509SVID(ca.tb, ca.cert, ca.key, id, options...)
return &x509svid.SVID{
svid := &x509svid.SVID{
ID: id,
Certificates: append([]*x509.Certificate{cert}, ca.chain(false)...),
PrivateKey: key,
}
applyX509SVIDOptions(svid, options...)
return svid
}

func (ca *CA) CreateX509Certificate(options ...CertificateOption) ([]*x509.Certificate, crypto.Signer) {
func (ca *CA) CreateX509Certificate(options ...SVIDOption) ([]*x509.Certificate, crypto.Signer) {
cert, key := CreateX509Certificate(ca.tb, ca.cert, ca.key, options...)
return append([]*x509.Certificate{cert}, ca.chain(false)...), key
}

func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string) *jwtsvid.SVID {
func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string, options ...SVIDOption) *jwtsvid.SVID {
claims := jwt.Claims{
Subject: id.String(),
Issuer: "FAKECA",
Expand All @@ -114,6 +106,9 @@ func (ca *CA) CreateJWTSVID(id spiffeid.ID, audience []string) *jwtsvid.SVID {

svid, err := jwtsvid.ParseInsecure(signedToken, audience)
require.NoError(ca.tb, err)

applyJWTSVIDOptions(svid, options...)

return svid
}

Expand Down Expand Up @@ -146,7 +141,7 @@ func (ca *CA) JWTBundle() *jwtbundle.Bundle {
return jwtbundle.FromJWTAuthorities(ca.td, ca.JWTAuthorities())
}

func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...CertificateOption) (*x509.Certificate, crypto.Signer) {
func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...SVIDOption) (*x509.Certificate, crypto.Signer) {
now := time.Now()
serial := NewSerial(tb)
key := NewEC256Key(tb)
Expand All @@ -161,7 +156,7 @@ func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey cryp
NotAfter: now.Add(time.Hour),
}

applyOptions(tmpl, options...)
applyCertOptions(tmpl, options...)

if parent == nil {
parent = tmpl
Expand All @@ -170,7 +165,7 @@ func CreateCACertificate(tb testing.TB, parent *x509.Certificate, parentKey cryp
return CreateCertificate(tb, tmpl, parent, key.Public(), parentKey), key
}

func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...CertificateOption) (*x509.Certificate, crypto.Signer) {
func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, options ...SVIDOption) (*x509.Certificate, crypto.Signer) {
now := time.Now()
serial := NewSerial(tb)
key := NewEC256Key(tb)
Expand All @@ -184,12 +179,12 @@ func CreateX509Certificate(tb testing.TB, parent *x509.Certificate, parentKey cr
KeyUsage: x509.KeyUsageDigitalSignature,
}

applyOptions(tmpl, options...)
applyCertOptions(tmpl, options...)

return CreateCertificate(tb, tmpl, parent, key.Public(), parentKey), key
}

func CreateX509SVID(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, id spiffeid.ID, options ...CertificateOption) (*x509.Certificate, crypto.Signer) {
func CreateX509SVID(tb testing.TB, parent *x509.Certificate, parentKey crypto.Signer, id spiffeid.ID, options ...SVIDOption) (*x509.Certificate, crypto.Signer) {
serial := NewSerial(tb)
options = append(options,
WithSerial(serial),
Expand Down Expand Up @@ -230,46 +225,105 @@ func NewSerial(tb testing.TB) *big.Int {
return new(big.Int).SetBytes(b)
}

func WithSerial(serial *big.Int) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.SerialNumber = serial
})
type SVIDOption struct {
certificateOption func(*x509.Certificate)
x509SvidOption func(*x509svid.SVID)
jwtSvidOption func(*jwtsvid.SVID)
}

func (s SVIDOption) applyJWTSVIDOption(svid *jwtsvid.SVID) {
if s.jwtSvidOption != nil {
s.jwtSvidOption(svid)
}
}

func (s SVIDOption) applyCertOption(certificate *x509.Certificate) {
if s.certificateOption != nil {
s.certificateOption(certificate)
}
}

func (s SVIDOption) applyX509SVIDOption(svid *x509svid.SVID) {
if s.x509SvidOption != nil {
s.x509SvidOption(svid)
}
}

func WithSerial(serial *big.Int) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.SerialNumber = serial
},
}
}

func WithKeyUsage(keyUsage x509.KeyUsage) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.KeyUsage = keyUsage
},
}
}

func WithLifetime(notBefore, notAfter time.Time) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.NotBefore = notBefore
c.NotAfter = notAfter
},
}
}

func WithIPAddresses(ips ...net.IP) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.IPAddresses = ips
},
}
}

func WithKeyUsage(keyUsage x509.KeyUsage) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.KeyUsage = keyUsage
})
func WithURIs(uris ...*url.URL) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.URIs = uris
},
}
}

func WithLifetime(notBefore, notAfter time.Time) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.NotBefore = notBefore
c.NotAfter = notAfter
})
func WithSubject(subject pkix.Name) SVIDOption {
return SVIDOption{
certificateOption: func(c *x509.Certificate) {
c.Subject = subject
},
}
}

func WithIPAddresses(ips ...net.IP) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.IPAddresses = ips
})
func WithHint(hint string) SVIDOption {
return SVIDOption{
x509SvidOption: func(svid *x509svid.SVID) {
svid.Hint = hint
},
jwtSvidOption: func(svid *jwtsvid.SVID) {
svid.Hint = hint
},
}
}

func WithURIs(uris ...*url.URL) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.URIs = uris
})
func applyCertOptions(c *x509.Certificate, options ...SVIDOption) {
for _, opt := range options {
opt.applyCertOption(c)
}
}

func WithSubject(subject pkix.Name) CertificateOption {
return certificateOption(func(c *x509.Certificate) {
c.Subject = subject
})
func applyX509SVIDOptions(svid *x509svid.SVID, options ...SVIDOption) {
for _, opt := range options {
opt.applyX509SVIDOption(svid)
}
}

func applyOptions(c *x509.Certificate, options ...CertificateOption) {
func applyJWTSVIDOptions(svid *jwtsvid.SVID, options ...SVIDOption) {
for _, opt := range options {
opt.apply(c)
opt.applyJWTSVIDOption(svid)
}
}

Expand Down
1 change: 1 addition & 0 deletions v2/internal/test/fakeworkloadapi/workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ func (r *X509SVIDResponse) ToProto(tb testing.TB) *workload.X509SVIDResponse {
X509Svid: x509util.ConcatRawCertsFromCerts(svid.Certificates),
X509SvidKey: keyDER,
Bundle: bundle,
Hint: svid.Hint,
})
}
for _, v := range r.FederatedBundles {
Expand Down
3 changes: 3 additions & 0 deletions v2/svid/jwtsvid/svid.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type SVID struct {
Expiry time.Time
// Claims is the parsed claims from token
Claims map[string]interface{}
// Hint is an operator-specified string used to provide guidance on how this
// identity should be used by a workload when more than one SVID is returned.
Hint string

// token is the serialized JWT token
token string
Expand Down
4 changes: 4 additions & 0 deletions v2/svid/x509svid/svid.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type SVID struct {

// PrivateKey is the private key for the X509-SVID.
PrivateKey crypto.Signer

// Hint is an operator-specified string used to provide guidance on how this
// identity should be used by a workload when more than one SVID is returned.
Hint string
}

// Load loads the X509-SVID from PEM encoded files on disk. certFile and
Expand Down
23 changes: 21 additions & 2 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) {

// parseX509SVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
// Otherwise, all SVIDs are parsed and returned.
func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
Expand All @@ -436,10 +436,20 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi
n = 1
}

hints := make(map[string]struct{}, n)
svids := make([]*x509svid.SVID, 0, n)
for i := 0; i < n; i++ {
svid := resp.Svids[i]
// In the event of more than one X509SVID message with the same hint value set, then the first message in the
// list SHOULD be selected.
if _, ok := hints[svid.Hint]; ok && svid.Hint != "" {
continue
}

hints[svid.Hint] = struct{}{}

s, err := x509svid.ParseRaw(svid.X509Svid, svid.X509SvidKey)
s.Hint = svid.Hint
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -506,7 +516,7 @@ func parseX509BundlesResponse(resp *workload.X509BundlesResponse) (*x509bundle.S

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
// Otherwise, all SVIDs are parsed and returned.
func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
Expand All @@ -516,10 +526,19 @@ func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly
n = 1
}

hints := make(map[string]struct{}, n)
svids := make([]*jwtsvid.SVID, 0, n)
for i := 0; i < n; i++ {
svid := resp.Svids[i]
// In the event of more than one X509SVID message with the same hint value set, then the first message in the
// list SHOULD be selected.
if _, ok := hints[svid.Hint]; ok && svid.Hint != "" {
continue
}
hints[svid.Hint] = struct{}{}

s, err := jwtsvid.ParseInsecure(svid.Svid, audience)
s.Hint = svid.Hint
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 4d7bbf4

Please sign in to comment.