From 672988d2463ddf8abbade7cb9f0656d848682ae3 Mon Sep 17 00:00:00 2001 From: Dimitrij Drus Date: Mon, 16 Sep 2024 12:39:44 +0200 Subject: [PATCH] feat: New endpoint auth type to create http message signatures for outbound requests according to RFC 9421 (#1507) --- docs/content/docs/configuration/types.adoc | 39 ++ go.mod | 2 + go.sum | 4 + .../authstrategy/http_message_signatures.go | 248 +++++++++ .../http_message_signatures_test.go | 472 ++++++++++++++++++ .../authstrategy/mapstructure_decoder.go | 50 +- .../authstrategy/mapstructure_decoder_test.go | 411 ++++++++++++++- .../mock_creation_context_test.go | 180 +++++++ .../authenticators/anonymous_authenticator.go | 14 +- .../anonymous_authenticator_test.go | 4 +- .../authenticator_type_registry.go | 2 + .../basic_auth_authenticator.go | 14 +- .../basic_auth_authenticator_test.go | 6 +- .../authenticators/config_decoder.go | 4 +- .../authenticators/generic_authenticator.go | 10 +- .../generic_authenticator_test.go | 4 +- .../authenticators/jwt_authenticator.go | 15 +- .../authenticators/jwt_authenticator_test.go | 4 +- .../mock_creation_context_test.go | 49 ++ .../oauth2_introspection_authenticator.go | 15 +- ...oauth2_introspection_authenticator_test.go | 4 +- .../mechanisms/authorizers/cel_authorizer.go | 10 +- .../authorizers/cel_authorizer_test.go | 6 +- .../mechanisms/authorizers/config_decoder.go | 4 +- .../authorizers/remote_authorizer.go | 10 +- .../authorizers/remote_authorizer_test.go | 4 +- .../contextualizers/config_decoder.go | 4 +- .../contextualizers/generic_contextualizer.go | 14 +- .../generic_contextualizer_test.go | 4 +- .../provider/httpendpoint/config_decoder.go | 2 +- 30 files changed, 1515 insertions(+), 94 deletions(-) create mode 100644 internal/rules/endpoint/authstrategy/http_message_signatures.go create mode 100644 internal/rules/endpoint/authstrategy/http_message_signatures_test.go create mode 100644 internal/rules/endpoint/authstrategy/mock_creation_context_test.go diff --git a/docs/content/docs/configuration/types.adoc b/docs/content/docs/configuration/types.adoc index ebf92fa0b..fac8ac47c 100644 --- a/docs/content/docs/configuration/types.adoc +++ b/docs/content/docs/configuration/types.adoc @@ -341,6 +341,45 @@ config: ---- ==== +=== HTTP Message Signatures + +This strategy implements HTTP message signatures according to https://datatracker.ietf.org/doc/html/rfc9421[RFC 9421] to sign outbound requests. + +`type` must be set to `http_message_signatures`. `config` supports the following properties: + +* *`ttl`*: _link:{{< relref "#_duration" >}}[Duration]_ (optional) ++ +The TTL of the resulting signature. Defaults to 1m. Responsible for setting `created` and `expires` parameters in the resulting signature. + +* *`label`*: _string_ (optional) ++ +The label to use. Defaults to `sig`. + +* *`components`*: _string array_ (mandatory) ++ +The components to be covered by the signature. While the RFC allows for signatures that do not cover any components, this is considered a security risk. When using the `"content-digest"` component, Heimdall will compute hash values of the request body using `sha-256` and `sha-512` algorithms. It will then add a `Content-Digest` header with these hash values to the request, and this header will be included in the signature calculation. + +* *`signer`*: _link:{{< relref "/docs/configuration/types.adoc#_signer" >}}[Signer]_ (mandatory) ++ +The configuration of the key material used for signature creation purposes, as well as the name used for the `tag` parameter in the resulting signature. + +.Strategy configuration +==== + +[source, yaml] +---- +type: http_message_signatures +config: + ttl: 2m + label: foo + components: ["@method", "content-digest", "@authority", "x-my-fancy-header"] + signer: + name: bar + key-store: + path: /path/to/key.pem +---- +==== + === OAuth2 Client Credentials Grant Flow Strategy This strategy implements the https://datatracker.ietf.org/doc/html/rfc6749#section-4.4[OAuth2 Client Credentials Grant Flow] to obtain an access token expected by the endpoint. Heimdall caches the received access token. diff --git a/go.mod b/go.mod index 78c6f6916..f5049fa02 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/alicebob/miniredis/v2 v2.33.0 + github.com/dadrus/httpsig v0.0.0-20240814203911-f6539fdef42a github.com/dlclark/regexp2 v1.11.4 github.com/drone/envsubst/v2 v2.0.0-20210730161058-179042472c46 github.com/elnormous/contenttype v1.0.4 @@ -130,6 +131,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dunglas/httpsfv v1.0.2 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect github.com/fatih/structs v1.1.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index 82c1a03a9..3a0fa337d 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnTh github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/dadrus/httpsig v0.0.0-20240814203911-f6539fdef42a h1:eXhsbwb2ROng8D7DcMgMinL5FfY4Ao6N2K9DTwOXFfs= +github.com/dadrus/httpsig v0.0.0-20240814203911-f6539fdef42a/go.mod h1:P31eM5Rh3dqq9FLr1QASaZsk8/8qIiKKUYFKjBC/yYc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -104,6 +106,8 @@ github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yA github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/drone/envsubst/v2 v2.0.0-20210730161058-179042472c46 h1:7QPwrLT79GlD5sizHf27aoY2RTvw62mO6x7mxkScNk0= github.com/drone/envsubst/v2 v2.0.0-20210730161058-179042472c46/go.mod h1:esf2rsHFNlZlxsqsZDojNBcnNs5REqIvRrWRHqX0vEU= +github.com/dunglas/httpsfv v1.0.2 h1:iERDp/YAfnojSDJ7PW3dj1AReJz4MrwbECSSE59JWL0= +github.com/dunglas/httpsfv v1.0.2/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= github.com/elnormous/contenttype v1.0.4 h1:FjmVNkvQOGqSX70yvocph7keC8DtmJaLzTTq6ZOQCI8= github.com/elnormous/contenttype v1.0.4/go.mod h1:5KTOW8m1kdX1dLMiUJeN9szzR2xkngiv2K+RVZwWBbI= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= diff --git a/internal/rules/endpoint/authstrategy/http_message_signatures.go b/internal/rules/endpoint/authstrategy/http_message_signatures.go new file mode 100644 index 000000000..e2d366538 --- /dev/null +++ b/internal/rules/endpoint/authstrategy/http_message_signatures.go @@ -0,0 +1,248 @@ +// Copyright 2024 Dimitrij Drus +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package authstrategy + +import ( + "context" + "crypto/sha256" + "crypto/x509" + "encoding/binary" + "fmt" + "net/http" + "sync" + "time" + + "github.com/dadrus/httpsig" + "github.com/go-jose/go-jose/v4" + "github.com/rs/zerolog" + + "github.com/dadrus/heimdall/internal/heimdall" + "github.com/dadrus/heimdall/internal/keystore" + "github.com/dadrus/heimdall/internal/x" + "github.com/dadrus/heimdall/internal/x/errorchain" + "github.com/dadrus/heimdall/internal/x/pkix" + "github.com/dadrus/heimdall/internal/x/stringx" +) + +type KeyStore struct { + Path string `mapstructure:"path" validate:"required"` + Password string `mapstructure:"password"` +} + +type SignerConfig struct { + Name string `mapstructure:"name"` + KeyStore KeyStore `mapstructure:"key_store" validate:"required"` + KeyID string `mapstructure:"key_id"` +} + +type HTTPMessageSignatures struct { + Signer SignerConfig `mapstructure:"signer" validate:"required"` + Components []string `mapstructure:"components" validate:"gt=0,dive,required"` + TTL *time.Duration `mapstructure:"ttl"` + Label string `mapstructure:"label"` + + mut sync.RWMutex + // used to allow downloading the keys for signature verification purposes + // since the http message signatures rfc does not define a format for key transport + // JWK is used here. + pubKeys []jose.JSONWebKey + // used to monitor the expiration of configured certificates + certChain []*x509.Certificate + signer httpsig.Signer +} + +func (s *HTTPMessageSignatures) OnChanged(logger zerolog.Logger) { + err := s.init() + if err != nil { + logger.Warn().Err(err). + Str("_file", s.Signer.KeyStore.Path). + Msg("Signer key store reload failed") + } else { + logger.Info(). + Str("_file", s.Signer.KeyStore.Path). + Msg("Signer key store reloaded") + } +} + +func (s *HTTPMessageSignatures) init() error { + ks, err := keystore.NewKeyStoreFromPEMFile(s.Signer.KeyStore.Path, s.Signer.KeyStore.Password) + if err != nil { + return errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed loading keystore for http_message_signatures strategy").CausedBy(err) + } + + var kse *keystore.Entry + + if len(s.Signer.KeyID) == 0 { + kse, err = ks.Entries()[0], nil + } else { + kse, err = ks.GetKey(s.Signer.KeyID) + } + + if err != nil { + return errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed retrieving key from key store for http_message_signatures strategy").CausedBy(err) + } + + if len(kse.CertChain) != 0 { + opts := []pkix.ValidationOption{ + pkix.WithKeyUsage(x509.KeyUsageDigitalSignature), + pkix.WithRootCACertificates([]*x509.Certificate{kse.CertChain[len(kse.CertChain)-1]}), + pkix.WithCurrentTime(time.Now()), + } + + if len(kse.CertChain) > 2 { //nolint: mnd + opts = append(opts, pkix.WithIntermediateCACertificates(kse.CertChain[1:len(kse.CertChain)-1])) + } + + if err = pkix.ValidateCertificate(kse.CertChain[0], opts...); err != nil { + return errorchain.NewWithMessage(heimdall.ErrConfiguration, + "certificate for http_message_signatures strategy cannot be used for signing purposes"). + CausedBy(err) + } + } + + keys := make([]jose.JSONWebKey, len(ks.Entries())) + for idx, entry := range ks.Entries() { + keys[idx] = entry.JWK() + } + + signer, err := httpsig.NewSigner( + toHTTPSigKey(kse), + httpsig.WithComponents(s.Components...), + httpsig.WithTag(x.IfThenElse(len(s.Signer.Name) != 0, s.Signer.Name, "heimdall")), + httpsig.WithLabel(s.Label), + httpsig.WithTTL(x.IfThenElseExec(s.TTL != nil, + func() time.Duration { return *s.TTL }, + func() time.Duration { return 1 * time.Minute }, + )), + ) + if err != nil { + return errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed to configure http_message_signatures strategy").CausedBy(err) + } + + s.mut.Lock() + defer s.mut.Unlock() + + s.signer = signer + s.pubKeys = keys + s.certChain = kse.CertChain + + return nil +} + +func (s *HTTPMessageSignatures) Apply(ctx context.Context, req *http.Request) error { + logger := zerolog.Ctx(ctx) + logger.Debug().Msg("Applying http_message_signatures strategy to authenticate request") + + s.mut.RLock() + defer s.mut.RUnlock() + + header, err := s.signer.Sign(httpsig.MessageFromRequest(req)) + if err != nil { + return err + } + + // set the updated headers + req.Header = header + + return nil +} + +func (s *HTTPMessageSignatures) Keys() []jose.JSONWebKey { + s.mut.RLock() + defer s.mut.RUnlock() + + return s.pubKeys +} + +func (s *HTTPMessageSignatures) Hash() []byte { + const int64BytesCount = 8 + + hash := sha256.New() + hash.Write(stringx.ToBytes(s.Label)) + + for _, component := range s.Components { + hash.Write(stringx.ToBytes(component)) + } + + if s.TTL != nil { + ttlBytes := make([]byte, int64BytesCount) + binary.LittleEndian.PutUint64(ttlBytes, uint64(*s.TTL)) + + hash.Write(ttlBytes) + } + + hash.Write(stringx.ToBytes(s.Signer.Name)) + hash.Write(stringx.ToBytes(s.Signer.KeyID)) + + return hash.Sum(nil) +} + +func (s *HTTPMessageSignatures) Name() string { return "http message signer" } +func (s *HTTPMessageSignatures) Certificates() []*x509.Certificate { + s.mut.RLock() + defer s.mut.RUnlock() + + return s.certChain +} + +func toHTTPSigKey(entry *keystore.Entry) httpsig.Key { + var httpSigAlg httpsig.SignatureAlgorithm + + switch entry.Alg { + case keystore.AlgRSA: + httpSigAlg = getRSAAlgorithm(entry.KeySize) + case keystore.AlgECDSA: + httpSigAlg = getECDSAAlgorithm(entry.KeySize) + default: + panic("unsupported key algorithm: " + entry.Alg) + } + + return httpsig.Key{ + Algorithm: httpSigAlg, + KeyID: entry.KeyID, + Key: entry.PrivateKey, + } +} + +func getECDSAAlgorithm(keySize int) httpsig.SignatureAlgorithm { + switch keySize { + case 256: //nolint: mnd + return httpsig.EcdsaP256Sha256 + case 384: //nolint: mnd + return httpsig.EcdsaP384Sha384 + case 512: //nolint: mnd + return httpsig.EcdsaP521Sha512 + default: + panic(fmt.Sprintf("unsupported ECDSA key size: %d", keySize)) + } +} + +func getRSAAlgorithm(keySize int) httpsig.SignatureAlgorithm { + switch keySize { + case 2048: //nolint: mnd + return httpsig.RsaPssSha256 + case 3072: //nolint: mnd + return httpsig.RsaPssSha384 + case 4096: //nolint: mnd + return httpsig.RsaPssSha512 + default: + panic(fmt.Sprintf("unsupported RSA key size: %d", keySize)) + } +} diff --git a/internal/rules/endpoint/authstrategy/http_message_signatures_test.go b/internal/rules/endpoint/authstrategy/http_message_signatures_test.go new file mode 100644 index 000000000..d63135edf --- /dev/null +++ b/internal/rules/endpoint/authstrategy/http_message_signatures_test.go @@ -0,0 +1,472 @@ +// Copyright 2024 Dimitrij Drus +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package authstrategy + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/dadrus/httpsig" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dadrus/heimdall/internal/heimdall" + "github.com/dadrus/heimdall/internal/keystore" + "github.com/dadrus/heimdall/internal/x/pkix/pemx" + "github.com/dadrus/heimdall/internal/x/testsupport" +) + +func TestToHTTPSigKey(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + kse *keystore.Entry + expAlg httpsig.SignatureAlgorithm + }{ + { + expAlg: httpsig.RsaPssSha256, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgRSA, KeySize: 2048, PrivateKey: &rsa.PrivateKey{}}, + }, + { + expAlg: httpsig.RsaPssSha384, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgRSA, KeySize: 3072, PrivateKey: &rsa.PrivateKey{}}, + }, + { + expAlg: httpsig.RsaPssSha512, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgRSA, KeySize: 4096, PrivateKey: &rsa.PrivateKey{}}, + }, + { + expAlg: httpsig.EcdsaP256Sha256, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgECDSA, KeySize: 256, PrivateKey: &ecdsa.PrivateKey{}}, + }, + { + expAlg: httpsig.EcdsaP384Sha384, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgECDSA, KeySize: 384, PrivateKey: &ecdsa.PrivateKey{}}, + }, + { + expAlg: httpsig.EcdsaP521Sha512, + kse: &keystore.Entry{KeyID: "foo", Alg: keystore.AlgECDSA, KeySize: 512, PrivateKey: &ecdsa.PrivateKey{}}, + }, + } { + t.Run(string(tc.expAlg), func(t *testing.T) { + key := toHTTPSigKey(tc.kse) + + assert.Equal(t, tc.expAlg, key.Algorithm) + assert.Equal(t, tc.kse.KeyID, key.KeyID) + assert.Equal(t, tc.kse.PrivateKey, key.Key) + }) + } +} + +func TestHTTPMessageSignaturesInit(t *testing.T) { + t.Parallel() + + rootCA, err := testsupport.NewRootCA("Test Root CA 1", time.Hour*24) + require.NoError(t, err) + + // INT CA + intCAPrivKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + intCACert, err := rootCA.IssueCertificate( + testsupport.WithSubject(pkix.Name{ + CommonName: "Test Int CA 1", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithIsCA(), + testsupport.WithValidity(time.Now(), time.Hour*24), + testsupport.WithSubjectPubKey(&intCAPrivKey.PublicKey, x509.ECDSAWithSHA384)) + require.NoError(t, err) + + intCA := testsupport.NewCA(intCAPrivKey, intCACert) + + // EE CERTS + ee1PrivKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + ee1cert, err := intCA.IssueCertificate( + testsupport.WithSubject(pkix.Name{ + CommonName: "Test EE 1", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithValidity(time.Now(), time.Hour*24), + testsupport.WithSubjectPubKey(&ee1PrivKey.PublicKey, x509.ECDSAWithSHA384), + testsupport.WithKeyUsage(x509.KeyUsageDigitalSignature)) + require.NoError(t, err) + + ee2PrivKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + ee2cert, err := intCA.IssueCertificate( + testsupport.WithSubject(pkix.Name{ + CommonName: "Test EE 2", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithValidity(time.Now(), time.Hour*24), + testsupport.WithSubjectPubKey(&ee2PrivKey.PublicKey, x509.ECDSAWithSHA384)) + require.NoError(t, err) + + pemBytes, err := pemx.BuildPEM( + pemx.WithECDSAPrivateKey(ee1PrivKey, pemx.WithHeader("X-Key-ID", "key1")), + pemx.WithX509Certificate(ee1cert), + pemx.WithECDSAPrivateKey(ee2PrivKey, pemx.WithHeader("X-Key-ID", "key2")), + pemx.WithX509Certificate(ee2cert), + pemx.WithX509Certificate(intCACert), + pemx.WithX509Certificate(rootCA.Certificate), + ) + require.NoError(t, err) + + testDir := t.TempDir() + trustStorePath := filepath.Join(testDir, "keystore.pem") + + err = os.WriteFile(trustStorePath, pemBytes, 0o600) + require.NoError(t, err) + + for _, tc := range []struct { + uc string + conf *HTTPMessageSignatures + assert func(t *testing.T, err error, conf *HTTPMessageSignatures) + }{ + { + uc: "failed loading keystore", + conf: &HTTPMessageSignatures{}, + assert: func(t *testing.T, err error, _ *HTTPMessageSignatures) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "failed loading keystore") + }, + }, + { + uc: "no key for given key id", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}, KeyID: "foo"}, + }, + assert: func(t *testing.T, err error, _ *HTTPMessageSignatures) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "failed retrieving key from key store") + }, + }, + { + uc: "certificate cannot be used for signing", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}, KeyID: "key2"}, + }, + assert: func(t *testing.T, err error, _ *HTTPMessageSignatures) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "cannot be used for signing purposes") + }, + }, + { + uc: "bad signer configuration", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}}, + Components: []string{"@foo"}, + }, + assert: func(t *testing.T, err error, _ *HTTPMessageSignatures) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "failed to configure") + }, + }, + { + uc: "successful configuration with default ttl", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}, KeyID: "key1"}, + Components: []string{"@method"}, + }, + assert: func(t *testing.T, err error, conf *HTTPMessageSignatures) { + t.Helper() + + require.NoError(t, err) + + assert.NotNil(t, conf.signer) + assert.NotEmpty(t, conf.Certificates()) + assert.NotEmpty(t, conf.Keys()) + assert.Equal(t, "http message signer", conf.Name()) + }, + }, + { + uc: "successful configuration with custom ttl", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}, KeyID: "key1"}, + Components: []string{"@method"}, + TTL: func() *time.Duration { + ttl := 1 * time.Hour + + return &ttl + }(), + }, + assert: func(t *testing.T, err error, conf *HTTPMessageSignatures) { + t.Helper() + + require.NoError(t, err) + + assert.NotNil(t, conf.signer) + assert.NotEmpty(t, conf.Certificates()) + assert.NotEmpty(t, conf.Keys()) + assert.Equal(t, "http message signer", conf.Name()) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + err := tc.conf.init() + + tc.assert(t, err, tc.conf) + }) + } +} + +func TestHTTPMessageSignaturesHash(t *testing.T) { + t.Parallel() + + ttl := 1 * time.Hour + conf1 := &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: "/path/to/keystore.pem"}, KeyID: "key1"}, + Components: []string{"@method"}, + TTL: &ttl, + } + conf2 := &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: "/path/to/keystore.pem"}, KeyID: "key1", Name: "foo"}, + Components: []string{"@status"}, + TTL: &ttl, + Label: "test", + } + + hash1 := conf1.Hash() + hash2 := conf2.Hash() + + assert.NotEmpty(t, hash1) + assert.NotEmpty(t, hash2) + assert.NotEqual(t, hash1, hash2) + assert.Equal(t, hash1, conf1.Hash()) + assert.Equal(t, hash2, conf2.Hash()) +} + +func TestHTTPMessageSignaturesApply(t *testing.T) { + t.Parallel() + + privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + cb := testsupport.NewCertificateBuilder( + testsupport.WithValidity(time.Now(), 15*time.Second), + testsupport.WithSerialNumber(big.NewInt(1)), + testsupport.WithSubject(pkix.Name{ + CommonName: "Test", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithSubjectPubKey(&privKey.PublicKey, x509.ECDSAWithSHA384), + testsupport.WithSelfSigned(), + testsupport.WithSignaturePrivKey(privKey), + testsupport.WithKeyUsage(x509.KeyUsageDigitalSignature), + ) + + cert, err := cb.Build() + require.NoError(t, err) + + pemBytes, err := pemx.BuildPEM( + pemx.WithECDSAPrivateKey(privKey, pemx.WithHeader("X-Key-ID", "test")), + pemx.WithX509Certificate(cert), + ) + require.NoError(t, err) + + testDir := t.TempDir() + trustStorePath := filepath.Join(testDir, "keystore.pem") + + err = os.WriteFile(trustStorePath, pemBytes, 0o600) + require.NoError(t, err) + + for _, tc := range []struct { + uc string + conf *HTTPMessageSignatures + assert func(t *testing.T, err error, req *http.Request) + }{ + { + uc: "fails", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}}, + Components: []string{"x-some-header"}, + }, + assert: func(t *testing.T, err error, req *http.Request) { + t.Helper() + + require.Error(t, err) + require.ErrorContains(t, err, "x-some-header") + assert.Empty(t, req.Header.Get("Signature")) + assert.Empty(t, req.Header.Get("Signature-Input")) + }, + }, + { + uc: "successful", + conf: &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: trustStorePath}}, + Components: []string{"@method", "content-digest"}, + }, + assert: func(t *testing.T, err error, req *http.Request) { + t.Helper() + + require.NoError(t, err) + assert.NotEmpty(t, req.Header.Get("Signature")) + sigInput := req.Header.Get("Signature-Input") + assert.Contains(t, sigInput, `("@method" "content-digest")`) + assert.Contains(t, sigInput, `created=`) + assert.Contains(t, sigInput, `expires=`) + assert.Contains(t, sigInput, `keyid="test"`) + assert.Contains(t, sigInput, `alg="ecdsa-p384-sha384"`) + assert.Contains(t, sigInput, `nonce=`) + assert.Contains(t, sigInput, `tag="heimdall"`) + contentDigest := req.Header.Get("Content-Digest") + assert.Contains(t, contentDigest, "sha-256=:X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=:") + assert.Contains(t, contentDigest, "sha-512=:WZDPaVn/7XgHaAy8pmojAkGWoRx2UFChF41A2svX+TaPm+AbwAgBWnrIiYllu7BNNyealdVLvRwEmTHWXvJwew==:") + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + err := tc.conf.init() + require.NoError(t, err) + + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodGet, + "http//example.com/test", + strings.NewReader(`{"hello": "world"}`), + ) + require.NoError(t, err) + + err = tc.conf.Apply(context.Background(), req) + + tc.assert(t, err, req) + }) + } +} + +func TestHTTPMessageSignaturesOnChanged(t *testing.T) { + t.Parallel() + + // GIVEN + testDir := t.TempDir() + + privKey1, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + privKey2, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + cert1, err := testsupport.NewCertificateBuilder(testsupport.WithValidity(time.Now(), 10*time.Hour), + testsupport.WithSerialNumber(big.NewInt(1)), + testsupport.WithSubject(pkix.Name{ + CommonName: "test cert 1", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithSubjectPubKey(&privKey1.PublicKey, x509.ECDSAWithSHA384), + testsupport.WithSelfSigned(), + testsupport.WithKeyUsage(x509.KeyUsageDigitalSignature), + testsupport.WithSignaturePrivKey(privKey1)). + Build() + require.NoError(t, err) + + cert2, err := testsupport.NewCertificateBuilder(testsupport.WithValidity(time.Now(), 10*time.Hour), + testsupport.WithSerialNumber(big.NewInt(1)), + testsupport.WithSubject(pkix.Name{ + CommonName: "test cert 1", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithSubjectPubKey(&privKey2.PublicKey, x509.ECDSAWithSHA384), + testsupport.WithSelfSigned(), + testsupport.WithKeyUsage(x509.KeyUsageDigitalSignature), + testsupport.WithSignaturePrivKey(privKey2)). + Build() + require.NoError(t, err) + + pemBytes1, err := pemx.BuildPEM( + pemx.WithECDSAPrivateKey(privKey1, pemx.WithHeader("X-Key-ID", "key1")), + pemx.WithX509Certificate(cert1), + ) + require.NoError(t, err) + + pemBytes2, err := pemx.BuildPEM( + pemx.WithECDSAPrivateKey(privKey2, pemx.WithHeader("X-Key-ID", "key1")), + pemx.WithX509Certificate(cert2), + ) + require.NoError(t, err) + + pemFile, err := os.Create(filepath.Join(testDir, "keystore.pem")) + require.NoError(t, err) + + _, err = pemFile.Write(pemBytes1) + require.NoError(t, err) + + conf := &HTTPMessageSignatures{ + Signer: SignerConfig{KeyStore: KeyStore{Path: pemFile.Name()}, KeyID: "key1"}, + Components: []string{"@method"}, + } + err = conf.init() + require.NoError(t, err) + + require.Equal(t, cert1, conf.certChain[0]) + require.Equal(t, &privKey1.PublicKey, conf.pubKeys[0].Key) + + // WHEN + _, err = pemFile.Seek(0, 0) + require.NoError(t, err) + + _, err = pemFile.Write(pemBytes2) + require.NoError(t, err) + + conf.OnChanged(log.Logger) + + // THEN + require.Equal(t, cert2, conf.certChain[0]) + require.Equal(t, &privKey2.PublicKey, conf.pubKeys[0].Key) + + // WHEN + err = os.Truncate(pemFile.Name(), 0) + require.NoError(t, err) + + conf.OnChanged(log.Logger) + + // THEN + require.Equal(t, cert2, conf.certChain[0]) + require.Equal(t, &privKey2.PublicKey, conf.pubKeys[0].Key) +} diff --git a/internal/rules/endpoint/authstrategy/mapstructure_decoder.go b/internal/rules/endpoint/authstrategy/mapstructure_decoder.go index bed8d21b3..6766f290c 100644 --- a/internal/rules/endpoint/authstrategy/mapstructure_decoder.go +++ b/internal/rules/endpoint/authstrategy/mapstructure_decoder.go @@ -22,12 +22,23 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/dadrus/heimdall/internal/heimdall" + "github.com/dadrus/heimdall/internal/keyholder" + "github.com/dadrus/heimdall/internal/otel/metrics/certificate" "github.com/dadrus/heimdall/internal/rules/endpoint" "github.com/dadrus/heimdall/internal/validation" + "github.com/dadrus/heimdall/internal/watcher" "github.com/dadrus/heimdall/internal/x/errorchain" ) -func DecodeAuthenticationStrategyHookFunc() mapstructure.DecodeHookFunc { +//go:generate mockery --name CreationContext --structname CreationContextMock --inpackage --testonly + +type CreationContext interface { + Watcher() watcher.Watcher + KeyHolderRegistry() keyholder.Registry + CertificateObserver() certificate.Observer +} + +func DecodeAuthenticationStrategyHookFunc(ctx CreationContext) mapstructure.DecodeHookFunc { return func(from reflect.Type, to reflect.Type, data any) (any, error) { var as endpoint.AuthenticationStrategy @@ -61,6 +72,8 @@ func DecodeAuthenticationStrategyHookFunc() mapstructure.DecodeHookFunc { return decodeStrategy("api_key", &APIKey{}, typed["config"]) case "oauth2_client_credentials": return decodeStrategy("oauth2_client_credentials", &OAuth2ClientCredentials{}, typed["config"]) + case "http_message_signatures": + return decodeHTTPMessageSignaturesStrategy(ctx, typed["config"]) default: return nil, errorchain.NewWithMessagef(heimdall.ErrConfiguration, "unsupported authentication type: '%s'", typed["type"]) @@ -68,6 +81,27 @@ func DecodeAuthenticationStrategyHookFunc() mapstructure.DecodeHookFunc { } } +func decodeHTTPMessageSignaturesStrategy(ctx CreationContext, config any) (any, error) { + httpSig := &HTTPMessageSignatures{} + + if _, err := decodeStrategy("http_message_signatures", httpSig, config); err != nil { + return nil, err + } + + if err := httpSig.init(); err != nil { + return nil, err + } + + if err := ctx.Watcher().Add(httpSig.Signer.KeyStore.Path, httpSig); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrInternal, + "failed registering http_message_signatures for updates").CausedBy(err) + } + + ctx.CertificateObserver().Add(httpSig) + + return httpSig, nil +} + func decodeStrategy[S endpoint.AuthenticationStrategy]( name string, strategy S, @@ -78,7 +112,19 @@ func decodeStrategy[S endpoint.AuthenticationStrategy]( "'%s' strategy requires 'config' property to be set", name) } - if err := mapstructure.Decode(config, strategy); err != nil { + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + ), + Result: strategy, + ErrorUnused: true, + }) + if err != nil { + return nil, errorchain.NewWithMessagef(heimdall.ErrConfiguration, + "failed to unmarshal '%s' strategy config", name).CausedBy(err) + } + + if err := dec.Decode(config); err != nil { return nil, errorchain.NewWithMessagef(heimdall.ErrConfiguration, "failed to unmarshal '%s' strategy config", name).CausedBy(err) } diff --git a/internal/rules/endpoint/authstrategy/mapstructure_decoder_test.go b/internal/rules/endpoint/authstrategy/mapstructure_decoder_test.go index 7a6d7bc6e..fcfebf92d 100644 --- a/internal/rules/endpoint/authstrategy/mapstructure_decoder_test.go +++ b/internal/rules/endpoint/authstrategy/mapstructure_decoder_test.go @@ -17,13 +17,29 @@ package authstrategy import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "math/big" + "os" + "path/filepath" "testing" + "time" "github.com/go-viper/mapstructure/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/dadrus/heimdall/internal/heimdall" + mocks3 "github.com/dadrus/heimdall/internal/otel/metrics/certificate/mocks" "github.com/dadrus/heimdall/internal/rules/endpoint" + "github.com/dadrus/heimdall/internal/watcher/mocks" + "github.com/dadrus/heimdall/internal/x" + "github.com/dadrus/heimdall/internal/x/pkix/pemx" "github.com/dadrus/heimdall/internal/x/testsupport" ) @@ -41,7 +57,7 @@ func TestDecodeAuthenticationStrategyHookFuncForBasicAuthStrategy(t *testing.T) assert func(t *testing.T, err error, as endpoint.AuthenticationStrategy) }{ { - uc: "basic auth with all required properties", + uc: "all required properties configured", config: []byte(` auth: type: basic_auth @@ -59,7 +75,25 @@ auth: }, }, { - uc: "basic auth without user property", + uc: "with unsupported properties", + config: []byte(` +auth: + type: basic_auth + config: + user: foo + password: bar + foo: bar +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "invalid keys: foo") + }, + }, + { + uc: "without user property", config: []byte(` auth: type: basic_auth @@ -74,7 +108,7 @@ auth: }, }, { - uc: "basic auth without password property", + uc: "without password property", config: []byte(` auth: type: basic_auth @@ -89,7 +123,7 @@ auth: }, }, { - uc: "basic auth without config property", + uc: "without config property", config: []byte(` auth: type: basic_auth @@ -102,13 +136,13 @@ auth: }, }, } { - t.Run("case="+tc.uc, func(t *testing.T) { + t.Run(tc.uc, func(t *testing.T) { // GIVEN var typ Type dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - DecodeAuthenticationStrategyHookFunc(), + DecodeAuthenticationStrategyHookFunc(nil), ), Result: &typ, }) @@ -140,7 +174,7 @@ func TestDecodeAuthenticationStrategyHookFuncForAPIKeyStrategy(t *testing.T) { assert func(t *testing.T, err error, as endpoint.AuthenticationStrategy) }{ { - uc: "api key with all required properties, with in=header", + uc: "all required properties, with in=header", config: []byte(` auth: type: api_key @@ -161,7 +195,26 @@ auth: }, }, { - uc: "api key with all required properties, with in=cookie", + uc: "with unsupported properties", + config: []byte(` +auth: + type: api_key + config: + name: foo + value: bar + in: header + foo: bar +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "invalid keys: foo") + }, + }, + { + uc: "all required properties, with in=cookie", config: []byte(` auth: type: api_key @@ -182,7 +235,7 @@ auth: }, }, { - uc: "api key with all required properties, with in=query", + uc: "all required properties, with in=query", config: []byte(` auth: type: api_key @@ -203,7 +256,7 @@ auth: }, }, { - uc: "api key with all required properties, with in=foobar", + uc: "all required properties, with in=foobar", config: []byte(` auth: type: api_key @@ -220,7 +273,7 @@ auth: }, }, { - uc: "api key without in property", + uc: "without in property", config: []byte(` auth: type: api_key @@ -236,7 +289,7 @@ auth: }, }, { - uc: "api key without name property", + uc: "without name property", config: []byte(` auth: type: api_key @@ -252,7 +305,7 @@ auth: }, }, { - uc: "api key without value property", + uc: "without value property", config: []byte(` auth: type: api_key @@ -268,7 +321,7 @@ auth: }, }, { - uc: "api key without config property", + uc: "without config property", config: []byte(` auth: type: api_key @@ -281,13 +334,13 @@ auth: }, }, } { - t.Run("case="+tc.uc, func(t *testing.T) { + t.Run(tc.uc, func(t *testing.T) { // GIVEN var typ Type dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - DecodeAuthenticationStrategyHookFunc(), + DecodeAuthenticationStrategyHookFunc(nil), ), Result: &typ, }) @@ -312,14 +365,13 @@ func TestDecodeAuthenticationStrategyHookFuncForClientCredentialsStrategy(t *tes AuthStrategy endpoint.AuthenticationStrategy `mapstructure:"auth"` } - // du to a bug in the linter for _, tc := range []struct { uc string config []byte assert func(t *testing.T, err error, as endpoint.AuthenticationStrategy) }{ { - uc: "client credentials with all required properties", + uc: "all required properties", config: []byte(` auth: type: oauth2_client_credentials @@ -340,7 +392,26 @@ auth: }, }, { - uc: "client credentials with all possible properties", + uc: "with unsupported properties", + config: []byte(` +auth: + type: oauth2_client_credentials + config: + client_id: foo + client_secret: bar + token_url: http://foobar.foo + foo: bar +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "invalid keys: foo") + }, + }, + { + uc: "all possible properties", config: []byte(` auth: type: oauth2_client_credentials @@ -365,7 +436,7 @@ auth: }, }, { - uc: "client credentials without client_id property", + uc: "without client_id property", config: []byte(` auth: type: oauth2_client_credentials @@ -381,7 +452,7 @@ auth: }, }, { - uc: "client credentials without client_secret property", + uc: "without client_secret property", config: []byte(` auth: type: oauth2_client_credentials @@ -397,7 +468,7 @@ auth: }, }, { - uc: "client credentials without token_url property", + uc: "without token_url property", config: []byte(` auth: type: oauth2_client_credentials @@ -413,7 +484,7 @@ auth: }, }, { - uc: "client credentials without config property", + uc: "without config property", config: []byte(` auth: type: oauth2_client_credentials @@ -426,13 +497,301 @@ auth: }, }, } { - t.Run("case="+tc.uc, func(t *testing.T) { + t.Run(tc.uc, func(t *testing.T) { // GIVEN var typ Type dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - DecodeAuthenticationStrategyHookFunc(), + DecodeAuthenticationStrategyHookFunc(nil), + ), + Result: &typ, + }) + require.NoError(t, err) + + conf, err := testsupport.DecodeTestConfig(tc.config) + require.NoError(t, err) + + // WHEN + err = dec.Decode(conf) + + // THEN + tc.assert(t, err, typ.AuthStrategy) + }) + } +} + +func TestDecodeAuthenticationStrategyHookFuncForHTTPMessageSignatures(t *testing.T) { + t.Parallel() + + testDir := t.TempDir() + + privKey1, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + cert1, err := testsupport.NewCertificateBuilder(testsupport.WithValidity(time.Now(), 10*time.Hour), + testsupport.WithSerialNumber(big.NewInt(1)), + testsupport.WithSubject(pkix.Name{ + CommonName: "test cert 1", + Organization: []string{"Test"}, + Country: []string{"EU"}, + }), + testsupport.WithSubjectPubKey(&privKey1.PublicKey, x509.ECDSAWithSHA384), + testsupport.WithSelfSigned(), + testsupport.WithKeyUsage(x509.KeyUsageDigitalSignature), + testsupport.WithSignaturePrivKey(privKey1)). + Build() + require.NoError(t, err) + + pemBytes1, err := pemx.BuildPEM( + pemx.WithECDSAPrivateKey(privKey1, pemx.WithHeader("X-Key-ID", "key1")), + pemx.WithX509Certificate(cert1), + ) + require.NoError(t, err) + + pemFile, err := os.Create(filepath.Join(testDir, "keystore.pem")) + require.NoError(t, err) + + _, err = pemFile.Write(pemBytes1) + require.NoError(t, err) + + type Type struct { + AuthStrategy endpoint.AuthenticationStrategy `mapstructure:"auth"` + } + + for _, tc := range []struct { + uc string + config []byte + configureContext func(t *testing.T, ccm *CreationContextMock) + assert func(t *testing.T, err error, as endpoint.AuthenticationStrategy) + }{ + { + uc: "without signer", + config: []byte(` +auth: + type: http_message_signatures + config: + components: ["@method"] +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorContains(t, err, "'signer' is a required field") + }, + }, + { + uc: "without key store", + config: []byte(` +auth: + type: http_message_signatures + config: + signer: + name: foo + components: ["@method"] +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorContains(t, err, "'signer'.'key_store' is a required field") + }, + }, + { + uc: "without key store path", + config: []byte(` +auth: + type: http_message_signatures + config: + signer: + key_store: + password: foo + components: ["@method"] +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorContains(t, err, "'signer'.'key_store'.'path' is a required field") + }, + }, + { + uc: "without component identifiers", + config: []byte(` +auth: + type: http_message_signatures + config: + signer: + key_store: + path: /some/file.pem +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorContains(t, err, "'components' must contain more than 0 items") + }, + }, + { + uc: "error while initializing strategy", + config: []byte(` +auth: + type: http_message_signatures + config: + components: ["@method"] + signer: + key_store: + path: /some/path.pem +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "/some/path.pem") + }, + }, + { + uc: "with unsupported properties", + config: []byte(` +auth: + type: http_message_signatures + config: + components: ["@method"] + foo: bar + signer: + key_store: + path: /some/path.pem +`), + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrConfiguration) + require.ErrorContains(t, err, "invalid keys: foo") + }, + }, + { + uc: "error while registering signer for updates watching", + config: []byte(` +auth: + type: http_message_signatures + config: + components: ["@method"] + signer: + key_store: + path: ` + pemFile.Name() + ` +`), + configureContext: func(t *testing.T, ccm *CreationContextMock) { + t.Helper() + + watcher := mocks.NewWatcherMock(t) + watcher.EXPECT().Add(pemFile.Name(), mock.Anything).Return(errors.New("test error")) + + ccm.EXPECT().Watcher().Return(watcher) + }, + assert: func(t *testing.T, err error, _ endpoint.AuthenticationStrategy) { + t.Helper() + + require.Error(t, err) + require.ErrorIs(t, err, heimdall.ErrInternal) + require.ErrorContains(t, err, "failed registering") + }, + }, + { + uc: "minimal possible configuration", + config: []byte(` +auth: + type: http_message_signatures + config: + components: ["@method"] + signer: + key_store: + path: ` + pemFile.Name() + ` +`), + configureContext: func(t *testing.T, ccm *CreationContextMock) { + t.Helper() + + watcher := mocks.NewWatcherMock(t) + watcher.EXPECT().Add(pemFile.Name(), mock.Anything).Return(nil) + + observer := mocks3.NewObserverMock(t) + observer.EXPECT().Add(mock.Anything) + + ccm.EXPECT().Watcher().Return(watcher) + ccm.EXPECT().CertificateObserver().Return(observer) + }, + assert: func(t *testing.T, err error, as endpoint.AuthenticationStrategy) { + t.Helper() + + require.NoError(t, err) + + httpSig, ok := as.(*HTTPMessageSignatures) + require.True(t, ok) + + assert.NotNil(t, httpSig.signer) + assert.NotEmpty(t, httpSig.Certificates()) + assert.NotEmpty(t, httpSig.Keys()) + assert.Equal(t, "http message signer", httpSig.Name()) + }, + }, + { + uc: "full possible configuration", + config: []byte(` +auth: + type: http_message_signatures + config: + ttl: 1m + label: bar + components: ["@method"] + signer: + name: foobar + key_id: key1 + key_store: + password: secret + path: ` + pemFile.Name() + ` +`), + configureContext: func(t *testing.T, ccm *CreationContextMock) { + t.Helper() + + watcher := mocks.NewWatcherMock(t) + watcher.EXPECT().Add(pemFile.Name(), mock.Anything).Return(nil) + + observer := mocks3.NewObserverMock(t) + observer.EXPECT().Add(mock.Anything) + + ccm.EXPECT().Watcher().Return(watcher) + ccm.EXPECT().CertificateObserver().Return(observer) + }, + assert: func(t *testing.T, err error, as endpoint.AuthenticationStrategy) { + t.Helper() + + require.NoError(t, err) + + httpSig, ok := as.(*HTTPMessageSignatures) + require.True(t, ok) + + assert.NotNil(t, httpSig.signer) + assert.NotEmpty(t, httpSig.Certificates()) + assert.NotEmpty(t, httpSig.Keys()) + assert.Equal(t, "http message signer", httpSig.Name()) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + // GIVEN + ccm := NewCreationContextMock(t) + configureContext := x.IfThenElse(tc.configureContext != nil, + tc.configureContext, + func(t *testing.T, _ *CreationContextMock) { t.Helper() }, + ) + configureContext(t, ccm) + + var typ Type + + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + DecodeAuthenticationStrategyHookFunc(ccm), ), Result: &typ, }) @@ -462,7 +821,7 @@ func TestDecodeAuthenticationStrategyHookFuncForUnknownStrategy(t *testing.T) { dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - DecodeAuthenticationStrategyHookFunc(), + DecodeAuthenticationStrategyHookFunc(nil), ), Result: &typ, }) diff --git a/internal/rules/endpoint/authstrategy/mock_creation_context_test.go b/internal/rules/endpoint/authstrategy/mock_creation_context_test.go new file mode 100644 index 000000000..c31b2fd8b --- /dev/null +++ b/internal/rules/endpoint/authstrategy/mock_creation_context_test.go @@ -0,0 +1,180 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package authstrategy + +import ( + keyholder "github.com/dadrus/heimdall/internal/keyholder" + certificate "github.com/dadrus/heimdall/internal/otel/metrics/certificate" + + mock "github.com/stretchr/testify/mock" + + watcher "github.com/dadrus/heimdall/internal/watcher" +) + +// CreationContextMock is an autogenerated mock type for the CreationContext type +type CreationContextMock struct { + mock.Mock +} + +type CreationContextMock_Expecter struct { + mock *mock.Mock +} + +func (_m *CreationContextMock) EXPECT() *CreationContextMock_Expecter { + return &CreationContextMock_Expecter{mock: &_m.Mock} +} + +// CertificateObserver provides a mock function with given fields: +func (_m *CreationContextMock) CertificateObserver() certificate.Observer { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for CertificateObserver") + } + + var r0 certificate.Observer + if rf, ok := ret.Get(0).(func() certificate.Observer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(certificate.Observer) + } + } + + return r0 +} + +// CreationContextMock_CertificateObserver_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CertificateObserver' +type CreationContextMock_CertificateObserver_Call struct { + *mock.Call +} + +// CertificateObserver is a helper method to define mock.On call +func (_e *CreationContextMock_Expecter) CertificateObserver() *CreationContextMock_CertificateObserver_Call { + return &CreationContextMock_CertificateObserver_Call{Call: _e.mock.On("CertificateObserver")} +} + +func (_c *CreationContextMock_CertificateObserver_Call) Run(run func()) *CreationContextMock_CertificateObserver_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *CreationContextMock_CertificateObserver_Call) Return(_a0 certificate.Observer) *CreationContextMock_CertificateObserver_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *CreationContextMock_CertificateObserver_Call) RunAndReturn(run func() certificate.Observer) *CreationContextMock_CertificateObserver_Call { + _c.Call.Return(run) + return _c +} + +// KeyHolderRegistry provides a mock function with given fields: +func (_m *CreationContextMock) KeyHolderRegistry() keyholder.Registry { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for KeyHolderRegistry") + } + + var r0 keyholder.Registry + if rf, ok := ret.Get(0).(func() keyholder.Registry); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(keyholder.Registry) + } + } + + return r0 +} + +// CreationContextMock_KeyHolderRegistry_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'KeyHolderRegistry' +type CreationContextMock_KeyHolderRegistry_Call struct { + *mock.Call +} + +// KeyHolderRegistry is a helper method to define mock.On call +func (_e *CreationContextMock_Expecter) KeyHolderRegistry() *CreationContextMock_KeyHolderRegistry_Call { + return &CreationContextMock_KeyHolderRegistry_Call{Call: _e.mock.On("KeyHolderRegistry")} +} + +func (_c *CreationContextMock_KeyHolderRegistry_Call) Run(run func()) *CreationContextMock_KeyHolderRegistry_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *CreationContextMock_KeyHolderRegistry_Call) Return(_a0 keyholder.Registry) *CreationContextMock_KeyHolderRegistry_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *CreationContextMock_KeyHolderRegistry_Call) RunAndReturn(run func() keyholder.Registry) *CreationContextMock_KeyHolderRegistry_Call { + _c.Call.Return(run) + return _c +} + +// Watcher provides a mock function with given fields: +func (_m *CreationContextMock) Watcher() watcher.Watcher { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Watcher") + } + + var r0 watcher.Watcher + if rf, ok := ret.Get(0).(func() watcher.Watcher); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(watcher.Watcher) + } + } + + return r0 +} + +// CreationContextMock_Watcher_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watcher' +type CreationContextMock_Watcher_Call struct { + *mock.Call +} + +// Watcher is a helper method to define mock.On call +func (_e *CreationContextMock_Expecter) Watcher() *CreationContextMock_Watcher_Call { + return &CreationContextMock_Watcher_Call{Call: _e.mock.On("Watcher")} +} + +func (_c *CreationContextMock_Watcher_Call) Run(run func()) *CreationContextMock_Watcher_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *CreationContextMock_Watcher_Call) Return(_a0 watcher.Watcher) *CreationContextMock_Watcher_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *CreationContextMock_Watcher_Call) RunAndReturn(run func() watcher.Watcher) *CreationContextMock_Watcher_Call { + _c.Call.Return(run) + return _c +} + +// NewCreationContextMock creates a new instance of CreationContextMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCreationContextMock(t interface { + mock.TestingT + Cleanup(func()) +}) *CreationContextMock { + mock := &CreationContextMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/rules/mechanisms/authenticators/anonymous_authenticator.go b/internal/rules/mechanisms/authenticators/anonymous_authenticator.go index aac77f2f9..f110d06f5 100644 --- a/internal/rules/mechanisms/authenticators/anonymous_authenticator.go +++ b/internal/rules/mechanisms/authenticators/anonymous_authenticator.go @@ -26,21 +26,25 @@ import ( // by intention. Used only during application bootstrap. func init() { // nolint: gochecknoinits registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { if typ != AuthenticatorAnonymous { return false, nil, nil } - auth, err := newAnonymousAuthenticator(id, conf) + auth, err := newAnonymousAuthenticator(ctx, id, conf) return true, auth, err }) } -func newAnonymousAuthenticator(id string, rawConfig map[string]any) (*anonymousAuthenticator, error) { +func newAnonymousAuthenticator( + ctx CreationContext, + id string, + rawConfig map[string]any, +) (*anonymousAuthenticator, error) { var auth anonymousAuthenticator - if err := decodeConfig(AuthenticatorAnonymous, rawConfig, &auth); err != nil { + if err := decodeConfig(ctx, AuthenticatorAnonymous, rawConfig, &auth); err != nil { return nil, err } @@ -71,7 +75,7 @@ func (a *anonymousAuthenticator) WithConfig(config map[string]any) (Authenticato return a, nil } - return newAnonymousAuthenticator(a.id, config) + return newAnonymousAuthenticator(nil, a.id, config) } func (a *anonymousAuthenticator) IsFallbackOnErrorAllowed() bool { diff --git a/internal/rules/mechanisms/authenticators/anonymous_authenticator_test.go b/internal/rules/mechanisms/authenticators/anonymous_authenticator_test.go index 5ecdda5c4..77fc2f22b 100644 --- a/internal/rules/mechanisms/authenticators/anonymous_authenticator_test.go +++ b/internal/rules/mechanisms/authenticators/anonymous_authenticator_test.go @@ -82,7 +82,7 @@ func TestCreateAnonymousAuthenticator(t *testing.T) { require.NoError(t, err) // WHEN - auth, err := newAnonymousAuthenticator(tc.id, conf) + auth, err := newAnonymousAuthenticator(nil, tc.id, conf) // THEN tc.assert(t, err, auth) @@ -151,7 +151,7 @@ func TestCreateAnonymousAuthenticatorFromPrototype(t *testing.T) { conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newAnonymousAuthenticator(tc.id, pc) + prototype, err := newAnonymousAuthenticator(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/authenticators/authenticator_type_registry.go b/internal/rules/mechanisms/authenticators/authenticator_type_registry.go index 03989faea..0a14a5f66 100644 --- a/internal/rules/mechanisms/authenticators/authenticator_type_registry.go +++ b/internal/rules/mechanisms/authenticators/authenticator_type_registry.go @@ -21,6 +21,7 @@ import ( "sync" "github.com/dadrus/heimdall/internal/keyholder" + "github.com/dadrus/heimdall/internal/otel/metrics/certificate" "github.com/dadrus/heimdall/internal/watcher" "github.com/dadrus/heimdall/internal/x/errorchain" ) @@ -38,6 +39,7 @@ var ( type CreationContext interface { Watcher() watcher.Watcher KeyHolderRegistry() keyholder.Registry + CertificateObserver() certificate.Observer } type TypeFactory func(ctx CreationContext, id string, typ string, config map[string]any) (bool, Authenticator, error) diff --git a/internal/rules/mechanisms/authenticators/basic_auth_authenticator.go b/internal/rules/mechanisms/authenticators/basic_auth_authenticator.go index 5ad8d5673..d6e0afb64 100644 --- a/internal/rules/mechanisms/authenticators/basic_auth_authenticator.go +++ b/internal/rules/mechanisms/authenticators/basic_auth_authenticator.go @@ -41,12 +41,12 @@ const ( //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { if typ != AuthenticatorBasicAuth { return false, nil, nil } - auth, err := newBasicAuthAuthenticator(id, conf) + auth, err := newBasicAuthAuthenticator(ctx, id, conf) return true, auth, err }) @@ -59,7 +59,11 @@ type basicAuthAuthenticator struct { allowFallbackOnError bool } -func newBasicAuthAuthenticator(id string, rawConfig map[string]any) (*basicAuthAuthenticator, error) { +func newBasicAuthAuthenticator( + ctx CreationContext, + id string, + rawConfig map[string]any, +) (*basicAuthAuthenticator, error) { type Config struct { UserID string `mapstructure:"user_id" validate:"required"` Password string `mapstructure:"password" validate:"required"` @@ -67,7 +71,7 @@ func newBasicAuthAuthenticator(id string, rawConfig map[string]any) (*basicAuthA } var conf Config - if err := decodeConfig(AuthenticatorBasicAuth, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthenticatorBasicAuth, rawConfig, &conf); err != nil { return nil, err } @@ -150,7 +154,7 @@ func (a *basicAuthAuthenticator) WithConfig(rawConfig map[string]any) (Authentic } var conf Config - if err := decodeConfig(AuthenticatorBasicAuth, rawConfig, &conf); err != nil { + if err := decodeConfig(nil, AuthenticatorBasicAuth, rawConfig, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/authenticators/basic_auth_authenticator_test.go b/internal/rules/mechanisms/authenticators/basic_auth_authenticator_test.go index 97afbbb31..d0abffd5a 100644 --- a/internal/rules/mechanisms/authenticators/basic_auth_authenticator_test.go +++ b/internal/rules/mechanisms/authenticators/basic_auth_authenticator_test.go @@ -140,7 +140,7 @@ foo: bar`), require.NoError(t, err) // WHEN - auth, err := newBasicAuthAuthenticator(tc.id, conf) + auth, err := newBasicAuthAuthenticator(nil, tc.id, conf) // THEN tc.assert(t, err, auth) @@ -310,7 +310,7 @@ password: baz`), conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newBasicAuthAuthenticator(tc.id, pc) + prototype, err := newBasicAuthAuthenticator(nil, tc.id, pc) require.NoError(t, err) // WHEN @@ -501,7 +501,7 @@ password: bar`)) } { t.Run("case="+tc.uc, func(t *testing.T) { // GIVEN - auth, err := newBasicAuthAuthenticator(tc.id, conf) + auth, err := newBasicAuthAuthenticator(nil, tc.id, conf) require.NoError(t, err) ctx := mocks.NewContextMock(t) diff --git a/internal/rules/mechanisms/authenticators/config_decoder.go b/internal/rules/mechanisms/authenticators/config_decoder.go index 96bc27d7f..39622b492 100644 --- a/internal/rules/mechanisms/authenticators/config_decoder.go +++ b/internal/rules/mechanisms/authenticators/config_decoder.go @@ -30,11 +30,11 @@ import ( "github.com/dadrus/heimdall/internal/x/errorchain" ) -func decodeConfig(authenticatorType string, input, output any) error { +func decodeConfig(ctx CreationContext, authenticatorType string, input, output any) error { dec, err := mapstructure.NewDecoder( &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - authstrategy.DecodeAuthenticationStrategyHookFunc(), + authstrategy.DecodeAuthenticationStrategyHookFunc(ctx), endpoint.DecodeEndpointHookFunc(), mapstructure.StringToTimeDurationHookFunc(), extractors.DecodeCompositeExtractStrategyHookFunc(), diff --git a/internal/rules/mechanisms/authenticators/generic_authenticator.go b/internal/rules/mechanisms/authenticators/generic_authenticator.go index d18cc5f86..8d27650ea 100644 --- a/internal/rules/mechanisms/authenticators/generic_authenticator.go +++ b/internal/rules/mechanisms/authenticators/generic_authenticator.go @@ -44,12 +44,12 @@ import ( //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { if typ != AuthenticatorGeneric { return false, nil, nil } - auth, err := newGenericAuthenticator(id, conf) + auth, err := newGenericAuthenticator(ctx, id, conf) return true, auth, err }) @@ -68,7 +68,7 @@ type genericAuthenticator struct { allowFallbackOnError bool } -func newGenericAuthenticator(id string, rawConfig map[string]any) (*genericAuthenticator, error) { +func newGenericAuthenticator(ctx CreationContext, id string, rawConfig map[string]any) (*genericAuthenticator, error) { type Config struct { Endpoint endpoint.Endpoint `mapstructure:"identity_info_endpoint" validate:"required"` //nolint:lll SubjectInfo SubjectInfo `mapstructure:"subject" validate:"required"` //nolint:lll @@ -82,7 +82,7 @@ func newGenericAuthenticator(id string, rawConfig map[string]any) (*genericAuthe } var conf Config - if err := decodeConfig(AuthenticatorGeneric, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthenticatorGeneric, rawConfig, &conf); err != nil { return nil, err } @@ -142,7 +142,7 @@ func (a *genericAuthenticator) WithConfig(config map[string]any) (Authenticator, } var conf Config - if err := decodeConfig(AuthenticatorGeneric, config, &conf); err != nil { + if err := decodeConfig(nil, AuthenticatorGeneric, config, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/authenticators/generic_authenticator_test.go b/internal/rules/mechanisms/authenticators/generic_authenticator_test.go index c8d485170..2795453e8 100644 --- a/internal/rules/mechanisms/authenticators/generic_authenticator_test.go +++ b/internal/rules/mechanisms/authenticators/generic_authenticator_test.go @@ -308,7 +308,7 @@ session_lifespan: require.NoError(t, err) // WHEN - auth, err := newGenericAuthenticator(tc.id, conf) + auth, err := newGenericAuthenticator(nil, tc.id, conf) // THEN tc.assertError(t, err, auth) @@ -712,7 +712,7 @@ forward_cookies: conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newGenericAuthenticator(tc.id, pc) + prototype, err := newGenericAuthenticator(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/authenticators/jwt_authenticator.go b/internal/rules/mechanisms/authenticators/jwt_authenticator.go index c8eaedd6f..438099378 100644 --- a/internal/rules/mechanisms/authenticators/jwt_authenticator.go +++ b/internal/rules/mechanisms/authenticators/jwt_authenticator.go @@ -53,12 +53,12 @@ const defaultJWTAuthenticatorTTL = 10 * time.Minute //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { if typ != AuthenticatorJwt { return false, nil, nil } - auth, err := newJwtAuthenticator(id, conf) + auth, err := newJwtAuthenticator(ctx, id, conf) return true, auth, err }) @@ -76,7 +76,12 @@ type jwtAuthenticator struct { validateJWKCert bool } -func newJwtAuthenticator(id string, rawConfig map[string]any) (*jwtAuthenticator, error) { // nolint: funlen +// nolint: funlen +func newJwtAuthenticator( + ctx CreationContext, + id string, + rawConfig map[string]any, +) (*jwtAuthenticator, error) { // nolint: funlen type Config struct { JWKSEndpoint *endpoint.Endpoint `mapstructure:"jwks_endpoint" validate:"required_without=MetadataEndpoint,excluded_with=MetadataEndpoint"` //nolint:lll,tagalign MetadataEndpoint *oauth2.MetadataEndpoint `mapstructure:"metadata_endpoint" validate:"required_without=JWKSEndpoint,excluded_with=JWKSEndpoint"` //nolint:lll,tagalign @@ -90,7 +95,7 @@ func newJwtAuthenticator(id string, rawConfig map[string]any) (*jwtAuthenticator } var conf Config - if err := decodeConfig(AuthenticatorJwt, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthenticatorJwt, rawConfig, &conf); err != nil { return nil, err } @@ -214,7 +219,7 @@ func (a *jwtAuthenticator) WithConfig(config map[string]any) (Authenticator, err } var conf Config - if err := decodeConfig(AuthenticatorJwt, config, &conf); err != nil { + if err := decodeConfig(nil, AuthenticatorJwt, config, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/authenticators/jwt_authenticator_test.go b/internal/rules/mechanisms/authenticators/jwt_authenticator_test.go index babfa0052..536ba2fe2 100644 --- a/internal/rules/mechanisms/authenticators/jwt_authenticator_test.go +++ b/internal/rules/mechanisms/authenticators/jwt_authenticator_test.go @@ -448,7 +448,7 @@ cache_ttl: 5s`), require.NoError(t, err) // WHEN - a, err := newJwtAuthenticator(tc.id, conf) + a, err := newJwtAuthenticator(nil, tc.id, conf) // THEN tc.assert(t, err, a) @@ -790,7 +790,7 @@ metadata_endpoint: conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newJwtAuthenticator(tc.id, pc) + prototype, err := newJwtAuthenticator(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/authenticators/mock_creation_context_test.go b/internal/rules/mechanisms/authenticators/mock_creation_context_test.go index dd4bd558b..d24493d32 100644 --- a/internal/rules/mechanisms/authenticators/mock_creation_context_test.go +++ b/internal/rules/mechanisms/authenticators/mock_creation_context_test.go @@ -4,6 +4,8 @@ package authenticators import ( keyholder "github.com/dadrus/heimdall/internal/keyholder" + certificate "github.com/dadrus/heimdall/internal/otel/metrics/certificate" + mock "github.com/stretchr/testify/mock" watcher "github.com/dadrus/heimdall/internal/watcher" @@ -22,6 +24,53 @@ func (_m *CreationContextMock) EXPECT() *CreationContextMock_Expecter { return &CreationContextMock_Expecter{mock: &_m.Mock} } +// CertificateObserver provides a mock function with given fields: +func (_m *CreationContextMock) CertificateObserver() certificate.Observer { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for CertificateObserver") + } + + var r0 certificate.Observer + if rf, ok := ret.Get(0).(func() certificate.Observer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(certificate.Observer) + } + } + + return r0 +} + +// CreationContextMock_CertificateObserver_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CertificateObserver' +type CreationContextMock_CertificateObserver_Call struct { + *mock.Call +} + +// CertificateObserver is a helper method to define mock.On call +func (_e *CreationContextMock_Expecter) CertificateObserver() *CreationContextMock_CertificateObserver_Call { + return &CreationContextMock_CertificateObserver_Call{Call: _e.mock.On("CertificateObserver")} +} + +func (_c *CreationContextMock_CertificateObserver_Call) Run(run func()) *CreationContextMock_CertificateObserver_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *CreationContextMock_CertificateObserver_Call) Return(_a0 certificate.Observer) *CreationContextMock_CertificateObserver_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *CreationContextMock_CertificateObserver_Call) RunAndReturn(run func() certificate.Observer) *CreationContextMock_CertificateObserver_Call { + _c.Call.Return(run) + return _c +} + // KeyHolderRegistry provides a mock function with given fields: func (_m *CreationContextMock) KeyHolderRegistry() keyholder.Registry { ret := _m.Called() diff --git a/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator.go b/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator.go index b4dc43b12..d207b374f 100644 --- a/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator.go +++ b/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator.go @@ -49,12 +49,12 @@ import ( //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authenticator, error) { if typ != AuthenticatorOAuth2Introspection { return false, nil, nil } - auth, err := newOAuth2IntrospectionAuthenticator(id, conf) + auth, err := newOAuth2IntrospectionAuthenticator(ctx, id, conf) return true, auth, err }) @@ -70,8 +70,11 @@ type oauth2IntrospectionAuthenticator struct { allowFallbackOnError bool } -func newOAuth2IntrospectionAuthenticator( // nolint: funlen - id string, rawConfig map[string]any, +// nolint: funlen +func newOAuth2IntrospectionAuthenticator( + ctx CreationContext, + id string, + rawConfig map[string]any, ) (*oauth2IntrospectionAuthenticator, error) { type Config struct { IntrospectionEndpoint *endpoint.Endpoint `mapstructure:"introspection_endpoint" validate:"required_without=MetadataEndpoint,excluded_with=MetadataEndpoint"` //nolint:lll,tagalign @@ -84,7 +87,7 @@ func newOAuth2IntrospectionAuthenticator( // nolint: funlen } var conf Config - if err := decodeConfig(AuthenticatorOAuth2Introspection, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthenticatorOAuth2Introspection, rawConfig, &conf); err != nil { return nil, err } @@ -198,7 +201,7 @@ func (a *oauth2IntrospectionAuthenticator) WithConfig(rawConfig map[string]any) } var conf Config - if err := decodeConfig(AuthenticatorOAuth2Introspection, rawConfig, &conf); err != nil { + if err := decodeConfig(nil, AuthenticatorOAuth2Introspection, rawConfig, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator_test.go b/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator_test.go index eb24652ad..d9d1fffe7 100644 --- a/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator_test.go +++ b/internal/rules/mechanisms/authenticators/oauth2_introspection_authenticator_test.go @@ -335,7 +335,7 @@ metadata_endpoint: require.NoError(t, err) // WHEN - a, err := newOAuth2IntrospectionAuthenticator(tc.id, conf) + a, err := newOAuth2IntrospectionAuthenticator(nil, tc.id, conf) // THEN tc.assert(t, err, a) @@ -588,7 +588,7 @@ subject: conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newOAuth2IntrospectionAuthenticator(tc.id, pc) + prototype, err := newOAuth2IntrospectionAuthenticator(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/authorizers/cel_authorizer.go b/internal/rules/mechanisms/authorizers/cel_authorizer.go index 43d90a76c..8325a017f 100644 --- a/internal/rules/mechanisms/authorizers/cel_authorizer.go +++ b/internal/rules/mechanisms/authorizers/cel_authorizer.go @@ -31,12 +31,12 @@ import ( //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authorizer, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authorizer, error) { if typ != AuthorizerCEL { return false, nil, nil } - auth, err := newCELAuthorizer(id, conf) + auth, err := newCELAuthorizer(ctx, id, conf) return true, auth, err }) @@ -47,13 +47,13 @@ type celAuthorizer struct { expressions compiledExpressions } -func newCELAuthorizer(id string, rawConfig map[string]any) (*celAuthorizer, error) { +func newCELAuthorizer(ctx CreationContext, id string, rawConfig map[string]any) (*celAuthorizer, error) { type Config struct { Expressions []Expression `mapstructure:"expressions" validate:"required,gt=0,dive"` } var conf Config - if err := decodeConfig(AuthorizerCEL, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthorizerCEL, rawConfig, &conf); err != nil { return nil, err } @@ -83,7 +83,7 @@ func (a *celAuthorizer) WithConfig(rawConfig map[string]any) (Authorizer, error) return a, nil } - return newCELAuthorizer(a.id, rawConfig) + return newCELAuthorizer(nil, a.id, rawConfig) } func (a *celAuthorizer) ID() string { return a.id } diff --git a/internal/rules/mechanisms/authorizers/cel_authorizer_test.go b/internal/rules/mechanisms/authorizers/cel_authorizer_test.go index f52ce8744..d376b05cd 100644 --- a/internal/rules/mechanisms/authorizers/cel_authorizer_test.go +++ b/internal/rules/mechanisms/authorizers/cel_authorizer_test.go @@ -142,7 +142,7 @@ expressions: require.NoError(t, err) // WHEN - a, err := newCELAuthorizer(tc.id, conf) + a, err := newCELAuthorizer(nil, tc.id, conf) // THEN tc.assert(t, err, a) @@ -218,7 +218,7 @@ expressions: conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newCELAuthorizer(tc.id, pc) + prototype, err := newCELAuthorizer(nil, tc.id, pc) require.NoError(t, err) // WHEN @@ -344,7 +344,7 @@ expressions: tc.configureContextAndSubject(t, ctx, sub) - auth, err := newCELAuthorizer(tc.id, conf) + auth, err := newCELAuthorizer(nil, tc.id, conf) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/authorizers/config_decoder.go b/internal/rules/mechanisms/authorizers/config_decoder.go index 8101f6626..809daa4d2 100644 --- a/internal/rules/mechanisms/authorizers/config_decoder.go +++ b/internal/rules/mechanisms/authorizers/config_decoder.go @@ -27,11 +27,11 @@ import ( "github.com/dadrus/heimdall/internal/x/errorchain" ) -func decodeConfig(authorizerType string, input, output any) error { +func decodeConfig(ctx CreationContext, authorizerType string, input, output any) error { dec, err := mapstructure.NewDecoder( &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - authstrategy.DecodeAuthenticationStrategyHookFunc(), + authstrategy.DecodeAuthenticationStrategyHookFunc(ctx), endpoint.DecodeEndpointHookFunc(), mapstructure.StringToTimeDurationHookFunc(), template.DecodeTemplateHookFunc(), diff --git a/internal/rules/mechanisms/authorizers/remote_authorizer.go b/internal/rules/mechanisms/authorizers/remote_authorizer.go index e9f7104d8..6d7314404 100644 --- a/internal/rules/mechanisms/authorizers/remote_authorizer.go +++ b/internal/rules/mechanisms/authorizers/remote_authorizer.go @@ -51,12 +51,12 @@ var errNoContent = errors.New("no payload received") //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Authorizer, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Authorizer, error) { if typ != AuthorizerRemote { return false, nil, nil } - auth, err := newRemoteAuthorizer(id, conf) + auth, err := newRemoteAuthorizer(ctx, id, conf) return true, auth, err }) @@ -93,7 +93,7 @@ func (ai *authorizationInformation) addResultsTo(key string, ctx heimdall.Contex } } -func newRemoteAuthorizer(id string, rawConfig map[string]any) (*remoteAuthorizer, error) { +func newRemoteAuthorizer(ctx CreationContext, id string, rawConfig map[string]any) (*remoteAuthorizer, error) { type Config struct { Endpoint endpoint.Endpoint `mapstructure:"endpoint" validate:"required"` //nolint:lll Expressions []Expression `mapstructure:"expressions" validate:"dive"` @@ -104,7 +104,7 @@ func newRemoteAuthorizer(id string, rawConfig map[string]any) (*remoteAuthorizer } var conf Config - if err := decodeConfig(AuthorizerRemote, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, AuthorizerRemote, rawConfig, &conf); err != nil { return nil, err } @@ -201,7 +201,7 @@ func (a *remoteAuthorizer) WithConfig(rawConfig map[string]any) (Authorizer, err } var conf Config - if err := decodeConfig(AuthorizerRemote, rawConfig, &conf); err != nil { + if err := decodeConfig(nil, AuthorizerRemote, rawConfig, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/authorizers/remote_authorizer_test.go b/internal/rules/mechanisms/authorizers/remote_authorizer_test.go index f8b0008c0..aed075708 100644 --- a/internal/rules/mechanisms/authorizers/remote_authorizer_test.go +++ b/internal/rules/mechanisms/authorizers/remote_authorizer_test.go @@ -233,7 +233,7 @@ values: require.NoError(t, err) // WHEN - auth, err := newRemoteAuthorizer(tc.id, conf) + auth, err := newRemoteAuthorizer(nil, tc.id, conf) // THEN tc.assert(t, err, auth) @@ -475,7 +475,7 @@ cache_ttl: 15s conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newRemoteAuthorizer(tc.id, pc) + prototype, err := newRemoteAuthorizer(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/mechanisms/contextualizers/config_decoder.go b/internal/rules/mechanisms/contextualizers/config_decoder.go index 6079557c6..5141fed61 100644 --- a/internal/rules/mechanisms/contextualizers/config_decoder.go +++ b/internal/rules/mechanisms/contextualizers/config_decoder.go @@ -27,11 +27,11 @@ import ( "github.com/dadrus/heimdall/internal/x/errorchain" ) -func decodeConfig(contextualizerType string, input, output any) error { +func decodeConfig(ctx CreationContext, contextualizerType string, input, output any) error { dec, err := mapstructure.NewDecoder( &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - authstrategy.DecodeAuthenticationStrategyHookFunc(), + authstrategy.DecodeAuthenticationStrategyHookFunc(ctx), endpoint.DecodeEndpointHookFunc(), mapstructure.StringToTimeDurationHookFunc(), template.DecodeTemplateHookFunc(), diff --git a/internal/rules/mechanisms/contextualizers/generic_contextualizer.go b/internal/rules/mechanisms/contextualizers/generic_contextualizer.go index 5b4cc2dc8..16dca6434 100644 --- a/internal/rules/mechanisms/contextualizers/generic_contextualizer.go +++ b/internal/rules/mechanisms/contextualizers/generic_contextualizer.go @@ -53,12 +53,12 @@ var errNoContent = errors.New("no payload received") //nolint:gochecknoinits func init() { registerTypeFactory( - func(_ CreationContext, id string, typ string, conf map[string]any) (bool, Contextualizer, error) { + func(ctx CreationContext, id string, typ string, conf map[string]any) (bool, Contextualizer, error) { if typ != ContextualizerGeneric { return false, nil, nil } - eh, err := newGenericContextualizer(id, conf) + eh, err := newGenericContextualizer(ctx, id, conf) return true, eh, err }) @@ -79,7 +79,11 @@ type genericContextualizer struct { v values.Values } -func newGenericContextualizer(id string, rawConfig map[string]any) (*genericContextualizer, error) { +func newGenericContextualizer( + ctx CreationContext, + id string, + rawConfig map[string]any, +) (*genericContextualizer, error) { type Config struct { Endpoint endpoint.Endpoint `mapstructure:"endpoint" validate:"required"` ForwardHeaders []string `mapstructure:"forward_headers"` @@ -91,7 +95,7 @@ func newGenericContextualizer(id string, rawConfig map[string]any) (*genericCont } var conf Config - if err := decodeConfig(ContextualizerGeneric, rawConfig, &conf); err != nil { + if err := decodeConfig(ctx, ContextualizerGeneric, rawConfig, &conf); err != nil { return nil, err } @@ -185,7 +189,7 @@ func (h *genericContextualizer) WithConfig(rawConfig map[string]any) (Contextual } var conf Config - if err := decodeConfig(ContextualizerGeneric, rawConfig, &conf); err != nil { + if err := decodeConfig(nil, ContextualizerGeneric, rawConfig, &conf); err != nil { return nil, err } diff --git a/internal/rules/mechanisms/contextualizers/generic_contextualizer_test.go b/internal/rules/mechanisms/contextualizers/generic_contextualizer_test.go index 38c79119f..1724bc7b5 100644 --- a/internal/rules/mechanisms/contextualizers/generic_contextualizer_test.go +++ b/internal/rules/mechanisms/contextualizers/generic_contextualizer_test.go @@ -166,7 +166,7 @@ continue_pipeline_on_error: true require.NoError(t, err) // WHEN - contextualizer, err := newGenericContextualizer(tc.id, conf) + contextualizer, err := newGenericContextualizer(nil, tc.id, conf) // THEN tc.assert(t, err, contextualizer) @@ -473,7 +473,7 @@ continue_pipeline_on_error: false conf, err := testsupport.DecodeTestConfig(tc.config) require.NoError(t, err) - prototype, err := newGenericContextualizer(tc.id, pc) + prototype, err := newGenericContextualizer(nil, tc.id, pc) require.NoError(t, err) // WHEN diff --git a/internal/rules/provider/httpendpoint/config_decoder.go b/internal/rules/provider/httpendpoint/config_decoder.go index 86fd8ee5d..7aca5f5df 100644 --- a/internal/rules/provider/httpendpoint/config_decoder.go +++ b/internal/rules/provider/httpendpoint/config_decoder.go @@ -27,7 +27,7 @@ func decodeConfig(input any, output any) error { dec, err := mapstructure.NewDecoder( &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - authstrategy.DecodeAuthenticationStrategyHookFunc(), + authstrategy.DecodeAuthenticationStrategyHookFunc(nil), endpoint.DecodeEndpointHookFunc(), mapstructure.StringToTimeDurationHookFunc(), ),