diff --git a/auth/auth.go b/auth/auth.go index e26802f1ea73..ea7c1b0ad8dc 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -284,18 +284,18 @@ type Error struct { uri string } -func (r *Error) Error() string { - if r.code != "" { - s := fmt.Sprintf("auth: %q", r.code) - if r.description != "" { - s += fmt.Sprintf(" %q", r.description) +func (e *Error) Error() string { + if e.code != "" { + s := fmt.Sprintf("auth: %q", e.code) + if e.description != "" { + s += fmt.Sprintf(" %q", e.description) } - if r.uri != "" { - s += fmt.Sprintf(" %q", r.uri) + if e.uri != "" { + s += fmt.Sprintf(" %q", e.uri) } return s } - return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", r.Response.StatusCode, r.Body) + return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body) } // Temporary returns true if the error is considered temporary and may be able diff --git a/auth/credentials/detect.go b/auth/credentials/detect.go index 24d153bb8b93..3e45726631cc 100644 --- a/auth/credentials/detect.go +++ b/auth/credentials/detect.go @@ -93,6 +93,7 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) { ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) { return metadata.ProjectID() }), + UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{}, }), nil } @@ -140,6 +141,9 @@ type DetectOptions struct { // Client configures the underlying client used to make network requests // when fetching tokens. Optional. Client *http.Client + // UniverseDomain is the default service domain for a given Cloud universe. + // The default value is "googleapis.com". Optional. + UniverseDomain string } func (o *DetectOptions) validate() error { diff --git a/auth/credentials/detect_test.go b/auth/credentials/detect_test.go index 20e77cc98f0d..b366b2c473e4 100644 --- a/auth/credentials/detect_test.go +++ b/auth/credentials/detect_test.go @@ -596,7 +596,7 @@ func TestDefaultCredentials_ExternalAccountKey(t *testing.T) { if want := "googleapis.com"; got != want { t.Fatalf("got %q, want %q", got, want) } - tok, err := creds.Token(context.Background()) + tok, err := creds.Token(ctx) if err != nil { t.Fatalf("creds.Token() = %v", err) } @@ -720,3 +720,111 @@ func TestDefaultCredentials_Validate(t *testing.T) { }) } } + +func TestDefaultCredentials_UniverseDomain(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + opts *DetectOptions + want string + }{ + { + name: "user json", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/user.json", + TokenURL: "example.com", + }, + want: "googleapis.com", + }, + { + name: "user json with file universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/user_universe_domain.json", + TokenURL: "example.com", + }, + want: "googleapis.com", + }, + { + name: "service account token URL json", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/sa.json", + }, + want: "googleapis.com", + }, + { + name: "external account json", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/exaccount_user.json", + UseSelfSignedJWT: true, + }, + want: "googleapis.com", + }, + { + name: "service account impersonation json", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/imp.json", + UseSelfSignedJWT: true, + }, + want: "googleapis.com", + }, + { + name: "service account json with file universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/sa_universe_domain.json", + UseSelfSignedJWT: true, + }, + want: "example.com", + }, + { + name: "service account json with options universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/sa.json", + UseSelfSignedJWT: true, + UniverseDomain: "foo.com", + }, + want: "foo.com", + }, + { + name: "service account json with file and options universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/sa_universe_domain.json", + UseSelfSignedJWT: true, + UniverseDomain: "bar.com", + }, + want: "bar.com", + }, + { + name: "external account json with options universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/exaccount_user.json", + UseSelfSignedJWT: true, + UniverseDomain: "foo.com", + }, + want: "foo.com", + }, + { + name: "impersonated service account json with options universe domain", + opts: &DetectOptions{ + CredentialsFile: "../internal/testdata/imp.json", + UseSelfSignedJWT: true, + UniverseDomain: "foo.com", + }, + want: "foo.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds, err := DetectDefault(tt.opts) + if err != nil { + t.Fatalf("%s: %v", tt.name, err) + } + ud, err := creds.UniverseDomain(ctx) + if err != nil { + t.Fatalf("%s: %v", tt.name, err) + } + if ud != tt.want { + t.Fatalf("%s: got %q, want %q", tt.name, ud, tt.want) + } + }) + } +} diff --git a/auth/credentials/filetypes.go b/auth/credentials/filetypes.go index 9bef2afe848e..3f4b12988772 100644 --- a/auth/credentials/filetypes.go +++ b/auth/credentials/filetypes.go @@ -101,6 +101,9 @@ func fileCredentials(b []byte, opts *DetectOptions) (*auth.Credentials, error) { default: return nil, fmt.Errorf("detect: unsupported filetype %q", fileType) } + if opts.UniverseDomain != "" { + universeDomain = opts.UniverseDomain + } return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{ ExpireEarly: opts.EarlyTokenRefresh, diff --git a/auth/credentials/internal/externalaccount/aws_provider.go b/auth/credentials/internal/externalaccount/aws_provider.go index 9fa86e6303da..09e7a252b1c6 100644 --- a/auth/credentials/internal/externalaccount/aws_provider.go +++ b/auth/credentials/internal/externalaccount/aws_provider.go @@ -173,17 +173,17 @@ func (sp *awsSubjectProvider) providerType() string { return awsProviderType } -func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) { - if cs.IMDSv2SessionTokenURL == "" { +func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) { + if sp.IMDSv2SessionTokenURL == "" { return "", nil } - req, err := http.NewRequestWithContext(ctx, "PUT", cs.IMDSv2SessionTokenURL, nil) + req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil) if err != nil { return "", err } req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL) - resp, err := cs.Client.Do(req) + resp, err := sp.Client.Do(req) if err != nil { return "", err } @@ -199,7 +199,7 @@ func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, e return string(respBody), nil } -func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) { +func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) { if canRetrieveRegionFromEnvironment() { if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" { return envAwsRegion, nil @@ -207,11 +207,11 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string] return getenv(awsDefaultRegionEnvVar), nil } - if cs.RegionURL == "" { + if sp.RegionURL == "" { return "", errors.New("detect: unable to determine AWS region") } - req, err := http.NewRequestWithContext(ctx, "GET", cs.RegionURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil) if err != nil { return "", err } @@ -220,7 +220,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string] req.Header.Add(name, value) } - resp, err := cs.Client.Do(req) + resp, err := sp.Client.Do(req) if err != nil { return "", err } @@ -244,7 +244,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string] return string(respBody[:bodyLen-1]), nil } -func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) { +func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) { if canRetrieveSecurityCredentialFromEnvironment() { return awsSecurityCredentials{ AccessKeyID: getenv(awsAccessKeyIDEnvVar), @@ -253,11 +253,11 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header }, nil } - roleName, err := cs.getMetadataRoleName(ctx, headers) + roleName, err := sp.getMetadataRoleName(ctx, headers) if err != nil { return } - credentials, err := cs.getMetadataSecurityCredentials(ctx, roleName, headers) + credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers) if err != nil { return } @@ -272,10 +272,10 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header return credentials, nil } -func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) { +func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) { var result awsSecurityCredentials - req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil) if err != nil { return result, err } @@ -283,7 +283,7 @@ func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context req.Header.Add(name, value) } - resp, err := cs.Client.Do(req) + resp, err := sp.Client.Do(req) if err != nil { return result, err } @@ -300,11 +300,11 @@ func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context return result, err } -func (cs *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) { - if cs.CredVerificationURL == "" { +func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) { + if sp.CredVerificationURL == "" { return "", errors.New("detect: unable to determine the AWS metadata server security credentials endpoint") } - req, err := http.NewRequestWithContext(ctx, "GET", cs.CredVerificationURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil) if err != nil { return "", err } @@ -312,7 +312,7 @@ func (cs *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers m req.Header.Add(name, value) } - resp, err := cs.Client.Do(req) + resp, err := sp.Client.Do(req) if err != nil { return "", err } diff --git a/auth/credentials/internal/externalaccount/executable_provider.go b/auth/credentials/internal/externalaccount/executable_provider.go index 51044ba797bf..681207b7b7a1 100644 --- a/auth/credentials/internal/externalaccount/executable_provider.go +++ b/auth/credentials/internal/externalaccount/executable_provider.go @@ -122,7 +122,7 @@ type executableResponse struct { Message string `json:"message,omitempty"` } -func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) { +func (sp *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) { var result executableResponse if err := json.Unmarshal(response, &result); err != nil { return "", jsonParsingError(source, string(response)) @@ -143,7 +143,7 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte if result.Version > executableSupportedMaxVersion || result.Version < 0 { return "", unsupportedVersionError(source, result.Version) } - if result.ExpirationTime == 0 && cs.OutputFile != "" { + if result.ExpirationTime == 0 && sp.OutputFile != "" { return "", missingFieldError(source, "expiration_time") } if result.TokenType == "" { @@ -169,24 +169,24 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte } } -func (cs *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) { - if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil { +func (sp *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) { + if token, err := sp.getTokenFromOutputFile(); token != "" || err != nil { return token, err } - return cs.getTokenFromExecutableCommand(ctx) + return sp.getTokenFromExecutableCommand(ctx) } -func (cs *executableSubjectProvider) providerType() string { +func (sp *executableSubjectProvider) providerType() string { return executableProviderType } -func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) { - if cs.OutputFile == "" { +func (sp *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) { + if sp.OutputFile == "" { // This ExecutableCredentialSource doesn't use an OutputFile. return "", nil } - file, err := os.Open(cs.OutputFile) + file, err := os.Open(sp.OutputFile) if err != nil { // No OutputFile found. Hasn't been created yet, so skip it. return "", nil @@ -199,7 +199,7 @@ func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err return "", nil } - token, err = cs.parseSubjectTokenFromSource(data, outputFileSource, cs.env.now().Unix()) + token, err = sp.parseSubjectTokenFromSource(data, outputFileSource, sp.env.now().Unix()) if err != nil { if _, ok := err.(nonCacheableError); ok { // If the cached token is expired we need a new token, @@ -231,20 +231,20 @@ func (sp *executableSubjectProvider) executableEnvironment() []string { return result } -func (cs *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) { +func (sp *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) { // For security reasons, we need our consumers to set this environment variable to allow executables to be run. - if cs.env.getenv(allowExecutablesEnvVar) != "1" { + if sp.env.getenv(allowExecutablesEnvVar) != "1" { return "", errors.New("detect: executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run") } - ctx, cancel := context.WithDeadline(ctx, cs.env.now().Add(cs.Timeout)) + ctx, cancel := context.WithDeadline(ctx, sp.env.now().Add(sp.Timeout)) defer cancel() - output, err := cs.env.run(ctx, cs.Command, cs.executableEnvironment()) + output, err := sp.env.run(ctx, sp.Command, sp.executableEnvironment()) if err != nil { return "", err } - return cs.parseSubjectTokenFromSource(output, executableSource, cs.env.now().Unix()) + return sp.parseSubjectTokenFromSource(output, executableSource, sp.env.now().Unix()) } func missingFieldError(source, field string) error { diff --git a/auth/credentials/internal/impersonate/impersonate.go b/auth/credentials/internal/impersonate/impersonate.go index 469c04f467aa..de54c636b8a2 100644 --- a/auth/credentials/internal/impersonate/impersonate.go +++ b/auth/credentials/internal/impersonate/impersonate.go @@ -87,29 +87,29 @@ func (o *Options) validate() error { } // Token performs the exchange to get a temporary service account token to allow access to GCP. -func (tp *Options) Token(ctx context.Context) (*auth.Token, error) { +func (o *Options) Token(ctx context.Context) (*auth.Token, error) { lifetime := defaultTokenLifetime - if tp.TokenLifetimeSeconds != 0 { - lifetime = fmt.Sprintf("%ds", tp.TokenLifetimeSeconds) + if o.TokenLifetimeSeconds != 0 { + lifetime = fmt.Sprintf("%ds", o.TokenLifetimeSeconds) } reqBody := generateAccessTokenReq{ Lifetime: lifetime, - Scope: tp.Scopes, - Delegates: tp.Delegates, + Scope: o.Scopes, + Delegates: o.Delegates, } b, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("detect: unable to marshal request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", tp.URL, bytes.NewReader(b)) + req, err := http.NewRequestWithContext(ctx, "POST", o.URL, bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("detect: unable to create impersonation request: %w", err) } req.Header.Set("Content-Type", "application/json") - if err := setAuthHeader(ctx, tp.Tp, req); err != nil { + if err := setAuthHeader(ctx, o.Tp, req); err != nil { return nil, err } - resp, err := tp.Client.Do(req) + resp, err := o.Client.Do(req) if err != nil { return nil, fmt.Errorf("detect: unable to generate access token: %w", err) } diff --git a/auth/grpctransport/grpctransport.go b/auth/grpctransport/grpctransport.go index 8a1e31bce4ff..17e83d9b34a0 100644 --- a/auth/grpctransport/grpctransport.go +++ b/auth/grpctransport/grpctransport.go @@ -22,6 +22,7 @@ import ( "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" + "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/transport" "go.opencensus.io/plugin/ocgrpc" "google.golang.org/grpc" @@ -71,6 +72,11 @@ type Options struct { // DetectOpts configures settings for detect Application Default // Credentials. DetectOpts *credentials.DetectOptions + // UniverseDomain is the default service domain for a given Cloud universe. + // The default value is "googleapis.com". This is the universe domain + // configured for the client, which will be compared to the universe domain + // that is separately configured for the credentials. + UniverseDomain string // InternalOptions are NOT meant to be set directly by consumers of this // package, they should only be set by generated client code. @@ -218,11 +224,11 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er } metadata[quotaProjectHeaderKey] = qp } - grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(&grpcCredentialsProvider{ - creds: creds, - metadata: metadata, + creds: creds, + metadata: metadata, + clientUniverseDomain: opts.UniverseDomain, }), ) @@ -246,10 +252,29 @@ type grpcCredentialsProvider struct { secure bool // Additional metadata attached as headers. - metadata map[string]string + metadata map[string]string + clientUniverseDomain string +} + +// getClientUniverseDomain returns the default service domain for a given Cloud universe. +// The default value is "googleapis.com". This is the universe domain +// configured for the client, which will be compared to the universe domain +// that is separately configured for the credentials. +func (c *grpcCredentialsProvider) getClientUniverseDomain() string { + if c.clientUniverseDomain == "" { + return internal.DefaultUniverseDomain + } + return c.clientUniverseDomain } func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + credentialsUniverseDomain, err := c.creds.UniverseDomain(ctx) + if err != nil { + return nil, err + } + if err := transport.ValidateUniverseDomain(c.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { + return nil, err + } token, err := c.creds.Token(ctx) if err != nil { return nil, err @@ -269,8 +294,8 @@ func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri .. return metadata, nil } -func (tp *grpcCredentialsProvider) RequireTransportSecurity() bool { - return tp.secure +func (c *grpcCredentialsProvider) RequireTransportSecurity() bool { + return c.secure } func addOCStatsHandler(dialOpts []grpc.DialOption, opts *Options) []grpc.DialOption { diff --git a/auth/grpctransport/grpctransport_test.go b/auth/grpctransport/grpctransport_test.go index 004d49b69bfe..e27571dd0f46 100644 --- a/auth/grpctransport/grpctransport_test.go +++ b/auth/grpctransport/grpctransport_test.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" echo "cloud.google.com/go/auth/grpctransport/testdata" + "cloud.google.com/go/auth/internal" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -242,6 +243,33 @@ func TestOptions_ResolveDetectOptions(t *testing.T) { } } +func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) { + nonDefault := "example.com" + tests := []struct { + name string + universeDomain string + want string + }{ + { + name: "default", + universeDomain: "", + want: internal.DefaultUniverseDomain, + }, + { + name: "non-default", + universeDomain: nonDefault, + want: nonDefault, + }, + } + for _, tt := range tests { + at := &grpcCredentialsProvider{clientUniverseDomain: tt.universeDomain} + got := at.getClientUniverseDomain() + if got != tt.want { + t.Errorf("%s: got %q, want %q", tt.name, got, tt.want) + } + } +} + func TestNewClient_DetectedServiceAccount(t *testing.T) { testQuota := "testquota" wantHeader := "bar" @@ -284,10 +312,11 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) { }, DetectOpts: &credentials.DetectOptions{ Audience: l.Addr().String(), - CredentialsFile: "../internal/testdata/sa.json", + CredentialsFile: "../internal/testdata/sa_universe_domain.json", UseSelfSignedJWT: true, }, - GRPCDialOpts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, + GRPCDialOpts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, + UniverseDomain: "example.com", // Also configured in sa_universe_domain.json }) if err != nil { t.Fatalf("NewClient() = %v", err) diff --git a/auth/httptransport/httptransport.go b/auth/httptransport/httptransport.go index 4794e0f87943..5fc3f93f5b89 100644 --- a/auth/httptransport/httptransport.go +++ b/auth/httptransport/httptransport.go @@ -59,6 +59,11 @@ type Options struct { // DetectOpts configures settings for detect Application Default // Credentials. DetectOpts *detect.DetectOptions + // UniverseDomain is the default service domain for a given Cloud universe. + // The default value is "googleapis.com". This is the universe domain + // configured for the client, which will be compared to the universe domain + // that is separately configured for the credentials. + UniverseDomain string // InternalOptions are NOT meant to be set directly by consumers of this // package, they should only be set by generated client code. @@ -140,8 +145,9 @@ func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) er base = http.DefaultTransport.(*http.Transport).Clone() } client.Transport = &authTransport{ - provider: creds, - base: base, + creds: creds, + base: base, + // TODO(quartzmo): Somehow set clientUniverseDomain from impersonate calls. } return nil } diff --git a/auth/httptransport/transport.go b/auth/httptransport/transport.go index 673a3e51f896..aef4663e6ce8 100644 --- a/auth/httptransport/transport.go +++ b/auth/httptransport/transport.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/transport" "cloud.google.com/go/auth/internal/transport/cert" "go.opencensus.io/plugin/ochttp" "golang.org/x/net/http2" @@ -75,9 +76,11 @@ func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, err if opts.Credentials != nil { creds = opts.Credentials } + creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil) trans = &authTransport{ - base: trans, - provider: creds, + base: trans, + creds: creds, + clientUniverseDomain: opts.UniverseDomain, } } return trans, nil @@ -161,8 +164,18 @@ func addOCTransport(trans http.RoundTripper, opts *Options) http.RoundTripper { } type authTransport struct { - provider *auth.Credentials - base http.RoundTripper + creds *auth.Credentials + base http.RoundTripper + clientUniverseDomain string +} + +// getClientUniverseDomain returns the universe domain configured for the client. +// The default value is "googleapis.com". +func (t *authTransport) getClientUniverseDomain() string { + if t.clientUniverseDomain == "" { + return internal.DefaultUniverseDomain + } + return t.clientUniverseDomain } // RoundTrip authorizes and authenticates the request with an @@ -178,7 +191,14 @@ func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { } }() } - token, err := t.provider.Token(req.Context()) + credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context()) + if err != nil { + return nil, err + } + if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil { + return nil, err + } + token, err := t.creds.Token(req.Context()) if err != nil { return nil, err } diff --git a/auth/httptransport/transport_test.go b/auth/httptransport/transport_test.go new file mode 100644 index 000000000000..da4d69befb3c --- /dev/null +++ b/auth/httptransport/transport_test.go @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptransport + +import ( + "testing" + + "cloud.google.com/go/auth/internal" +) + +func TestAuthTransport_GetClientUniverseDomain(t *testing.T) { + nonDefault := "example.com" + tests := []struct { + name string + universeDomain string + want string + }{ + { + name: "default", + universeDomain: "", + want: internal.DefaultUniverseDomain, + }, + { + name: "non-default", + universeDomain: nonDefault, + want: nonDefault, + }, + } + for _, tt := range tests { + at := &authTransport{clientUniverseDomain: tt.universeDomain} + got := at.getClientUniverseDomain() + if got != tt.want { + t.Errorf("%s: got %q, want %q", tt.name, got, tt.want) + } + } +} diff --git a/auth/idtoken/compute.go b/auth/idtoken/compute.go index 4c14b7ce19c7..d6757b60f871 100644 --- a/auth/idtoken/compute.go +++ b/auth/idtoken/compute.go @@ -46,7 +46,7 @@ func computeCredentials(opts *Options) (*auth.Credentials, error) { ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) { return metadata.ProjectID() }), - // TODO(quartzmo): add universe domain resolver here + UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{}, }), nil } diff --git a/auth/idtoken/idtoken.go b/auth/idtoken/idtoken.go index 41dbf7b1a267..34d92592d438 100644 --- a/auth/idtoken/idtoken.go +++ b/auth/idtoken/idtoken.go @@ -103,13 +103,13 @@ func NewCredentials(opts *Options) (*auth.Credentials, error) { return nil, fmt.Errorf("idtoken: couldn't find any credentials") } -func (opts *Options) jsonBytes() []byte { - if opts.CredentialsJSON != nil { - return opts.CredentialsJSON +func (o *Options) jsonBytes() []byte { + if o.CredentialsJSON != nil { + return o.CredentialsJSON } var fnOverride string - if opts != nil { - fnOverride = opts.CredentialsFile + if o != nil { + fnOverride = o.CredentialsFile } filename := internaldetect.GetFileNameFromEnv(fnOverride) if filename != "" { diff --git a/auth/internal/internal.go b/auth/internal/internal.go index 66953bf9576e..21dd2f020bf3 100644 --- a/auth/internal/internal.go +++ b/auth/internal/internal.go @@ -25,7 +25,10 @@ import ( "io" "net/http" "os" + "sync" "time" + + "cloud.google.com/go/compute/metadata" ) const ( @@ -35,6 +38,10 @@ const ( quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" projectEnvVar = "GOOGLE_CLOUD_PROJECT" maxBodySize = 1 << 20 + + // DefaultUniverseDomain is the default value for universe domain. + // Universe domain is the default service domain for a given Cloud universe. + DefaultUniverseDomain = "googleapis.com" ) // CloneDefaultClient returns a [http.Client] with some good defaults. @@ -134,3 +141,42 @@ type StaticProperty string func (p StaticProperty) GetProperty(context.Context) (string, error) { return string(p), nil } + +// ComputeUniverseDomainProvider fetches the credentials universe domain from +// the google cloud metadata service. +type ComputeUniverseDomainProvider struct { + universeDomainOnce sync.Once + universeDomain string + universeDomainErr error +} + +// GetProperty fetches the credentials universe domain from the google cloud +// metadata service. +func (c *ComputeUniverseDomainProvider) GetProperty(ctx context.Context) (string, error) { + c.universeDomainOnce.Do(func() { + c.universeDomain, c.universeDomainErr = getMetadataUniverseDomain(ctx) + }) + if c.universeDomainErr != nil { + return "", c.universeDomainErr + } + return c.universeDomain, nil +} + +// httpGetMetadataUniverseDomain is a package var for unit test substitution. +var httpGetMetadataUniverseDomain = func(ctx context.Context) (string, error) { + client := metadata.NewClient(&http.Client{Timeout: time.Second}) + // TODO(quartzmo): set ctx on request + return client.Get("universe/universe_domain") +} + +func getMetadataUniverseDomain(ctx context.Context) (string, error) { + universeDomain, err := httpGetMetadataUniverseDomain(ctx) + if err == nil { + return universeDomain, nil + } + if _, ok := err.(metadata.NotDefinedError); ok { + // http.StatusNotFound (404) + return DefaultUniverseDomain, nil + } + return "", err +} diff --git a/auth/internal/internal_test.go b/auth/internal/internal_test.go new file mode 100644 index 000000000000..b0eda2e5cea6 --- /dev/null +++ b/auth/internal/internal_test.go @@ -0,0 +1,77 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "errors" + "testing" + + "cloud.google.com/go/compute/metadata" +) + +func TestComputeUniverseDomainProvider(t *testing.T) { + fatalErr := errors.New("fatal error") + notDefinedError := metadata.NotDefinedError("universe/universe_domain") + testCases := []struct { + name string + getFunc func(ctx context.Context) (string, error) + want string + wantErr error + }{ + { + name: "test error", + getFunc: func(ctx context.Context) (string, error) { + return "", fatalErr + }, + want: "", + wantErr: fatalErr, + }, + { + name: "test error 404", + getFunc: func(ctx context.Context) (string, error) { + return "", notDefinedError + }, + want: DefaultUniverseDomain, + wantErr: nil, + }, + { + name: "test valid", + getFunc: func(ctx context.Context) (string, error) { + return "example.com", nil + }, + want: "example.com", + wantErr: nil, + }, + } + + oldGet := httpGetMetadataUniverseDomain + defer func() { + httpGetMetadataUniverseDomain = oldGet + }() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + httpGetMetadataUniverseDomain = tc.getFunc + c := ComputeUniverseDomainProvider{} + got, err := c.GetProperty(context.Background()) + if err != tc.wantErr { + t.Errorf("%s: got error %v; want error %v", tc.name, err, tc.wantErr) + } + if got != tc.want { + t.Errorf("%s: got %v; want %v", tc.name, got, tc.want) + } + }) + } +} diff --git a/auth/internal/transport/transport.go b/auth/internal/transport/transport.go index 1cf75af7c316..b76386d3c0df 100644 --- a/auth/internal/transport/transport.go +++ b/auth/internal/transport/transport.go @@ -16,7 +16,11 @@ // (grpctransport and httptransport). package transport -import "cloud.google.com/go/auth/credentials" +import ( + "fmt" + + "cloud.google.com/go/auth/credentials" +) // CloneDetectOptions clones a user set detect option into some new memory that // we can internally manipulate before sending onto the detect package. @@ -36,6 +40,7 @@ func CloneDetectOptions(oldDo *credentials.DetectOptions) *credentials.DetectOpt STSAudience: oldDo.STSAudience, CredentialsFile: oldDo.CredentialsFile, UseSelfSignedJWT: oldDo.UseSelfSignedJWT, + UniverseDomain: oldDo.UniverseDomain, // These fields are are pointer types that we just want to use exactly // as the user set, copy the ref @@ -55,3 +60,17 @@ func CloneDetectOptions(oldDo *credentials.DetectOptions) *credentials.DetectOpt return newDo } + +// ValidateUniverseDomain verifies that the universe domain configured for the +// client matches the universe domain configured for the credentials. +func ValidateUniverseDomain(clientUniverseDomain, credentialsUniverseDomain string) error { + if clientUniverseDomain != credentialsUniverseDomain { + return fmt.Errorf( + "the configured universe domain (%q) does not match the universe "+ + "domain found in the credentials (%q). If you haven't configured "+ + "the universe domain explicitly, \"googleapis.com\" is the default", + clientUniverseDomain, + credentialsUniverseDomain) + } + return nil +} diff --git a/auth/internal/transport/transport_test.go b/auth/internal/transport/transport_test.go index ecd444cc7aee..3518195a55ad 100644 --- a/auth/internal/transport/transport_test.go +++ b/auth/internal/transport/transport_test.go @@ -29,7 +29,7 @@ import ( // future. To make the test pass simply bump the int, but please also clone the // relevant fields. func TestCloneDetectOptions_FieldTest(t *testing.T) { - const WantNumberOfFields = 11 + const WantNumberOfFields = 12 o := credentials.DetectOptions{} got := reflect.TypeOf(o).NumField() if got != WantNumberOfFields { diff --git a/auth/threelegged.go b/auth/threelegged.go index 3433d6afcb00..1b8d83c4b4fe 100644 --- a/auth/threelegged.go +++ b/auth/threelegged.go @@ -123,44 +123,44 @@ func (e *tokenJSON) expiry() (t time.Time) { return } -func (c *Options3LO) client() *http.Client { - if c.Client != nil { - return c.Client +func (o *Options3LO) client() *http.Client { + if o.Client != nil { + return o.Client } return internal.CloneDefaultClient() } // authCodeURL returns a URL that points to a OAuth2 consent page. -func (c *Options3LO) authCodeURL(state string, values url.Values) string { +func (o *Options3LO) authCodeURL(state string, values url.Values) string { var buf bytes.Buffer - buf.WriteString(c.AuthURL) + buf.WriteString(o.AuthURL) v := url.Values{ "response_type": {"code"}, - "client_id": {c.ClientID}, + "client_id": {o.ClientID}, } - if c.RedirectURL != "" { - v.Set("redirect_uri", c.RedirectURL) + if o.RedirectURL != "" { + v.Set("redirect_uri", o.RedirectURL) } - if len(c.Scopes) > 0 { - v.Set("scope", strings.Join(c.Scopes, " ")) + if len(o.Scopes) > 0 { + v.Set("scope", strings.Join(o.Scopes, " ")) } if state != "" { v.Set("state", state) } - if c.AuthHandlerOpts != nil { - if c.AuthHandlerOpts.PKCEOpts != nil && - c.AuthHandlerOpts.PKCEOpts.Challenge != "" { - v.Set(codeChallengeKey, c.AuthHandlerOpts.PKCEOpts.Challenge) + if o.AuthHandlerOpts != nil { + if o.AuthHandlerOpts.PKCEOpts != nil && + o.AuthHandlerOpts.PKCEOpts.Challenge != "" { + v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge) } - if c.AuthHandlerOpts.PKCEOpts != nil && - c.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" { - v.Set(codeChallengeMethodKey, c.AuthHandlerOpts.PKCEOpts.ChallengeMethod) + if o.AuthHandlerOpts.PKCEOpts != nil && + o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" { + v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod) } } for k := range values { v.Set(k, v.Get(k)) } - if strings.Contains(c.AuthURL, "?") { + if strings.Contains(o.AuthURL, "?") { buf.WriteByte('&') } else { buf.WriteByte('?') @@ -205,24 +205,24 @@ func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider { // exchange handles the final exchange portion of the 3lo flow. Returns a Token, // refreshToken, and error. -func (c *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) { +func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) { // Build request v := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, } - if c.RedirectURL != "" { - v.Set("redirect_uri", c.RedirectURL) + if o.RedirectURL != "" { + v.Set("redirect_uri", o.RedirectURL) } - if c.AuthHandlerOpts != nil && - c.AuthHandlerOpts.PKCEOpts != nil && - c.AuthHandlerOpts.PKCEOpts.Verifier != "" { - v.Set(codeVerifierKey, c.AuthHandlerOpts.PKCEOpts.Verifier) + if o.AuthHandlerOpts != nil && + o.AuthHandlerOpts.PKCEOpts != nil && + o.AuthHandlerOpts.PKCEOpts.Verifier != "" { + v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier) } - for k := range c.URLParams { - v.Set(k, c.URLParams.Get(k)) + for k := range o.URLParams { + v.Set(k, o.URLParams.Get(k)) } - return fetchToken(ctx, c, v) + return fetchToken(ctx, o, v) } // This struct is not safe for concurrent access alone, but the way it is used @@ -274,27 +274,27 @@ func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) { } // fetchToken returns a Token, refresh token, and/or an error. -func fetchToken(ctx context.Context, c *Options3LO, v url.Values) (*Token, string, error) { +func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) { var refreshToken string - if c.AuthStyle == StyleInParams { - if c.ClientID != "" { - v.Set("client_id", c.ClientID) + if o.AuthStyle == StyleInParams { + if o.ClientID != "" { + v.Set("client_id", o.ClientID) } - if c.ClientSecret != "" { - v.Set("client_secret", c.ClientSecret) + if o.ClientSecret != "" { + v.Set("client_secret", o.ClientSecret) } } - req, err := http.NewRequest("POST", c.TokenURL, strings.NewReader(v.Encode())) + req, err := http.NewRequest("POST", o.TokenURL, strings.NewReader(v.Encode())) if err != nil { return nil, refreshToken, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if c.AuthStyle == StyleInHeader { - req.SetBasicAuth(url.QueryEscape(c.ClientID), url.QueryEscape(c.ClientSecret)) + if o.AuthStyle == StyleInHeader { + req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret)) } // Make request - r, err := c.client().Do(req.WithContext(ctx)) + r, err := o.client().Do(req.WithContext(ctx)) if err != nil { return nil, refreshToken, err }