diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index bdd1f5801e68..8a83229fd9dd 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -451,22 +451,21 @@ func (a *Alloc) GetServiceRegistrations( // This is an internal-only RPC and not exposed via the HTTP API. func (a *Alloc) SignIdentities(args *structs.AllocIdentitiesRequest, reply *structs.AllocIdentitiesResponse) error { - authErr := a.srv.Authenticate(a.ctx, args) - - // Ensure the connection was initiated by a client if TLS is used. - if err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient); err != nil { - return err + aclObj, err := a.srv.AuthenticateClientOnly(a.ctx, args) + a.srv.MeasureRPCRate("alloc", structs.RateMetricRead, args) + if err != nil { + return structs.ErrPermissionDenied } + if done, err := a.srv.forward("Alloc.SignIdentities", args, args, reply); done { return err } - a.srv.MeasureRPCRate("alloc", structs.RateMetricRead, args) - if authErr != nil { + defer metrics.MeasureSince([]string{"nomad", "alloc", "sign_identities"}, time.Now()) + + if !aclObj.AllowClientOp() { return structs.ErrPermissionDenied } - defer metrics.MeasureSince([]string{"nomad", "alloc", "sign_identities"}, time.Now()) - if len(args.Identities) == 0 { // Client bug. Fail loudly instead of letting clients waste time with // noops. diff --git a/nomad/auth/auth.go b/nomad/auth/auth.go index 7602086c5005..696a23fe707b 100644 --- a/nomad/auth/auth.go +++ b/nomad/auth/auth.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" + "golang.org/x/exp/slices" ) // aclCacheSize is the number of ACL objects to keep cached. ACLs have a parsing @@ -46,6 +47,9 @@ type Authenticator struct { getLeaderACL LeaderACLGetter region string + validServerCertNames []string + validClientCertNames []string + // aclCache is used to maintain the parsed ACL objects aclCache *structs.ACLCache[*acl.ACL] @@ -66,14 +70,19 @@ type AuthenticatorConfig struct { func NewAuthenticator(cfg *AuthenticatorConfig) *Authenticator { return &Authenticator{ - aclsEnabled: cfg.AclsEnabled, - tlsEnabled: cfg.TLSEnabled, - logger: cfg.Logger.With("auth"), - getState: cfg.StateFn, - getLeaderACL: cfg.GetLeaderACLFn, - region: cfg.Region, - aclCache: structs.NewACLCache[*acl.ACL](aclCacheSize), - encrypter: cfg.Encrypter, + aclsEnabled: cfg.AclsEnabled, + tlsEnabled: cfg.TLSEnabled, + logger: cfg.Logger.With("auth"), + getState: cfg.StateFn, + getLeaderACL: cfg.GetLeaderACLFn, + region: cfg.Region, + aclCache: structs.NewACLCache[*acl.ACL](aclCacheSize), + encrypter: cfg.Encrypter, + validServerCertNames: []string{"server." + cfg.Region + ".nomad"}, + validClientCertNames: []string{ + "client." + cfg.Region + ".nomad", + "server." + cfg.Region + ".nomad", + }, } } @@ -232,9 +241,7 @@ func (s *Authenticator) AuthenticateServerOnly(ctx RPCContext, args structs.Requ // set on the identity whether or not its valid for server RPC, so we // can capture it for metrics identity.TLSName = tlsCert.Subject.CommonName - - expected := "server." + s.region + ".nomad" - _, err := validateCertificateForName(tlsCert, expected) + _, err := validateCertificateForNames(tlsCert, s.validServerCertNames) if err != nil { return nil, err } @@ -277,15 +284,9 @@ func (s *Authenticator) AuthenticateClientOnly(ctx RPCContext, args structs.Requ // set on the identity whether or not its valid for server RPC, so we // can capture it for metrics identity.TLSName = tlsCert.Subject.CommonName - - expected := fmt.Sprintf("client.%s.nomad", s.region) - _, err := validateCertificateForName(tlsCert, expected) + _, err := validateCertificateForNames(tlsCert, s.validClientCertNames) if err != nil { - expected := fmt.Sprintf("server.%s.nomad", s.region) - _, err := validateCertificateForName(tlsCert, expected) - if err != nil { - return nil, err - } + return nil, err } } @@ -331,38 +332,35 @@ func (s *Authenticator) AuthenticateClientOnlyLegacy(ctx RPCContext, args struct // set on the identity whether or not its valid for server RPC, so we // can capture it for metrics identity.TLSName = tlsCert.Subject.CommonName - - expected := fmt.Sprintf("client.%s.nomad", s.region) - _, err := validateCertificateForName(tlsCert, expected) + _, err := validateCertificateForNames(tlsCert, s.validClientCertNames) if err != nil { - expected := fmt.Sprintf("server.%s.nomad", s.region) - _, err := validateCertificateForName(tlsCert, expected) - if err != nil { - return nil, err - } + return nil, err } } return acl.ClientACL, nil } -// validateCertificateForName returns true if the certificate is valid -// for the given domain name. -func validateCertificateForName(cert *x509.Certificate, expectedName string) (bool, error) { +// validateCertificateForNames returns true if the certificate is valid for any +// of the given domain names. +func validateCertificateForNames(cert *x509.Certificate, expectedNames []string) (bool, error) { if cert == nil { return false, nil } validNames := []string{cert.Subject.CommonName} validNames = append(validNames, cert.DNSNames...) - for _, valid := range validNames { - if expectedName == valid { + + for _, expectedName := range expectedNames { + if slices.Contains(validNames, expectedName) { return true, nil } } - return false, fmt.Errorf("invalid certificate, %s not in %s", - expectedName, strings.Join(validNames, ",")) + return false, fmt.Errorf("invalid certificate: %s not in expected %s", + strings.Join(validNames, ", "), + strings.Join(expectedNames, ", ")) + } // ResolveACLForToken resolves an ACL from a token only. It should be used only diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index bcac38976fe7..376d2f39a459 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -383,7 +383,7 @@ func TestAuthenticateServerOnly(t *testing.T) { aclObj, err := auth.AuthenticateServerOnly(ctx, args) must.EqError(t, err, - "invalid certificate, server.global.nomad not in client.global.nomad") + "invalid certificate: client.global.nomad not in expected server.global.nomad") must.Eq(t, "client.global.nomad:192.168.1.1", args.GetIdentity().String()) must.Nil(t, aclObj) }, @@ -507,7 +507,7 @@ func TestAuthenticateClientOnly(t *testing.T) { aclObj, err := auth.AuthenticateClientOnly(ctx, args) must.EqError(t, err, - "invalid certificate, server.global.nomad not in cli.global.nomad") + "invalid certificate: cli.global.nomad not in expected client.global.nomad, server.global.nomad") must.Eq(t, "cli.global.nomad:192.168.1.1", args.GetIdentity().String()) must.Nil(t, aclObj) }, diff --git a/nomad/util.go b/nomad/util.go index 6effadb98abc..bc99bd757d14 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -282,65 +282,3 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) { return alloc, nil } - -// tlsCertificateLevel represents a role level for mTLS certificates. -type tlsCertificateLevel int8 - -const ( - tlsCertificateLevelServer tlsCertificateLevel = iota - tlsCertificateLevelClient -) - -// validateTLSCertificateLevel checks if the provided RPC connection was -// initiated with a certificate that matches the given TLS role level. -// -// - tlsCertificateLevelServer requires a server certificate. -// - tlsCertificateLevelServer requires a client or server certificate. -func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error { - switch lvl { - case tlsCertificateLevelClient: - err := validateLocalClientTLSCertificate(srv, ctx) - if err != nil { - return validateLocalServerTLSCertificate(srv, ctx) - } - return nil - case tlsCertificateLevelServer: - return validateLocalServerTLSCertificate(srv, ctx) - } - - return fmt.Errorf("invalid TLS certificate level %v", lvl) -} - -// validateLocalClientTLSCertificate checks if the provided RPC connection was -// initiated by a client in the same region as the target server. -func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error { - expected := fmt.Sprintf("client.%s.nomad", srv.Region()) - - err := validateTLSCertificate(srv, ctx, expected) - if err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err) - } - return nil -} - -// validateLocalServerTLSCertificate checks if the provided RPC connection was -// initiated by a server in the same region as the target server. -func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error { - expected := fmt.Sprintf("server.%s.nomad", srv.Region()) - - err := validateTLSCertificate(srv, ctx, expected) - if err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err) - } - return nil -} - -// validateTLSCertificate checks if the RPC connection mTLS certificates are -// valid for the given name. -func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error { - if srv.config.TLSConfig == nil || !srv.config.TLSConfig.VerifyServerHostname { - return nil - } - - return ctx.ValidateCertificateForName(name) -}