Skip to content

Commit

Permalink
update cert validation to take a slice, clean up dead code
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross committed Oct 12, 2023
1 parent 9701aca commit b07a322
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 107 deletions.
17 changes: 8 additions & 9 deletions nomad/alloc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,22 +451,21 @@ func (a *Alloc) GetServiceRegistrations(
// This is an internal-only RPC and not exposed via the HTTP API.
func (a *Alloc) SignIdentities(args *structs.AllocIdentitiesRequest, reply *structs.AllocIdentitiesResponse) error {

authErr := a.srv.Authenticate(a.ctx, args)

// Ensure the connection was initiated by a client if TLS is used.
if err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient); err != nil {
return err
aclObj, err := a.srv.AuthenticateClientOnly(a.ctx, args)
a.srv.MeasureRPCRate("alloc", structs.RateMetricRead, args)
if err != nil {
return structs.ErrPermissionDenied
}

if done, err := a.srv.forward("Alloc.SignIdentities", args, args, reply); done {
return err
}
a.srv.MeasureRPCRate("alloc", structs.RateMetricRead, args)
if authErr != nil {
defer metrics.MeasureSince([]string{"nomad", "alloc", "sign_identities"}, time.Now())

if !aclObj.AllowClientOp() {
return structs.ErrPermissionDenied
}

defer metrics.MeasureSince([]string{"nomad", "alloc", "sign_identities"}, time.Now())

if len(args.Identities) == 0 {
// Client bug. Fail loudly instead of letting clients waste time with
// noops.
Expand Down
66 changes: 32 additions & 34 deletions nomad/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"golang.org/x/exp/slices"
)

// aclCacheSize is the number of ACL objects to keep cached. ACLs have a parsing
Expand Down Expand Up @@ -46,6 +47,9 @@ type Authenticator struct {
getLeaderACL LeaderACLGetter
region string

validServerCertNames []string
validClientCertNames []string

// aclCache is used to maintain the parsed ACL objects
aclCache *structs.ACLCache[*acl.ACL]

Expand All @@ -66,14 +70,19 @@ type AuthenticatorConfig struct {

func NewAuthenticator(cfg *AuthenticatorConfig) *Authenticator {
return &Authenticator{
aclsEnabled: cfg.AclsEnabled,
tlsEnabled: cfg.TLSEnabled,
logger: cfg.Logger.With("auth"),
getState: cfg.StateFn,
getLeaderACL: cfg.GetLeaderACLFn,
region: cfg.Region,
aclCache: structs.NewACLCache[*acl.ACL](aclCacheSize),
encrypter: cfg.Encrypter,
aclsEnabled: cfg.AclsEnabled,
tlsEnabled: cfg.TLSEnabled,
logger: cfg.Logger.With("auth"),
getState: cfg.StateFn,
getLeaderACL: cfg.GetLeaderACLFn,
region: cfg.Region,
aclCache: structs.NewACLCache[*acl.ACL](aclCacheSize),
encrypter: cfg.Encrypter,
validServerCertNames: []string{"server." + cfg.Region + ".nomad"},
validClientCertNames: []string{
"client." + cfg.Region + ".nomad",
"server." + cfg.Region + ".nomad",
},
}
}

Expand Down Expand Up @@ -232,9 +241,7 @@ func (s *Authenticator) AuthenticateServerOnly(ctx RPCContext, args structs.Requ
// set on the identity whether or not its valid for server RPC, so we
// can capture it for metrics
identity.TLSName = tlsCert.Subject.CommonName

expected := "server." + s.region + ".nomad"
_, err := validateCertificateForName(tlsCert, expected)
_, err := validateCertificateForNames(tlsCert, s.validServerCertNames)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,15 +284,9 @@ func (s *Authenticator) AuthenticateClientOnly(ctx RPCContext, args structs.Requ
// set on the identity whether or not its valid for server RPC, so we
// can capture it for metrics
identity.TLSName = tlsCert.Subject.CommonName

expected := fmt.Sprintf("client.%s.nomad", s.region)
_, err := validateCertificateForName(tlsCert, expected)
_, err := validateCertificateForNames(tlsCert, s.validClientCertNames)
if err != nil {
expected := fmt.Sprintf("server.%s.nomad", s.region)
_, err := validateCertificateForName(tlsCert, expected)
if err != nil {
return nil, err
}
return nil, err
}
}

Expand Down Expand Up @@ -331,38 +332,35 @@ func (s *Authenticator) AuthenticateClientOnlyLegacy(ctx RPCContext, args struct
// set on the identity whether or not its valid for server RPC, so we
// can capture it for metrics
identity.TLSName = tlsCert.Subject.CommonName

expected := fmt.Sprintf("client.%s.nomad", s.region)
_, err := validateCertificateForName(tlsCert, expected)
_, err := validateCertificateForNames(tlsCert, s.validClientCertNames)
if err != nil {
expected := fmt.Sprintf("server.%s.nomad", s.region)
_, err := validateCertificateForName(tlsCert, expected)
if err != nil {
return nil, err
}
return nil, err
}
}

return acl.ClientACL, nil
}

// validateCertificateForName returns true if the certificate is valid
// for the given domain name.
func validateCertificateForName(cert *x509.Certificate, expectedName string) (bool, error) {
// validateCertificateForNames returns true if the certificate is valid for any
// of the given domain names.
func validateCertificateForNames(cert *x509.Certificate, expectedNames []string) (bool, error) {
if cert == nil {
return false, nil
}

validNames := []string{cert.Subject.CommonName}
validNames = append(validNames, cert.DNSNames...)
for _, valid := range validNames {
if expectedName == valid {

for _, expectedName := range expectedNames {
if slices.Contains(validNames, expectedName) {
return true, nil
}
}

return false, fmt.Errorf("invalid certificate, %s not in %s",
expectedName, strings.Join(validNames, ","))
return false, fmt.Errorf("invalid certificate: %s not in expected %s",
strings.Join(validNames, ", "),
strings.Join(expectedNames, ", "))

}

// ResolveACLForToken resolves an ACL from a token only. It should be used only
Expand Down
4 changes: 2 additions & 2 deletions nomad/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func TestAuthenticateServerOnly(t *testing.T) {

aclObj, err := auth.AuthenticateServerOnly(ctx, args)
must.EqError(t, err,
"invalid certificate, server.global.nomad not in client.global.nomad")
"invalid certificate: client.global.nomad not in expected server.global.nomad")
must.Eq(t, "client.global.nomad:192.168.1.1", args.GetIdentity().String())
must.Nil(t, aclObj)
},
Expand Down Expand Up @@ -507,7 +507,7 @@ func TestAuthenticateClientOnly(t *testing.T) {

aclObj, err := auth.AuthenticateClientOnly(ctx, args)
must.EqError(t, err,
"invalid certificate, server.global.nomad not in cli.global.nomad")
"invalid certificate: cli.global.nomad not in expected client.global.nomad, server.global.nomad")
must.Eq(t, "cli.global.nomad:192.168.1.1", args.GetIdentity().String())
must.Nil(t, aclObj)
},
Expand Down
62 changes: 0 additions & 62 deletions nomad/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,65 +282,3 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) {

return alloc, nil
}

// tlsCertificateLevel represents a role level for mTLS certificates.
type tlsCertificateLevel int8

const (
tlsCertificateLevelServer tlsCertificateLevel = iota
tlsCertificateLevelClient
)

// validateTLSCertificateLevel checks if the provided RPC connection was
// initiated with a certificate that matches the given TLS role level.
//
// - tlsCertificateLevelServer requires a server certificate.
// - tlsCertificateLevelServer requires a client or server certificate.
func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error {
switch lvl {
case tlsCertificateLevelClient:
err := validateLocalClientTLSCertificate(srv, ctx)
if err != nil {
return validateLocalServerTLSCertificate(srv, ctx)
}
return nil
case tlsCertificateLevelServer:
return validateLocalServerTLSCertificate(srv, ctx)
}

return fmt.Errorf("invalid TLS certificate level %v", lvl)
}

// validateLocalClientTLSCertificate checks if the provided RPC connection was
// initiated by a client in the same region as the target server.
func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error {
expected := fmt.Sprintf("client.%s.nomad", srv.Region())

err := validateTLSCertificate(srv, ctx, expected)
if err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err)
}
return nil
}

// validateLocalServerTLSCertificate checks if the provided RPC connection was
// initiated by a server in the same region as the target server.
func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error {
expected := fmt.Sprintf("server.%s.nomad", srv.Region())

err := validateTLSCertificate(srv, ctx, expected)
if err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err)
}
return nil
}

// validateTLSCertificate checks if the RPC connection mTLS certificates are
// valid for the given name.
func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error {
if srv.config.TLSConfig == nil || !srv.config.TLSConfig.VerifyServerHostname {
return nil
}

return ctx.ValidateCertificateForName(name)
}

0 comments on commit b07a322

Please sign in to comment.