From d1efce031f2e33c8f135e563d765fd28595682d1 Mon Sep 17 00:00:00 2001 From: Windz Date: Thu, 5 Sep 2019 15:55:14 +0900 Subject: [PATCH] [minor] fix multi domain fetch (#38) * fix multi domain fetch (add fetcher implementation) - add Init() function to authorizerd --- README.md | 11 +- authorizerd.go | 42 +- authorizerd_mock_test.go | 27 + authorizerd_test.go | 172 +++- client/exp_rt.go | 1 + go.mod | 2 +- go.sum | 5 +- jwk/daemon.go | 4 - jwk/daemon_test.go | 38 + model.go | 41 - option.go | 8 - option_test.go | 35 - policy/assertion.go | 1 - policy/assertion_test.go | 9 + policy/daemon.go | 155 +-- policy/daemon_test.go | 1920 +++++++++++------------------------ policy/fetcher.go | 198 ++++ policy/fetcher_mock_test.go | 57 ++ policy/fetcher_test.go | 1337 ++++++++++++++++++++++++ policy/option.go | 18 +- policy/option_test.go | 132 +-- pubkey/daemon.go | 4 - pubkey/daemon_test.go | 36 +- role/claim.go | 6 +- role/processor_test.go | 3 + 25 files changed, 2606 insertions(+), 1656 deletions(-) delete mode 100644 model.go create mode 100644 policy/fetcher.go create mode 100644 policy/fetcher_mock_test.go create mode 100644 policy/fetcher_test.go diff --git a/README.md b/README.md index 6da5aacb..6fb77bfc 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0) [![release](https://img.shields.io/github/release/yahoojapan/athenz-authorizer.svg?style=flat-square)](https://github.com/yahoojapan/athenz-authorizer/releases/latest) [![CircleCI](https://circleci.com/gh/yahoojapan/athenz-authorizer.svg)](https://circleci.com/gh/yahoojapan/athenz-authorizer) [![codecov](https://codecov.io/gh/yahoojapan/athenz-authorizer/branch/master/graph/badge.svg?token=2CzooNJtUu&style=flat-square)](https://codecov.io/gh/yahoojapan/athenz-authorizer) [![Go Report Card](https://goreportcard.com/badge/github.com/yahoojapan/athenz-authorizer)](https://goreportcard.com/report/github.com/yahoojapan/athenz-authorizer) [![GolangCI](https://golangci.com/badges/github.com/yahoojapan/athenz-authorizer.svg?style=flat-square)](https://golangci.com/r/github.com/yahoojapan/athenz-authorizer) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/828220605c43419e92fb0667876dd2d0)](https://www.codacy.com/app/i.can.feel.gravity/athenz-authorizer?utm_source=github.com&utm_medium=referral&utm_content=yahoojapan/athenz-authorizer&utm_campaign=Badge_Grade) [![GoDoc](http://godoc.org/github.com/yahoojapan/athenz-authorizer?status.svg)](http://godoc.org/github.com/yahoojapan/athenz-authorizer) ## What is Athenz authorizer -Athenz authorizer is a library to cache the policies of [Athenz](https://github.com/yahoo/athenz) to authorizer authenication and authorization check of user request. +Athenz authorizer is a library to cache the policies of [Athenz](https://github.com/yahoo/athenz) to authorizer authentication and authorization check of user request. ![Overview](./doc/policy_updater_overview.png) @@ -33,7 +33,7 @@ go func() { // Verify role token if err := daemon.VerifyRoleToken(ctx, roleTok, act, res); err != nil { - // token not authorizated + // token not authorized } ``` @@ -49,9 +49,9 @@ Athenz pubkey daemon (pubkeyd) is responsible for periodically update the Athenz ### Athenz policy daemon -Athenz policy daemon (policyd) is responsible for periodically update the policy data of specified Athenz domain from Athenz server. The received policy data will be verified using the public key got from pubkeyd, and cache into memory. Whenever user requesting for the access check, the verification check will be used instead of asking Athenz server everytime. +Athenz policy daemon (policyd) is responsible for periodically update the policy data of specified Athenz domain from Athenz server. The received policy data will be verified using the public key got from pubkeyd, and cache into memory. Whenever user requesting for the access check, the verification check will be used instead of asking Athenz server every time. -## Configuratrion +## Configuration The authorizer uses functional options pattern to initialize the instance. All the options are defined [here](./option.go). @@ -66,8 +66,7 @@ The authorizer uses functional options pattern to initialize the instance. All t | PubkeyEtagExpTime | ETag cache TTL of Athenz public key data | 168 Hours (1 Week) | No | | | PubkeyEtagFlushDur | ETag cache purge duration | 84 Hours | No | | | PolicyRefreshDuration | The refresh duration to update Athenz policy data | 30 Minutes | No | | -| PolicyExpireMargin | The expire margin to update the policy data. It forces update the policy data before the policy expiration margin. | 3 Hours | No | | -| PolicyEtagFlushDur | Policy data cache purge duration | 12 Hours | No | | +| PolicyExpireMargin | The expire margin to update the policy data. It forces update the policy data before the policy expiration margin. | 3 Hours | No | | ## License diff --git a/authorizerd.go b/authorizerd.go index 76a0a36d..9eddcf5a 100644 --- a/authorizerd.go +++ b/authorizerd.go @@ -25,6 +25,7 @@ import ( "github.com/kpango/gache" "github.com/kpango/glg" + "golang.org/x/sync/errgroup" "github.com/pkg/errors" "github.com/yahoojapan/athenz-authorizer/v2/jwk" @@ -35,6 +36,7 @@ import ( // Authorizerd represents a daemon for user to verify the role token type Authorizerd interface { + Init(ctx context.Context) error Start(ctx context.Context) <-chan error VerifyRoleToken(ctx context.Context, tok, act, res string) error VerifyRoleJWT(ctx context.Context, tok, act, res string) error @@ -74,7 +76,6 @@ type authorizer struct { athenzDomains []string policyRefreshDuration string policyErrRetryInterval string - policyEtagFlushDur string // jwkd parameters disableJwkd bool @@ -127,13 +128,12 @@ func New(opts ...Option) (Authorizerd, error) { if !prov.disablePolicyd { if prov.policyd, err = policy.New( policy.WithExpireMargin(prov.policyExpireMargin), - policy.WithEtagFlushDuration(prov.policyEtagFlushDur), policy.WithAthenzURL(prov.athenzURL), policy.WithAthenzDomains(prov.athenzDomains...), policy.WithRefreshDuration(prov.policyRefreshDuration), policy.WithErrRetryInterval(prov.policyErrRetryInterval), policy.WithHTTPClient(prov.client), - policy.WithPubKeyProvider(prov.pubkeyd.GetProvider()), + policy.WithPubKeyProvider(pubkeyProvider), ); err != nil { return nil, errors.Wrap(err, "error create policyd") } @@ -159,6 +159,40 @@ func New(opts ...Option) (Authorizerd, error) { return prov, nil } +// Init initializes child daemons synchronously. +func (a *authorizer) Init(ctx context.Context) error { + eg, egCtx := errgroup.WithContext(ctx) + eg.Go(func() error { + select { + case <-egCtx.Done(): + return egCtx.Err() + default: + if !a.disablePubkeyd { + err := a.pubkeyd.Update(egCtx) + if err != nil { + return err + } + } + if !a.disablePolicyd { + return a.policyd.Update(egCtx) + } + return nil + } + }) + if !a.disableJwkd { + eg.Go(func() error { + select { + case <-egCtx.Done(): + return egCtx.Err() + default: + return a.jwkd.Update(egCtx) + } + }) + } + + return eg.Wait() +} + // Start starts authorizer daemon. func (a *authorizer) Start(ctx context.Context) <-chan error { var ( @@ -260,7 +294,7 @@ func (a *authorizer) verify(ctx context.Context, m mode, tok, act, res string) e } func (a *authorizer) VerifyRoleCert(ctx context.Context, peerCerts []*x509.Certificate, act, res string) error { - dr := make([]string, 0, 2) + var dr []string drcheck := make(map[string]struct{}) domainRoles := make(map[string][]string) for _, cert := range peerCerts { diff --git a/authorizerd_mock_test.go b/authorizerd_mock_test.go index 28f8096b..6364bb11 100644 --- a/authorizerd_mock_test.go +++ b/authorizerd_mock_test.go @@ -39,6 +39,33 @@ func (cm *ConfdMock) Start(ctx context.Context) <-chan error { return ech } +type PubkeydMock struct { + StartFunc func(context.Context) <-chan error + UpdateFunc func(context.Context) error + GetProviderFunc func() pubkey.Provider +} + +func (pm *PubkeydMock) Start(ctx context.Context) <-chan error { + if pm.StartFunc != nil { + return pm.StartFunc(ctx) + } + return nil +} + +func (pm *PubkeydMock) Update(ctx context.Context) error { + if pm.UpdateFunc != nil { + return pm.UpdateFunc(ctx) + } + return nil +} + +func (pm *PubkeydMock) GetProvider() pubkey.Provider { + if pm.GetProviderFunc != nil { + return pm.GetProviderFunc() + } + return nil +} + type PolicydMock struct { UpdateFunc func(context.Context) error CheckPolicyFunc func(ctx context.Context, domain string, roles []string, action, resource string) error diff --git a/authorizerd_test.go b/authorizerd_test.go index 44dea32f..7d465bd7 100644 --- a/authorizerd_test.go +++ b/authorizerd_test.go @@ -102,6 +102,170 @@ func TestNew(t *testing.T) { } } +func Test_authorizer_Init(t *testing.T) { + type fields struct { + pubkeyd pubkey.Daemon + policyd policy.Daemon + jwkd jwk.Daemon + disablePubkeyd bool + disablePolicyd bool + disableJwkd bool + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + wantErrStr string + }{ + { + name: "cancelled context, no waiting", + fields: fields{ + pubkeyd: &PubkeydMock{ + UpdateFunc: func(context.Context) error { + time.Sleep(10 * time.Millisecond) + return errors.New("pubkeyd error") + }, + }, + policyd: nil, + jwkd: &JwkdMock{ + UpdateFunc: func(context.Context) error { + time.Sleep(10 * time.Millisecond) + return errors.New("jwkd error") + }, + }, + disablePubkeyd: false, + disablePolicyd: true, + disableJwkd: false, + }, + args: args{ + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }(), + }, + wantErrStr: context.Canceled.Error(), + }, + { + name: "all disable", + fields: fields{ + pubkeyd: nil, + policyd: nil, + jwkd: nil, + disablePubkeyd: true, + disablePolicyd: true, + disableJwkd: true, + }, + args: args{ + ctx: context.Background(), + }, + wantErrStr: "", + }, + { + name: "jwkd is not blocked", + fields: fields{ + pubkeyd: &PubkeydMock{ + UpdateFunc: func(context.Context) error { + time.Sleep(10 * time.Millisecond) + return errors.New("pubkeyd error") + }, + }, + policyd: nil, + jwkd: &JwkdMock{ + UpdateFunc: func(context.Context) error { + return errors.New("jwkd done") + }, + }, + disablePubkeyd: false, + disablePolicyd: true, + disableJwkd: false, + }, + args: args{ + ctx: context.Background(), + }, + wantErrStr: "jwkd done", + }, + { + name: "policyd is blocked by pubkeyd", + fields: *(func() *fields { + pubkeydDone := false + return &fields{ + pubkeyd: &PubkeydMock{ + UpdateFunc: func(context.Context) error { + time.Sleep(10 * time.Millisecond) + pubkeydDone = true + return nil + }, + }, + policyd: &PolicydMock{ + UpdateFunc: func(context.Context) error { + if pubkeydDone { + return nil + } + return errors.New("policyd error") + }, + }, + jwkd: nil, + disablePubkeyd: false, + disablePolicyd: true, + disableJwkd: true, + } + }()), + args: args{ + ctx: context.Background(), + }, + wantErrStr: "", + }, + { + name: "all daemons init success", + fields: fields{ + pubkeyd: &PubkeydMock{ + UpdateFunc: func(context.Context) error { + return nil + }, + }, + policyd: &PolicydMock{ + UpdateFunc: func(context.Context) error { + return nil + }, + }, + jwkd: &JwkdMock{ + UpdateFunc: func(context.Context) error { + return nil + }, + }, + disablePubkeyd: false, + disablePolicyd: false, + disableJwkd: false, + }, + args: args{ + ctx: context.Background(), + }, + wantErrStr: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &authorizer{ + pubkeyd: tt.fields.pubkeyd, + policyd: tt.fields.policyd, + jwkd: tt.fields.jwkd, + disablePubkeyd: tt.fields.disablePubkeyd, + disablePolicyd: tt.fields.disablePolicyd, + disableJwkd: tt.fields.disableJwkd, + } + err := a.Init(tt.args.ctx) + if (err == nil && tt.wantErrStr != "") || (err != nil && err.Error() != tt.wantErrStr) { + t.Errorf("authorizer.Init() error = %v, wantErr %v", err, tt.wantErrStr) + return + } + }) + } +} + func Test_authorizer_Start(t *testing.T) { type fields struct { pubkeyd pubkey.Daemon @@ -503,7 +667,6 @@ func Test_authorizer_VerifyRoleJWT(t *testing.T) { policyExpireMargin string athenzDomains []string policyRefreshDuration string - policyEtagFlushDur string } type args struct { ctx context.Context @@ -695,7 +858,6 @@ func Test_authorizer_VerifyRoleJWT(t *testing.T) { policyExpireMargin: tt.fields.policyExpireMargin, athenzDomains: tt.fields.athenzDomains, policyRefreshDuration: tt.fields.policyRefreshDuration, - policyEtagFlushDur: tt.fields.policyEtagFlushDur, } err := p.VerifyRoleJWT(tt.args.ctx, tt.args.tok, tt.args.act, tt.args.res) if err != nil { @@ -736,7 +898,6 @@ func Test_authorizer_verify(t *testing.T) { policyExpireMargin string athenzDomains []string policyRefreshDuration string - policyEtagFlushDur string } type args struct { ctx context.Context @@ -772,7 +933,6 @@ func Test_authorizer_verify(t *testing.T) { policyExpireMargin: tt.fields.policyExpireMargin, athenzDomains: tt.fields.athenzDomains, policyRefreshDuration: tt.fields.policyRefreshDuration, - policyEtagFlushDur: tt.fields.policyEtagFlushDur, } if err := p.verify(tt.args.ctx, tt.args.m, tt.args.tok, tt.args.act, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("authorizer.verify() error = %v, wantErr %v", err, tt.wantErr) @@ -799,7 +959,6 @@ func Test_authorizer_VerifyRoleCert(t *testing.T) { policyExpireMargin string athenzDomains []string policyRefreshDuration string - policyEtagFlushDur string } type args struct { ctx context.Context @@ -971,7 +1130,6 @@ bu80CwTnWhmdBo36Ig== policyExpireMargin: tt.fields.policyExpireMargin, athenzDomains: tt.fields.athenzDomains, policyRefreshDuration: tt.fields.policyRefreshDuration, - policyEtagFlushDur: tt.fields.policyEtagFlushDur, } if err := p.VerifyRoleCert(tt.args.ctx, tt.args.peerCerts, tt.args.act, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("authorizer.VerifyRoleCert() error = %v, wantErr %v", err, tt.wantErr) @@ -998,7 +1156,6 @@ func Test_authorizer_GetPolicyCache(t *testing.T) { policyExpireMargin string athenzDomains []string policyRefreshDuration string - policyEtagFlushDur string } type args struct { ctx context.Context @@ -1039,7 +1196,6 @@ func Test_authorizer_GetPolicyCache(t *testing.T) { policyExpireMargin: tt.fields.policyExpireMargin, athenzDomains: tt.fields.athenzDomains, policyRefreshDuration: tt.fields.policyRefreshDuration, - policyEtagFlushDur: tt.fields.policyEtagFlushDur, } if got := a.GetPolicyCache(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { t.Errorf("authorizer.GetPolicyCache() = %v, want %v", got, tt.want) diff --git a/client/exp_rt.go b/client/exp_rt.go index 95f1194a..0eba7db4 100644 --- a/client/exp_rt.go +++ b/client/exp_rt.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package client import ( diff --git a/go.mod b/go.mod index ac132b10..ef549c81 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,6 @@ require ( github.com/kpango/glg v1.4.6 github.com/lestrrat-go/jwx v0.9.0 github.com/pkg/errors v0.8.1 - github.com/yahoo/athenz v1.8.29 + github.com/yahoo/athenz v1.8.31 golang.org/x/sync v0.0.0-20190423024810-112230192c58 ) diff --git a/go.sum b/go.sum index 3b547078..df5e1395 100644 --- a/go.sum +++ b/go.sum @@ -55,9 +55,10 @@ github.com/spaolacci/murmur3 v1.0.1-0.20190317074736-539464a789e9/go.mod h1:JwIa github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/vmihailenco/msgpack v4.0.1+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= -github.com/yahoo/athenz v1.8.29 h1:9L3bexB9ZtiZIoGi4qarWoO4hPpWLscmkvwtseNqR8o= -github.com/yahoo/athenz v1.8.29/go.mod h1:Y+ZECUuQYgqNJ1WgR3jiNP6GxCbiMChPO0Ew7l1GNaE= +github.com/yahoo/athenz v1.8.31 h1:tfAS1P0yo4WtgRcoGwlVUaqYxrCGWJFFpDprbmmOFZ4= +github.com/yahoo/athenz v1.8.31/go.mod h1:5p1/T3n45DT6IICiDgsG//9Sf0DR3iGESoMq5jVybrQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/jwk/daemon.go b/jwk/daemon.go index f721c991..044097a4 100644 --- a/jwk/daemon.go +++ b/jwk/daemon.go @@ -65,10 +65,6 @@ func (j *jwkd) Start(ctx context.Context) <-chan error { glg.Info("Starting jwk updater") ech := make(chan error, 100) fch := make(chan struct{}, 1) - if err := j.Update(ctx); err != nil { - ech <- errors.Wrap(err, "error update athenz json web key") - fch <- struct{}{} - } go func() { defer close(fch) diff --git a/jwk/daemon_test.go b/jwk/daemon_test.go index 0a45a49c..b079ac10 100644 --- a/jwk/daemon_test.go +++ b/jwk/daemon_test.go @@ -99,6 +99,44 @@ func Test_jwkd_Start(t *testing.T) { afterFunc func() } tests := []test{ + func() test { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + return test{ + name: "canceled context", + fields: fields{ + athenzURL: strings.Replace(srv.URL, "https://", "", 1), + refreshDuration: time.Millisecond * 10, + errRetryInterval: time.Millisecond, + client: srv.Client(), + }, + args: args{ + ctx: ctx, + }, + checkFunc: func(j *jwkd, ch <-chan error) error { + err := <-ch + wantErr := context.Canceled + if err != wantErr { + return fmt.Errorf("got: %v, want: %v", err, wantErr) + } + for err = range ch { + if err != nil { + return err + } + } + + if k := j.keys.Load(); k != nil { + return errors.New("keys updated") + } + + return nil + }, + } + }(), func() test { k := `{ "e":"AQAB", diff --git a/model.go b/model.go deleted file mode 100644 index 42cb1d0b..00000000 --- a/model.go +++ /dev/null @@ -1,41 +0,0 @@ -/* -Copyright (C) 2018 Yahoo Japan Corporation Athenz team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package authorizerd - -type signedPolicy struct { - KeyID string `json:"keyId"` - Signature string `json:"signature"` - SignedPolicyData struct { - Expires string `json:"expires"` - Modified string `json:"modified"` - PolicyData struct { - Domain string `json:"domain"` - Policies []struct { - Assertions []struct { - Action string `json:"action"` - Effect string `json:"effect"` - Resource string `json:"resource"` - Role string `json:"role"` - } `json:"assertions"` - Modified string `json:"modified"` - Name string `json:"name"` - } `json:"policies"` - } `json:"policyData"` - ZmsKeyID string `json:"zmsKeyId"` - ZmsSignature string `json:"zmsSignature"` - } `json:"signedPolicyData"` -} diff --git a/option.go b/option.go index fc2079db..99b36208 100644 --- a/option.go +++ b/option.go @@ -198,14 +198,6 @@ func WithPolicyExpireMargin(t string) Option { } } -// WithPolicyEtagFlushDuration returns a PolicyEtagFlushDur functional option -func WithPolicyEtagFlushDuration(t string) Option { - return func(authz *authorizer) error { - authz.policyEtagFlushDur = t - return nil - } -} - /* jwkd parameters */ diff --git a/option_test.go b/option_test.go index 78d73224..1286b291 100644 --- a/option_test.go +++ b/option_test.go @@ -527,41 +527,6 @@ func TestWithPolicyExpireMargin(t *testing.T) { }) } } -func TestWithPolicyEtagFlushDuration(t *testing.T) { - type args struct { - t string - } - tests := []struct { - name string - args args - checkFunc func(Option) error - }{ - { - name: "set success", - args: args{ - t: "dummy", - }, - checkFunc: func(opt Option) error { - authz := &authorizer{} - if err := opt(authz); err != nil { - return err - } - if authz.policyEtagFlushDur != "dummy" { - return fmt.Errorf("invalid param was set") - } - return nil - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := WithPolicyEtagFlushDuration(tt.args.t) - if err := tt.checkFunc(got); err != nil { - t.Errorf("WithPolicyEtagFlushDuration() error = %v", err) - } - }) - } -} func TestWithCacheExp(t *testing.T) { type args struct { diff --git a/policy/assertion.go b/policy/assertion.go index f8fb1590..f086f696 100644 --- a/policy/assertion.go +++ b/policy/assertion.go @@ -48,7 +48,6 @@ func NewAssertion(action, resource, effect string) (*Assertion, error) { res := domres[1] reg, err := regexp.Compile("^" + replacer.Replace(strings.ToLower(action+"-"+res)) + "$") - if err != nil { return nil, errors.Wrap(err, "assertion format not correct") } diff --git a/policy/assertion_test.go b/policy/assertion_test.go index ba7cfa14..9edbed40 100644 --- a/policy/assertion_test.go +++ b/policy/assertion_test.go @@ -100,6 +100,15 @@ func TestNewAssertion(t *testing.T) { }, wantErr: errors.New("assertion format not correct: Access denied due to invalid/empty policy resources"), }, + { + name: "invalid regex", + args: args{ + resource: "dom:res(", + action: "act", + effect: "deny", + }, + wantErr: errors.New("assertion format not correct: error parsing regexp: missing closing ): `^act-res($`"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/policy/daemon.go b/policy/daemon.go index 0aafb72f..4db058fe 100644 --- a/policy/daemon.go +++ b/policy/daemon.go @@ -20,8 +20,6 @@ import ( "context" "encoding/json" "fmt" - "io" - "io/ioutil" "net/http" "strings" "sync" @@ -57,26 +55,18 @@ type policyd struct { refreshDuration time.Duration errRetryInterval time.Duration - etagCache gache.Gache - etagFlushDur time.Duration - athenzURL string athenzDomains []string - client *http.Client - pkp pubkey.Provider -} - -type etagCache struct { - etag string - sp *SignedPolicy + client *http.Client + pkp pubkey.Provider + fetchers map[string]Fetcher // used for concurrent read, should never be updated } // New represent the constructor of Policyd func New(opts ...Option) (Daemon, error) { p := &policyd{ rolePolicies: gache.New(), - etagCache: gache.New(), } for _, opt := range append(defaultOptions, opts...) { @@ -85,6 +75,23 @@ func New(opts ...Option) (Daemon, error) { } } + // create fetchers + p.fetchers = make(map[string]Fetcher, len(p.athenzDomains)) + for _, domain := range p.athenzDomains { + f := fetcher{ + domain: domain, + expireMargin: p.expireMargin, + retryInterval: p.errRetryInterval, + retryMaxCount: 3, + athenzURL: p.athenzURL, + spVerifier: func(sp *SignedPolicy) error { + return sp.Verify(p.pkp) + }, + client: p.client, + } + p.fetchers[domain] = &f + } + return p, nil } @@ -93,17 +100,11 @@ func (p *policyd) Start(ctx context.Context) <-chan error { glg.Info("Starting policyd updater") ech := make(chan error, 100) fch := make(chan struct{}, 1) - if err := p.Update(ctx); err != nil { - glg.Debugf("Error initialize policy data, err: %v", err) - ech <- errors.Wrap(err, "error update policy") - fch <- struct{}{} - } go func() { defer close(fch) defer close(ech) - p.etagCache.StartExpired(ctx, p.etagFlushDur) ticker := time.NewTicker(p.refreshDuration) for { select { @@ -143,45 +144,48 @@ func (p *policyd) Start(ctx context.Context) <-chan error { // Update updates and cache policy data func (p *policyd) Update(ctx context.Context) error { - glg.Info("Updating policy") - defer glg.Info("Updated policy") + jobID := fastime.Now().Unix() + glg.Infof("[%d] will update policy", jobID) eg := errgroup.Group{} rp := gache.New() - for _, domain := range p.athenzDomains { + for _, fetcher := range p.fetchers { + f := fetcher // for closure select { case <-ctx.Done(): glg.Info("Update policy interrupted") return ctx.Err() default: - dom := domain eg.Go(func() error { select { case <-ctx.Done(): glg.Info("Update policy interrupted") return ctx.Err() default: - return p.fetchAndCachePolicy(ctx, rp, dom) + return fetchAndCachePolicy(ctx, rp, f) } }) } } if err := eg.Wait(); err != nil { + glg.Errorf("[%d] update policy fail", jobID) return err } rp.StartExpired(ctx, p.policyExpiredDuration). EnableExpiredHook(). SetExpiredHook(func(ctx context.Context, key string) { - //key = :role. - p.fetchAndCachePolicy(ctx, p.rolePolicies, strings.Split(key, ":role.")[0]) + // key = :role. + fetchAndCachePolicy(ctx, p.rolePolicies, p.fetchers[strings.Split(key, ":role.")[0]]) }) p.rolePolicies, rp = rp, p.rolePolicies + glg.Debugf("tmp cache becomes effective") rp.Stop() rp.Clear() + glg.Infof("[%d] update policy done", jobID) return nil } @@ -256,101 +260,30 @@ func (p *policyd) GetPolicyCache(ctx context.Context) map[string]interface{} { return p.rolePolicies.ToRawMap(ctx) } -func (p *policyd) fetchAndCachePolicy(ctx context.Context, g gache.Gache, dom string) error { - spd, upd, err := p.fetchPolicy(ctx, dom) +func fetchAndCachePolicy(ctx context.Context, g gache.Gache, f Fetcher) error { + sp, err := f.FetchWithRetry(ctx) if err != nil { - glg.Debugf("fetch policy failed, err: %v", err) - return errors.Wrap(err, "error fetch policy") + errMsg := "fetch policy fail" + glg.Errorf("%s, error: %v", errMsg, err) + if sp == nil { + return errors.Wrap(err, errMsg) + } } glg.DebugFunc(func() string { - rawpol, _ := json.Marshal(spd) - return fmt.Sprintf("fetched policy data, domain: %s,updated: %v, body: %s", dom, upd, (string)(rawpol)) + rawpol, _ := json.Marshal(sp) + return fmt.Sprintf("will merge policy, domain: %s, body: %s", f.Domain(), (string)(rawpol)) }) - if err = simplifyAndCachePolicy(ctx, g, spd); err != nil { - glg.Debugf("simplify and cache error: %v", err) - return errors.Wrap(err, "error simplify and cache") + if err := simplifyAndCachePolicy(ctx, g, sp); err != nil { + errMsg := "simplify and cache policy fail" + glg.Debugf("%s, error: %v", errMsg, err) + return errors.Wrap(err, errMsg) } return nil } -func (p *policyd) fetchPolicy(ctx context.Context, domain string) (*SignedPolicy, bool, error) { - glg.Infof("Fetching policy for domain %s", domain) - // https://{www.athenz.com/zts/v1}/domain/{athenz domain}/signed_policy_data - url := fmt.Sprintf("https://%s/domain/%s/signed_policy_data", p.athenzURL, domain) - - glg.Debugf("fetching policy, url: %v", url) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - glg.Errorf("fetch policy error, domain: %s, error: %v", domain, err) - return nil, false, errors.Wrap(err, "error creating fetch policy request") - } - - // etag header - t, ok := p.etagCache.Get(domain) - if ok { - ec := t.(*etagCache) - glg.Debugf("request on domain: %s, using etag: %s", domain, ec.etag) - req.Header.Set("If-None-Match", ec.etag) - } - - res, err := p.client.Do(req.WithContext(ctx)) - if err != nil { - glg.Errorf("Error making HTTP request, domain: %s, error: %v", domain, err) - return nil, false, errors.Wrap(err, "error making request") - } - - // if server return NotModified, return policy from cache - if res.StatusCode == http.StatusNotModified { - cache := t.(*etagCache) - glg.Debugf("Server return not modified, keep using domain: %s, etag: %v", domain, cache.etag) - return cache.sp, false, nil - } - - if res.StatusCode != http.StatusOK { - glg.Errorf("Domain %s: Server return not OK", domain) - return nil, false, errors.Wrap(ErrFetchPolicy, "error fetching policy data") - } - - // read and decode - sp := new(SignedPolicy) - if err = json.NewDecoder(res.Body).Decode(&sp); err != nil { - glg.Errorf("Error decoding policy, domain: %s, err: %v", domain, err) - return nil, false, errors.Wrap(err, "error decode response") - } - - // verify policy data - if err = sp.Verify(p.pkp); err != nil { - glg.Errorf("Error verifying policy, domain: %s, err: %v", domain, err) - return nil, false, errors.Wrap(err, "error verify policy data") - } - - if _, err = io.Copy(ioutil.Discard, res.Body); err != nil { - glg.Warn(errors.Wrap(err, "error io.copy")) - } - if err = res.Body.Close(); err != nil { - glg.Warn(errors.Wrap(err, "error body.close")) - } - - // set etag cache - etag := res.Header.Get("ETag") - if etag != "" { - etagValidDur := sp.SignedPolicyData.Expires.Time.Sub(fastime.Now()) - p.expireMargin - glg.Debugf("Set domain %s with etag %v, duration: %s", domain, etag, etagValidDur) - if etagValidDur > 0 { - p.etagCache.SetWithExpire(domain, &etagCache{etag, sp}, etagValidDur) - } else { - // this triggers only if the new policies from server have expiry time < expiry margin - // hence, will not use ETag on next fetch request - p.etagCache.Delete(domain) - } - } - - return sp, true, nil -} - func simplifyAndCachePolicy(ctx context.Context, rp gache.Gache, sp *SignedPolicy) error { eg := errgroup.Group{} assm := new(sync.Map) // assertion map @@ -409,7 +342,7 @@ func simplifyAndCachePolicy(ctx context.Context, rp gache.Gache, sp *SignedPolic } rp.SetWithExpire(ass.Role, asss, time.Duration(sp.DomainSignedPolicyData.SignedPolicyData.Expires.Sub(now))) - glg.Debugf("added assertion to the cache: %+v", ass) + glg.Debugf("added assertion to the tmp cache: %+v", ass) return true }) if retErr != nil { diff --git a/policy/daemon_test.go b/policy/daemon_test.go index 59ba7f83..e693b93a 100644 --- a/policy/daemon_test.go +++ b/policy/daemon_test.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package policy import ( @@ -20,73 +21,109 @@ import ( "fmt" "math" "net/http" - "net/http/httptest" "reflect" "runtime" - "strings" "testing" "time" "github.com/ardielle/ardielle-go/rdl" - cmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/kpango/fastime" "github.com/kpango/gache" "github.com/pkg/errors" - authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" "github.com/yahoo/athenz/utils/zpe-updater/util" "github.com/yahoojapan/athenz-authorizer/v2/pubkey" ) func TestNew(t *testing.T) { + gacheCmp := cmp.Comparer(func(x, y gache.Gache) bool { + ctx := context.Background() + return cmp.Equal(x.ToRawMap(ctx), y.ToRawMap(ctx), cmpopts.EquateEmpty()) + }) + fetcherCmp := cmp.Comparer(func(x, y Fetcher) bool { + return x.Domain() == y.Domain() + }) type args struct { opts []Option } tests := []struct { - name string - args args - want Daemon - checkFunc func(got Daemon) error - wantErr bool + name string + args args + want Daemon + wantErr string }{ { name: "new success", args: args{ opts: []Option{}, }, - checkFunc: func(got Daemon) error { - p := got.(*policyd) - if p.expireMargin != time.Hour*3 { - return errors.New("invalid expireMargin") - } - return nil + want: &policyd{ + rolePolicies: gache.New(), + expireMargin: 3 * time.Hour, + policyExpiredDuration: 1 * time.Minute, + refreshDuration: 30 * time.Minute, + errRetryInterval: 1 * time.Minute, + client: http.DefaultClient, }, + wantErr: "", }, { name: "new success with options", args: args{ opts: []Option{WithExpireMargin("5s")}, }, - checkFunc: func(got Daemon) error { - p := got.(*policyd) - if p.expireMargin != time.Second*5 { - return errors.New("invalid expireMargin") - } - return nil + want: &policyd{ + rolePolicies: gache.New(), + expireMargin: 5 * time.Second, + policyExpiredDuration: 1 * time.Minute, + refreshDuration: 30 * time.Minute, + errRetryInterval: 1 * time.Minute, + client: http.DefaultClient, + }, + wantErr: "", + }, + { + name: "new success, domains with fetchers", + args: args{ + opts: []Option{WithAthenzDomains("dom1", "dom2")}, + }, + want: &policyd{ + rolePolicies: gache.New(), + expireMargin: 3 * time.Hour, + policyExpiredDuration: 1 * time.Minute, + refreshDuration: 30 * time.Minute, + errRetryInterval: 1 * time.Minute, + client: http.DefaultClient, + athenzDomains: []string{"dom1", "dom2"}, + fetchers: map[string]Fetcher{ + "dom1": &fetcher{domain: "dom1"}, + "dom2": &fetcher{domain: "dom2"}, + }, + }, + wantErr: "", + }, + { + name: "new fail, option error", + args: args{ + opts: []Option{ + func(*policyd) error { return errors.New("option error") }, + }, }, + want: nil, + wantErr: "error create policyd: option error", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.opts...) - if (err != nil) != tt.wantErr { + if (err == nil && tt.wantErr != "") || (err != nil && err.Error() != tt.wantErr) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - - if tt.checkFunc != nil { - if err := tt.checkFunc(got); err != nil { - t.Errorf("New() = %v", err) - } + options := []cmp.Option{gacheCmp, fetcherCmp, cmp.AllowUnexported(policyd{}), cmpopts.EquateEmpty()} + if !cmp.Equal(got, tt.want, options...) { + t.Errorf("New() = %v, want %v", got, tt.want) } }) } @@ -100,11 +137,9 @@ func Test_policyd_Start(t *testing.T) { refreshDuration time.Duration errRetryInterval time.Duration pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration athenzURL string athenzDomains []string - client *http.Client + fetchers map[string]Fetcher } type args struct { ctx context.Context @@ -118,12 +153,44 @@ func Test_policyd_Start(t *testing.T) { } tests := []test{ func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:dummyRes","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`)) - })) - srv := httptest.NewTLSServer(handler) + domain := "dummyDom" + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: "dummyDom", + Policies: []*util.Policy{ + { + Name: "dummyDom:policy.dummyPol", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: "dummyDom:role.dummyRole", + Action: "dummyAct", + Resource: "dummyDom:dummyRes", + Effect: "ALLOW", + }, + }, + }, + }, + }, + }, + }, + } + fetchers := map[string]Fetcher{ + "dummyDom": &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return sp, nil + }, + }, + } ctx, cancel := context.WithCancel(context.Background()) return test{ @@ -131,20 +198,10 @@ func Test_policyd_Start(t *testing.T) { fields: fields{ rolePolicies: gache.New(), policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagFlushDur: time.Second, - refreshDuration: time.Second, + refreshDuration: time.Millisecond * 30, expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - }, - athenzDomains: []string{"dummyDom"}, + athenzDomains: []string{domain}, + fetchers: fetchers, }, args: args{ ctx: ctx, @@ -159,10 +216,6 @@ func Test_policyd_Start(t *testing.T) { if len(asss.([]*Assertion)) != 1 { return errors.Errorf("invalid length assertions. want: 1, result: %d", len(asss.([]*Assertion))) } - _, ok = p.etagCache.Get("dummyDom") - if !ok { - return errors.New("etagCache is empty") - } return nil }, afterFunc: func() { @@ -171,16 +224,49 @@ func Test_policyd_Start(t *testing.T) { } }(), func() test { + domain := "dummyDom" + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: "dummyDom", + Policies: []*util.Policy{ + { + Name: "dummyDom:policy.dummyPol", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: "dummyDom:role.dummyRole", + Effect: "ALLOW", + Action: "", + Resource: "", + }, + }, + }, + }, + }, + }, + }, + } c := 0 - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c++ - w.Header().Add("ETag", fmt.Sprintf("%v%d", "dummyEtag", c)) - res := fmt.Sprintf("dummyRes%d", c) - act := fmt.Sprintf("dummyAct%d", c) - w.WriteHeader(http.StatusOK) - w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:%s","action":"%s","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`, res, act))) - })) - srv := httptest.NewTLSServer(handler) + fetchers := map[string]Fetcher{ + "dummyDom": &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + c++ + a := sp.SignedPolicyData.PolicyData.Policies[0].Assertions[0] + a.Action = fmt.Sprintf("dummyAct%d", c) + a.Resource = fmt.Sprintf("dummyDom:dummyRes%d", c) + return sp, nil + }, + }, + } ctx, cancel := context.WithCancel(context.Background()) return test{ @@ -188,20 +274,10 @@ func Test_policyd_Start(t *testing.T) { fields: fields{ rolePolicies: gache.New(), policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagFlushDur: time.Second, refreshDuration: time.Millisecond * 30, expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - }, - athenzDomains: []string{"dummyDom"}, + athenzDomains: []string{domain}, + fetchers: fetchers, }, args: args{ ctx: ctx, @@ -223,15 +299,6 @@ func Test_policyd_Start(t *testing.T) { return errors.Errorf("invalid assertion, got: %v, want: ^dummyact%d-dummyres%d$", ass.Reg.String(), c, c) } - ec, ok := p.etagCache.Get("dummyDom") - if !ok { - return errors.New("etagCache is empty") - } - ecwant := fmt.Sprintf("dummyEtag%d", c) - if ec.(*etagCache).etag != ecwant { - return errors.Errorf("invalid etag, got: %v, want: %s", ec, ecwant) - } - return nil }, afterFunc: func() { @@ -240,20 +307,52 @@ func Test_policyd_Start(t *testing.T) { } }(), func() test { + domain := "dummyDom" + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: "dummyDom", + Policies: []*util.Policy{ + { + Name: "dummyDom:policy.dummyPol", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: "dummyDom:role.dummyRole", + Effect: "ALLOW", + Action: "", + Resource: "", + }, + }, + }, + }, + }, + }, + }, + } c := 0 - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if c < 3 { - c++ - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Add("ETag", fmt.Sprintf("%v%d", "dummyEtag", c)) - res := fmt.Sprintf("dummyRes%d", c) - act := fmt.Sprintf("dummyAct%d", c) - w.WriteHeader(http.StatusOK) - w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:%s","action":"%s","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`, res, act))) - })) - srv := httptest.NewTLSServer(handler) + fetchers := map[string]Fetcher{ + "dummyDom": &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + a := sp.SignedPolicyData.PolicyData.Policies[0].Assertions[0] + c++ + if c < 3 { + return nil, errors.New("fetchWithRetryMock error") + } + a.Action = fmt.Sprintf("dummyAct%d", c) + a.Resource = fmt.Sprintf("dummyDom:dummyRes%d", c) + return sp, nil + }, + }, + } ctx, cancel := context.WithCancel(context.Background()) return test{ @@ -261,29 +360,19 @@ func Test_policyd_Start(t *testing.T) { fields: fields{ rolePolicies: gache.New(), policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - etagFlushDur: time.Second, - refreshDuration: time.Minute, + refreshDuration: time.Millisecond * 30, errRetryInterval: time.Millisecond * 5, expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - }, - athenzDomains: []string{"dummyDom"}, + athenzDomains: []string{domain}, + fetchers: fetchers, }, args: args{ ctx: ctx, }, checkFunc: func(p *policyd, ch <-chan error) error { - time.Sleep(time.Millisecond * 100) + time.Sleep(time.Millisecond * 120) cancel() - time.Sleep(time.Millisecond * 50) + time.Sleep(time.Millisecond * 30) asss, ok := p.rolePolicies.Get("dummyDom:role.dummyRole") if !ok { return errors.New("rolePolicies is empty") @@ -297,15 +386,6 @@ func Test_policyd_Start(t *testing.T) { return errors.Errorf("invalid assertion, got: %v, want: ^dummyact%d-dummyres%d$", ass.Reg.String(), c, c) } - ec, ok := p.etagCache.Get("dummyDom") - if !ok { - return errors.New("etagCache is empty") - } - ecwant := fmt.Sprintf("dummyEtag%d", c) - if ec.(*etagCache).etag != ecwant { - return errors.Errorf("invalid etag, got: %v, want: %s", ec, ecwant) - } - return nil }, afterFunc: func() { @@ -327,11 +407,9 @@ func Test_policyd_Start(t *testing.T) { refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, athenzURL: tt.fields.athenzURL, athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + fetchers: tt.fields.fetchers, } ch := p.Start(tt.args.ctx) if tt.checkFunc != nil { @@ -351,173 +429,265 @@ func Test_policyd_Update(t *testing.T) { refreshDuration time.Duration errRetryInterval time.Duration pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration athenzURL string athenzDomains []string client *http.Client + fetchers map[string]Fetcher } type args struct { ctx context.Context } type test struct { - name string - fields fields - args args - beforeFunc func() - checkFunc func(pol *policyd) error - wantErr string - afterFunc func() + name string + fields fields + args args + wantErr string + wantRps map[string]interface{} } tests := []test{ - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:dummyRes","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`)) - })) - srv := httptest.NewTLSServer(handler) + func() (t test) { + t.name = "cancelled context, no actions" - return test{ - name: "Update policy success", - fields: fields{ - rolePolicies: gache.New(), - policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil + // dummy values + domain := "dummyDom" + fetchers := make(map[string]Fetcher) + fetchers[domain] = &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return nil, errors.New("fetchWithRetryMock executed") + }, + } + ctx, cancel := context.WithCancel(context.Background()) + + // prepare test + cancel() + t.fields = fields{ + rolePolicies: gache.New(), + athenzDomains: []string{domain}, + fetchers: fetchers, + } + t.args = args{ + ctx: ctx, + } + + // want + t.wantErr = context.Canceled.Error() + t.wantRps = make(map[string]interface{}) + return t + }(), + func() (t test) { + t.name = "Update policy success" + + // dummy values + domain := "dummyDom" + fetchers := make(map[string]Fetcher) + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: "dummyDom", + Policies: []*util.Policy{ + { + Name: "dummyDom:policy.dummyPol", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: "dummyDom:role.dummyRole", + Effect: "ALLOW", + Action: "dummyAct", + Resource: "dummyDom:dummyRes", + }, + }, + }, }, - } + }, }, - athenzDomains: []string{"dummyDom"}, }, - args: args{ - ctx: context.Background(), + } + fetchers[domain] = &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return sp, nil }, - wantErr: "", - checkFunc: func(pol *policyd) error { - pols, ok := pol.rolePolicies.Get("dummyDom:role.dummyRole") - if !ok { - return errors.New("role policies not found") - } - if len(pols.([]*Assertion)) != 1 { - return errors.New("role policies not correct") - } + } + ctx := context.Background() - return nil - }, + // prepare test + t.fields = fields{ + rolePolicies: gache.New(), + policyExpiredDuration: time.Hour, + athenzDomains: []string{domain}, + fetchers: fetchers, + } + t.args = args{ + ctx: ctx, } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - domain := strings.Split(r.URL.Path, "/")[2] - w.Header().Add("ETag", domain+"Etag") - w.WriteHeader(http.StatusOK) - spd := fmt.Sprintf(`{"signedPolicyData":{"policyData":{"domain":"%s","policies":[{"name":"%s:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"%s:role.dummyRole","resource":"%s:dummyRes","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`, domain, domain, domain, domain) - w.Write([]byte(spd)) - })) - srv := httptest.NewTLSServer(handler) + // want + wantAssertion, _ := NewAssertion("dummyAct", "dummyDom:dummyRes", "ALLOW") + t.wantErr = "" + t.wantRps = make(map[string]interface{}) + t.wantRps["dummyDom:role.dummyRole"] = []*Assertion{wantAssertion} + return t + }(), + func() (t test) { + t.name = "Update policy success with multiple athenz domains" + // dummy values + createSp := func(domain string) *SignedPolicy { + return &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: domain, + Policies: []*util.Policy{ + { + Name: fmt.Sprintf("%s:policy.dummyPol", domain), + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: fmt.Sprintf("%s:role.dummyRole", domain), + Effect: "ALLOW", + Action: "dummyAct", + Resource: fmt.Sprintf("%s:dummyRes", domain), + }, + }, + }, + }, + }, + }, + }, + } + } domains := make([]string, 1000) + fetchers := make(map[string]Fetcher, 1000) for i := 0; i < 1000; i++ { - domains[i] = fmt.Sprintf("dummyDom%d", i) - } - - return test{ - name: "Update policy success with multiple athenz domains", - fields: fields{ - rolePolicies: gache.New(), - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - policyExpiredDuration: time.Second * 120, - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } + d := fmt.Sprintf("dummyDom%d", i) + domains[i] = d + fetchers[d] = &fetcherMock{ + domainMock: func() string { return d }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return createSp(d), nil }, - athenzDomains: domains, - }, - args: args{ - ctx: context.Background(), - }, - checkFunc: func(pol *policyd) error { - gotLen := len(pol.rolePolicies.ToRawMap(context.Background())) - wantLen := len(domains) - if gotLen != wantLen { - return fmt.Errorf("role policies length is not correct, got: %v, want: %v", gotLen, wantLen) - } + } + } + ctx := context.Background() - for _, dom := range domains { - domRole := fmt.Sprintf("%s:role.dummyRole", dom) - pols, ok := pol.rolePolicies.Get(domRole) - if !ok { - return errors.Errorf("role policies %s not found", domRole) - } - if len(pols.([]*Assertion)) != 1 { - return errors.Errorf("role policies of %s not correct", domRole) - } - } + // prepare test + t.fields = fields{ + rolePolicies: gache.New(), + policyExpiredDuration: time.Hour, + athenzDomains: domains, + fetchers: fetchers, + } + t.args = args{ + ctx: ctx, + } - return nil - }, + // want + t.wantErr = "" + t.wantRps = make(map[string]interface{}, 1000) + for i := 0; i < 1000; i++ { + d := fmt.Sprintf("dummyDom%d", i) + key := fmt.Sprintf("%s:role.dummyRole", d) + wantAssertion, _ := NewAssertion("dummyAct", fmt.Sprintf("%s:dummyRes", d), "ALLOW") + t.wantRps[key] = []*Assertion{wantAssertion} } + return t }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:dummyRes","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`)) - })) - srv := httptest.NewTLSServer(handler) - - ctx, cancel := context.WithDeadline(context.Background(), fastime.Now().Add(time.Millisecond*10)) - return test{ - name: "Update error, context timeout", - fields: fields{ - rolePolicies: gache.New(), - policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil + func() (t test) { + t.name = "Update error, context timeout, no partial update" + + // dummy values + createSp := func(domain string) *SignedPolicy { + return &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: domain, + Policies: []*util.Policy{ + { + Name: fmt.Sprintf("%s:policy.dummyPol", domain), + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: fmt.Sprintf("%s:role.dummyRole", domain), + Effect: "ALLOW", + Action: "dummyAct", + Resource: fmt.Sprintf("%s:dummyRes", domain), + }, + }, + }, + }, }, - } + }, }, - athenzDomains: []string{"dummyDom"}, - }, - args: args{ - ctx: ctx, - }, - wantErr: "context deadline exceeded", - beforeFunc: func() { - time.Sleep(time.Second) - }, - afterFunc: func() { + } + } + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) // timeout should be long enough to enter Fetch() + go func() { + select { + case <-ctx.Done(): + return + case <-time.After(time.Hour): cancel() - }, + } + }() + domains := make([]string, 100) + fetchers := make(map[string]Fetcher, 100) + for i := 0; i < 100; i++ { + d := fmt.Sprintf("discardDom%d", i) + domains[i] = d + fetchers[d] = &fetcherMock{ + domainMock: func() string { return d }, + fetchWithRetryMock: func(ctx context.Context) (*SignedPolicy, error) { + if d == "discardDom"+"0" { + // blocking + <-ctx.Done() + return nil, ctx.Err() + } + return createSp(d), nil + }, + } } + + // prepare test + t.fields = fields{ + rolePolicies: gache.New(), + policyExpiredDuration: time.Hour, + athenzDomains: domains, + fetchers: fetchers, + } + t.args = args{ + ctx: ctx, + } + + // want + t.wantErr = "fetch policy fail: " + context.DeadlineExceeded.Error() + t.wantRps = make(map[string]interface{}) + return t }(), } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.afterFunc != nil { - defer tt.afterFunc() - } - p := &policyd{ expireMargin: tt.fields.expireMargin, rolePolicies: tt.fields.rolePolicies, @@ -525,22 +695,20 @@ func Test_policyd_Update(t *testing.T) { refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, athenzURL: tt.fields.athenzURL, athenzDomains: tt.fields.athenzDomains, client: tt.fields.client, + fetchers: tt.fields.fetchers, } - if tt.beforeFunc != nil { - tt.beforeFunc() - } - if err := p.Update(tt.args.ctx); (err != nil) && tt.wantErr != "" && err.Error() != tt.wantErr { - t.Errorf("policy.Update() error = %v, wantErr %v", err, tt.wantErr) + err := p.Update(tt.args.ctx) + if (err == nil && tt.wantErr != "") || (err != nil && err.Error() != tt.wantErr) { + t.Errorf("policyd.Update() error = %v, wantErr %v", err, tt.wantErr) + return } - if tt.checkFunc != nil { - if err := tt.checkFunc(p); err != nil { - t.Errorf("policy.Update() error = %v", err) - } + gotRps := p.GetPolicyCache(context.Background()) + if !cmp.Equal(gotRps, tt.wantRps, cmpopts.IgnoreFields(Assertion{}, "Reg")) { + t.Errorf("policyd.Update() rolePolicies = %v, want %v", gotRps, tt.wantRps) + t.Errorf("policyd.Update() rolePolicies diff = %s", cmp.Diff(gotRps, tt.wantRps, cmpopts.IgnoreFields(Assertion{}, "Reg"))) } }) } @@ -553,8 +721,6 @@ func Test_policyd_CheckPolicy(t *testing.T) { refreshDuration time.Duration errRetryInterval time.Duration pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration athenzURL string athenzDomains []string client *http.Client @@ -695,6 +861,24 @@ func Test_policyd_CheckPolicy(t *testing.T) { }, want: errors.New("no match: Access denied due to no match to any of the assertions defined in domain policy file"), }, + { + name: "check policy, canceled context", + fields: fields{ + rolePolicies: gache.New(), + }, + args: args{ + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }(), + domain: "dummyDom", + roles: []string{"dummyRole"}, + action: "dummyAct", + resource: "dummyRes", + }, + want: context.Canceled, + }, { name: "check policy deny with multiple roles with allow and deny", fields: fields{ @@ -759,8 +943,6 @@ func Test_policyd_CheckPolicy(t *testing.T) { refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, athenzURL: tt.fields.athenzURL, athenzDomains: tt.fields.athenzDomains, client: tt.fields.client, @@ -788,8 +970,6 @@ func Test_policyd_CheckPolicy_goroutine(t *testing.T) { refreshDuration time.Duration errRetryInterval time.Duration pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration athenzURL string athenzDomains []string client *http.Client @@ -881,8 +1061,6 @@ func Test_policyd_CheckPolicy_goroutine(t *testing.T) { refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, athenzURL: tt.fields.athenzURL, athenzDomains: tt.fields.athenzDomains, client: tt.fields.client, @@ -890,7 +1068,8 @@ func Test_policyd_CheckPolicy_goroutine(t *testing.T) { b := make([]byte, 10240) lenStart := runtime.Stack(b, true) - // t.Log(string(b[:len])) + oldStack := string(b[:lenStart]) + // t.Log(oldStack) err := p.CheckPolicy(tt.args.ctx, tt.args.domain, tt.args.roles, tt.args.action, tt.args.resource) if err == nil { if tt.want != nil { @@ -905,1111 +1084,180 @@ func Test_policyd_CheckPolicy_goroutine(t *testing.T) { } // check runtime stack for go routine leak - time.Sleep(time.Second) // wait for some background process to cleanup + time.Sleep(time.Millisecond * 500) // wait for some background process to cleanup lenEnd := runtime.Stack(b, true) - // t.Log(string(b[:len])) - if math.Abs(float64(lenStart-lenEnd)) > 5 { - t.Errorf("go routine leak:\n%v", string(b[:lenEnd])) + // t.Log(string(b[:lenEnd])) + if math.Abs(float64(lenStart-lenEnd)) > 10 { // to tolerate fastime package goroutine status change, leaking will cause much larger stack length difference + t.Errorf("go routine leak:\n%v", cmp.Diff(oldStack, string(b[:lenEnd]))) } }) } } -func Test_policyd_fetchAndCachePolicy(t *testing.T) { - type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - policyExpiredDuration time.Duration - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - athenzURL string - athenzDomains []string - client *http.Client - } +func Test_fetchAndCachePolicy(t *testing.T) { type args struct { ctx context.Context g gache.Gache - dom string + f Fetcher } type test struct { - name string - fields fields - args args - checkFunc func(pol *policyd) error - wantErr bool + name string + args args + wantErr string + wantRps map[string]interface{} } - tests := []test{ - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"dummyDom:dummyRes","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`)) - })) - srv := httptest.NewTLSServer(handler) - g := gache.New() - - return test{ - name: "fetch policy success with updated policy", - fields: fields{ - rolePolicies: gache.New(), - policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil + createDummySp := func() *SignedPolicy { + return &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "dummyKeyID", + Signature: "dummySig", + SignedPolicyData: &util.SignedPolicyData{ + ZmsKeyId: "dummyKeyID", + ZmsSignature: "dummySig", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Expires: &rdl.Timestamp{Time: fastime.Now().Add(time.Hour)}, + PolicyData: &util.PolicyData{ + Domain: "dummyDom", + Policies: []*util.Policy{ + { + Name: "dummyDom:policy.dummyPol", + Modified: &rdl.Timestamp{Time: fastime.Now()}, + Assertions: []*util.Assertion{ + { + Role: "dummyDom:role.dummyRole", + Effect: "ALLOW", + Action: "dummyAct", + Resource: "dummyDom:dummyRes", + }, + }, }, - } + }, }, }, - args: args{ - ctx: context.Background(), - g: g, - dom: "dummyDom", - }, - checkFunc: func(pol *policyd) error { - pols, ok := g.Get("dummyDom:role.dummyRole") - if !ok { - return errors.New("role policies not found") - } - if len(pols.([]*Assertion)) != 1 { - return errors.New("role policies not correct") - } + }, + } + } + tests := []test{ + func() (t test) { + t.name = "fetch success, update cache" - return nil + // dummy values + domain := "dummyDom" + sp := createDummySp() + fetcher := &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return sp, nil }, } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - srv := httptest.NewTLSServer(handler) - g := gache.New() - - return test{ - name: "fetch policy failed", - fields: fields{ - rolePolicies: gache.New(), - policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - }, - }, - args: args{ - ctx: context.Background(), - g: g, - dom: "dummyDomain", - }, - wantErr: true, - } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{"policyData":{"domain":"dummyDom","policies":[{"name":"dummyDom:policy.dummyPol","modified":"2099-02-14T05:42:07.219Z","assertions":[{"role":"dummyDom:role.dummyRole","resource":"","action":"dummyAct","effect":"ALLOW"}]}]},"zmsSignature":"dummySig","zmsKeyId":"dummyKeyID","modified":"2099-03-04T04:33:27.318Z","expires":"2099-03-21T08:11:18.729Z"},"signature":"dummySig","keyId":"dummyKeyID"}`)) - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "simplifyAndCache failed", - fields: fields{ - rolePolicies: gache.New(), - policyExpiredDuration: time.Minute * 30, - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - }, - }, - args: args{ - ctx: context.Background(), - g: gache.New(), - dom: "dummyDom", - }, - wantErr: true, - } - }(), - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - policyExpiredDuration: tt.fields.policyExpiredDuration, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, - } - if err := p.fetchAndCachePolicy(tt.args.ctx, tt.args.g, tt.args.dom); (err != nil) != tt.wantErr { - t.Errorf("policy.fetchAndCachePolicy() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.checkFunc != nil { - if err := tt.checkFunc(p); err != nil { - t.Errorf("policy.fetchAndCachePolicy() error = %v", err) - } - } - }) - } -} - -func Test_policyd_fetchPolicy(t *testing.T) { - type fields struct { - expireMargin time.Duration - rolePolicies gache.Gache - policyExpiredDuration time.Duration - refreshDuration time.Duration - errRetryInterval time.Duration - pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration - athenzURL string - athenzDomains []string - client *http.Client - } - type args struct { - ctx context.Context - domain string - } - type test struct { - name string - fields fields - args args - checkFunc func(p *policyd, sp *SignedPolicy, upd bool, err error) error - } - mockPkp := func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return nil - }, - } - } - tests := []test{ - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tByte, err := rdl.Timestamp{ - Time: fastime.Now().Add(1 * time.Hour).UTC(), - }.MarshalJSON() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } else { - w.Header().Add("ETag", "dummyEtag") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - json := fmt.Sprintf(`{"signedPolicyData":{ - "zmsKeyId": "1", - "expires": %v - }}`, string(tByte)) - w.Write([]byte(json)) - } - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "test fetch success", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Minute, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if err != nil { - return err - } - - etag, ok := p.etagCache.Get("dummyDomain") - if !ok { - return errors.New("etag not set") - } - etagCac := etag.(*etagCache) - if etagCac.etag != "dummyEtag" { - return errors.New("etag header not correct") - } - - want := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "1", - }, - }, - } - - if !cmp.Equal(etagCac.sp, sp) { - return errors.Errorf("etag value not match, got: %v, want: %v", etag, want) - } - - if upd == false { - return errors.New("Invalid upd flag") - } - - return err - }, - } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "test fetch error url", - fields: fields{ - athenzURL: " ", - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Second, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummy", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if sp != nil { - return errors.New("invalid return") - } - if upd != false { - return errors.New("invalid return ") - } - wantErr := `error creating fetch policy request: parse https:// /domain/dummy/signed_policy_data: invalid character " " in host name` - if err.Error() != wantErr { - return errors.Errorf("got error: %v, want: %v", err, wantErr) - } - - return nil - }, - } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("If-None-Match") != "dummyEtag" { - w.Header().Add("ETag", "dummyEtag") - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusNotModified) - } - })) - srv := httptest.NewTLSServer(handler) - - etagCac := gache.New() - etagCac.Set("dummyDomain", &etagCache{ - etag: "dummyEtag", - sp: &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - Expires: func() *rdl.Timestamp { - t := rdl.NewTimestamp(fastime.Now().Add(time.Hour)) - return &t - }(), - }, - }, - }, - }) - - return test{ - name: "test etag exists and response 304", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Second, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if err != nil { - return err - } - - etag, ok := p.etagCache.Get("dummyDomain") - if !ok { - return errors.New("etag not set") - } - etagCac := etag.(*etagCache) - if etagCac.etag != "dummyEtag" { - return errors.New("etag header not correct") - } - - if etagCac.sp != sp { - return errors.Errorf("etag value not match, got: %v, want: %v", etag, sp) - } - - if upd != false { - return errors.New("Invalid upd flag") - } - - return err - }, - } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("If-None-Match") == "dummyOldEtag" { - w.Header().Add("ETag", "dummyNewEtag") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{ - "zmsKeyId":"dummyNewId", - "expires":"2099-03-21T08:11:18.729Z" - }}`)) - } else { - w.WriteHeader(http.StatusNotModified) - } - })) - srv := httptest.NewTLSServer(handler) - - etagCac := gache.New() - etagCac.Set("dummyDomain", &etagCache{ - etag: "dummyOldEtag", - sp: &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - Expires: &rdl.Timestamp{ - Time: fastime.Now().Add(time.Hour).UTC(), - }, - }, - }, - }, - }) - - return test{ - name: "test etag exists but response 200", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Second, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if err != nil { - return err - } + ctx := context.Background() - etag, ok := p.etagCache.Get("dummyDomain") - if !ok { - return errors.New("etag not set") - } - etagCac := etag.(*etagCache) - if etagCac.etag != "dummyNewEtag" { - return errors.New("etag header not correct") - } - - if !cmp.Equal(etagCac.sp, sp) { - return errors.Errorf("etag value not match, got: %v, want: %v", etagCac, sp) - } - - if upd != true { - return errors.New("Invalid upd flag") - } - - return err - }, + // prepare test + t.args = args{ + ctx: ctx, + g: gache.New(), + f: fetcher, } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "test fetch error make https request", - fields: fields{ - athenzURL: "dummyURL", - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if sp != nil { - return errors.Errorf("sp should be nil") - } - if upd != false { - return errors.New("Invalid upd flag") - } - wantErr := "error making request: Get https://dummyURL/domain/dummyDomain/signed_policy_data: dial tcp: lookup dummyURL" - if !strings.HasPrefix(err.Error(), wantErr) { - return errors.Errorf("invalid error, got: %v, want: %v", err, wantErr) - } - return nil - }, - } + // want + wantAssertion, _ := NewAssertion("dummyAct", "dummyDom:dummyRes", "ALLOW") + t.wantErr = "" + t.wantRps = make(map[string]interface{}) + t.wantRps["dummyDom:role.dummyRole"] = []*Assertion{wantAssertion} + return t }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - srv := httptest.NewTLSServer(handler) + func() (t test) { + t.name = "fetch failed, no cache, error" - return test{ - name: "test fetch error return not ok", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if sp != nil { - return errors.Errorf("sp should be nil") - } - if upd != false { - return errors.New("Invalid upd flag") - } - wantErr := "error fetching policy data: Error fetching athenz policy" - if err.Error() != wantErr { - return errors.Errorf("invalid error, got: %v, want: %v", err, wantErr) - } - - return nil + // dummy values + domain := "dummyDom" + fetcher := &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return nil, errors.New("no cache") }, } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("")) - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "test fetch error decode policy", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if sp != nil { - return errors.Errorf("sp should be nil") - } - if upd != false { - return errors.New("Invalid upd flag") - } - wantErr := "error decode response: EOF" - if err.Error() != wantErr { - return errors.Errorf("invalid error, got: %v, want: %v", err, wantErr) - } + ctx := context.Background() - return nil - }, + // prepare test + t.args = args{ + ctx: ctx, + g: gache.New(), + f: fetcher, } - }(), - func() test { - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{ - "zmsKeyId":"1", - "expires":"2099-03-21T08:11:18.729Z" - }}`)) - })) - srv := httptest.NewTLSServer(handler) - - return test{ - name: "test fetch verify error", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: gache.New(), - expireMargin: time.Hour, - client: srv.Client(), - pkp: func(e pubkey.AthenzEnv, id string) authcore.Verifier { - return VerifierMock{ - VerifyFunc: func(d, s string) error { - return errors.New("error") - }, - } - }, - }, - args: args{ - ctx: context.Background(), - domain: "dummyDomain", - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - if sp != nil { - return errors.Errorf("sp should be nil") - } - if upd != false { - return errors.New("Invalid upd flag") - } - wantErr := "error verify policy data: error verify signature: error" - if err.Error() != wantErr { - return errors.Errorf("invalid error, got: %v, want: %v", err, wantErr) - } - return nil - }, - } + // want + t.wantErr = "fetch policy fail: no cache" + t.wantRps = make(map[string]interface{}) + return t }(), - func() test { - domain := "dummyDomain" - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotModified) - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "cached-policy", - Expires: &rdl.Timestamp{ - Time: fastime.Now().Add(-1 * time.Hour).UTC(), - }, - }, - }, - } - wantSp := cachedSp - - // old etag cache - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test policy already expired (304), no expiry checking, return expired policy", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - if err != nil { - return err - } - if !reflect.DeepEqual(sp, wantSp) { - return fmt.Errorf("SignedPolicy got: %+v, want: %+v", *sp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } - if upd != false { - return errors.New("upd should be false") - } - if fastime.Now().Before(sp.SignedPolicyData.Expires.Time) { - // strange behavior - return errors.New("returned policy should be expired") - } - - // check etag cache values - etagCac, ok := p.etagCache.Get(domain) - if !ok { - return errors.New("etag cache should be found") - } - // check policy same - gotCachedSp := etagCac.(*etagCache).sp - if gotCachedSp != wantSp { - return fmt.Errorf("etag cache SignedPolicy got: %+v, want: %+v", *gotCachedSp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } + func() (t test) { + t.name = "fetch failed, but with cached policy, update cache" - return nil - }, - } - }(), - func() test { - domain := "dummyDomain" - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("ETag", "\"dummyNewEtag\"") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"signedPolicyData":{ - "zmsKeyId": "zmsKeyId-137" - }}`)) - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-144", - }, + // dummy values + domain := "dummyDom" + sp := createDummySp() + fetcher := &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return sp, errors.New("something wrong, return cache") }, } - wantSp := cachedSp - - // old etag cache, to confirm update - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test policy without expiry, keep etagCache, return error", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - wantError := "error verify policy data: policy without expiry" - if !strings.HasPrefix(err.Error(), wantError) { - return fmt.Errorf("err got: %v, want: %v", err.Error(), wantError) - } - if sp != nil { - return errors.New("sp should be nil") - } - if upd != false { - return errors.New("upd should be false") - } - - // check etag cache values - etagCac, _, ok := p.etagCache.GetWithExpire(domain) - if !ok { - return errors.New("etag cache should be found") - } - // check etag - wantEtag := "\"dummyOldEtag\"" - gotEtag := etagCac.(*etagCache).etag - if gotEtag != wantEtag { - return fmt.Errorf("etag got: %v, want: %v", gotEtag, wantEtag) - } - // check policy equal - gotCachedSp := etagCac.(*etagCache).sp - if !reflect.DeepEqual(gotCachedSp, wantSp) { - return fmt.Errorf("etag cache SignedPolicy got: %+v, want: %+v", *gotCachedSp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } + ctx := context.Background() - return nil - }, - } - }(), - func() test { - domain := "dummyDomain" - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("ETag", "\"dummyNewEtag\"") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - // only customized format works - w.Write([]byte(`{"signedPolicyData":{ - "zmsKeyId": "zmsKeyId-235", - "expires":"2099-12-31" - }}`)) - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-243", - }, - }, + // prepare test + t.args = args{ + ctx: ctx, + g: gache.New(), + f: fetcher, } - wantSp := cachedSp - - // old etag cache, to confirm delete - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test policy with invalid expiry, keep etagCache, return error", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Minute, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - wantError := "error verify policy data: policy already expired" - if !strings.HasPrefix(err.Error(), wantError) { - return fmt.Errorf("err got: %v, want: %v", err.Error(), wantError) - } - if sp != nil { - return errors.New("sp should be nil") - } - if upd != false { - return errors.New("upd should be false") - } - - // check etag cache values - etagCac, _, ok := p.etagCache.GetWithExpire(domain) - if !ok { - return errors.New("etag cache should be found") - } - // check etag - wantEtag := "\"dummyOldEtag\"" - gotEtag := etagCac.(*etagCache).etag - if gotEtag != wantEtag { - return fmt.Errorf("etag got: %v, want: %v", gotEtag, wantEtag) - } - // check policy equal - gotCachedSp := etagCac.(*etagCache).sp - if !reflect.DeepEqual(gotCachedSp, wantSp) { - return fmt.Errorf("etag cache SignedPolicy got: %+v, want: %+v", *gotCachedSp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } - return nil - }, - } + // want + wantAssertion, _ := NewAssertion("dummyAct", "dummyDom:dummyRes", "ALLOW") + t.wantErr = "" + t.wantRps = make(map[string]interface{}) + t.wantRps["dummyDom:role.dummyRole"] = []*Assertion{wantAssertion} + return t }(), - func() test { - domain := "dummyDomain" - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tByte, err := rdl.Timestamp{ - Time: fastime.Now().Add(-1 * time.Hour).UTC(), - }.MarshalJSON() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } else { - w.Header().Set("ETag", "\"dummyNewEtag\"") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - json := fmt.Sprintf(`{"signedPolicyData":{ - "zmsKeyId": "zmsKeyId-322", - "expires": %v - }}`, string(tByte)) - w.Write([]byte(json)) - } - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-332", - }, - }, - } - wantSp := cachedSp - - // old etag cache, to confirm delete - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test policy already expired, expiry check fail, keep etagCache, return error", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: time.Minute, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - wantError := "error verify policy data: policy already expired" - if !strings.HasPrefix(err.Error(), wantError) { - return fmt.Errorf("err got: %v, want: %v", err.Error(), wantError) - } - if sp != nil { - return errors.New("sp should be nil") - } - if upd != false { - return errors.New("upd should be false") - } - - // check etag cache values - etagCac, _, ok := p.etagCache.GetWithExpire(domain) - if !ok { - return errors.New("etag cache should be found") - } - // check etag - wantEtag := "\"dummyOldEtag\"" - gotEtag := etagCac.(*etagCache).etag - if gotEtag != wantEtag { - return fmt.Errorf("etag got: %v, want: %v", gotEtag, wantEtag) - } - // check policy equal - gotCachedSp := etagCac.(*etagCache).sp - if !reflect.DeepEqual(gotCachedSp, wantSp) { - return fmt.Errorf("etag cache SignedPolicy got: %+v, want: %+v", *gotCachedSp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } + func() (t test) { + t.name = "simplifyAndCache failed, error" - return nil - }, - } - }(), - func() test { - domain := "dummyDomain" - policyExpires := rdl.Timestamp{ - Time: fastime.Now().Add(1 * time.Hour).Truncate(time.Millisecond).UTC(), - } - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tByte, err := policyExpires.MarshalJSON() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } else { - w.Header().Set("ETag", "\"dummyNewEtag\"") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - json := fmt.Sprintf(`{"signedPolicyData":{ - "zmsKeyId": "zmsKeyId-403", - "expires": %v - }}`, string(tByte)) - w.Write([]byte(json)) - } - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-414", - }, + // dummy values + domain := "dummyDom" + sp := createDummySp() + fetcher := &fetcherMock{ + domainMock: func() string { return domain }, + fetchWithRetryMock: func(context.Context) (*SignedPolicy, error) { + return sp, nil }, } - wantSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-403", - Expires: &policyExpires, - }, - }, - } - - // old etag cache, to confirm delete - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test policy expire within the expireMargin, remove etagCache, return policy", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: 2 * time.Hour, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - if err != nil { - return err - } - if !reflect.DeepEqual(sp.DomainSignedPolicyData.SignedPolicyData, wantSp.DomainSignedPolicyData.SignedPolicyData) { - return fmt.Errorf("SignedPolicy got: %+v, want: %+v", *sp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } - if upd != true { - return errors.New("upd should be true") - } - - // check etag cache empty - _, ok := p.etagCache.Get(domain) - if ok { - return errors.New("etag cache should be not found") - } + ctx := context.Background() - return nil - }, - } - }(), - func() test { - domain := "dummyDomain" - policyExpires := rdl.Timestamp{ - Time: fastime.Now().Add(1 * time.Hour).Truncate(time.Millisecond).UTC(), + // prepare test + sp.SignedPolicyData.PolicyData.Policies[0].Assertions[0].Resource = "invalid-resource" + t.args = args{ + ctx: ctx, + g: gache.New(), + f: fetcher, } - expireMargin := 30 * time.Minute - handler := http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tByte, err := policyExpires.MarshalJSON() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } else { - w.Header().Set("ETag", "\"dummyNewEtag\"") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(http.StatusOK) - json := fmt.Sprintf(`{"signedPolicyData":{ - "zmsKeyId": "zmsKeyId-482", - "expires": %v - }}`, string(tByte)) - w.Write([]byte(json)) - } - })) - srv := httptest.NewTLSServer(handler) - cachedSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-492", - }, - }, - } - wantSp := &SignedPolicy{ - util.DomainSignedPolicyData{ - SignedPolicyData: &util.SignedPolicyData{ - ZmsKeyId: "zmsKeyId-482", - Expires: &policyExpires, - }, - }, - } - - // old etag cache, to confirm update - etagCac := gache.New() - etagCac.Set(domain, &etagCache{ - etag: "\"dummyOldEtag\"", - sp: cachedSp, - }) - - return test{ - name: "test valid policy (200), set etagCache with (policyExpires - expireMargin)", - fields: fields{ - athenzURL: strings.Replace(srv.URL, "https://", "", 1), - policyExpiredDuration: time.Minute * 30, - etagCache: etagCac, - expireMargin: expireMargin, - client: srv.Client(), - pkp: mockPkp, - }, - args: args{ - ctx: context.Background(), - domain: domain, - }, - checkFunc: func(p *policyd, sp *SignedPolicy, upd bool, err error) error { - // check return values - if err != nil { - return err - } - if !reflect.DeepEqual(sp.DomainSignedPolicyData.SignedPolicyData, wantSp.DomainSignedPolicyData.SignedPolicyData) { - return fmt.Errorf("SignedPolicy got: %+v, want: %+v", *sp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } - if upd != true { - return errors.New("upd should be true") - } - - // check etag cache values - etagCac, gotExpiry, ok := p.etagCache.GetWithExpire(domain) - if !ok { - return errors.New("etag cache should be found") - } - // check etag - wantEtag := "\"dummyNewEtag\"" - gotEtag := etagCac.(*etagCache).etag - if gotEtag != wantEtag { - return fmt.Errorf("etag got: %v, want: %v", gotEtag, wantEtag) - } - // check policy equal - gotCachedSp := etagCac.(*etagCache).sp - if !reflect.DeepEqual(gotCachedSp, wantSp) { - return fmt.Errorf("etag cache SignedPolicy got: %+v, want: %+v", *gotCachedSp.DomainSignedPolicyData.SignedPolicyData, *wantSp.DomainSignedPolicyData.SignedPolicyData) - } - // // check cache expire time - wantExpiry := wantSp.DomainSignedPolicyData.SignedPolicyData.Expires.UnixNano() - expireMargin.Nanoseconds() - if gotExpiry-wantExpiry > (time.Second * 1).Nanoseconds() { - return fmt.Errorf("etag cache expiry got: %v, want: %v", gotExpiry, wantExpiry) - } - return nil - }, - } + // want + t.wantErr = "simplify and cache policy fail: assertion format not correct: Access denied due to invalid/empty policy resources" + t.wantRps = make(map[string]interface{}) + return t }(), } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &policyd{ - expireMargin: tt.fields.expireMargin, - rolePolicies: tt.fields.rolePolicies, - policyExpiredDuration: tt.fields.policyExpiredDuration, - refreshDuration: tt.fields.refreshDuration, - errRetryInterval: tt.fields.errRetryInterval, - pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, - athenzURL: tt.fields.athenzURL, - athenzDomains: tt.fields.athenzDomains, - client: tt.fields.client, + err := fetchAndCachePolicy(tt.args.ctx, tt.args.g, tt.args.f) + if (err == nil && tt.wantErr != "") || (err != nil && err.Error() != tt.wantErr) { + t.Errorf("fetchAndCachePolicy() error = %v, wantErr %v", err, tt.wantErr) + return } - got, got1, err := p.fetchPolicy(tt.args.ctx, tt.args.domain) - - if err := tt.checkFunc(p, got, got1, err); err != nil { - t.Errorf("policy.fetchPolicy() error = %v", err) + gotRps := tt.args.g.ToRawMap(context.Background()) + if !cmp.Equal(gotRps, tt.wantRps, cmpopts.IgnoreFields(Assertion{}, "Reg")) { + t.Errorf("fetchAndCachePolicy() g = %v, want %v", gotRps, tt.wantRps) + t.Errorf("fetchAndCachePolicy() g diff = %s", cmp.Diff(gotRps, tt.wantRps, cmpopts.IgnoreFields(Assertion{}, "Reg"))) } }) } @@ -2699,8 +1947,6 @@ func Test_policyd_GetPolicyCache(t *testing.T) { refreshDuration time.Duration errRetryInterval time.Duration pkp pubkey.Provider - etagCache gache.Gache - etagFlushDur time.Duration athenzURL string athenzDomains []string client *http.Client @@ -2769,8 +2015,6 @@ func Test_policyd_GetPolicyCache(t *testing.T) { refreshDuration: tt.fields.refreshDuration, errRetryInterval: tt.fields.errRetryInterval, pkp: tt.fields.pkp, - etagCache: tt.fields.etagCache, - etagFlushDur: tt.fields.etagFlushDur, athenzURL: tt.fields.athenzURL, athenzDomains: tt.fields.athenzDomains, client: tt.fields.client, diff --git a/policy/fetcher.go b/policy/fetcher.go new file mode 100644 index 00000000..b2fedcf1 --- /dev/null +++ b/policy/fetcher.go @@ -0,0 +1,198 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package policy + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "sync/atomic" + "time" + "unsafe" + + "github.com/kpango/fastime" + "github.com/kpango/glg" + "github.com/pkg/errors" +) + +// SignedPolicyVerifier type defines the function signature to verify a signed policy. +type SignedPolicyVerifier func(*SignedPolicy) error + +// Fetcher represents fetcher object for fetching signed policy +type Fetcher interface { + Domain() string + Fetch(context.Context) (*SignedPolicy, error) + FetchWithRetry(context.Context) (*SignedPolicy, error) +} + +type fetcher struct { + + // etag related + expireMargin time.Duration + + // retry related + retryInterval time.Duration + retryMaxCount int + + // athenz related + domain string + athenzURL string + spVerifier SignedPolicyVerifier + + client *http.Client + policyCache unsafe.Pointer +} + +type taggedPolicy struct { + etag string + etagExpiry time.Time + sp *SignedPolicy + ctime time.Time +} + +// Domain returns the fetcher domain +func (f *fetcher) Domain() string { + return f.domain +} + +// Fetch fetches the policy. When calling concurrently, it is not guarantee that the cache will always have the latest version. +func (f *fetcher) Fetch(ctx context.Context) (*SignedPolicy, error) { + glg.Infof("will fetch policy for domain: %s", f.domain) + // https://{www.athenz.com/zts/v1}/domain/{athenz domain}/signed_policy_data + url := fmt.Sprintf("https://%s/domain/%s/signed_policy_data", f.athenzURL, f.domain) + + glg.Debugf("will fetch policy from url: %s", url) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + errMsg := "create fetch policy request fail" + glg.Errorf("%s, domain: %s, error: %v", errMsg, f.domain, err) + return nil, errors.Wrap(err, errMsg) + } + + // etag header + var tp *taggedPolicy + if f.policyCache != nil { + tp = (*taggedPolicy)(atomic.LoadPointer(&f.policyCache)) + if tp.etag != "" && tp.etagExpiry.After(fastime.Now()) { + glg.Debugf("request on domain: %s, with etag: %s", f.domain, tp.etag) + req.Header.Set("If-None-Match", tp.etag) + } + } + + res, err := f.client.Do(req.WithContext(ctx)) + if err != nil { + errMsg := "fetch policy HTTP request fail" + glg.Errorf("%s, domain: %s, error: %v", errMsg, f.domain, err) + return nil, errors.Wrap(err, errMsg) + } + defer func() { + if err := flushAndClose(res.Body); err != nil { + glg.Warn(errors.Wrap(err, "close Response.Body fail")) + } + }() + + // if server responses NotModified, return policy from cache + if res.StatusCode == http.StatusNotModified { + glg.Debugf("policy = 304 not modified, use cache for domain: %s, etag: %v", f.domain, tp.etag) + return tp.sp, nil + } + + if res.StatusCode != http.StatusOK { + errMsg := "fetch policy HTTP response != 200 OK" + glg.Errorf("%s, domain: %s, status: %d", errMsg, f.domain, res.StatusCode) + return nil, errors.Wrap(ErrFetchPolicy, errMsg) + } + + // read and decode + sp := new(SignedPolicy) + if err = json.NewDecoder(res.Body).Decode(&sp); err != nil { + errMsg := "policy decode fail" + glg.Errorf("%s, domain: %s, error: %v", errMsg, f.domain, err) + return nil, errors.Wrap(err, errMsg) + } + + // verify policy data + if err = f.spVerifier(sp); err != nil { + errMsg := "invalid policy" + glg.Errorf("%s, domain: %s, error: %v", errMsg, f.domain, err) + return nil, errors.Wrap(err, errMsg) + } + + // set policy cache + etag := res.Header.Get("ETag") + etagExpiry := sp.SignedPolicyData.Expires.Time.Add(-f.expireMargin) + newTp := &taggedPolicy{ + etag: etag, + etagExpiry: etagExpiry, + sp: sp, + ctime: fastime.Now(), + } + glg.Debugf("set policy cache for domain: %s, policy: %s", f.domain, newTp) + atomic.StorePointer(&f.policyCache, unsafe.Pointer(newTp)) + + return sp, nil +} + +// FetchWithRetry fetches policy with retry. Returns cached policy if all retries failed too. +func (f *fetcher) FetchWithRetry(ctx context.Context) (*SignedPolicy, error) { + var lastErr error + for i := -1; i < f.retryMaxCount; i++ { + sp, err := f.Fetch(ctx) + if err == nil { + return sp, nil + } + + lastErr = err + time.Sleep(f.retryInterval) + } + + errMsg := "max. retry count excess" + glg.Info("Will use policy cache, since: %s, domain: %s, error: %v", errMsg, f.domain, lastErr) + if lastErr == nil { + lastErr = fmt.Errorf("retryMaxCount %v", f.retryMaxCount) + } + if f.policyCache == nil { + return nil, errors.Wrap(errors.Wrap(lastErr, errMsg), "no policy cache") + } + return (*taggedPolicy)(atomic.LoadPointer(&f.policyCache)).sp, errors.Wrap(lastErr, errMsg) +} + +func (t *taggedPolicy) String() string { + var policyDomain string + if t.sp != nil && t.sp.SignedPolicyData != nil && t.sp.SignedPolicyData.PolicyData != nil { + policyDomain = t.sp.SignedPolicyData.PolicyData.Domain + } + return fmt.Sprintf("{ ctime: %s, etag: %s, etagExpiry: %s, sp.domain: %s }", t.ctime.UTC().String(), t.etag, t.etagExpiry.UTC().String(), policyDomain) +} + +// flushAndClose helps to flush and close a ReadCloser. Used for request body internal. +// Returns if there is any errors. +func flushAndClose(rc io.ReadCloser) error { + if rc != nil { + // flush + _, err := io.Copy(ioutil.Discard, rc) + if err != nil { + return err + } + // close + return rc.Close() + } + return nil +} diff --git a/policy/fetcher_mock_test.go b/policy/fetcher_mock_test.go new file mode 100644 index 00000000..32ade598 --- /dev/null +++ b/policy/fetcher_mock_test.go @@ -0,0 +1,57 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package policy + +import "context" + +// fetcherMock is the adapter implementation of Fetcher interface for mocking. +type fetcherMock struct { + domainMock func() string + fetchMock func(context.Context) (*SignedPolicy, error) + fetchWithRetryMock func(context.Context) (*SignedPolicy, error) +} + +// Domain is just an adapter. +func (r *fetcherMock) Domain() string { + return r.domainMock() +} + +// Fetch is just an adapter. +func (r *fetcherMock) Fetch(ctx context.Context) (*SignedPolicy, error) { + return r.fetchMock(ctx) +} + +// FetchWithRetry is just an adapter. +func (r *fetcherMock) FetchWithRetry(ctx context.Context) (*SignedPolicy, error) { + return r.fetchWithRetryMock(ctx) +} + +// readCloserMock is the adapter implementation of io.ReadCloser interface for mocking. +type readCloserMock struct { + readMock func(p []byte) (n int, err error) + closeMock func() error +} + +// Read is just an adapter. +func (r *readCloserMock) Read(p []byte) (n int, err error) { + return r.readMock(p) +} + +// Close is just an adapter. +func (r *readCloserMock) Close() error { + return r.closeMock() +} diff --git a/policy/fetcher_test.go b/policy/fetcher_test.go new file mode 100644 index 00000000..5a48691b --- /dev/null +++ b/policy/fetcher_test.go @@ -0,0 +1,1337 @@ +/* +Copyright (C) 2018 Yahoo Japan Corporation Athenz team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package policy + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/ardielle/ardielle-go/rdl" + "github.com/kpango/fastime" + "github.com/pkg/errors" + authcore "github.com/yahoo/athenz/libs/go/zmssvctoken" + "github.com/yahoo/athenz/utils/zpe-updater/util" + "github.com/yahoojapan/athenz-authorizer/v2/pubkey" +) + +func Test_flushAndClose(t *testing.T) { + type args struct { + readCloser io.ReadCloser + } + type testcase struct { + name string + args args + wantError error + } + tests := []testcase{ + { + name: "Check flushAndClose, readCloser is nil", + args: args{ + readCloser: nil, + }, + wantError: nil, + }, + { + name: "Check flushAndClose, flush & close success", + args: args{ + readCloser: &readCloserMock{ + readMock: func(p []byte) (n int, err error) { + return 0, io.EOF + }, + closeMock: func() error { + return nil + }, + }, + }, + wantError: nil, + }, + { + name: "Check flushAndClose, flush fail", + args: args{ + readCloser: &readCloserMock{ + readMock: func(p []byte) (n int, err error) { + return 0, fmt.Errorf("read-error-1332") + }, + closeMock: func() error { + return nil + }, + }, + }, + wantError: fmt.Errorf("read-error-1332"), + }, + { + name: "Check flushAndClose, close fail", + args: args{ + readCloser: &readCloserMock{ + readMock: func(p []byte) (n int, err error) { + return 0, io.EOF + }, + closeMock: func() error { + return fmt.Errorf("close-error-1349") + }, + }, + }, + wantError: fmt.Errorf("close-error-1349"), + }, + { + name: "Check flushAndClose, flush & close fail", + args: args{ + readCloser: &readCloserMock{ + readMock: func(p []byte) (n int, err error) { + return 0, fmt.Errorf("read-error-1360") + }, + closeMock: func() error { + return fmt.Errorf("close-error-1363") + }, + }, + }, + wantError: fmt.Errorf("read-error-1360"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotError := flushAndClose(tt.args.readCloser) + if !reflect.DeepEqual(gotError, tt.wantError) { + t.Errorf("flushAndClose() error = %v, want %v", gotError, tt.wantError) + } + }) + } +} + +func Test_fetcher_Domain(t *testing.T) { + type fields struct { + expireMargin time.Duration + retryInterval time.Duration + retryMaxCount int + domain string + athenzURL string + spVerifier SignedPolicyVerifier + client *http.Client + policyCache unsafe.Pointer + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "get domain success", + fields: fields{ + domain: "domain-217", + }, + want: "domain-217", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fetcher{ + expireMargin: tt.fields.expireMargin, + retryInterval: tt.fields.retryInterval, + retryMaxCount: tt.fields.retryMaxCount, + domain: tt.fields.domain, + athenzURL: tt.fields.athenzURL, + spVerifier: tt.fields.spVerifier, + client: tt.fields.client, + policyCache: tt.fields.policyCache, + } + if got := f.Domain(); got != tt.want { + t.Errorf("fetcher.Domain() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_fetcher_Fetch(t *testing.T) { + type fields struct { + expireMargin time.Duration + retryInterval time.Duration + retryMaxCount int + domain string + athenzURL string + spVerifier SignedPolicyVerifier + client *http.Client + policyCache unsafe.Pointer + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + want *SignedPolicy + wantPolicyCache *taggedPolicy + wantErrStr string + } + dummySignedPolicyVerifier := func(sp *SignedPolicy) error { + return sp.Verify(func(e pubkey.AthenzEnv, id string) authcore.Verifier { + return VerifierMock{ + VerifyFunc: func(d, s string) error { + return nil + }, + } + }) + } + createTestServer := func(hf http.HandlerFunc) (*httptest.Server, *http.Client, string) { + srv := httptest.NewTLSServer(hf) + return srv, srv.Client(), strings.Replace(srv.URL, "https://", "", 1) + } + createExpires := func(d time.Duration) (time.Time, string, error) { + t := fastime.Now().Add(d).UTC().Round(time.Millisecond) + tByte, err := rdl.Timestamp{ + Time: t, + }.MarshalJSON() + return t, string(tByte), err + } + compareTaggedPolicy := func(a, b *taggedPolicy) error { + if a == b { + return nil + } + if a.etag != b.etag { + return errors.New("etag") + } + if a.etagExpiry != b.etagExpiry { + return errors.New("etagExpiry") + } + if !reflect.DeepEqual(a.sp, b.sp) { + return errors.New("sp") + } + if time.Duration(math.Abs(float64(a.ctime.Sub(b.ctime)))) > time.Second { + return errors.New("ctime") + } + return nil + } + tests := []test{ + func() (t test) { + t.name = "success, no cache" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + etag := `"dummyEtag"` + zmsKeyID := "dummyZmsKeyId" + expires, expiresStr, err := createExpires(2 * expireMargin) + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") != "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + w.Header().Add("ETag", etag) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{ + "zmsKeyId": "%s", + "expires": %s + }}`, zmsKeyID, expiresStr))) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: expires.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + // policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, no etag" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + etag := `"dummyEtag"` + zmsKeyID := "dummyZmsKeyId" + expires, expiresStr, err := createExpires(2 * expireMargin) + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") != "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + w.Header().Add("ETag", etag) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{ + "zmsKeyId": "%s", + "expires": %s + }}`, zmsKeyID, expiresStr))) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: expires.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + policyCache := unsafe.Pointer(&taggedPolicy{}) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, etag with 200" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + etag := `"dummyEtag"` + zmsKeyID := "dummyZmsKeyId" + expires, expiresStr, err := createExpires(2 * expireMargin) + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") != `"dummyEtag"` { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + w.Header().Add("ETag", "dummyNewEtag") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{ + "zmsKeyId": "%s", + "expires": %s + }}`, zmsKeyID, expiresStr))) + }) + + // want objects + wantEtag := "dummyNewEtag" + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: wantEtag, + etagExpiry: expires.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + policyCache := unsafe.Pointer(&taggedPolicy{ + etag: etag, + etagExpiry: expires.Add(-expireMargin), + }) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, etag with 304" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") == `"dummyEtag"` { + w.WriteHeader(http.StatusNotModified) + return + } + + w.WriteHeader(http.StatusInternalServerError) + }) + + // want objects + expires := fastime.Now().Add(2 * expireMargin) + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: expires.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, etag expiry passed, request without etag" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + etag := `"dummyEtag"` + zmsKeyID := "dummyZmsKeyId" + expires, expiresStr, err := createExpires(2 * expireMargin) + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") != "" { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + w.Header().Add("ETag", etag) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"signedPolicyData":{ + "zmsKeyId": "%s", + "expires": %s + }}`, zmsKeyID, expiresStr))) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: expires.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + policyCache := unsafe.Pointer(&taggedPolicy{ + etag: "dummyOldEtag", + etagExpiry: fastime.Now().Add(-expireMargin), + sp: nil, + }) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, on 304, return cached policy even if expired" + + // http response + domain := "dummyDomain" + expireMargin := time.Hour + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/domain/dummyDomain/signed_policy_data" { + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.Header.Get("If-None-Match") == `"dummyEtag"` { + w.WriteHeader(http.StatusNotModified) + return + } + + w.WriteHeader(http.StatusInternalServerError) + }) + + // want objects + expires := fastime.Now().Add(-expireMargin) + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{Time: expires}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "dummyZmsKeyId", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: fastime.Now().Add(expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: time.Second, + retryMaxCount: 3, + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, url error" + + // http response + domain := "dummyDomain" + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `create fetch policy request fail: parse https:// /domain/dummyDomain/signed_policy_data: invalid character " " in host name` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: " ", + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, request error" + + // http response + domain := "dummyDomain" + _, client, _ := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `fetch policy HTTP request fail: Get https://127.0.0.1/api/domain/dummyDomain/signed_policy_data: dial tcp 127.0.0.1:443: connect: connection refused` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: "127.0.0.1/api", + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, server error" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `fetch policy HTTP response != 200 OK: Error fetching athenz policy` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, policy decode error" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("")) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `policy decode fail: EOF` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, policy verify error" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `invalid policy: dummy policy verify error` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + spVerifier: func(sp *SignedPolicy) error { + return errors.New("dummy policy verify error") + }, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, policy verify error, null expires" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"signedPolicyData":{}}`)) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `invalid policy: policy without expiry` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, policy verify error, invalid expires" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"signedPolicyData":{"expires":"2099-12-31"}}`)) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `invalid policy: policy already expired at 0001-01-01 00:00:00 +0000 UTC` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "fail, policy verify error, expired policy" + + // http response + domain := "dummyDomain" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"signedPolicyData":{"expires":"2006-01-02T15:04:05.999Z"}}`)) + }) + + // want objects + t.want = nil + t.wantPolicyCache = &taggedPolicy{ctime: fastime.Now()} + t.wantErrStr = `invalid policy: policy already expired at 2006-01-02 15:04:05.999 +0000 UTC` + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + domain: domain, + athenzURL: url, + spVerifier: dummySignedPolicyVerifier, + client: client, + policyCache: policyCache, + } + + return t + }(), + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fetcher{ + expireMargin: tt.fields.expireMargin, + retryInterval: tt.fields.retryInterval, + retryMaxCount: tt.fields.retryMaxCount, + domain: tt.fields.domain, + athenzURL: tt.fields.athenzURL, + spVerifier: tt.fields.spVerifier, + client: tt.fields.client, + policyCache: tt.fields.policyCache, + } + got, err := f.Fetch(tt.args.ctx) + if (err == nil && tt.wantErrStr != "") || (err != nil && err.Error() != tt.wantErrStr) { + t.Errorf("fetcher.Fetch() error = %v, wantErr %v", err, tt.wantErrStr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("fetcher.Fetch() = %v, want %v", got.SignedPolicyData, tt.want.SignedPolicyData) + return + } + gotPolicyCache := (*taggedPolicy)(f.policyCache) + if err = compareTaggedPolicy(gotPolicyCache, tt.wantPolicyCache); err != nil { + t.Errorf("fetcher.Fetch() policyCache = %v, want %v, error %v", gotPolicyCache, tt.wantPolicyCache, err) + return + } + }) + } +} + +func Test_fetcher_FetchWithRetry(t *testing.T) { + type fields struct { + expireMargin time.Duration + retryInterval time.Duration + retryMaxCount int + domain string + athenzURL string + spVerifier SignedPolicyVerifier + client *http.Client + policyCache unsafe.Pointer + } + type args struct { + ctx context.Context + } + type test struct { + name string + fields fields + args args + want *SignedPolicy + wantPolicyCache *taggedPolicy + wantErrStr string + cmpTP func(a, b *taggedPolicy) error + } + createTestServer := func(hf http.HandlerFunc) (*httptest.Server, *http.Client, string) { + srv := httptest.NewTLSServer(hf) + return srv, srv.Client(), strings.Replace(srv.URL, "https://", "", 1) + } + compareTaggedPolicy := func(a, b *taggedPolicy) error { + if a == b { + return nil + } + if a.etag != b.etag { + return errors.New("etag") + } + if a.etagExpiry != b.etagExpiry { + return errors.New("etagExpiry") + } + if !reflect.DeepEqual(a.sp, b.sp) { + return errors.New("sp") + } + if time.Duration(math.Abs(float64(a.ctime.Sub(b.ctime)))) > time.Second { + return errors.New("ctime") + } + return nil + } + tests := []test{ + func() (t test) { + t.name = "success, no retry" + var requestCount uint32 + + // HTTP response + expireMargin := time.Hour + retryInterval := time.Minute + retryMaxCount := 0 + keyID := "keyId" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("ETag", fmt.Sprintf(`"dummyEtag%d"`, atomic.AddUint32(&requestCount, 1))) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keyId":"%v","signedPolicyData":{"expires":""}}`, keyID))) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "keyId", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag1"`, + etagExpiry: time.Time{}.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + t.cmpTP = compareTaggedPolicy + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: retryInterval, + retryMaxCount: retryMaxCount, + domain: "dummyDomain", + athenzURL: url, + spVerifier: func(sp *SignedPolicy) error { return nil }, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "success, after retry" + var requestCount uint32 + + // HTTP response + expireMargin := time.Hour + retryInterval := 100 * time.Millisecond + retryMaxCount := 2 + keyID := "keyId" + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + rc := atomic.AddUint32(&requestCount, 1) + if rc < 3 { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("ETag", fmt.Sprintf(`"dummyEtag%d"`, rc)) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"keyId":"%v","signedPolicyData":{"expires":""}}`, keyID))) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "keyId", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag3"`, + etagExpiry: time.Time{}.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "" + t.cmpTP = func(a, b *taggedPolicy) error { + err := compareTaggedPolicy(a, b) + if err != nil { + return err + } + + // check retry interval + diff := a.ctime.Sub(b.ctime) + if diff < retryInterval*time.Duration(retryMaxCount) || diff > retryInterval*time.Duration(retryMaxCount+1) { + return errors.New("retry interval not working") + } + return nil + } + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: retryInterval, + retryMaxCount: retryMaxCount, + domain: "dummyDomain", + athenzURL: url, + spVerifier: func(sp *SignedPolicy) error { return nil }, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "all fail, no policy cache" + + // HTTP response + expireMargin := time.Hour + retryInterval := time.Millisecond + retryMaxCount := 2 + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + // want objects + t.want = nil + t.wantPolicyCache = nil + t.wantErrStr = "no policy cache: max. retry count excess: fetch policy HTTP response != 200 OK: Error fetching athenz policy" + t.cmpTP = compareTaggedPolicy + + // test input + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: retryInterval, + retryMaxCount: retryMaxCount, + domain: "dummyDomain", + athenzURL: url, + spVerifier: func(sp *SignedPolicy) error { return nil }, + client: client, + // policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "all fail, return cached policy" + + // HTTP response + expireMargin := time.Hour + retryInterval := time.Millisecond + retryMaxCount := 2 + _, client, url := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "keyId", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: time.Time{}.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "max. retry count excess: fetch policy HTTP response != 200 OK: Error fetching athenz policy" + t.cmpTP = compareTaggedPolicy + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: retryInterval, + retryMaxCount: retryMaxCount, + domain: "dummyDomain", + athenzURL: url, + spVerifier: func(sp *SignedPolicy) error { return nil }, + client: client, + policyCache: policyCache, + } + + return t + }(), + func() (t test) { + t.name = "retryMaxCount < 0" + + // HTTP response + expireMargin := time.Hour + retryInterval := time.Millisecond + retryMaxCount := -1 + + // want objects + sp := &SignedPolicy{ + util.DomainSignedPolicyData{ + KeyId: "keyId", + Signature: "", + SignedPolicyData: &util.SignedPolicyData{ + Expires: &rdl.Timestamp{}, + Modified: nil, + PolicyData: nil, + ZmsKeyId: "", + ZmsSignature: "", + }, + }, + } + t.want = sp + t.wantPolicyCache = &taggedPolicy{ + etag: `"dummyEtag"`, + etagExpiry: time.Time{}.Add(-expireMargin), + sp: sp, + ctime: fastime.Now(), + } + t.wantErrStr = "max. retry count excess: retryMaxCount -1" + t.cmpTP = compareTaggedPolicy + + // test input + policyCache := unsafe.Pointer(t.wantPolicyCache) + t.args = args{ + ctx: context.Background(), + } + t.fields = fields{ + expireMargin: expireMargin, + retryInterval: retryInterval, + retryMaxCount: retryMaxCount, + domain: "dummyDomain", + spVerifier: func(sp *SignedPolicy) error { return nil }, + policyCache: policyCache, + } + + return t + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fetcher{ + expireMargin: tt.fields.expireMargin, + retryInterval: tt.fields.retryInterval, + retryMaxCount: tt.fields.retryMaxCount, + domain: tt.fields.domain, + athenzURL: tt.fields.athenzURL, + spVerifier: tt.fields.spVerifier, + client: tt.fields.client, + policyCache: tt.fields.policyCache, + } + got, err := f.FetchWithRetry(tt.args.ctx) + if (err == nil && tt.wantErrStr != "") || (err != nil && err.Error() != tt.wantErrStr) { + t.Errorf("fetcher.FetchWithRetry() error = %v, wantErr %v", err, tt.wantErrStr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("fetcher.FetchWithRetry() = %v, want %v", got, tt.want) + } + gotPolicyCache := (*taggedPolicy)(f.policyCache) + if err = tt.cmpTP(gotPolicyCache, tt.wantPolicyCache); err != nil { + t.Errorf("fetcher.FetchWithRetry() policyCache = %v, want %v, error %v", gotPolicyCache, tt.wantPolicyCache, err) + return + } + }) + } + +} + +func Test_taggedPolicy_String(t *testing.T) { + type fields struct { + etag string + etagExpiry time.Time + sp *SignedPolicy + ctime time.Time + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "default value", + fields: fields{}, + want: `{ ctime: 0001-01-01 00:00:00 +0000 UTC, etag: , etagExpiry: 0001-01-01 00:00:00 +0000 UTC, sp.domain: }`, + }, + { + name: "custom value", + fields: fields{ + etag: `"etag"`, + etagExpiry: time.Unix(1567454350, 167000000), + ctime: time.Unix(1566454350, 167000000), + // sp: &SignedPolicy{}, + }, + want: `{ ctime: 2019-08-22 06:12:30.167 +0000 UTC, etag: "etag", etagExpiry: 2019-09-02 19:59:10.167 +0000 UTC, sp.domain: }`, + }, + { + name: "policy without data", + fields: fields{ + etag: `"etag"`, + sp: &SignedPolicy{ + DomainSignedPolicyData: util.DomainSignedPolicyData{ + SignedPolicyData: &util.SignedPolicyData{ + PolicyData: nil, + }, + }, + }, + }, + want: `{ ctime: 0001-01-01 00:00:00 +0000 UTC, etag: "etag", etagExpiry: 0001-01-01 00:00:00 +0000 UTC, sp.domain: }`, + }, + { + name: "policy with data", + fields: fields{ + etag: `"etag"`, + sp: &SignedPolicy{ + DomainSignedPolicyData: util.DomainSignedPolicyData{ + SignedPolicyData: &util.SignedPolicyData{ + PolicyData: &util.PolicyData{Domain: "domain"}, + }, + }, + }, + }, + want: `{ ctime: 0001-01-01 00:00:00 +0000 UTC, etag: "etag", etagExpiry: 0001-01-01 00:00:00 +0000 UTC, sp.domain: domain }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tp := &taggedPolicy{ + etag: tt.fields.etag, + etagExpiry: tt.fields.etagExpiry, + sp: tt.fields.sp, + ctime: tt.fields.ctime, + } + if got := tp.String(); got != tt.want { + t.Errorf("taggedPolicy.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/policy/option.go b/policy/option.go index 555b7546..b6f0ffff 100644 --- a/policy/option.go +++ b/policy/option.go @@ -28,7 +28,6 @@ import ( var ( defaultOptions = []Option{ WithExpireMargin("3h"), - WithEtagFlushDuration("12h"), WithPolicyExpiredDuration("1m"), WithRefreshDuration("30m"), WithErrRetryInterval("1m"), @@ -39,21 +38,6 @@ var ( // Option represents a functional option type Option func(*policyd) error -// WithEtagFlushDuration returns an ETagFlushDur functional option -func WithEtagFlushDuration(t string) Option { - return func(pol *policyd) error { - if t == "" { - return nil - } - etagFlushDur, err := time.ParseDuration(t) - if err != nil { - return errors.Wrap(err, "invalid flush duration") - } - pol.etagFlushDur = etagFlushDur - return nil - } -} - // WithExpireMargin returns an ExpiryMargin functional option func WithExpireMargin(t string) Option { return func(pol *policyd) error { @@ -100,7 +84,7 @@ func WithPolicyExpiredDuration(t string) Option { } rd, err := time.ParseDuration(t) if err != nil { - return errors.Wrap(err, "invalid refresh duration") + return errors.Wrap(err, "invalid flush duration") } pol.policyExpiredDuration = rd return nil diff --git a/policy/option_test.go b/policy/option_test.go index 283a874e..079354e3 100644 --- a/policy/option_test.go +++ b/policy/option_test.go @@ -27,72 +27,6 @@ import ( "github.com/yahoojapan/athenz-authorizer/v2/pubkey" ) -func TestWithEtagFlushDur(t *testing.T) { - type args struct { - t string - } - tests := []struct { - name string - args args - checkFunc func(Option) error - }{ - { - name: "set success", - args: args{ - "1h", - }, - checkFunc: func(opt Option) error { - pol := &policyd{} - if err := opt(pol); err != nil { - return err - } - if pol.etagFlushDur != time.Hour { - return fmt.Errorf("Error") - } - - return nil - }, - }, { - name: "invalid format", - args: args{ - "dummy", - }, - checkFunc: func(opt Option) error { - pol := &policyd{} - if err := opt(pol); err == nil { - return fmt.Errorf("expected error, but not return") - } - - return nil - }, - }, - { - name: "empty value", - args: args{ - "", - }, - checkFunc: func(opt Option) error { - pol := &policyd{} - if err := opt(pol); err != nil { - return err - } - if !reflect.DeepEqual(pol, &policyd{}) { - return fmt.Errorf("expected no changes, but got %v", pol) - } - return nil - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := WithEtagFlushDuration(tt.args.t) - if err := tt.checkFunc(got); err != nil { - t.Errorf("WithEtagFlushDur() error = %v", err) - } - }) - } -} - func TestWithExpireMargin(t *testing.T) { type args struct { t string @@ -278,6 +212,72 @@ func TestWithAthenzDomains(t *testing.T) { } } +func TestWithPolicyExpiredDuration(t *testing.T) { + type args struct { + t string + } + tests := []struct { + name string + args args + checkFunc func(Option) error + }{ + { + name: "set success", + args: args{ + "1h", + }, + checkFunc: func(opt Option) error { + pol := &policyd{} + if err := opt(pol); err != nil { + return err + } + if pol.policyExpiredDuration != time.Hour { + return fmt.Errorf("Error") + } + + return nil + }, + }, { + name: "invalid format", + args: args{ + "dummy", + }, + checkFunc: func(opt Option) error { + pol := &policyd{} + if err := opt(pol); err == nil { + return fmt.Errorf("expected error, but not return") + } + + return nil + }, + }, + { + name: "empty value", + args: args{ + "", + }, + checkFunc: func(opt Option) error { + pol := &policyd{} + if err := opt(pol); err != nil { + return err + } + if !reflect.DeepEqual(pol, &policyd{}) { + return fmt.Errorf("expected no changes, but got %v", pol) + } + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithPolicyExpiredDuration(tt.args.t) + if err := tt.checkFunc(got); err != nil { + t.Errorf("WithPolicyExpiredDuration() error = %v", err) + } + }) + } +} + func TestWithRefreshDuration(t *testing.T) { type args struct { t string diff --git a/pubkey/daemon.go b/pubkey/daemon.go index b45a37fe..5f5e5978 100644 --- a/pubkey/daemon.go +++ b/pubkey/daemon.go @@ -106,10 +106,6 @@ func (p *pubkeyd) Start(ctx context.Context) <-chan error { ech := make(chan error, 100) fch := make(chan struct{}, 1) - if err := p.Update(ctx); err != nil { - ech <- errors.Wrap(err, "error update pubkey") - fch <- struct{}{} - } go func() { defer close(fch) diff --git a/pubkey/daemon_test.go b/pubkey/daemon_test.go index 065255a6..bec30f7e 100644 --- a/pubkey/daemon_test.go +++ b/pubkey/daemon_test.go @@ -868,7 +868,7 @@ func Test_pubkeyd_Update(t *testing.T) { } } -func Test_pubkeyd_StartpubkeyUpdater(t *testing.T) { +func Test_pubkeyd_Start(t *testing.T) { type fields struct { refreshDuration time.Duration errRetryInterval time.Duration @@ -928,8 +928,22 @@ func Test_pubkeyd_StartpubkeyUpdater(t *testing.T) { }, checkFunc: func(c *pubkeyd, ch <-chan error) error { cancel() + err := <-ch + + // check error + wantErr := context.Canceled + if err != wantErr { + return fmt.Errorf("got: %v, want: %v", err, wantErr) + } + for err := range ch { + if err != nil { + return err + } + } + err = nil + + // check pubkey cache ind := 0 - var err error checker := func(key interface{}, value interface{}) bool { ind++ valType := fmt.Sprint(reflect.TypeOf(value)) @@ -949,19 +963,23 @@ func Test_pubkeyd_StartpubkeyUpdater(t *testing.T) { } return nil } - err = check(c.confCache.ZMSPubKeys, 1) + err = check(c.confCache.ZMSPubKeys, 0) if err != nil { return err } err = nil ind = 0 - err = check(c.confCache.ZTSPubKeys, 1) + err = check(c.confCache.ZTSPubKeys, 0) if err != nil { return err } - err = nil - ind = 0 + // check etag cache + ecLen := len(c.etagCache.ToRawMap(context.Background())) + wantEcLen := 0 + if ecLen != wantEcLen { + return errors.Errorf("invalid length ZMSPubKeys. got: %d, want: %d", ecLen, wantEcLen) + } c.etagCache.Foreach(context.Background(), func(key string, val interface{}, _ int64) bool { if key != "zms" && key != "zts" { err = errors.Errorf("unexpected key %s", key) @@ -974,6 +992,10 @@ func Test_pubkeyd_StartpubkeyUpdater(t *testing.T) { } return true }) + if err != nil { + return err + } + return nil }, } @@ -999,7 +1021,7 @@ func Test_pubkeyd_StartpubkeyUpdater(t *testing.T) { fields: fields{ athenzURL: strings.Replace(srv.URL, "https://", "", 1), sysAuthDomain: "dummyDom", - refreshDuration: time.Minute, + refreshDuration: 10 * time.Millisecond, errRetryInterval: time.Minute, etagCache: gache.New(), etagExpTime: time.Minute, diff --git a/role/claim.go b/role/claim.go index 409b4a28..e9d462fb 100644 --- a/role/claim.go +++ b/role/claim.go @@ -42,18 +42,18 @@ func (c *Claim) Valid() error { vErr := new(jwt.ValidationError) now := jwt.TimeFunc().Unix() - if c.VerifyExpiresAt(now, true) == false { + if !c.VerifyExpiresAt(now, true) { delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) vErr.Inner = fmt.Errorf("token is expired by %v", delta) vErr.Errors |= jwt.ValidationErrorExpired } - if c.VerifyIssuedAt(now, false) == false { + if !c.VerifyIssuedAt(now, false) { vErr.Inner = fmt.Errorf("Token used before issued") vErr.Errors |= jwt.ValidationErrorIssuedAt } - if c.VerifyNotBefore(now, false) == false { + if !c.VerifyNotBefore(now, false) { vErr.Inner = fmt.Errorf("token is not valid yet") vErr.Errors |= jwt.ValidationErrorNotValidYet } diff --git a/role/processor_test.go b/role/processor_test.go index 3bdc4621..88be4cc6 100644 --- a/role/processor_test.go +++ b/role/processor_test.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package role import ( @@ -30,6 +31,8 @@ import ( "github.com/yahoojapan/athenz-authorizer/v2/pubkey" ) +// test data is generated by `role/asserts/private.pem` + func TestNew(t *testing.T) { type args struct { opts []Option