diff --git a/docs/content/docs/configuration/reference/reference.adoc b/docs/content/docs/configuration/reference/reference.adoc index a9367510c..b7b26d40a 100644 --- a/docs/content/docs/configuration/reference/reference.adoc +++ b/docs/content/docs/configuration/reference/reference.adoc @@ -300,6 +300,19 @@ rules: config: cookies: foo-bar: '{{ .Subject.ID }}' + - id: get_token + type: oauth2_client_credentials + config: + header: + name: X-Token + token_url: https://my-oauth-provider.com/token + client_id: my_client + client_secret: VerySecret! + auth_method: basic_auth + cache_ttl: 5m + scopes: + - foo + - bar error_handlers: - id: default diff --git a/docs/content/docs/configuration/rules/pipeline_mechanisms/finalizers.adoc b/docs/content/docs/configuration/rules/pipeline_mechanisms/finalizers.adoc index 1995de849..f24d23bbb 100644 --- a/docs/content/docs/configuration/rules/pipeline_mechanisms/finalizers.adoc +++ b/docs/content/docs/configuration/rules/pipeline_mechanisms/finalizers.adoc @@ -9,7 +9,7 @@ menu: parent: "Pipeline Mechanisms" --- -Finalizers, as the name implies, finalize the successful execution of the pipeline and make the available information about the link:{{< relref "overview.adoc#_subject" >}}[`Subject`] and the link:{{< relref "overview.adoc#_request" >}}[`Request`] to the upstream service in a format expected, respectively required by it. This ranges from adding a simple header to a structured JWT in a specific header. +Finalizers, as the name implies, finalize the execution of the pipeline and enrich the request with data such as subject information or authentication tokens required by the upstream service. The available options range from adding a simple header over a structured JWT in a specific header, to driving specific protocols, e.g. to obtain a token required by the upstream service. == Finalizer Types @@ -83,9 +83,11 @@ config: === JWT -This finalizer enables transformation of the link:{{< relref "overview.adoc#_subject" >}}[`Subject`] and the link:{{< relref "overview.adoc#_request" >}}[`Request`] object into a token in a https://www.rfc-editor.org/rfc/rfc7519[JWT] format, which is then made available to your upstream service in either the HTTP `Authorization` header with `Bearer` scheme set, or in a custom header. In addition to setting the JWT specific claims, it allows setting custom claims as well. Your upstream service can then verify the signature of the JWT by making use of heimdall's JWKS endpoint to retrieve the required public keys/certificates from. +This finalizer enables transformation of the link:{{< relref "overview.adoc#_subject" >}}[`Subject`] object into a token in a https://www.rfc-editor.org/rfc/rfc7519[JWT] format, which is then made available to your upstream service in either the HTTP `Authorization` header with `Bearer` scheme set, or in a custom header. In addition to setting the JWT specific claims, it allows setting custom claims as well. Your upstream service can then verify the signature of the JWT by making use of heimdall's JWKS endpoint to retrieve the required public keys/certificates from. -NOTE: To enable the usage of this finalizer, you have to set the `type` property to `jwt`. The usage of this finalizer type requires a configured link:{{< relref "/docs/configuration/cryptographic_material.adoc" >}}[Signer] as well. At least it is a must in production environments. +To enable the usage of this finalizer, you have to set the `type` property to `jwt`. + +NOTE: The usage of this finalizer type requires a configured link:{{< relref "/docs/configuration/cryptographic_material.adoc" >}}[Signer] as well. At least it is a must in production environments. Configuration using the `config` property is optional. Following properties are available: @@ -122,3 +124,66 @@ config: } ---- ==== + +=== OAuth2 Client Credentials + +This finalizer drives the https://www.rfc-editor.org/rfc/rfc6749#section-4.4[OAuth2 Client Credentials Grant] flow to obtain a token, which should be used for communication with the upstream service. By default, as long as not otherwise configured (see the options below), the obtained token is made available to your upstream service in the HTTP `Authorization` header with `Bearer` scheme set. Unlike the other finalizers, it does not have access to any objects created by the rule execution pipeline. + +To enable the usage of this finalizer, you have to set the `type` property to `oauth2_client_credentials`. + +Configuration using the `config` property is mandatory. Following properties are available: + +* *`token_url`*: _string_ (mandatory, not overridable) ++ +The token endpoint of the authorization server. + +* *`client_id`*: _string_ (mandatory, not overridable) ++ +The client identifier for heimdall. + +* *`client_secret`*: _string_ (mandatory, not overridable) ++ +The client secret for heimdall. + +* *`auth_method`*: _string_ (optional, not overridable) ++ +The authentication method to be used according to https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1[RFC 6749, Client Password]. Can one of + +** `basic_auth` (default if `auth_method` is not set): With that authentication method, the `"application/x-www-form-urlencoded"` encoded values of `client_id` and `client_secret` are sent to the authorization server via the `Authorization` header using the `Basic` scheme. + +** `request_body`: With that authentication method the `client_id` and `client_secret` are sent in the request body together with the other parameters (e.g. `scopes`) defined by the flow. ++ +WARNING: Usage of `request_body` authentication method is not recommended and should be avoided. + +* *`scopes`*: _string array_ (optional, overridable) ++ +The scopes required for the access token. + +* *`cache_ttl`*: _link:{{< relref "/docs/configuration/reference/types.adoc#_duration" >}}[Duration]_ (optional, overridable) ++ +How long to cache the token received from the token endpoint. Defaults to the token expiration information from the token endpoint (the value of the `expires_in` field) if present. If the token expiration inforation is not present and `cache_ttl` is not configured, the received token is not cached. If the token expiration information is present in the response and `cache_ttl` is configured the shorter value is taken. If caching is enabled, the token is cached until 5 seconds before its expiration. To disable caching, set it to `0s`. The cache key calculation is based on the entire `oauth2_client_credentials` configuration without considering the `header` property. + +* *`header`*: _object_ (optional, overridable) ++ +Defines the `name` and `scheme` to be used for the header. Defaults to `Authorization` with scheme `Bearer`. If defined, the `name` property must be set. If `scheme` is not defined, no scheme will be prepended to the resulting JWT. + +.OAuth2 Client Credentials finalizer configuration +==== +[source, yaml] +---- +id: get_token +type: oauth2_client_credentials +config: + cache_ttl: 5m + header: + name: X-Token + scheme: MyScheme + token_url: https://my-oauth-provider.com/token + client_id: my_client + client_secret: VerySecret! + auth_method: basic_auth + scopes: + - foo + - bar +---- +==== diff --git a/docs/content/docs/configuration/rules/pipeline_mechanisms/overview.adoc b/docs/content/docs/configuration/rules/pipeline_mechanisms/overview.adoc index 31973c572..261406db2 100644 --- a/docs/content/docs/configuration/rules/pipeline_mechanisms/overview.adoc +++ b/docs/content/docs/configuration/rules/pipeline_mechanisms/overview.adoc @@ -14,7 +14,7 @@ All mechanisms supported by heimdall fall into following categories: * link:{{< relref "authenticators.adoc">}}[Authenticators], which inspect HTTP requests for presence of authentication objects, like e.g. the presence of a specific cookie. If such objects exist, authenticators verify the related authentication status and obtain information about the corresponding subject. A subject, could be a user who tries to use particular functionality of the upstream service, a machine (if you have machine-2-machine interaction), or something different. Authenticators ensure the subject is authenticated and the information available about it is valid. * link:{{< relref "authorizers.adoc">}}[Authorizers], which ensure that the subject obtained via an authenticator has the required permissions to submit the given HTTP request and thus to execute the corresponding logic in the upstream service. E.g. a specific endpoint of the upstream service might only be accessible to a "user" from the "admin" group, or to an HTTP request if a specific HTTP header is set. * link:{{< relref "contextualizers.adoc">}}[Contextualizers], which enrich the information about the subject obtained via an authenticator with further contextual information, required either by the upstream service itself or an authorizer. This can be handy if the actual authentication system doesn't have all information about the subject (which is usually the case in microservice architectures), or if dynamic information about the subject, like the current location based on the IP address, is required. -* link:{{< relref "finalizers.adoc">}}[Finalizers], which, as the name implies, finalize the successful execution of the pipeline and make the gathered information about the subject and the request available to the upstream service in a format expected, respectively required by it. This ranges from adding a simple header or cookie, to a structured JWT. +* link:{{< relref "finalizers.adoc">}}[Finalizers], which, as the name implies, finalize the execution of the pipeline and enrich the request with data such as subject information or authentication tokens required by the upstream service. The available options range from adding a simple header over a structured JWT, to driving specific protocols, e.g. to obtain a token required by the upstream service. * link:{{< relref "error_handlers.adoc">}}[Error Handlers], which are responsible for execution of logic if any of the mechanisms described above fail. These range from a simple error response to the client, which sent the request, to sophisticated ones, supporting complex logic and redirects. == General Configuration diff --git a/docs/content/docs/getting_started/concepts.adoc b/docs/content/docs/getting_started/concepts.adoc index 611817d24..7ad8ad2eb 100644 --- a/docs/content/docs/getting_started/concepts.adoc +++ b/docs/content/docs/getting_started/concepts.adoc @@ -60,7 +60,7 @@ Here, heimdall communicates with other systems as well, either to get further in Here, heimdall performs authorization checks, either locally, or by communicating with yet again further systems, like Open Policy Agent, Ory Keto and alike. ** *finalization* mechanisms, so-called link:{{< relref "/docs/configuration/rules/pipeline_mechanisms/finalizers.adoc" >}}[Finalizers], to be executed (if multiple are defined, they are executed in the order of their definition) - step 4 in the figure above. + -This step finalizes the execution of the pipeline and transform the information collected so far about the subject and the request into objects expected by the upstream service. That reaches from a simple custom header, carrying e.g. the id of the subject, to a JWT carried in the `Authorization` header. +This step finalizes the execution of the pipeline and enriches the request with data such as subject information or authentication tokens required by the upstream service. The available options range from adding a simple header over a structured JWT, to driving specific protocols, e.g. to obtain a token required by the upstream service. * an error pipeline, consisting of link:{{< relref "/docs/configuration/rules/pipeline_mechanisms/error_handlers.adoc" >}}[error handler] mechanisms (if multiple are defined, they are executed as fallbacks), which are executed if any of the regular pipeline mechanisms fail. These mechanisms range from a simple error response to the client (which sent the request), to sophisticated ones supporting complex logic and redirects. The diagram below sketches the related execution logic diff --git a/go.sum b/go.sum index 677bd7f04..2d44339eb 100644 --- a/go.sum +++ b/go.sum @@ -132,8 +132,6 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-co-op/gocron v1.35.1 h1:xi0tfAhxeAmGUKkjiA7bTIjh2VdBJpUYDJ+lPx/EPcM= -github.com/go-co-op/gocron v1.35.1/go.mod h1:NLi+bkm4rRSy1F8U7iacZOz0xPseMoIOnvabGoSe/no= github.com/go-co-op/gocron v1.35.2 h1:lG3rdA9TqBBC/PtT2ukQqgLm6jEepnAzz3+OQetvPTE= github.com/go-co-op/gocron v1.35.2/go.mod h1:NLi+bkm4rRSy1F8U7iacZOz0xPseMoIOnvabGoSe/no= github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 h1:zga7zaRE8HCbWjcXMDlfvmQtH0/kMVLo7cQ48dy6kWg= @@ -644,8 +642,6 @@ google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTp google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= -google.golang.org/grpc v1.58.2 h1:SXUpjxeVF3FKrTYQI4f4KvbGD5u2xccdYdurwowix5I= -google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/internal/config/test_data/test_config.yaml b/internal/config/test_data/test_config.yaml index 6760c546f..51deb2231 100644 --- a/internal/config/test_data/test_config.yaml +++ b/internal/config/test_data/test_config.yaml @@ -355,6 +355,20 @@ rules: config: cookies: foo-bar: '{{ .Subject.ID }}' + - id: client_cred_grant + type: oauth2_client_credentials + config: + token_url: https://my-auth-provider/token + client_id: foo + client_secret: bar + auth_method: basic_auth + cache_ttl: 5m + scopes: + - foo + - bar + header: + name: My-Header + scheme: Foo error_handlers: - id: default type: default diff --git a/internal/rules/endpoint/endpoint.go b/internal/rules/endpoint/endpoint.go index 033d5ed00..3549c2664 100644 --- a/internal/rules/endpoint/endpoint.go +++ b/internal/rules/endpoint/endpoint.go @@ -125,7 +125,14 @@ func (e Endpoint) CreateRequest(ctx context.Context, body io.Reader, rndr Render return req, nil } -func (e Endpoint) SendRequest(ctx context.Context, body io.Reader, renderer Renderer) ([]byte, error) { +type ResponseReader func(resp *http.Response) ([]byte, error) + +func (e Endpoint) SendRequest( + ctx context.Context, + body io.Reader, + renderer Renderer, + reader ...ResponseReader, +) ([]byte, error) { req, err := e.CreateRequest(ctx, body, renderer) if err != nil { return nil, err @@ -143,6 +150,10 @@ func (e Endpoint) SendRequest(ctx context.Context, body io.Reader, renderer Rend defer resp.Body.Close() + if len(reader) != 0 { + return reader[0](resp) + } + return e.readResponse(resp) } diff --git a/internal/rules/mechanisms/finalizers/constants.go b/internal/rules/mechanisms/finalizers/constants.go index 6fa27a607..9db8aac37 100644 --- a/internal/rules/mechanisms/finalizers/constants.go +++ b/internal/rules/mechanisms/finalizers/constants.go @@ -17,8 +17,9 @@ package finalizers const ( - FinalizerNoop = "noop" - FinalizerJwt = "jwt" - FinalizerHeader = "header" - FinalizerCookie = "cookie" + FinalizerNoop = "noop" + FinalizerJwt = "jwt" + FinalizerHeader = "header" + FinalizerCookie = "cookie" + FinalizerOAuth2ClientCredentials = "oauth2_client_credentials" // nolint: gosec ) diff --git a/internal/rules/mechanisms/finalizers/finalizer_type_registry_test.go b/internal/rules/mechanisms/finalizers/finalizer_type_registry_test.go index 26feb3f39..cc05e69b0 100644 --- a/internal/rules/mechanisms/finalizers/finalizer_type_registry_test.go +++ b/internal/rules/mechanisms/finalizers/finalizer_type_registry_test.go @@ -27,7 +27,7 @@ func TestCreateFinalizerPrototype(t *testing.T) { t.Parallel() // there are 4 finalizers implemented, which should have been registered - require.Len(t, typeFactories, 4) + require.Len(t, typeFactories, 5) for _, tc := range []struct { uc string diff --git a/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer.go b/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer.go new file mode 100644 index 000000000..87fa3fb2a --- /dev/null +++ b/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer.go @@ -0,0 +1,390 @@ +package finalizers + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/goccy/go-json" + "github.com/rs/zerolog" + + "github.com/dadrus/heimdall/internal/cache" + "github.com/dadrus/heimdall/internal/heimdall" + "github.com/dadrus/heimdall/internal/rules/endpoint" + "github.com/dadrus/heimdall/internal/rules/mechanisms/subject" + "github.com/dadrus/heimdall/internal/validation" + "github.com/dadrus/heimdall/internal/x" + "github.com/dadrus/heimdall/internal/x/errorchain" + "github.com/dadrus/heimdall/internal/x/stringx" +) + +// by intention. Used only during application bootstrap +// +//nolint:gochecknoinits +func init() { + registerTypeFactory( + func(id string, typ string, conf map[string]any) (bool, Finalizer, error) { + if typ != FinalizerOAuth2ClientCredentials { + return false, nil, nil + } + + finalizer, err := newOAuth2ClientCredentialsFinalizer(id, conf) + + return true, finalizer, err + }) +} + +type AuthMethod string + +const ( + authMethodBasicAuth AuthMethod = "basic_auth" + authMethodRequestBody AuthMethod = "request_body" +) + +type TokenSuccessfulResponse struct { + AccessToken string `json:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type TokenErrorResponse struct { //nolint:errname + ErrorType string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e *TokenErrorResponse) Error() string { + builder := strings.Builder{} + builder.WriteString("error from oauth2 server: ") + builder.WriteString("error: ") + builder.WriteString(e.ErrorType) + + if len(e.ErrorDescription) != 0 { + builder.WriteString(", error_description: ") + builder.WriteString(e.ErrorDescription) + } + + if len(e.ErrorURI) != 0 { + builder.WriteString(", error_uri: ") + builder.WriteString(e.ErrorURI) + } + + return builder.String() +} + +type TokenEndpointResponse struct { + *TokenSuccessfulResponse + *TokenErrorResponse +} + +func (r TokenEndpointResponse) Error() error { + // weird go behavior + if r.TokenErrorResponse != nil { + return r.TokenErrorResponse + } + + return nil +} + +type oauth2ClientCredentialsFinalizer struct { + id string + tokenURL string + clientID string + clientSecret string + authMethod AuthMethod + scopes []string + ttl *time.Duration + headerName string + headerScheme string +} + +func newOAuth2ClientCredentialsFinalizer( + id string, + rawConfig map[string]any, +) (*oauth2ClientCredentialsFinalizer, error) { + type HeaderConfig struct { + Name string `mapstructure:"name" validate:"required"` + Scheme string `mapstructure:"scheme"` + } + + type Config struct { + TokenURL string `mapstructure:"token_url" validate:"required,http_url"` + ClientID string `mapstructure:"client_id" validate:"required"` + ClientSecret string `mapstructure:"client_secret" validate:"required"` + AuthMethod AuthMethod `mapstructure:"auth_method" validate:"omitempty,oneof=basic_auth request_body"` + Scopes []string `mapstructure:"scopes"` + TTL *time.Duration `mapstructure:"cache_ttl"` + Header *HeaderConfig `mapstructure:"header"` + } + + var conf Config + if err := decodeConfig(rawConfig, &conf); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed to unmarshal oauth2_client_credentials finalizer config").CausedBy(err) + } + + if err := validation.ValidateStruct(conf); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed validating oauth2_client_credentials finalizer config").CausedBy(err) + } + + return &oauth2ClientCredentialsFinalizer{ + id: id, + tokenURL: conf.TokenURL, + clientID: conf.ClientID, + clientSecret: conf.ClientSecret, + scopes: conf.Scopes, + ttl: conf.TTL, + authMethod: x.IfThenElse(len(conf.AuthMethod) == 0, authMethodBasicAuth, conf.AuthMethod), + headerName: x.IfThenElseExec(conf.Header != nil, + func() string { return conf.Header.Name }, + func() string { return "Authorization" }), + headerScheme: x.IfThenElseExec(conf.Header != nil, + func() string { return conf.Header.Scheme }, + func() string { return "Bearer" }), + }, nil +} + +func (f *oauth2ClientCredentialsFinalizer) ContinueOnError() bool { return false } +func (f *oauth2ClientCredentialsFinalizer) ID() string { return f.id } + +func (f *oauth2ClientCredentialsFinalizer) WithConfig(rawConfig map[string]any) (Finalizer, error) { + type HeaderConfig struct { + Name string `mapstructure:"name" validate:"required"` + Scheme string `mapstructure:"scheme"` + } + + type Config struct { + Scopes []string `mapstructure:"scopes"` + TTL *time.Duration `mapstructure:"cache_ttl"` + Header *HeaderConfig `mapstructure:"header"` + } + + var conf Config + if err := decodeConfig(rawConfig, &conf); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed to unmarshal oauth2_client_credentials finalizer config").CausedBy(err) + } + + if err := validation.ValidateStruct(conf); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration, + "failed validating oauth2_client_credentials finalizer config").CausedBy(err) + } + + return &oauth2ClientCredentialsFinalizer{ + id: f.id, + tokenURL: f.tokenURL, + clientID: f.clientID, + clientSecret: f.clientSecret, + authMethod: f.authMethod, + scopes: x.IfThenElse(conf.Scopes != nil, conf.Scopes, f.scopes), + ttl: x.IfThenElse(conf.TTL != nil, conf.TTL, f.ttl), + headerName: x.IfThenElseExec(conf.Header != nil, + func() string { return conf.Header.Name }, + func() string { return f.headerName }), + headerScheme: x.IfThenElseExec(conf.Header != nil, + func() string { return conf.Header.Scheme }, + func() string { return f.headerScheme }), + }, nil +} + +func (f *oauth2ClientCredentialsFinalizer) Execute(ctx heimdall.Context, _ *subject.Subject) error { + logger := zerolog.Ctx(ctx.AppContext()) + logger.Debug().Msg("Finalizing using oauth2_client_credentials finalizer") + + cch := cache.Ctx(ctx.AppContext()) + + var ( + ok bool + cacheKey string + cacheEntry any + token string + ) + + if f.isCacheEnabled() { + cacheKey = f.calculateCacheKey() + cacheEntry = cch.Get(cacheKey) + } + + if cacheEntry != nil { + if token, ok = cacheEntry.(string); !ok { + logger.Warn().Msg("Wrong object type from cache") + cch.Delete(cacheKey) + } else { + logger.Debug().Msg("Reusing access token from cache") + } + } + + if len(token) == 0 { + logger.Debug().Msg("Retrieving new access token") + + tokenInfo, err := f.getAccessToken(ctx.AppContext()) + if err != nil { + return err + } + + token = tokenInfo.AccessToken + + if cacheTTL := f.getCacheTTL(tokenInfo); cacheTTL > 0 { + cch.Set(cacheKey, token, cacheTTL) + } + } + + ctx.AddHeaderForUpstream(f.headerName, fmt.Sprintf("%s %s", f.headerScheme, token)) + + return nil +} + +func (f *oauth2ClientCredentialsFinalizer) calculateCacheKey() string { + const int64BytesCount = 8 + + ttlBytes := make([]byte, int64BytesCount) + if f.ttl != nil { + binary.LittleEndian.PutUint64(ttlBytes, uint64(*f.ttl)) + } else { + binary.LittleEndian.PutUint64(ttlBytes, 0) + } + + digest := sha256.New() + digest.Write(stringx.ToBytes(FinalizerOAuth2ClientCredentials)) + digest.Write(stringx.ToBytes(f.clientID)) + digest.Write(stringx.ToBytes(f.clientSecret)) + digest.Write(stringx.ToBytes(f.tokenURL)) + digest.Write(stringx.ToBytes(strings.Join(f.scopes, ""))) + digest.Write(ttlBytes) + + return hex.EncodeToString(digest.Sum(nil)) +} + +func (f *oauth2ClientCredentialsFinalizer) getAccessToken(ctx context.Context) (*TokenSuccessfulResponse, error) { + ept := endpoint.Endpoint{ + URL: f.tokenURL, + Method: http.MethodPost, + AuthStrategy: f.authStrategy(), + Headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "Accept-Type": "application/json", + }, + } + + data := url.Values{"grant_type": []string{"client_credentials"}} + if len(f.scopes) != 0 { + data.Add("scope", strings.Join(f.scopes, " ")) + } + + // This is not recommended, but there are non-compliant servers out there + // which do not support the Basic Auth authentication method required by + // the spec. See also https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 + if ept.AuthStrategy == nil { + data.Add("client_id", f.clientID) + data.Add("client_secret", f.clientSecret) + } + + rawData, err := ept.SendRequest( + ctx, + strings.NewReader(data.Encode()), + nil, + func(resp *http.Response) ([]byte, error) { + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusBadRequest { + return nil, errorchain.NewWithMessagef(heimdall.ErrCommunication, + "unexpected response code: %v", resp.StatusCode) + } + + rawData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrInternal, + "failed to read response").CausedBy(err) + } + + if resp.StatusCode == http.StatusBadRequest { + var ter TokenErrorResponse + if err := json.Unmarshal(rawData, &ter); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrInternal, + "failed to unmarshal response").CausedBy(err) + } + + return nil, errorchain.New(heimdall.ErrCommunication).CausedBy(&ter) + } + + return rawData, nil + }, + ) + if err != nil { + return nil, err + } + + // some oauth2 provider are not compliant and return errors via HTTP 200 instead of 400 + // this is the reason for using a union struct here (see the error check below) + var resp TokenEndpointResponse + if err := json.Unmarshal(rawData, &resp); err != nil { + return nil, errorchain.NewWithMessage(heimdall.ErrInternal, + "failed to unmarshal response").CausedBy(err) + } + + if resp.Error() != nil { + return nil, errorchain.New(heimdall.ErrCommunication).CausedBy(resp.Error()) + } + + return resp.TokenSuccessfulResponse, nil +} + +func (f *oauth2ClientCredentialsFinalizer) authStrategy() endpoint.AuthenticationStrategy { + if f.authMethod == authMethodRequestBody { + return nil + } + + return &endpoint.BasicAuthStrategy{ + User: url.QueryEscape(f.clientID), + Password: url.QueryEscape(f.clientSecret), + } +} + +func (f *oauth2ClientCredentialsFinalizer) getCacheTTL(resp *TokenSuccessfulResponse) time.Duration { + // timeLeeway defines the default time deviation to ensure the token is still valid + // when used from cache + const timeLeeway = 5 + + if !f.isCacheEnabled() { + return 0 + } + + // we cache by default using the settings in the token endpoint response (if available) + // or if ttl has been configured. Latter overwrites the settings in the token endpoint response + // if it is shorter than the ttl in the token endpoint response + tokenEndpointResponseTTL := x.IfThenElseExec(resp.ExpiresIn != 0, + func() time.Duration { + expiresIn := resp.ExpiresIn - timeLeeway + + return x.IfThenElse(expiresIn > 0, time.Duration(expiresIn)*time.Second, 0) + }, + func() time.Duration { return 0 }) + + configuredTTL := x.IfThenElseExec(f.ttl != nil, + func() time.Duration { return *f.ttl }, + func() time.Duration { return 0 }) + + switch { + case configuredTTL == 0 && tokenEndpointResponseTTL == 0: + return 0 + case configuredTTL == 0 && tokenEndpointResponseTTL != 0: + return tokenEndpointResponseTTL + case configuredTTL != 0 && tokenEndpointResponseTTL == 0: + return configuredTTL + default: + return min(configuredTTL, tokenEndpointResponseTTL) + } +} + +func (f *oauth2ClientCredentialsFinalizer) isCacheEnabled() bool { + // cache is enabled if it is not configured (in that case the ttl value from the + // token response if used), or if it is configured and the value > 0 + return f.ttl == nil || (f.ttl != nil && *f.ttl > 0) +} diff --git a/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer_test.go b/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer_test.go new file mode 100644 index 000000000..0123a6d6a --- /dev/null +++ b/internal/rules/mechanisms/finalizers/oauth2_client_credentials_finalizer_test.go @@ -0,0 +1,1086 @@ +package finalizers + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/goccy/go-json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/dadrus/heimdall/internal/cache" + mocks2 "github.com/dadrus/heimdall/internal/cache/mocks" + "github.com/dadrus/heimdall/internal/heimdall" + "github.com/dadrus/heimdall/internal/heimdall/mocks" + "github.com/dadrus/heimdall/internal/x" + "github.com/dadrus/heimdall/internal/x/testsupport" +) + +func TestNewClientCredentialsFinalizer(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + uc string + id string + config []byte + assert func(t *testing.T, err error, finalizer *oauth2ClientCredentialsFinalizer) + }{ + { + uc: "without configuration", + assert: func(t *testing.T, err error, _ *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "failed validating") + assert.Contains(t, err.Error(), "token_url") + assert.Contains(t, err.Error(), "client_id") + assert.Contains(t, err.Error(), "client_secret") + }, + }, + { + uc: "with empty configuration", + config: []byte(``), + assert: func(t *testing.T, err error, _ *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "failed validating") + assert.Contains(t, err.Error(), "token_url") + assert.Contains(t, err.Error(), "client_id") + assert.Contains(t, err.Error(), "client_secret") + }, + }, + { + uc: "with unsupported attributes", + config: []byte(` +token_url: https://foo.bar +foo: bar +`), + assert: func(t *testing.T, err error, _ *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "invalid keys") + }, + }, + { + uc: "with bad auth method attributes", + config: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +auth_method: bar +`), + assert: func(t *testing.T, err error, _ *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "'auth_method' must be one of [basic_auth request_body]") + }, + }, + { + uc: "with minimal valid config", + id: "minimal", + config: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +`), + assert: func(t *testing.T, err error, finalizer *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + require.NotNil(t, finalizer) + + assert.Equal(t, "minimal", finalizer.ID()) + assert.Equal(t, "https://foo.bar", finalizer.tokenURL) + assert.Equal(t, "foo", finalizer.clientID) + assert.Equal(t, "bar", finalizer.clientSecret) + assert.Equal(t, "Authorization", finalizer.headerName) + assert.Equal(t, "Bearer", finalizer.headerScheme) + assert.Equal(t, authMethodBasicAuth, finalizer.authMethod) + assert.Nil(t, finalizer.ttl) + assert.Empty(t, finalizer.scopes) + assert.False(t, finalizer.ContinueOnError()) + }, + }, + { + uc: "with full valid config", + id: "full", + config: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +auth_method: request_body +cache_ttl: 11s +scopes: + - foo + - baz +header: + name: "X-My-Header" + scheme: Foo +`), + assert: func(t *testing.T, err error, finalizer *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + require.NotNil(t, finalizer) + + assert.Equal(t, "full", finalizer.ID()) + assert.Equal(t, "https://foo.bar", finalizer.tokenURL) + assert.Equal(t, "foo", finalizer.clientID) + assert.Equal(t, "bar", finalizer.clientSecret) + assert.Equal(t, "X-My-Header", finalizer.headerName) + assert.Equal(t, "Foo", finalizer.headerScheme) + assert.Equal(t, authMethodRequestBody, finalizer.authMethod) + assert.Equal(t, 11*time.Second, *finalizer.ttl) + assert.Len(t, finalizer.scopes, 2) + assert.Contains(t, finalizer.scopes, "foo") + assert.Contains(t, finalizer.scopes, "baz") + assert.False(t, finalizer.ContinueOnError()) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + conf, err := testsupport.DecodeTestConfig(tc.config) + require.NoError(t, err) + + // WHEN + finalizer, err := newOAuth2ClientCredentialsFinalizer(tc.id, conf) + + // THEN + tc.assert(t, err, finalizer) + }) + } +} + +func TestCreateClientCredentialsFinalizerFromPrototype(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + uc string + id string + prototypeConfig []byte + config []byte + assert func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) + }{ + { + uc: "no new configuration provided", + id: "1", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +scopes: + - foo + - baz +header: + name: "X-My-Header" + scheme: Foo +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + assert.Equal(t, prototype, configured) + }, + }, + { + uc: "empty configuration provided", + id: "2", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +scopes: + - foo + - baz +header: + name: "X-My-Header" + scheme: Foo +`), + config: []byte(``), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + assert.Equal(t, prototype, configured) + assert.Equal(t, "2", configured.ID()) + }, + }, + { + uc: "scopes reconfigured", + id: "3", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +scopes: + - foo + - baz +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + + assert.NotEqual(t, prototype, configured) + assert.Equal(t, prototype.ID(), configured.ID()) + assert.Equal(t, "https://foo.bar", prototype.tokenURL) + assert.Equal(t, prototype.tokenURL, configured.tokenURL) + assert.Equal(t, "foo", prototype.clientID) + assert.Equal(t, prototype.clientID, configured.clientID) + assert.Equal(t, "bar", prototype.clientSecret) + assert.Equal(t, prototype.clientSecret, configured.clientSecret) + assert.Equal(t, 11*time.Second, *prototype.ttl) + assert.Equal(t, prototype.ttl, configured.ttl) + assert.Equal(t, "Authorization", prototype.headerName) + assert.Equal(t, prototype.headerName, configured.headerName) + assert.Equal(t, "Bearer", prototype.headerScheme) + assert.Equal(t, prototype.headerScheme, configured.headerScheme) + assert.Empty(t, prototype.scopes) + assert.Len(t, configured.scopes, 2) + assert.Contains(t, configured.scopes, "foo") + assert.Contains(t, configured.scopes, "baz") + assert.Equal(t, prototype.authMethod, configured.authMethod) + }, + }, + { + uc: "ttl reconfigured", + id: "3", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +cache_ttl: 12s +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + + assert.NotEqual(t, prototype, configured) + assert.Equal(t, prototype.ID(), configured.ID()) + assert.Equal(t, "https://foo.bar", prototype.tokenURL) + assert.Equal(t, prototype.tokenURL, configured.tokenURL) + assert.Equal(t, "foo", prototype.clientID) + assert.Equal(t, prototype.clientID, configured.clientID) + assert.Equal(t, "bar", prototype.clientSecret) + assert.Equal(t, prototype.clientSecret, configured.clientSecret) + assert.Equal(t, 11*time.Second, *prototype.ttl) + assert.Equal(t, 12*time.Second, *configured.ttl) + assert.Equal(t, "Authorization", prototype.headerName) + assert.Equal(t, prototype.headerName, configured.headerName) + assert.Equal(t, "Bearer", prototype.headerScheme) + assert.Equal(t, prototype.headerScheme, configured.headerScheme) + assert.Empty(t, prototype.scopes) + assert.Equal(t, prototype.scopes, configured.scopes) + assert.Equal(t, prototype.authMethod, configured.authMethod) + }, + }, + { + uc: "unsupported attributes while reconfiguring", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +foo: 10s +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "failed to unmarshal") + + require.NotNil(t, prototype) + require.Nil(t, configured) + }, + }, + { + uc: "header name reconfigured", + id: "3", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +header: + name: X-Foo-Bar +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + + assert.NotEqual(t, prototype, configured) + assert.Equal(t, prototype.ID(), configured.ID()) + assert.Equal(t, "https://foo.bar", prototype.tokenURL) + assert.Equal(t, prototype.tokenURL, configured.tokenURL) + assert.Equal(t, "foo", prototype.clientID) + assert.Equal(t, prototype.clientID, configured.clientID) + assert.Equal(t, "bar", prototype.clientSecret) + assert.Equal(t, prototype.clientSecret, configured.clientSecret) + assert.Equal(t, 11*time.Second, *prototype.ttl) + assert.Equal(t, prototype.ttl, configured.ttl) + assert.Equal(t, "Authorization", prototype.headerName) + assert.Equal(t, "X-Foo-Bar", configured.headerName) + assert.Equal(t, "Bearer", prototype.headerScheme) + assert.Empty(t, configured.headerScheme) + assert.Empty(t, prototype.scopes) + assert.Equal(t, prototype.scopes, configured.scopes) + assert.Equal(t, prototype.authMethod, configured.authMethod) + }, + }, + { + uc: "header name and scheme reconfigured", + id: "3", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +header: + name: X-Foo-Bar + scheme: Foo +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.NoError(t, err) + + assert.NotEqual(t, prototype, configured) + assert.Equal(t, prototype.ID(), configured.ID()) + assert.Equal(t, "https://foo.bar", prototype.tokenURL) + assert.Equal(t, prototype.tokenURL, configured.tokenURL) + assert.Equal(t, "foo", prototype.clientID) + assert.Equal(t, prototype.clientID, configured.clientID) + assert.Equal(t, "bar", prototype.clientSecret) + assert.Equal(t, prototype.clientSecret, configured.clientSecret) + assert.Equal(t, 11*time.Second, *prototype.ttl) + assert.Equal(t, prototype.ttl, configured.ttl) + assert.Equal(t, "Authorization", prototype.headerName) + assert.Equal(t, "X-Foo-Bar", configured.headerName) + assert.Equal(t, "Bearer", prototype.headerScheme) + assert.Equal(t, "Foo", configured.headerScheme) + assert.Empty(t, prototype.scopes) + assert.Equal(t, prototype.scopes, configured.scopes) + assert.Equal(t, prototype.authMethod, configured.authMethod) + }, + }, + { + uc: "only header scheme reconfigured", + id: "3", + prototypeConfig: []byte(` +token_url: https://foo.bar +client_id: foo +client_secret: bar +cache_ttl: 11s +`), + config: []byte(` +header: + scheme: Foo +`), + assert: func(t *testing.T, err error, prototype *oauth2ClientCredentialsFinalizer, configured *oauth2ClientCredentialsFinalizer) { + t.Helper() + + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrConfiguration) + assert.Contains(t, err.Error(), "failed validating") + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + pc, err := testsupport.DecodeTestConfig(tc.prototypeConfig) + require.NoError(t, err) + + conf, err := testsupport.DecodeTestConfig(tc.config) + require.NoError(t, err) + + prototype, err := newOAuth2ClientCredentialsFinalizer(tc.id, pc) + require.NoError(t, err) + + // WHEN + finalizer, err := prototype.WithConfig(conf) + + // THEN + var ( + ok bool + realFinalizer *oauth2ClientCredentialsFinalizer + ) + + if err == nil { + realFinalizer, ok = finalizer.(*oauth2ClientCredentialsFinalizer) + require.True(t, ok) + } + + tc.assert(t, err, prototype, realFinalizer) + }) + } +} + +func TestClientCredentialsFinalizerExecute(t *testing.T) { + t.Parallel() + + type ( + RequestAsserter func(t *testing.T, req *http.Request) + ResponseBuilder func(t *testing.T) (any, int) + ) + + var ( + endpointCalled bool + assertRequest RequestAsserter + buildResponse ResponseBuilder + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + endpointCalled = true + + if req.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + + return + } + if err := req.ParseForm(); err != nil { + w.WriteHeader(http.StatusInternalServerError) + + return + } + + assertRequest(t, req) + + resp, code := buildResponse(t) + + rawResp, err := json.MarshalContext(req.Context(), resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", strconv.Itoa(len(rawResp))) + + w.WriteHeader(code) + _, err = w.Write(rawResp) + assert.NoError(t, err) + })) + defer srv.Close() + + for _, tc := range []struct { + uc string + finalizer *oauth2ClientCredentialsFinalizer + configureMocks func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) + assertRequest RequestAsserter + buildResponse ResponseBuilder + assert func(t *testing.T, err error, tokenEndpointCalled bool) + }{ + { + uc: "reusing response from cache", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "Authorization", + headerScheme: "Bearer", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return("foobar") + ctx.EXPECT().AddHeaderForUpstream("Authorization", "Bearer foobar") + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.False(t, tokenEndpointCalled) + }, + }, + { + uc: "cache entry of wrong type and no ttl in issued token", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "Authorization", + headerScheme: "Bearer", + tokenURL: srv.URL, + clientID: "foo", + clientSecret: "bar", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(10) + cch.EXPECT().Delete(mock.Anything) + ctx.EXPECT().AddHeaderForUpstream("Authorization", "Bearer barfoo") + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "foo", clientIDAndSecret[0]) + assert.Equal(t, "bar", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + assert.Empty(t, req.FormValue("scope")) + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "barfoo", + TokenType: "Foo", + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "ttl not configured, no cache entry and token has expires_in claim", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "Authorization", + headerScheme: "Bar", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + cch.EXPECT().Set(mock.Anything, "barfoo", 5*time.Minute-5*time.Second).Return() + ctx.EXPECT().AddHeaderForUpstream("Authorization", "Bar barfoo").Return() + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + assert.Empty(t, req.FormValue("scope")) + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "barfoo", + TokenType: "Foo", + ExpiresIn: int64((5 * time.Minute).Seconds()), + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "error while unmarshalling successful response", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + }, + assertRequest: func(t *testing.T, req *http.Request) { t.Helper() }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return "foo", http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.True(t, tokenEndpointCalled) + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrInternal) + }, + }, + { + uc: "error while unmarshalling error response", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + }, + assertRequest: func(t *testing.T, req *http.Request) { t.Helper() }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return "foo", http.StatusBadRequest + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.True(t, tokenEndpointCalled) + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrInternal) + }, + }, + { + uc: "error while sending request", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + tokenURL: "http://127.0.0.1:11111", + clientID: "bar", + clientSecret: "foo", + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.False(t, tokenEndpointCalled) + require.Error(t, err) + assert.ErrorIs(t, err, heimdall.ErrCommunication) + }, + }, + { + uc: "full configuration, no cache hit and token has expires_in claim", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 3 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + cch.EXPECT().Set(mock.Anything, "foobar", 3*time.Minute).Return() + ctx.EXPECT().AddHeaderForUpstream("X-My-Header", "Foo foobar").Return() + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "foobar", + TokenType: "Foo", + ExpiresIn: int64((5 * time.Minute).Seconds()), + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "disabled cache", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 0 * time.Second + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + ctx.EXPECT().AddHeaderForUpstream("X-My-Header", "Foo foobar").Return() + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "foobar", + TokenType: "Foo", + ExpiresIn: int64((5 * time.Minute).Seconds()), + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "custom cache ttl and no expires_in in token", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 3 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + cch.EXPECT().Set(mock.Anything, "foobar", 3*time.Minute).Return() + ctx.EXPECT().AddHeaderForUpstream("X-My-Header", "Foo foobar").Return() + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "foobar", + TokenType: "Foo", + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "using request_body authentication strategy", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar foo", + clientSecret: "foo bar", + authMethod: authMethodRequestBody, + ttl: func() *time.Duration { + ttl := 3 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + cch.EXPECT().Set(mock.Anything, "foobar", 3*time.Minute).Return() + ctx.EXPECT().AddHeaderForUpstream("X-My-Header", "Foo foobar").Return() + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "bar foo", req.FormValue("client_id")) + assert.Equal(t, "foo bar", req.FormValue("client_secret")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenSuccessfulResponse{ + AccessToken: "foobar", + TokenType: "Foo", + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + require.NoError(t, err) + assert.True(t, tokenEndpointCalled) + }, + }, + { + uc: "misbehaving server on error", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 0 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + // the following is not compliant as error is defined otherwise + // in https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + res, err := json.Marshal(map[string]any{ + "error": "invalid_request", + "error_description": "whatever", + }) + require.NoError(t, err) + + return &TokenErrorResponse{ + ErrorType: string(res), + }, http.StatusOK + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.True(t, tokenEndpointCalled) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid_request") + }, + }, + { + uc: "misbehaving server on error, response code unexpected", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 0 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenErrorResponse{ + ErrorType: "invalid_request", + ErrorDescription: "whatever", + }, http.StatusForbidden + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.True(t, tokenEndpointCalled) + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected response code: 403") + }, + }, + { + uc: "compliant server on error", + finalizer: &oauth2ClientCredentialsFinalizer{ + id: "test", + headerName: "X-My-Header", + headerScheme: "Foo", + tokenURL: srv.URL, + clientID: "bar", + clientSecret: "foo", + ttl: func() *time.Duration { + ttl := 3 * time.Minute + + return &ttl + }(), + scopes: []string{"baz", "zab"}, + }, + configureMocks: func(t *testing.T, ctx *mocks.ContextMock, cch *mocks2.CacheMock) { + t.Helper() + + cch.EXPECT().Get(mock.Anything).Return(nil) + }, + assertRequest: func(t *testing.T, req *http.Request) { + t.Helper() + + val, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(req.Header.Get("Authorization"), "Basic ")) + assert.NoError(t, err) + + clientIDAndSecret := strings.Split(string(val), ":") + assert.Equal(t, "bar", clientIDAndSecret[0]) + assert.Equal(t, "foo", clientIDAndSecret[1]) + + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + assert.Equal(t, "application/json", req.Header.Get("Accept-Type")) + assert.Equal(t, "client_credentials", req.FormValue("grant_type")) + scopes := strings.Split(req.FormValue("scope"), " ") + assert.Len(t, scopes, 2) + assert.Contains(t, scopes, "baz") + assert.Contains(t, scopes, "zab") + }, + buildResponse: func(t *testing.T) (any, int) { + t.Helper() + + return &TokenErrorResponse{ + ErrorType: "invalid_request", + ErrorDescription: "whatever", + ErrorURI: "https://www.rfc-editor.org/rfc/rfc6749#section-5.1", + }, http.StatusBadRequest + }, + assert: func(t *testing.T, err error, tokenEndpointCalled bool) { + t.Helper() + + assert.True(t, tokenEndpointCalled) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid_request") + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + endpointCalled = false + configureMocks := x.IfThenElse(tc.configureMocks != nil, + tc.configureMocks, + func(t *testing.T, _ *mocks.ContextMock, _ *mocks2.CacheMock) { t.Helper() }, + ) + + cch := mocks2.NewCacheMock(t) + ctx := mocks.NewContextMock(t) + + ctx.EXPECT().AppContext().Return(cache.WithContext(context.Background(), cch)) + configureMocks(t, ctx, cch) + + assertRequest = tc.assertRequest + buildResponse = tc.buildResponse + + // WHEN + err := tc.finalizer.Execute(ctx, nil) + + // THEN + tc.assert(t, err, endpointCalled) + }) + } +} diff --git a/schema/config.schema.json b/schema/config.schema.json index 65b859ea9..eef1c36bb 100644 --- a/schema/config.schema.json +++ b/schema/config.schema.json @@ -1417,6 +1417,92 @@ } } }, + "finalizerClientCredentials": { + "description": "Drives the OAuth2 client credentials flow and adds the corresponding token to the headers for the upstream", + "type": "object", + "additionalProperties": false, + "required": [ + "id", + "type", + "config" + ], + "properties": { + "type": { + "const": "oauth2_client_credentials" + }, + "id": { + "description": "The unique id of the finalizer to be used in the rule definition", + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": false, + "required": [ + "client_id", + "client_secret", + "token_url" + ], + "properties": { + "client_id": { + "description": "The OAuth 2.0 Client ID to be used for the OAuth 2.0 Client Credentials Grant", + "type": "string" + }, + "client_secret": { + "description": "The OAuth 2.0 Client Secret to be used for the OAuth 2.0 Client Credentials Grant", + "type": "string" + }, + "auth_method": { + "description": "How to transfer the client_id and client_secret to the oauth provider", + "type": "string", + "default": "basic_auth", + "enum": [ + "basic_auth", + "request_body" + ] + }, + "token_url": { + "description": "The OAuth 2.0 Token Endpoint where the OAuth 2.0 Client Credentials Grant will be performed", + "type": "string" + }, + "scopes": { + "description": "The OAuth 2.0 Scopes to be requested during the OAuth 2.0 Client Credentials Grant", + "type": "array", + "items": { + "type": "string" + } + }, + "cache_ttl": { + "type": "string", + "description": "How long to cache the issued token. Defaults to the value of the `expires_in` of the issued token. If `expires_in` is not present in the response, the token is not cached until this property is not explicitly configured. If `expires_in` is present in the response and this property is configured the shorter value is taken. 0 or negative value will disable caching. ", + "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", + "examples": [ + "1h", + "1m", + "30s" + ] + }, + "header": { + "type": "object", + "description": "Header and scheme to use to transport the issued token to the upstream service", + "additionalProperties": false, + "required": ["name"], + "properties": { + "name": { + "description": "The header name to use", + "type": "string", + "default": "Authorization" + }, + "scheme": { + "description": "The scheme to use", + "type": "string", + "default": "Bearer" + } + } + } + } + } + } + }, "errorType": { "description": "Error type", "type": "string", @@ -1812,6 +1898,9 @@ }, { "$ref": "#/definitions/finalizerCookie" + }, + { + "$ref": "#/definitions/finalizerClientCredentials" } ] }