From 7e735658420a00816226ba014d6f0c8ec41a6c4f Mon Sep 17 00:00:00 2001 From: grac3gao Date: Wed, 9 Dec 2020 21:36:50 -0800 Subject: [PATCH] address comments --- cmd/broker/fanout/main.go | 3 +- cmd/broker/retry/main.go | 3 +- .../v1/pullsubscription_lifecycle.go | 8 +- .../v1/pullsubscription_lifecycle_test.go | 2 +- .../v1alpha1/brokercell_lifecycle.go | 22 ++--- .../v1alpha1/brokercell_lifecycle_test.go | 6 +- .../v1beta1/pullsubscription_lifecycle.go | 8 +- .../pullsubscription_lifecycle_test.go | 2 +- pkg/broker/handler/fanout_test.go | 5 +- pkg/broker/handler/pool.go | 10 +- pkg/broker/handler/pool_test.go | 14 ++- pkg/broker/handler/retry_test.go | 6 +- pkg/gclient/authcheck/client.go | 38 ++++++++ pkg/gclient/authcheck/interfaces.go | 29 ++++++ pkg/gclient/authcheck/testing/fake.go | 43 +++++++++ pkg/reconciler/brokercell/brokercell.go | 6 +- .../pullsubscription/keda/pullsubscription.go | 2 +- .../static/pullsubscription.go | 2 +- pkg/utils/authcheck/authcheck.go | 70 +++++++------- pkg/utils/authcheck/authcheck_test.go | 85 +++++++++++++++++ pkg/utils/authcheck/authchek_test.go | 81 ----------------- pkg/utils/authcheck/authtype.go | 18 ++-- pkg/utils/authcheck/authtype_test.go | 41 ++++++++- pkg/utils/authcheck/list.go | 7 +- pkg/utils/authcheck/probechecker.go | 25 +++-- pkg/utils/authcheck/probechecker_test.go | 91 ++++++++++++------- 26 files changed, 413 insertions(+), 214 deletions(-) create mode 100644 pkg/gclient/authcheck/client.go create mode 100644 pkg/gclient/authcheck/interfaces.go create mode 100644 pkg/gclient/authcheck/testing/fake.go create mode 100644 pkg/utils/authcheck/authcheck_test.go delete mode 100644 pkg/utils/authcheck/authchek_test.go diff --git a/cmd/broker/fanout/main.go b/cmd/broker/fanout/main.go index ca44a0469a..689ff88b27 100644 --- a/cmd/broker/fanout/main.go +++ b/cmd/broker/fanout/main.go @@ -24,6 +24,7 @@ import ( "github.com/google/knative-gcp/pkg/broker/config/volume" "github.com/google/knative-gcp/pkg/broker/handler" + authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" "github.com/google/knative-gcp/pkg/metrics" "github.com/google/knative-gcp/pkg/utils" "github.com/google/knative-gcp/pkg/utils/appcredentials" @@ -103,7 +104,7 @@ func main() { if err != nil { logger.Fatal("Failed to create fanout sync pool", zap.Error(err)) } - if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType); err != nil { + if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType, authcheckclient.NewAuthCheckClient()); err != nil { logger.Fatalw("Failed to start fanout sync pool", zap.Error(err)) } diff --git a/cmd/broker/retry/main.go b/cmd/broker/retry/main.go index 2d0facc22a..73310fac71 100644 --- a/cmd/broker/retry/main.go +++ b/cmd/broker/retry/main.go @@ -19,6 +19,7 @@ package main import ( "context" "flag" + "net/http" "time" "cloud.google.com/go/pubsub" @@ -101,7 +102,7 @@ func main() { if err != nil { logger.Fatal("Failed to get retry sync pool", zap.Error(err)) } - if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType); err != nil { + if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType, http.DefaultClient); err != nil { logger.Fatal("Failed to start retry sync pool", zap.Error(err)) } diff --git a/pkg/apis/intevents/v1/pullsubscription_lifecycle.go b/pkg/apis/intevents/v1/pullsubscription_lifecycle.go index 5a2f8c3ca6..22bc52fbaa 100644 --- a/pkg/apis/intevents/v1/pullsubscription_lifecycle.go +++ b/pkg/apis/intevents/v1/pullsubscription_lifecycle.go @@ -95,12 +95,12 @@ func (s *PullSubscriptionStatus) MarkDeployedUnknown(reason, messageFormat strin // PropagateDeploymentAvailability uses the availability of the provided Deployment to determine if // PullSubscriptionConditionDeployed should be marked as true or false. -// For authentication check purpose, this method will return true if a false condition +// For authentication check purpose, this method will return false if a false condition // is caused by deployment's replicaset unavailable. // ReplicaSet unavailable is a sign for potential authentication problems. func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deployment) bool { deploymentAvailableFound := false - replicaUnavailable := false + replicaAvailable := true for _, cond := range d.Status.Conditions { if cond.Type == appsv1.DeploymentAvailable { deploymentAvailableFound = true @@ -108,7 +108,7 @@ func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deplo pullSubscriptionCondSet.Manage(s).MarkTrue(PullSubscriptionConditionDeployed) } else if cond.Status == corev1.ConditionFalse { if cond.Reason == replicaUnavailableReason { - replicaUnavailable = true + replicaAvailable = false } pullSubscriptionCondSet.Manage(s).MarkFalse(PullSubscriptionConditionDeployed, cond.Reason, cond.Message) } else if cond.Status == corev1.ConditionUnknown { @@ -119,5 +119,5 @@ func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deplo if !deploymentAvailableFound { pullSubscriptionCondSet.Manage(s).MarkUnknown(PullSubscriptionConditionDeployed, "DeploymentUnavailable", "Deployment %q is unavailable.", d.Name) } - return replicaUnavailable + return replicaAvailable } diff --git a/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go b/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go index ccffa1be68..9667480406 100644 --- a/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go +++ b/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go @@ -526,7 +526,7 @@ func TestPullSubscriptionStatusGetCondition(t *testing.T) { func TestPropagateDeploymentAvailability(t *testing.T) { s := &PullSubscriptionStatus{} got := s.PropagateDeploymentAvailability(replicaUnavailableDeployment) - want := true + want := false if diff := cmp.Diff(got, want); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } diff --git a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle.go b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle.go index ab2d26464f..b7ba22e317 100644 --- a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle.go +++ b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle.go @@ -77,11 +77,11 @@ func (bs *BrokerCellStatus) InitializeConditions() { // PropagateIngressAvailability uses the availability of the provided Endpoints // to determine if BrokerCellConditionIngress should be marked as true or // false. -// For authentication check purpose, this method will return true if a false condition +// For authentication check purpose, this method will return false if a false condition // is caused by deployment's replicaset unavailable. // ReplicaSet unavailable is a sign for potential authentication problems. func (bs *BrokerCellStatus) PropagateIngressAvailability(ep *corev1.Endpoints, ind *appsv1.Deployment) bool { - replicaUnavailable := false + replicaAvailable := true if duck.EndpointsAreAvailable(ep) { brokerCellCondSet.Manage(bs).MarkTrue(BrokerCellConditionIngress) } else { @@ -89,14 +89,14 @@ func (bs *BrokerCellStatus) PropagateIngressAvailability(ep *corev1.Endpoints, i if cond.Type == appsv1.DeploymentAvailable { if cond.Status == corev1.ConditionFalse { if cond.Reason == replicaUnavailableReason { - replicaUnavailable = true + replicaAvailable = false } } } } brokerCellCondSet.Manage(bs).MarkFalse(BrokerCellConditionIngress, "EndpointsUnavailable", "Endpoints %q is unavailable.", ep.Name) } - return replicaUnavailable + return replicaAvailable } func (bs *BrokerCellStatus) MarkIngressFailed(reason, format string, args ...interface{}) { @@ -110,12 +110,12 @@ func (bs *BrokerCellStatus) MarkIngressUnknown(reason, format string, args ...in // PropagateFanoutAvailability uses the availability of the provided Deployment // to determine if BrokerCellConditionFanout should be marked as true or // false. -// For authentication check purpose, this method will return true if a false condition +// For authentication check purpose, this method will return false if a false condition // is caused by deployment's replicaset unavailable. // ReplicaSet unavailable is a sign for potential authentication problems. func (bs *BrokerCellStatus) PropagateFanoutAvailability(d *appsv1.Deployment) bool { deploymentAvailableFound := false - replicaUnavailable := false + replicaAvailable := true for _, cond := range d.Status.Conditions { if cond.Type == appsv1.DeploymentAvailable { deploymentAvailableFound = true @@ -123,7 +123,7 @@ func (bs *BrokerCellStatus) PropagateFanoutAvailability(d *appsv1.Deployment) bo brokerCellCondSet.Manage(bs).MarkTrue(BrokerCellConditionFanout) } else if cond.Status == corev1.ConditionFalse { if cond.Reason == replicaUnavailableReason { - replicaUnavailable = true + replicaAvailable = false } brokerCellCondSet.Manage(bs).MarkFalse(BrokerCellConditionFanout, cond.Reason, cond.Message) } else if cond.Status == corev1.ConditionUnknown { @@ -134,7 +134,7 @@ func (bs *BrokerCellStatus) PropagateFanoutAvailability(d *appsv1.Deployment) bo if !deploymentAvailableFound { brokerCellCondSet.Manage(bs).MarkUnknown(BrokerCellConditionFanout, "DeploymentUnavailable", "Deployment %q is unavailable.", d.Name) } - return replicaUnavailable + return replicaAvailable } func (bs *BrokerCellStatus) MarkFanoutFailed(reason, format string, args ...interface{}) { @@ -153,7 +153,7 @@ func (bs *BrokerCellStatus) MarkFanoutUnknown(reason, format string, args ...int // ReplicaSet unavailable is a sign for potential authentication problems. func (bs *BrokerCellStatus) PropagateRetryAvailability(d *appsv1.Deployment) bool { deploymentAvailableFound := false - replicaUnavailable := false + replicaAvailable := true for _, cond := range d.Status.Conditions { if cond.Type == appsv1.DeploymentAvailable { deploymentAvailableFound = true @@ -161,7 +161,7 @@ func (bs *BrokerCellStatus) PropagateRetryAvailability(d *appsv1.Deployment) boo brokerCellCondSet.Manage(bs).MarkTrue(BrokerCellConditionRetry) } else if cond.Status == corev1.ConditionFalse { if cond.Reason == replicaUnavailableReason { - replicaUnavailable = true + replicaAvailable = false } brokerCellCondSet.Manage(bs).MarkFalse(BrokerCellConditionRetry, cond.Reason, cond.Message) } else if cond.Status == corev1.ConditionUnknown { @@ -172,7 +172,7 @@ func (bs *BrokerCellStatus) PropagateRetryAvailability(d *appsv1.Deployment) boo if !deploymentAvailableFound { brokerCellCondSet.Manage(bs).MarkUnknown(BrokerCellConditionRetry, "DeploymentUnavailable", "Deployment %q is unavailable.", d.Name) } - return replicaUnavailable + return replicaAvailable } func (bs *BrokerCellStatus) MarkRetryFailed(reason, format string, args ...interface{}) { diff --git a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go index 2ed64ea6f4..81985be3ee 100644 --- a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go +++ b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go @@ -438,7 +438,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { t.Run("propagate ingress availability", func(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateIngressAvailability(&corev1.Endpoints{}, replicaUnavailableDeployment) - want := true + want := false if diff := cmp.Diff(got, want); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } @@ -447,7 +447,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { t.Run("propagate fanout availability", func(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateFanoutAvailability(replicaUnavailableDeployment) - want := true + want := false if diff := cmp.Diff(got, want); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } @@ -456,7 +456,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { t.Run("propagate retry availability", func(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateRetryAvailability(replicaUnavailableDeployment) - want := true + want := false if diff := cmp.Diff(got, want); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } diff --git a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle.go b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle.go index 42ae9af4d9..eedc9e0bc5 100644 --- a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle.go +++ b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle.go @@ -87,12 +87,12 @@ func (s *PullSubscriptionStatus) MarkNoSubscription(reason, messageFormat string // PropagateDeploymentAvailability uses the availability of the provided Deployment to determine if // PullSubscriptionConditionDeployed should be marked as true or false. -// For authentication check purpose, this method will return true if a false condition +// For authentication check purpose, this method will return false if a false condition // is caused by deployment's replicaset unavailable. // ReplicaSet unavailable is a sign for potential authentication problems. func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deployment) bool { deploymentAvailableFound := false - replicaUnavailable := false + replicaAvailable := true for _, cond := range d.Status.Conditions { if cond.Type == appsv1.DeploymentAvailable { deploymentAvailableFound = true @@ -100,7 +100,7 @@ func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deplo pullSubscriptionCondSet.Manage(s).MarkTrue(PullSubscriptionConditionDeployed) } else if cond.Status == corev1.ConditionFalse { if cond.Reason == replicaUnavailableReason { - replicaUnavailable = true + replicaAvailable = false } pullSubscriptionCondSet.Manage(s).MarkFalse(PullSubscriptionConditionDeployed, cond.Reason, cond.Message) } else if cond.Status == corev1.ConditionUnknown { @@ -111,5 +111,5 @@ func (s *PullSubscriptionStatus) PropagateDeploymentAvailability(d *appsv1.Deplo if !deploymentAvailableFound { pullSubscriptionCondSet.Manage(s).MarkUnknown(PullSubscriptionConditionDeployed, "DeploymentUnavailable", "Deployment %q is unavailable.", d.Name) } - return replicaUnavailable + return replicaAvailable } diff --git a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go index fe21c33a76..55710e781f 100644 --- a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go +++ b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go @@ -526,7 +526,7 @@ func TestPullSubscriptionStatusGetCondition(t *testing.T) { func TestPropagateDeploymentAvailability(t *testing.T) { s := &PullSubscriptionStatus{} got := s.PropagateDeploymentAvailability(replicaUnavailableDeployment) - want := true + want := false if diff := cmp.Diff(got, want); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } diff --git a/pkg/broker/handler/fanout_test.go b/pkg/broker/handler/fanout_test.go index de4aa6f007..8ec18ddc57 100644 --- a/pkg/broker/handler/fanout_test.go +++ b/pkg/broker/handler/fanout_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/knative-gcp/pkg/broker/config" "github.com/google/knative-gcp/pkg/broker/eventutil" handlertesting "github.com/google/knative-gcp/pkg/broker/handler/testing" + authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" reportertest "github.com/google/knative-gcp/pkg/metrics/testing" _ "knative.dev/pkg/metrics/testing" @@ -62,7 +63,7 @@ func TestFanoutWatchAndSync(t *testing.T) { } t.Run("start sync pool creates no handler", func(t *testing.T) { - _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "") + _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) if err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } @@ -155,7 +156,7 @@ func TestFanoutSyncPoolE2E(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, ""); err != nil { + if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } diff --git a/pkg/broker/handler/pool.go b/pkg/broker/handler/pool.go index c2ac0cab22..e2dbdbf7a5 100644 --- a/pkg/broker/handler/pool.go +++ b/pkg/broker/handler/pool.go @@ -23,6 +23,7 @@ import ( "sync" "time" + authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" "github.com/google/knative-gcp/pkg/logging" "github.com/google/knative-gcp/pkg/utils/authcheck" @@ -44,6 +45,7 @@ type probeChecker struct { lastReportTime time.Time maxStaleDuration time.Duration port int + authCheckClient authcheckclient.Client authType authcheck.AuthType } @@ -85,7 +87,11 @@ func (c *probeChecker) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } // Perform Authentication check. - authcheck.AuthenticationCheck(req.Context(), c.logger, c.authType, w) + if err := authcheck.AuthenticationCheck(req.Context(), c.authType, c.authCheckClient); err != nil { + c.logger.Error("authentication check failed", zap.Error(err)) + w.WriteHeader(http.StatusUnauthorized) + return + } // Zero maxStaleDuration means infinite. if c.maxStaleDuration == 0 { @@ -106,6 +112,7 @@ func StartSyncPool( maxStaleDuration time.Duration, probeCheckPort int, authType authcheck.AuthType, + authCheckClient authcheckclient.Client, ) (SyncPool, error) { if err := syncPool.SyncOnce(ctx); err != nil { @@ -116,6 +123,7 @@ func StartSyncPool( maxStaleDuration: maxStaleDuration, port: probeCheckPort, authType: authType, + authCheckClient: authCheckClient, } go c.start(ctx) if syncSignal != nil { diff --git a/pkg/broker/handler/pool_test.go b/pkg/broker/handler/pool_test.go index 0cd366ec9d..4172072d70 100644 --- a/pkg/broker/handler/pool_test.go +++ b/pkg/broker/handler/pool_test.go @@ -25,6 +25,9 @@ import ( "time" "github.com/google/go-cmp/cmp" + + authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" + "github.com/google/knative-gcp/pkg/utils/authcheck" ) func TestSyncPool(t *testing.T) { @@ -42,7 +45,7 @@ func TestSyncPool(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - _, gotErr := StartSyncPool(ctx, syncPool, make(chan struct{}), 30*time.Second, p, "") + _, gotErr := StartSyncPool(ctx, syncPool, make(chan struct{}), 30*time.Second, p, authcheck.WorkloadIdentityGSA, authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) if gotErr == nil { t.Error("StartSyncPool got unexpected result") } @@ -65,7 +68,7 @@ func TestSyncPool(t *testing.T) { } ch := make(chan struct{}) - if _, err := StartSyncPool(ctx, syncPool, ch, time.Second, p, ""); err != nil { + if _, err := StartSyncPool(ctx, syncPool, ch, time.Second, p, authcheck.WorkloadIdentityGSA, authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { t.Errorf("StartSyncPool got unexpected error: %v", err) } syncPool.verifySyncOnceCalled(t) @@ -74,10 +77,13 @@ func TestSyncPool(t *testing.T) { ch <- struct{}{} syncPool.verifySyncOnceCalled(t) - // False because authentication check will fail. - assertProbeCheckResult(t, p, false, "healthz") + assertProbeCheckResult(t, p, true, "healthz") // False because path is not healthz. assertProbeCheckResult(t, p, false, "empty") + + time.Sleep(time.Second) + // False because it exceeds StaleDuration. + assertProbeCheckResult(t, p, false, "healthz") }) } diff --git a/pkg/broker/handler/retry_test.go b/pkg/broker/handler/retry_test.go index 3db8e2c48e..56e7237e6d 100644 --- a/pkg/broker/handler/retry_test.go +++ b/pkg/broker/handler/retry_test.go @@ -18,6 +18,7 @@ package handler import ( "context" + "net/http" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/google/knative-gcp/pkg/broker/config" "github.com/google/knative-gcp/pkg/broker/eventutil" handlertesting "github.com/google/knative-gcp/pkg/broker/handler/testing" + authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" reportertest "github.com/google/knative-gcp/pkg/metrics/testing" _ "knative.dev/pkg/metrics/testing" @@ -61,7 +63,7 @@ func TestRetryWatchAndSync(t *testing.T) { } t.Run("start sync pool creates no handler", func(t *testing.T) { - _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "") + _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) if err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } @@ -157,7 +159,7 @@ func TestRetrySyncPoolE2E(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, ""); err != nil { + if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } diff --git a/pkg/gclient/authcheck/client.go b/pkg/gclient/authcheck/client.go new file mode 100644 index 0000000000..12f3f2484f --- /dev/null +++ b/pkg/gclient/authcheck/client.go @@ -0,0 +1,38 @@ +/* +Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package authcheck provides interfaces and wrappers around the http client. +package authcheck + +import ( + "net/http" +) + +type authCheckClient struct { + httpClient *http.Client +} + +var _ Client = &authCheckClient{} + +func NewAuthCheckClient() Client { + return &authCheckClient{ + httpClient: http.DefaultClient, + } +} + +func (client *authCheckClient) Do(req *http.Request) (*http.Response, error) { + return client.httpClient.Do(req) +} diff --git a/pkg/gclient/authcheck/interfaces.go b/pkg/gclient/authcheck/interfaces.go new file mode 100644 index 0000000000..643a146e5d --- /dev/null +++ b/pkg/gclient/authcheck/interfaces.go @@ -0,0 +1,29 @@ +/* +Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package authcheck provides interfaces and wrappers around the http client. +package authcheck + +import ( + "net/http" +) + +type Client interface { + // Do sends an HTTP request and returns an HTTP response, + // following policy (such as redirects, cookies, auth) as configured on the client. + // See https://golang.org/pkg/net/http/#Client.Do + Do(req *http.Request) (*http.Response, error) +} diff --git a/pkg/gclient/authcheck/testing/fake.go b/pkg/gclient/authcheck/testing/fake.go new file mode 100644 index 0000000000..641fa3bced --- /dev/null +++ b/pkg/gclient/authcheck/testing/fake.go @@ -0,0 +1,43 @@ +/* +Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package authcheck provides interfaces and wrappers around the http client. +package testing + +import ( + "bytes" + "io/ioutil" + "net/http" + + "github.com/google/knative-gcp/pkg/gclient/authcheck" +) + +type FakeAuthCheckClient struct { + statusCode int +} + +func NewFakeAuthCheckClient(statusCode int) authcheck.Client { + return &FakeAuthCheckClient{ + statusCode: statusCode, + } +} + +func (f *FakeAuthCheckClient) Do(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: f.statusCode, + Body: ioutil.NopCloser(bytes.NewBufferString("fake auth check client")), + }, nil +} diff --git a/pkg/reconciler/brokercell/brokercell.go b/pkg/reconciler/brokercell/brokercell.go index 3681b16c6f..70455818e2 100644 --- a/pkg/reconciler/brokercell/brokercell.go +++ b/pkg/reconciler/brokercell/brokercell.go @@ -176,7 +176,7 @@ func (r *Reconciler) ReconcileKind(ctx context.Context, bc *intv1alpha1.BrokerCe return err } // If deployment has replicaUnavailable error, it potentially has authentication configuration issues. - if replicaUnavailable := bc.Status.PropagateIngressAvailability(endpoints, ind); replicaUnavailable { + if replicaAvailable := bc.Status.PropagateIngressAvailability(endpoints, ind); !replicaAvailable { podList, err := authcheck.GetPodList(ctx, resources.GetLabelSelector(bc.Name, resources.IngressName), r.KubeClientSet, bc.Namespace) if err != nil { logging.FromContext(ctx).Error("Failed to propagate authentication check message from ingress component", zap.Any("namespace", bc.Namespace), zap.Any("name", bc.Name), zap.Error(err)) @@ -204,7 +204,7 @@ func (r *Reconciler) ReconcileKind(ctx context.Context, bc *intv1alpha1.BrokerCe return err } // If deployment has replicaUnavailable error, it potentially has authentication configuration issues. - if replicaUnavailable := bc.Status.PropagateFanoutAvailability(fd); replicaUnavailable { + if replicaAvailable := bc.Status.PropagateFanoutAvailability(fd); !replicaAvailable { podList, err := authcheck.GetPodList(ctx, resources.GetLabelSelector(bc.Name, resources.FanoutName), r.KubeClientSet, bc.Namespace) if err != nil { logging.FromContext(ctx).Error("Failed to propagate authentication check message from fanout component", zap.Any("namespace", bc.Namespace), zap.Any("name", bc.Name), zap.Error(err)) @@ -229,7 +229,7 @@ func (r *Reconciler) ReconcileKind(ctx context.Context, bc *intv1alpha1.BrokerCe return err } // If deployment has replicaUnavailable error, it potentially has authentication configuration issues. - if replicaUnavailable := bc.Status.PropagateRetryAvailability(rd); replicaUnavailable { + if replicaAvailable := bc.Status.PropagateRetryAvailability(rd); !replicaAvailable { podList, err := authcheck.GetPodList(ctx, resources.GetLabelSelector(bc.Name, resources.RetryName), r.KubeClientSet, bc.Namespace) if err != nil { logging.FromContext(ctx).Error("Failed to propagate authentication check message from retry component", zap.Any("namespace", bc.Namespace), zap.Any("name", bc.Name), zap.Error(err)) diff --git a/pkg/reconciler/intevents/pullsubscription/keda/pullsubscription.go b/pkg/reconciler/intevents/pullsubscription/keda/pullsubscription.go index 0a5c5d6bee..39479399b3 100644 --- a/pkg/reconciler/intevents/pullsubscription/keda/pullsubscription.go +++ b/pkg/reconciler/intevents/pullsubscription/keda/pullsubscription.go @@ -91,7 +91,7 @@ func (r *Reconciler) ReconcileScaledObject(ctx context.Context, ra *appsv1.Deplo } // If deployment has replicaUnavailable error, it potentially has authentication configuration issues. - if replicaUnavailable := src.Status.PropagateDeploymentAvailability(existing); replicaUnavailable { + if replicaAvailable := src.Status.PropagateDeploymentAvailability(existing); !replicaAvailable { podList, err := authcheck.GetPodList(ctx, psresources.GetLabelSelector(r.ControllerAgentName, src.Name), r.KubeClientSet, src.Namespace) if err != nil { logging.FromContext(ctx).Error("Error propagating authentication check message", zap.Error(err)) diff --git a/pkg/reconciler/intevents/pullsubscription/static/pullsubscription.go b/pkg/reconciler/intevents/pullsubscription/static/pullsubscription.go index aad021d75a..7b894fce5c 100644 --- a/pkg/reconciler/intevents/pullsubscription/static/pullsubscription.go +++ b/pkg/reconciler/intevents/pullsubscription/static/pullsubscription.go @@ -62,7 +62,7 @@ func (r *Reconciler) ReconcileDeployment(ctx context.Context, ra *appsv1.Deploym } // If deployment has replicaUnavailable error, it potentially has authentication configuration issues. - if replicaUnavailable := src.Status.PropagateDeploymentAvailability(existing); replicaUnavailable { + if replicaAvailable := src.Status.PropagateDeploymentAvailability(existing); !replicaAvailable { podList, err := authcheck.GetPodList(ctx, psresources.GetLabelSelector(r.ControllerAgentName, src.Name), r.KubeClientSet, src.Namespace) if err != nil { logging.FromContext(ctx).Error("Error propagating authentication check message", zap.Error(err)) diff --git a/pkg/utils/authcheck/authcheck.go b/pkg/utils/authcheck/authcheck.go index 2da98af469..4193e389dd 100644 --- a/pkg/utils/authcheck/authcheck.go +++ b/pkg/utils/authcheck/authcheck.go @@ -25,56 +25,42 @@ import ( "fmt" "io/ioutil" "net/http" - nethttp "net/http" - "go.uber.org/zap" "golang.org/x/oauth2/google" + + authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" ) const ( - // Resource is used as the path to get the default token from metadata server. + // resource is used as the path to get the default token from metadata server. // In workload-identity-gsa mode, this path will return a token if // corresponding k8s service account and google service account establish a correct relationship. resource = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" - // Scope is used as the scope to get token from default credential. + // scope is used as the scope to get token from default credential. scope = "https://www.googleapis.com/auth/cloud-platform" + // authMessage is the key words to determine if a termination log is about authentication. + authMessage = "checking authentication" ) // AuthenticationCheck performs the authentication check running in the Pod. -func AuthenticationCheck(ctx context.Context, logger *zap.Logger, authType AuthType, response nethttp.ResponseWriter) { +func AuthenticationCheck(ctx context.Context, authType AuthType, client authcheckclient.Client) error { var err error - if authType == Secret { + switch authType { + case Secret: err = AuthenticationCheckForSecret(ctx) - } else if authType == WorkloadIdentityGSA { - err = AuthenticationCheckForWorkloadIdentityGSA(resource) - } else { - logger.Error(fmt.Sprint("unknown auth type: ", authType)) - response.WriteHeader(nethttp.StatusUnauthorized) - return + case WorkloadIdentityGSA: + err = AuthenticationCheckForWorkloadIdentityGSA(resource, client) + case WorkloadIdentity: + // Skip authentication check running in Pods which use new generation of Workload Identity. + return nil + default: + return fmt.Errorf("unknown auth type: %s", authType) } if err != nil { - // Transfer the error into a string message, otherwise, marshalling error may return nil unexpectedly. - message := fmt.Sprintf("using %s mode, when checking authentication, get error: %s", authType, err.Error()) - b, err := json.Marshal(map[string]interface{}{ - "error": message, - }) - if err != nil { - logger.Error(fmt.Sprint("error marshalling the message: ", message), zap.Error(err)) - response.WriteHeader(nethttp.StatusUnauthorized) - return - } - errs := ioutil.WriteFile("/dev/termination-log", b, 0644) - if errs != nil { - logger.Error(fmt.Sprintf("error writing the message: %s into termination log", message), zap.Error(err)) - response.WriteHeader(nethttp.StatusUnauthorized) - return - } - logger.Info(message) - response.WriteHeader(nethttp.StatusUnauthorized) - return + return writeTerminationLog(err, authType) } - response.WriteHeader(nethttp.StatusOK) + return nil } // AuthenticationCheckForSecret performs the authentication check for Pod in secret mode. @@ -94,13 +80,13 @@ func AuthenticationCheckForSecret(ctx context.Context) error { } // AuthenticationCheckForWorkloadIdentityGSA performs the authentication check for Pod in workload-identity-gsa mode. -func AuthenticationCheckForWorkloadIdentityGSA(resource string) error { +func AuthenticationCheckForWorkloadIdentityGSA(resource string, client authcheckclient.Client) error { req, err := http.NewRequest(http.MethodGet, resource, nil) if err != nil { return fmt.Errorf("error setting up the http request: %w", err) } req.Header.Set("Metadata-Flavor", "Google") - resp, err := http.DefaultClient.Do(req) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("error getting the http response: %w", err) } @@ -112,3 +98,19 @@ func AuthenticationCheckForWorkloadIdentityGSA(resource string) error { } return nil } + +func writeTerminationLog(inputErr error, authType AuthType) error { + // Transfer the error into a string message, otherwise, marshalling error may return nil unexpectedly. + message := fmt.Sprintf("%s, pod uses %s mode, get error: %s", authMessage, authType, inputErr.Error()) + b, err := json.Marshal(map[string]interface{}{ + "error": message, + }) + if err != nil { + return fmt.Errorf("error marshalling the message: %s", message) + } + err = ioutil.WriteFile("/dev/termination-log", b, 0644) + if err != nil { + return fmt.Errorf("error writing the message into termination log, message: %s", message) + } + return inputErr +} diff --git a/pkg/utils/authcheck/authcheck_test.go b/pkg/utils/authcheck/authcheck_test.go new file mode 100644 index 0000000000..93ce49ab53 --- /dev/null +++ b/pkg/utils/authcheck/authcheck_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package authcheck + +import ( + "context" + "net/http" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + + authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" +) + +func TestAuthenticationCheck(t *testing.T) { + testCases := []struct { + name string + authType AuthType + statusCode int + wantError bool + setEnv bool + }{ + { + name: "authentication check failed, using empty authType", + authType: "empty", + statusCode: http.StatusUnauthorized, + wantError: true, + setEnv: false, + }, + { + name: "authentication check failed, using secret authType", + authType: Secret, + statusCode: http.StatusUnauthorized, + wantError: true, + setEnv: true, + }, + { + name: "authentication check failed, using workload-identity-gsa authType", + authType: WorkloadIdentityGSA, + statusCode: http.StatusUnauthorized, + wantError: true, + setEnv: true, + }, + { + name: "authentication check succeeded", + authType: WorkloadIdentityGSA, + statusCode: http.StatusOK, + wantError: false, + setEnv: true, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if tc.setEnv { + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "empty") + } + + err := AuthenticationCheck(ctx, tc.authType, authchecktesting.NewFakeAuthCheckClient(tc.statusCode)) + + if diff := cmp.Diff(tc.wantError, err != nil); diff != "" { + t.Error("unexpected error (-want, +got) = ", err) + } + }) + } +} diff --git a/pkg/utils/authcheck/authchek_test.go b/pkg/utils/authcheck/authchek_test.go deleted file mode 100644 index 208baac69c..0000000000 --- a/pkg/utils/authcheck/authchek_test.go +++ /dev/null @@ -1,81 +0,0 @@ -/* -Copyright 2020 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package authcheck - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "os" - "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/google/knative-gcp/pkg/logging" -) - -func TestAuthenticationCheck(t *testing.T) { - testCases := []struct { - name string - authType AuthType - wantStatusCode int - setEnv bool - }{ - { - name: "authentication check uses empty authType", - authType: "", - wantStatusCode: http.StatusUnauthorized, - setEnv: false, - }, - { - name: "authentication check uses secret authType", - authType: Secret, - wantStatusCode: http.StatusUnauthorized, - setEnv: true, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := logging.FromContext(ctx) - response := httptest.NewRecorder() - - if tc.setEnv { - os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "empty") - } - - AuthenticationCheck(ctx, logger, tc.authType, response) - - if diff := cmp.Diff(tc.wantStatusCode, response.Result().StatusCode); diff != "" { - t.Error("unexpected status (-want, +got) = ", diff) - } - }) - } -} - -func TestAuthenticationCheckForWorkloadIdentityGSA(t *testing.T) { - resource := "http:/empty" - wantErr := errors.New("error getting the http response: Get \"http:///empty\": http: no Host in request URL") - gotErr := AuthenticationCheckForWorkloadIdentityGSA(resource) - if diff := cmp.Diff(gotErr.Error(), wantErr.Error()); diff != "" { - t.Error("unexpected status (-want, +got) = ", diff) - } -} diff --git a/pkg/utils/authcheck/authtype.go b/pkg/utils/authcheck/authtype.go index 791b5fb1ad..15176fb5bc 100644 --- a/pkg/utils/authcheck/authtype.go +++ b/pkg/utils/authcheck/authtype.go @@ -57,7 +57,12 @@ const ( var ( // Regex for a valid google service account email. - emailRegexp = regexp.MustCompile(`^[a-z][a-z0-9-]{5,29}@[a-z][a-z0-9-]{5,29}\.iam\.gserviceaccount.com$`) + // The format of google service account email is service-account-name@project-id.iam.gserviceaccount.com + // Service account name must be between 6 and 30 characters (inclusive), + // must begin with a lowercase letter, and consist of lowercase alphanumeric characters that can be separated by hyphens. + // Project IDs must start with a lowercase letter and can have lowercase ASCII letters, digits or hyphens, + // must be between 6 and 30 characters. Some older project may have dot as well, like project-example.example.com + emailRegexp = regexp.MustCompile(`^[a-z][a-z0-9-]{5,29}@[a-z][a-z0-9-]{5,29}.*\.iam\.gserviceaccount\.com$`) BrokerSecret = &corev1.SecretKeySelector{ LocalObjectReference: corev1.LocalObjectReference{Name: "google-broker-key"}, @@ -123,15 +128,8 @@ func getAuthTypeForWorkloadIdentity(ctx context.Context, serviceAccountLister co } else if email := kServiceAccount.Annotations[resources.WorkloadIdentityKey]; email != "" { // Check if email is a valid google service account email. if match := emailRegexp.FindStringSubmatch(email); len(match) == 0 { - // The format of google service account email is service-account-name@project-id.iam.gserviceaccount.com - - // Service account name must be between 6 and 30 characters (inclusive), - // must begin with a lowercase letter, and consist of lowercase alphanumeric characters that can be separated by hyphens. - - // Project IDs must start with a lowercase letter and can have lowercase ASCII letters, digits or hyphens, - // must be between 6 and 30 characters. - return "", fmt.Errorf("the annotation %s of Kubernetes Service Account %s does not contain a valid Google Service Account", - resources.WorkloadIdentityKey, args.ServiceAccountName) + return "", fmt.Errorf("%s is not a valid Google Service Account as the value of Kubernetes Service Account %s for annotation %s", + email, args.ServiceAccountName, resources.WorkloadIdentityKey) } return WorkloadIdentityGSA, nil } diff --git a/pkg/utils/authcheck/authtype_test.go b/pkg/utils/authcheck/authtype_test.go index 9c602f4e3b..91bcb69627 100644 --- a/pkg/utils/authcheck/authtype_test.go +++ b/pkg/utils/authcheck/authtype_test.go @@ -86,7 +86,18 @@ func TestGetAuthTypeForSources(t *testing.T) { wantError: nil, }, { - name: "error get authType, invalid service account", + name: "successfully get authType for workload identity, with old style project id", + objects: []runtime.Object{ + pkgtesting.NewServiceAccount(serviceAccountName, testNS, + pkgtesting.WithServiceAccountAnnotation("longname@project-example.example.com.iam.gserviceaccount.com"), + ), + }, + args: serviceAccountArgs, + wantAuthType: WorkloadIdentityGSA, + wantError: nil, + }, + { + name: "error get authType, invalid service account without a email-like format", objects: []runtime.Object{ pkgtesting.NewServiceAccount(serviceAccountName, testNS, pkgtesting.WithServiceAccountAnnotation("name"), @@ -94,8 +105,32 @@ func TestGetAuthTypeForSources(t *testing.T) { }, args: serviceAccountArgs, wantAuthType: "", - wantError: fmt.Errorf("using Workload Identity for authentication configuration: the annotation iam.gke.io/gcp-service-account "+ - "of Kubernetes Service Account %s does not contain a valid Google Service Account", serviceAccountName), + wantError: fmt.Errorf("using Workload Identity for authentication configuration: name is not a valid Google Service Account " + + "as the value of Kubernetes Service Account test-ksa for annotation iam.gke.io/gcp-service-account"), + }, + { + name: "error get authType, invalid service account with a email-like format but without project id", + objects: []runtime.Object{ + pkgtesting.NewServiceAccount(serviceAccountName, testNS, + pkgtesting.WithServiceAccountAnnotation("name@.iam.gserviceaccount.com"), + ), + }, + args: serviceAccountArgs, + wantAuthType: "", + wantError: fmt.Errorf("using Workload Identity for authentication configuration: name@.iam.gserviceaccount.com is not a valid Google Service Account " + + "as the value of Kubernetes Service Account test-ksa for annotation iam.gke.io/gcp-service-account"), + }, + { + name: "error get authType, invalid service account with a email-like format but name is too short", + objects: []runtime.Object{ + pkgtesting.NewServiceAccount(serviceAccountName, testNS, + pkgtesting.WithServiceAccountAnnotation("name@project-id.iam.gserviceaccount.com"), + ), + }, + args: serviceAccountArgs, + wantAuthType: "", + wantError: fmt.Errorf("using Workload Identity for authentication configuration: name@project-id.iam.gserviceaccount.com is not a valid Google Service Account " + + "as the value of Kubernetes Service Account test-ksa for annotation iam.gke.io/gcp-service-account"), }, { name: "error get authType, service account doesn't exist", diff --git a/pkg/utils/authcheck/list.go b/pkg/utils/authcheck/list.go index 7260787a9c..910da5c8cd 100644 --- a/pkg/utils/authcheck/list.go +++ b/pkg/utils/authcheck/list.go @@ -31,7 +31,7 @@ import ( "knative.dev/pkg/logging" ) -// GetPodList get a list of Pod in a certain namespace with certain label selector. +// GetPodList get a list of Pods in a certain namespace with certain label selector. func GetPodList(ctx context.Context, ls labels.Selector, kubeClientSet kubernetes.Interface, namespace string) (*corev1.PodList, error) { pl, err := kubeClientSet.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ LabelSelector: ls.String(), @@ -47,7 +47,8 @@ func GetPodList(ctx context.Context, ls labels.Selector, kubeClientSet kubernete return pl, nil } -// GetTerminationLogFromPodList get termination log from Pod. +// GetTerminationLogFromPodList gets the termination log from Pods that failed due to authentication check errors. +// It returns the first authentication termination log from any Pods in the list. func GetTerminationLogFromPodList(pl *corev1.PodList) string { for _, pod := range pl.Items { for _, cs := range pod.Status.ContainerStatuses { @@ -62,5 +63,5 @@ func GetTerminationLogFromPodList(pl *corev1.PodList) string { } func isAuthMessage(message string) bool { - return strings.Contains(message, "checking authentication") + return strings.Contains(message, authMessage) } diff --git a/pkg/utils/authcheck/probechecker.go b/pkg/utils/authcheck/probechecker.go index 048a79f87d..99c1fcebf7 100644 --- a/pkg/utils/authcheck/probechecker.go +++ b/pkg/utils/authcheck/probechecker.go @@ -24,27 +24,31 @@ import ( "strconv" "go.uber.org/zap" + + authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" ) // DefaultProbeCheckPort is the default port for checking sync pool health. const DefaultProbeCheckPort = 8080 type ProbeChecker struct { - logger *zap.Logger - port int - authType AuthType + logger *zap.Logger + port int + authType AuthType + authCheckClient authcheckclient.Client } // NewProbeChecker returns ProbeChecker with default probe checker port. func NewProbeChecker(logger *zap.Logger, authType AuthType) ProbeChecker { return ProbeChecker{ - logger: logger, - port: DefaultProbeCheckPort, - authType: authType, + logger: logger, + port: DefaultProbeCheckPort, + authType: authType, + authCheckClient: authcheckclient.NewAuthCheckClient(), } } -// Start will initialize s http serve and start to listen. +// Start will initialize an http server and start to listen. func (pc *ProbeChecker) Start(ctx context.Context) { srv := &http.Server{ Addr: ":" + strconv.Itoa(pc.port), @@ -68,6 +72,11 @@ func (pc *ProbeChecker) Start(ctx context.Context) { func (pc *ProbeChecker) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.URL.Path == "/healthz" { // Perform Authentication check. - AuthenticationCheck(req.Context(), pc.logger, pc.authType, w) + if err := AuthenticationCheck(req.Context(), pc.authType, pc.authCheckClient); err != nil { + pc.logger.Error("authentication check failed", zap.Error(err)) + w.WriteHeader(http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) } } diff --git a/pkg/utils/authcheck/probechecker_test.go b/pkg/utils/authcheck/probechecker_test.go index 018f003612..3f3ba3ccc1 100644 --- a/pkg/utils/authcheck/probechecker_test.go +++ b/pkg/utils/authcheck/probechecker_test.go @@ -26,50 +26,71 @@ import ( "github.com/google/go-cmp/cmp" + authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" "github.com/google/knative-gcp/pkg/logging" ) func TestProbeCheckResult(t *testing.T) { - t.Helper() - - ctx, cancel := context.WithCancel(context.Background()) - - // Get a free port. - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - t.Fatal("Failed to resolve TCP address:", err) - } - l, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Fatal("Failed to listen TCP:", err) + testCases := []struct { + name string + inputStatus int + wantStatusCode int + }{ + { + name: "probe check got a failure result", + inputStatus: http.StatusBadRequest, + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "probe check got a success result", + inputStatus: http.StatusAccepted, + wantStatusCode: http.StatusOK, + }, } - l.Close() - port := l.Addr().(*net.TCPAddr).Port + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + // Get a free port. + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + t.Fatal("Failed to resolve TCP address:", err) + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatal("Failed to listen TCP:", err) + } + l.Close() + port := l.Addr().(*net.TCPAddr).Port - logger := logging.FromContext(ctx) - probeChecker := ProbeChecker{ - logger: logger, - port: port, - authType: "", - } - go probeChecker.Start(ctx) + logger := logging.FromContext(ctx) + probeChecker := ProbeChecker{ + logger: logger, + port: port, + authType: WorkloadIdentityGSA, + authCheckClient: authchecktesting.NewFakeAuthCheckClient(tc.inputStatus), + } + go probeChecker.Start(ctx) - time.Sleep(1 * time.Second) + time.Sleep(1 * time.Second) - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/healthz", port), nil) - if err != nil { - t.Fatal("Failed to create probe check request:", err) - } + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/healthz", port), nil) + if err != nil { + t.Fatal("Failed to create probe check request:", err) + } - client := http.DefaultClient + client := http.DefaultClient + resp, err := client.Do(req) - resp, err := client.Do(req) - if err != nil { - t.Fatal("Failed to execute probe check:", err) - return - } - if diff := cmp.Diff(resp.StatusCode, http.StatusUnauthorized); diff != "" { - t.Error("unexpected probe check result (-want, +got) = ", diff) + if err != nil { + t.Fatal("Failed to execute probe check:", err) + return + } + if diff := cmp.Diff(resp.StatusCode, tc.wantStatusCode); diff != "" { + t.Error("unexpected probe check result (-want, +got) = ", diff) + } + cancel() + }) } - cancel() }