From ade72cba394677e407098a26ed0997cbeaecd427 Mon Sep 17 00:00:00 2001 From: grac3gao Date: Tue, 15 Dec 2020 17:03:49 -0800 Subject: [PATCH] address comments --- cmd/broker/fanout/main.go | 3 +- cmd/broker/retry/main.go | 3 +- .../v1/pullsubscription_lifecycle_test.go | 2 +- .../v1alpha1/brokercell_lifecycle_test.go | 6 +-- .../pullsubscription_lifecycle_test.go | 2 +- pkg/broker/handler/fanout_test.go | 6 +-- pkg/broker/handler/pool.go | 14 +++--- pkg/broker/handler/pool_test.go | 25 +++++++++-- pkg/broker/handler/retry_test.go | 7 ++- pkg/gclient/authcheck/client.go | 38 ---------------- pkg/gclient/authcheck/interfaces.go | 29 ------------- pkg/gclient/authcheck/testing/fake.go | 43 ------------------- pkg/utils/authcheck/authcheck.go | 30 +++++++++---- pkg/utils/authcheck/authcheck_test.go | 28 ++++++------ pkg/utils/authcheck/probechecker.go | 20 ++++----- pkg/utils/authcheck/probechecker_test.go | 36 ++++++++++++---- 16 files changed, 112 insertions(+), 180 deletions(-) delete mode 100644 pkg/gclient/authcheck/client.go delete mode 100644 pkg/gclient/authcheck/interfaces.go delete mode 100644 pkg/gclient/authcheck/testing/fake.go diff --git a/cmd/broker/fanout/main.go b/cmd/broker/fanout/main.go index 689ff88b27..aa166059fb 100644 --- a/cmd/broker/fanout/main.go +++ b/cmd/broker/fanout/main.go @@ -24,7 +24,6 @@ import ( "github.com/google/knative-gcp/pkg/broker/config/volume" "github.com/google/knative-gcp/pkg/broker/handler" - authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" "github.com/google/knative-gcp/pkg/metrics" "github.com/google/knative-gcp/pkg/utils" "github.com/google/knative-gcp/pkg/utils/appcredentials" @@ -104,7 +103,7 @@ func main() { if err != nil { logger.Fatal("Failed to create fanout sync pool", zap.Error(err)) } - if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType, authcheckclient.NewAuthCheckClient()); err != nil { + if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, authcheck.NewDefaultAuthenticationCheck(env.AuthType)); err != nil { logger.Fatalw("Failed to start fanout sync pool", zap.Error(err)) } diff --git a/cmd/broker/retry/main.go b/cmd/broker/retry/main.go index 73310fac71..945a979bdb 100644 --- a/cmd/broker/retry/main.go +++ b/cmd/broker/retry/main.go @@ -19,7 +19,6 @@ package main import ( "context" "flag" - "net/http" "time" "cloud.google.com/go/pubsub" @@ -102,7 +101,7 @@ func main() { if err != nil { logger.Fatal("Failed to get retry sync pool", zap.Error(err)) } - if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, env.AuthType, http.DefaultClient); err != nil { + if _, err := handler.StartSyncPool(ctx, syncPool, syncSignal, env.MaxStaleDuration, handler.DefaultProbeCheckPort, authcheck.NewDefaultAuthenticationCheck(env.AuthType)); err != nil { logger.Fatal("Failed to start retry sync pool", zap.Error(err)) } diff --git a/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go b/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go index 9667480406..d761049228 100644 --- a/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go +++ b/pkg/apis/intevents/v1/pullsubscription_lifecycle_test.go @@ -527,7 +527,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { s := &PullSubscriptionStatus{} got := s.PropagateDeploymentAvailability(replicaUnavailableDeployment) want := false - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } } diff --git a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go index 81985be3ee..1f54cf0ca5 100644 --- a/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go +++ b/pkg/apis/intevents/v1alpha1/brokercell_lifecycle_test.go @@ -439,7 +439,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateIngressAvailability(&corev1.Endpoints{}, replicaUnavailableDeployment) want := false - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } }) @@ -448,7 +448,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateFanoutAvailability(replicaUnavailableDeployment) want := false - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } }) @@ -457,7 +457,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { s := &BrokerCellStatus{} got := s.PropagateRetryAvailability(replicaUnavailableDeployment) want := false - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } }) diff --git a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go index 55710e781f..2af33082a2 100644 --- a/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go +++ b/pkg/apis/intevents/v1beta1/pullsubscription_lifecycle_test.go @@ -527,7 +527,7 @@ func TestPropagateDeploymentAvailability(t *testing.T) { s := &PullSubscriptionStatus{} got := s.PropagateDeploymentAvailability(replicaUnavailableDeployment) want := false - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Error("unexpected condition (-want, +got) =", diff) } } diff --git a/pkg/broker/handler/fanout_test.go b/pkg/broker/handler/fanout_test.go index 8ec18ddc57..2f528e6cfe 100644 --- a/pkg/broker/handler/fanout_test.go +++ b/pkg/broker/handler/fanout_test.go @@ -29,8 +29,8 @@ import ( "github.com/google/knative-gcp/pkg/broker/config" "github.com/google/knative-gcp/pkg/broker/eventutil" handlertesting "github.com/google/knative-gcp/pkg/broker/handler/testing" - authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" reportertest "github.com/google/knative-gcp/pkg/metrics/testing" + "github.com/google/knative-gcp/pkg/utils/authcheck" _ "knative.dev/pkg/metrics/testing" ) @@ -63,7 +63,7 @@ func TestFanoutWatchAndSync(t *testing.T) { } t.Run("start sync pool creates no handler", func(t *testing.T) { - _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) + _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)) if err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } @@ -156,7 +156,7 @@ func TestFanoutSyncPoolE2E(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { + if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)); err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } diff --git a/pkg/broker/handler/pool.go b/pkg/broker/handler/pool.go index e2dbdbf7a5..c3d4d2a1e4 100644 --- a/pkg/broker/handler/pool.go +++ b/pkg/broker/handler/pool.go @@ -23,7 +23,6 @@ import ( "sync" "time" - authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" "github.com/google/knative-gcp/pkg/logging" "github.com/google/knative-gcp/pkg/utils/authcheck" @@ -45,8 +44,7 @@ type probeChecker struct { lastReportTime time.Time maxStaleDuration time.Duration port int - authCheckClient authcheckclient.Client - authType authcheck.AuthType + authCheck authcheck.AuthenticationCheck } func (c *probeChecker) reportHealth() { @@ -87,9 +85,9 @@ func (c *probeChecker) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } // Perform Authentication check. - if err := authcheck.AuthenticationCheck(req.Context(), c.authType, c.authCheckClient); err != nil { + if err := c.authCheck.Check(req.Context()); err != nil { c.logger.Error("authentication check failed", zap.Error(err)) - w.WriteHeader(http.StatusUnauthorized) + w.WriteHeader(http.StatusInternalServerError) return } @@ -111,8 +109,7 @@ func StartSyncPool( syncSignal <-chan struct{}, maxStaleDuration time.Duration, probeCheckPort int, - authType authcheck.AuthType, - authCheckClient authcheckclient.Client, + authCheck authcheck.AuthenticationCheck, ) (SyncPool, error) { if err := syncPool.SyncOnce(ctx); err != nil { @@ -122,8 +119,7 @@ func StartSyncPool( logger: logging.FromContext(ctx), maxStaleDuration: maxStaleDuration, port: probeCheckPort, - authType: authType, - authCheckClient: authCheckClient, + authCheck: authCheck, } go c.start(ctx) if syncSignal != nil { diff --git a/pkg/broker/handler/pool_test.go b/pkg/broker/handler/pool_test.go index 4172072d70..3ed03f8d64 100644 --- a/pkg/broker/handler/pool_test.go +++ b/pkg/broker/handler/pool_test.go @@ -18,6 +18,7 @@ package handler import ( "context" + "errors" "fmt" "net" "net/http" @@ -26,7 +27,6 @@ import ( "github.com/google/go-cmp/cmp" - authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" "github.com/google/knative-gcp/pkg/utils/authcheck" ) @@ -45,7 +45,7 @@ func TestSyncPool(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - _, gotErr := StartSyncPool(ctx, syncPool, make(chan struct{}), 30*time.Second, p, authcheck.WorkloadIdentityGSA, authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) + _, gotErr := StartSyncPool(ctx, syncPool, make(chan struct{}), 30*time.Second, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)) if gotErr == nil { t.Error("StartSyncPool got unexpected result") } @@ -68,7 +68,7 @@ func TestSyncPool(t *testing.T) { } ch := make(chan struct{}) - if _, err := StartSyncPool(ctx, syncPool, ch, time.Second, p, authcheck.WorkloadIdentityGSA, authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { + if _, err := StartSyncPool(ctx, syncPool, ch, time.Second, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)); err != nil { t.Errorf("StartSyncPool got unexpected error: %v", err) } syncPool.verifySyncOnceCalled(t) @@ -143,3 +143,22 @@ func GetFreePort() (int, error) { defer l.Close() return l.Addr().(*net.TCPAddr).Port, nil } + +type FakeAuthenticationCheck struct { + authType authcheck.AuthType + noError bool +} + +func NewFakeAuthenticationCheck(authType authcheck.AuthType, noError bool) authcheck.AuthenticationCheck { + return &FakeAuthenticationCheck{ + authType: authType, + noError: noError, + } +} + +func (ac *FakeAuthenticationCheck) Check(ctx context.Context) error { + if ac.noError { + return nil + } + return errors.New("induced error") +} diff --git a/pkg/broker/handler/retry_test.go b/pkg/broker/handler/retry_test.go index 56e7237e6d..acd370549c 100644 --- a/pkg/broker/handler/retry_test.go +++ b/pkg/broker/handler/retry_test.go @@ -18,7 +18,6 @@ package handler import ( "context" - "net/http" "testing" "time" @@ -29,8 +28,8 @@ import ( "github.com/google/knative-gcp/pkg/broker/config" "github.com/google/knative-gcp/pkg/broker/eventutil" handlertesting "github.com/google/knative-gcp/pkg/broker/handler/testing" - authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" reportertest "github.com/google/knative-gcp/pkg/metrics/testing" + "github.com/google/knative-gcp/pkg/utils/authcheck" _ "knative.dev/pkg/metrics/testing" ) @@ -63,7 +62,7 @@ func TestRetryWatchAndSync(t *testing.T) { } t.Run("start sync pool creates no handler", func(t *testing.T) { - _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)) + _, err = StartSyncPool(ctx, syncPool, signal, time.Minute, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)) if err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } @@ -159,7 +158,7 @@ func TestRetrySyncPoolE2E(t *testing.T) { t.Fatalf("failed to get random free port: %v", err) } - if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, "", authchecktesting.NewFakeAuthCheckClient(http.StatusAccepted)); err != nil { + if _, err := StartSyncPool(ctx, syncPool, signal, time.Minute, p, NewFakeAuthenticationCheck(authcheck.WorkloadIdentity, true)); err != nil { t.Errorf("unexpected error from starting sync pool: %v", err) } diff --git a/pkg/gclient/authcheck/client.go b/pkg/gclient/authcheck/client.go deleted file mode 100644 index 12f3f2484f..0000000000 --- a/pkg/gclient/authcheck/client.go +++ /dev/null @@ -1,38 +0,0 @@ -/* -Copyright 2020 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package authcheck provides interfaces and wrappers around the http client. -package authcheck - -import ( - "net/http" -) - -type authCheckClient struct { - httpClient *http.Client -} - -var _ Client = &authCheckClient{} - -func NewAuthCheckClient() Client { - return &authCheckClient{ - httpClient: http.DefaultClient, - } -} - -func (client *authCheckClient) Do(req *http.Request) (*http.Response, error) { - return client.httpClient.Do(req) -} diff --git a/pkg/gclient/authcheck/interfaces.go b/pkg/gclient/authcheck/interfaces.go deleted file mode 100644 index 643a146e5d..0000000000 --- a/pkg/gclient/authcheck/interfaces.go +++ /dev/null @@ -1,29 +0,0 @@ -/* -Copyright 2020 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package authcheck provides interfaces and wrappers around the http client. -package authcheck - -import ( - "net/http" -) - -type Client interface { - // Do sends an HTTP request and returns an HTTP response, - // following policy (such as redirects, cookies, auth) as configured on the client. - // See https://golang.org/pkg/net/http/#Client.Do - Do(req *http.Request) (*http.Response, error) -} diff --git a/pkg/gclient/authcheck/testing/fake.go b/pkg/gclient/authcheck/testing/fake.go deleted file mode 100644 index 641fa3bced..0000000000 --- a/pkg/gclient/authcheck/testing/fake.go +++ /dev/null @@ -1,43 +0,0 @@ -/* -Copyright 2020 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package authcheck provides interfaces and wrappers around the http client. -package testing - -import ( - "bytes" - "io/ioutil" - "net/http" - - "github.com/google/knative-gcp/pkg/gclient/authcheck" -) - -type FakeAuthCheckClient struct { - statusCode int -} - -func NewFakeAuthCheckClient(statusCode int) authcheck.Client { - return &FakeAuthCheckClient{ - statusCode: statusCode, - } -} - -func (f *FakeAuthCheckClient) Do(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: f.statusCode, - Body: ioutil.NopCloser(bytes.NewBufferString("fake auth check client")), - }, nil -} diff --git a/pkg/utils/authcheck/authcheck.go b/pkg/utils/authcheck/authcheck.go index 4193e389dd..d656d70767 100644 --- a/pkg/utils/authcheck/authcheck.go +++ b/pkg/utils/authcheck/authcheck.go @@ -27,8 +27,6 @@ import ( "net/http" "golang.org/x/oauth2/google" - - authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" ) const ( @@ -42,23 +40,39 @@ const ( authMessage = "checking authentication" ) +type AuthenticationCheck interface { + Check(ctx context.Context) error +} + +type DefaultAuthenticationCheck struct { + authType AuthType + client *http.Client +} + +func NewDefaultAuthenticationCheck(authType AuthType) AuthenticationCheck { + return &DefaultAuthenticationCheck{ + authType: authType, + client: http.DefaultClient, + } +} + // AuthenticationCheck performs the authentication check running in the Pod. -func AuthenticationCheck(ctx context.Context, authType AuthType, client authcheckclient.Client) error { +func (ac *DefaultAuthenticationCheck) Check(ctx context.Context) error { var err error - switch authType { + switch ac.authType { case Secret: err = AuthenticationCheckForSecret(ctx) case WorkloadIdentityGSA: - err = AuthenticationCheckForWorkloadIdentityGSA(resource, client) + err = AuthenticationCheckForWorkloadIdentityGSA(resource, ac.client) case WorkloadIdentity: // Skip authentication check running in Pods which use new generation of Workload Identity. return nil default: - return fmt.Errorf("unknown auth type: %s", authType) + return fmt.Errorf("unknown auth type: %s", ac.authType) } if err != nil { - return writeTerminationLog(err, authType) + return writeTerminationLog(err, ac.authType) } return nil } @@ -80,7 +94,7 @@ func AuthenticationCheckForSecret(ctx context.Context) error { } // AuthenticationCheckForWorkloadIdentityGSA performs the authentication check for Pod in workload-identity-gsa mode. -func AuthenticationCheckForWorkloadIdentityGSA(resource string, client authcheckclient.Client) error { +func AuthenticationCheckForWorkloadIdentityGSA(resource string, client *http.Client) error { req, err := http.NewRequest(http.MethodGet, resource, nil) if err != nil { return fmt.Errorf("error setting up the http request: %w", err) diff --git a/pkg/utils/authcheck/authcheck_test.go b/pkg/utils/authcheck/authcheck_test.go index 93ce49ab53..33efb11158 100644 --- a/pkg/utils/authcheck/authcheck_test.go +++ b/pkg/utils/authcheck/authcheck_test.go @@ -19,12 +19,11 @@ package authcheck import ( "context" "net/http" + "net/http/httptest" "os" "testing" "github.com/google/go-cmp/cmp" - - authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" ) func TestAuthenticationCheck(t *testing.T) { @@ -38,31 +37,24 @@ func TestAuthenticationCheck(t *testing.T) { { name: "authentication check failed, using empty authType", authType: "empty", - statusCode: http.StatusUnauthorized, + statusCode: http.StatusInternalServerError, wantError: true, setEnv: false, }, { name: "authentication check failed, using secret authType", authType: Secret, - statusCode: http.StatusUnauthorized, + statusCode: http.StatusInternalServerError, wantError: true, setEnv: true, }, { name: "authentication check failed, using workload-identity-gsa authType", authType: WorkloadIdentityGSA, - statusCode: http.StatusUnauthorized, + statusCode: http.StatusInternalServerError, wantError: true, setEnv: true, }, - { - name: "authentication check succeeded", - authType: WorkloadIdentityGSA, - statusCode: http.StatusOK, - wantError: false, - setEnv: true, - }, } for _, tc := range testCases { tc := tc @@ -75,7 +67,17 @@ func TestAuthenticationCheck(t *testing.T) { os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "empty") } - err := AuthenticationCheck(ctx, tc.authType, authchecktesting.NewFakeAuthCheckClient(tc.statusCode)) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + defer server.Close() + authCheck := &DefaultAuthenticationCheck{ + authType: tc.authType, + client: server.Client(), + } + + err := authCheck.Check(ctx) if diff := cmp.Diff(tc.wantError, err != nil); diff != "" { t.Error("unexpected error (-want, +got) = ", err) diff --git a/pkg/utils/authcheck/probechecker.go b/pkg/utils/authcheck/probechecker.go index 99c1fcebf7..2b1bdccbbd 100644 --- a/pkg/utils/authcheck/probechecker.go +++ b/pkg/utils/authcheck/probechecker.go @@ -24,27 +24,23 @@ import ( "strconv" "go.uber.org/zap" - - authcheckclient "github.com/google/knative-gcp/pkg/gclient/authcheck" ) // DefaultProbeCheckPort is the default port for checking sync pool health. const DefaultProbeCheckPort = 8080 type ProbeChecker struct { - logger *zap.Logger - port int - authType AuthType - authCheckClient authcheckclient.Client + logger *zap.Logger + port int + authCheck AuthenticationCheck } // NewProbeChecker returns ProbeChecker with default probe checker port. func NewProbeChecker(logger *zap.Logger, authType AuthType) ProbeChecker { return ProbeChecker{ - logger: logger, - port: DefaultProbeCheckPort, - authType: authType, - authCheckClient: authcheckclient.NewAuthCheckClient(), + logger: logger, + port: DefaultProbeCheckPort, + authCheck: NewDefaultAuthenticationCheck(authType), } } @@ -72,9 +68,9 @@ func (pc *ProbeChecker) Start(ctx context.Context) { func (pc *ProbeChecker) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.URL.Path == "/healthz" { // Perform Authentication check. - if err := AuthenticationCheck(req.Context(), pc.authType, pc.authCheckClient); err != nil { + if err := pc.authCheck.Check(req.Context()); err != nil { pc.logger.Error("authentication check failed", zap.Error(err)) - w.WriteHeader(http.StatusUnauthorized) + w.WriteHeader(http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) diff --git a/pkg/utils/authcheck/probechecker_test.go b/pkg/utils/authcheck/probechecker_test.go index 3f3ba3ccc1..57eaef77e5 100644 --- a/pkg/utils/authcheck/probechecker_test.go +++ b/pkg/utils/authcheck/probechecker_test.go @@ -18,6 +18,7 @@ package authcheck import ( "context" + "errors" "fmt" "net" "net/http" @@ -26,24 +27,23 @@ import ( "github.com/google/go-cmp/cmp" - authchecktesting "github.com/google/knative-gcp/pkg/gclient/authcheck/testing" "github.com/google/knative-gcp/pkg/logging" ) func TestProbeCheckResult(t *testing.T) { testCases := []struct { name string - inputStatus int + noError bool wantStatusCode int }{ { name: "probe check got a failure result", - inputStatus: http.StatusBadRequest, - wantStatusCode: http.StatusUnauthorized, + noError: false, + wantStatusCode: http.StatusInternalServerError, }, { name: "probe check got a success result", - inputStatus: http.StatusAccepted, + noError: true, wantStatusCode: http.StatusOK, }, } @@ -66,10 +66,9 @@ func TestProbeCheckResult(t *testing.T) { logger := logging.FromContext(ctx) probeChecker := ProbeChecker{ - logger: logger, - port: port, - authType: WorkloadIdentityGSA, - authCheckClient: authchecktesting.NewFakeAuthCheckClient(tc.inputStatus), + logger: logger, + port: port, + authCheck: NewFakeAuthenticationCheck("", tc.noError), } go probeChecker.Start(ctx) @@ -94,3 +93,22 @@ func TestProbeCheckResult(t *testing.T) { }) } } + +type FakeAuthenticationCheck struct { + authType AuthType + noError bool +} + +func NewFakeAuthenticationCheck(authType AuthType, noError bool) AuthenticationCheck { + return &FakeAuthenticationCheck{ + authType: authType, + noError: noError, + } +} + +func (ac *FakeAuthenticationCheck) Check(ctx context.Context) error { + if ac.noError { + return nil + } + return errors.New("induced error") +}