From 35df71e869e2ffd268dde1f481d816a9387e0f90 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Mon, 30 Aug 2021 18:31:27 +0200 Subject: [PATCH] Propagate context properly --- access/resource_secret_scope.go | 2 +- common/azure_auth.go | 10 ++++++---- common/azure_auth_test.go | 10 ++++++---- common/azure_cli_auth.go | 4 ++-- common/azure_cli_auth_test.go | 9 +++++---- common/client.go | 18 +++++------------- common/client_test.go | 3 ++- common/gcp.go | 15 ++++++++------- common/gcp_test.go | 11 ++++++----- common/http.go | 2 +- mws/resource_workspace_test.go | 2 +- provider/provider_test.go | 5 +++-- 12 files changed, 46 insertions(+), 45 deletions(-) diff --git a/access/resource_secret_scope.go b/access/resource_secret_scope.go index 1fe3abfc9b..176d5fc745 100644 --- a/access/resource_secret_scope.go +++ b/access/resource_secret_scope.go @@ -59,7 +59,7 @@ func (a SecretScopesAPI) Create(s SecretScope) error { BackendType: "DATABRICKS", } if s.KeyvaultMetadata != nil { - if err := a.client.Authenticate(); err != nil { + if err := a.client.Authenticate(a.context); err != nil { return err } if !a.client.IsAzure() { diff --git a/common/azure_auth.go b/common/azure_auth.go index cd9133ee46..d68023eb03 100644 --- a/common/azure_auth.go +++ b/common/azure_auth.go @@ -84,7 +84,7 @@ func (aa *DatabricksClient) IsAzureClientSecretSet() bool { return aa.AzureClientID != "" && aa.AzureClientSecret != "" && aa.AzureTenantID != "" } -func (aa *DatabricksClient) configureWithClientSecret() (func(r *http.Request) error, error) { +func (aa *DatabricksClient) configureWithClientSecret(ctx context.Context) (func(*http.Request) error, error) { if !aa.IsAzure() { return nil, nil } @@ -110,11 +110,13 @@ func (aa *DatabricksClient) configureWithClientSecret() (func(r *http.Request) e } log.Printf("[INFO] Generating AAD token for Azure Service Principal") - return aa.simpleAADRequestVisitor(aa.InitContext, aa.getClientSecretAuthorizer, aa.addSpManagementTokenVisitor) + return aa.simpleAADRequestVisitor(ctx, aa.getClientSecretAuthorizer, aa.addSpManagementTokenVisitor) } -func (aa *DatabricksClient) configureWithManagedIdentity() (func(r *http.Request) error, error) { - ctx := context.TODO() +func (aa *DatabricksClient) configureWithManagedIdentity(ctx context.Context) (func(*http.Request) error, error) { + if !aa.IsAzure() { + return nil, nil + } if !adal.MSIAvailable(ctx, aa.httpClient.HTTPClient) { return nil, nil } diff --git a/common/azure_auth_test.go b/common/azure_auth_test.go index 54c743ed3e..f405fb020d 100644 --- a/common/azure_auth_test.go +++ b/common/azure_auth_test.go @@ -182,12 +182,13 @@ func TestDatabricksClient_ensureWorkspaceURL(t *testing.T) { func TestDatabricksClient_configureWithClientSecretPAT(t *testing.T) { client := DatabricksClient{InsecureSkipVerify: true} - auth, err := client.configureWithClientSecret() + ctx := context.Background() + auth, err := client.configureWithClientSecret(ctx) assert.Nil(t, auth) assert.NoError(t, err) client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c" - auth, err = client.configureWithClientSecret() + auth, err = client.configureWithClientSecret(ctx) assert.Nil(t, auth) assert.NoError(t, err) @@ -245,7 +246,7 @@ func TestDatabricksClient_configureWithClientSecretPAT(t *testing.T) { client.AzureEnvironment = &azure.Environment{ ResourceManagerEndpoint: fmt.Sprintf("%s/", server.URL), } - auth, err = client.configureWithClientSecret() + auth, err = client.configureWithClientSecret(ctx) assert.NotNil(t, auth) assert.NoError(t, err) @@ -307,7 +308,8 @@ func TestDatabricksClient_configureWithClientSecretAAD(t *testing.T) { ResourceManagerEndpoint: fmt.Sprintf("%s/", server.URL), } client.configureHTTPCLient() - auth, err := client.configureWithClientSecret() + ctx := context.Background() + auth, err := client.configureWithClientSecret(ctx) assert.NoError(t, err) client.authVisitor = auth diff --git a/common/azure_cli_auth.go b/common/azure_cli_auth.go index fc1a7e16dc..87fa215e68 100644 --- a/common/azure_cli_auth.go +++ b/common/azure_cli_auth.go @@ -93,7 +93,7 @@ func (aa *DatabricksClient) cliAuthorizer(resource string) (autorest.Authorizer, return autorest.NewBearerAuthorizer(&rct), nil } -func (aa *DatabricksClient) configureWithAzureCLI() (func(r *http.Request) error, error) { +func (aa *DatabricksClient) configureWithAzureCLI(ctx context.Context) (func(*http.Request) error, error) { if !aa.IsAzure() { return nil, nil } @@ -126,5 +126,5 @@ func (aa *DatabricksClient) configureWithAzureCLI() (func(r *http.Request) error }, nil } log.Printf("[INFO] Using Azure CLI authentication with AAD tokens") - return aa.simpleAADRequestVisitor(context.TODO(), aa.cliAuthorizer) + return aa.simpleAADRequestVisitor(ctx, aa.cliAuthorizer) } diff --git a/common/azure_cli_auth_test.go b/common/azure_cli_auth_test.go index b4f974251a..b89822a14f 100644 --- a/common/azure_cli_auth_test.go +++ b/common/azure_cli_auth_test.go @@ -175,7 +175,8 @@ func TestConfigureWithAzureCLI_SP(t *testing.T) { AzureTenantID: "c", AzureDatabricksResourceID: "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c", } - auth, err := aa.configureWithAzureCLI() + ctx := context.Background() + auth, err := aa.configureWithAzureCLI(ctx) assert.NoError(t, err) assert.Nil(t, auth) } @@ -193,7 +194,7 @@ func TestConfigureWithAzureCLI(t *testing.T) { client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c" client.AzureUsePATForCLI = true - auth, err := client.configureWithAzureCLI() + auth, err := client.configureWithAzureCLI(context.Background()) assert.NoError(t, err) err = auth(httptest.NewRequest("GET", "/clusters/list", http.NoBody)) @@ -213,7 +214,7 @@ func TestConfigureWithAzureCLI_Error(t *testing.T) { client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c" client.AzureUsePATForCLI = true - auth, err := client.configureWithAzureCLI() + auth, err := client.configureWithAzureCLI(context.Background()) assert.NoError(t, err) err = auth(httptest.NewRequest("GET", "/clusters/list", http.NoBody)) @@ -235,7 +236,7 @@ func TestConfigureWithAzureCLI_NotInstalled(t *testing.T) { client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c" client.AzureUsePATForCLI = true - _, err := client.configureWithAzureCLI() + _, err := client.configureWithAzureCLI(context.Background()) require.Error(t, err) assert.True(t, strings.HasPrefix(err.Error(), "most likely Azure CLI is not installed"), "Actual message: %s", err.Error()) diff --git a/common/client.go b/common/client.go index a2a69992b0..6ee84de773 100644 --- a/common/client.go +++ b/common/client.go @@ -100,10 +100,6 @@ type DatabricksClient struct { // options used to enable unit testing mode for OIDC googleAuthOptions []option.ClientOption - // Context used during provider initialisation, - // mostly for OAuth-based validation. - InitContext context.Context - // Mutex used by Authenticate method to guard `authVisitor`, which // has to be lazily created on the first request to Databricks API. // It is done because databricks host and token may often be available @@ -208,8 +204,7 @@ func (c *DatabricksClient) Configure(attrsUsed ...string) error { } // Authenticate lazily authenticates across authorizers or returns error -func (c *DatabricksClient) Authenticate() error { - // TODO: add context +func (c *DatabricksClient) Authenticate(ctx context.Context) error { if c.authVisitor != nil { return nil } @@ -218,7 +213,7 @@ func (c *DatabricksClient) Authenticate() error { if c.authVisitor != nil { return nil } - authorizers := []func() (func(r *http.Request) error, error){ + authorizers := []func(context.Context) (func(*http.Request) error, error){ c.configureAuthWithDirectParams, c.configureWithClientSecret, c.configureWithManagedIdentity, @@ -228,7 +223,7 @@ func (c *DatabricksClient) Authenticate() error { c.configureFromDatabricksCfg, } for _, authProvider := range authorizers { - authorizer, err := authProvider() + authorizer, err := authProvider(ctx) if err != nil { return c.niceError(fmt.Sprintf("cannot configure auth: %s", err)) } @@ -263,7 +258,7 @@ func (c *DatabricksClient) fixHost() { } } -func (c *DatabricksClient) configureAuthWithDirectParams() (func(r *http.Request) error, error) { +func (c *DatabricksClient) configureAuthWithDirectParams(ctx context.Context) (func(*http.Request) error, error) { authType := "Bearer" var needsHostBecause string if c.Username != "" && c.Password != "" { @@ -285,7 +280,7 @@ func (c *DatabricksClient) configureAuthWithDirectParams() (func(r *http.Request return c.authorizer(authType, c.Token), nil } -func (c *DatabricksClient) configureFromDatabricksCfg() (func(r *http.Request) error, error) { +func (c *DatabricksClient) configureFromDatabricksCfg(ctx context.Context) (func(r *http.Request) error, error) { configFile := c.ConfigFile if configFile == "" { configFile = "~/.databrickscfg" @@ -354,9 +349,6 @@ func (c *DatabricksClient) configureHTTPCLient() { if c.RateLimitPerSecond == 0 { c.RateLimitPerSecond = DefaultRateLimitPerSecond } - if c.InitContext == nil { - c.InitContext = context.Background() - } c.rateLimiter = rate.NewLimiter(rate.Limit(c.RateLimitPerSecond), 1) // Set up a retryable HTTP Client to handle cases where the service returns // a transient error on initial creation diff --git a/common/client_test.go b/common/client_test.go index e22f153ea7..c7c1d759ef 100644 --- a/common/client_test.go +++ b/common/client_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "os" "strings" "testing" @@ -17,7 +18,7 @@ func configureAndAuthenticate(dc *DatabricksClient) (*DatabricksClient, error) { if err != nil { return dc, err } - return dc, dc.Authenticate() + return dc, dc.Authenticate(context.Background()) } func TestDatabricksClientConfigure_Nothing(t *testing.T) { diff --git a/common/gcp.go b/common/gcp.go index 98fa653799..e6294edcd9 100644 --- a/common/gcp.go +++ b/common/gcp.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "net/http" @@ -8,9 +9,9 @@ import ( "google.golang.org/api/impersonate" ) -func (c *DatabricksClient) getGoogleOIDCSource() (oauth2.TokenSource, error) { +func (c *DatabricksClient) getGoogleOIDCSource(ctx context.Context) (oauth2.TokenSource, error) { // source for generateIdToken - ts, err := impersonate.IDTokenSource(c.InitContext, impersonate.IDTokenConfig{ + ts, err := impersonate.IDTokenSource(ctx, impersonate.IDTokenConfig{ Audience: c.Host, TargetPrincipal: c.GoogleServiceAccount, IncludeEmail: true, @@ -24,16 +25,16 @@ func (c *DatabricksClient) getGoogleOIDCSource() (oauth2.TokenSource, error) { return ts, nil } -func (c *DatabricksClient) configureWithGoogleForAccountsAPI() (func(r *http.Request) error, error) { +func (c *DatabricksClient) configureWithGoogleForAccountsAPI(ctx context.Context) (func(*http.Request) error, error) { if c.GoogleServiceAccount == "" || !c.IsGcp() || !c.isAccountsClient() { return nil, nil } - oidcSource, err := c.getGoogleOIDCSource() + oidcSource, err := c.getGoogleOIDCSource(ctx) if err != nil { return nil, err } // source for generateAccessToken - platformSource, err := impersonate.CredentialsTokenSource(c.InitContext, impersonate.CredentialsConfig{ + platformSource, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{ TargetPrincipal: c.GoogleServiceAccount, Scopes: []string{ "https://www.googleapis.com/auth/cloud-platform", @@ -63,11 +64,11 @@ func newOidcAuthorizerForAccountsAPI(oidcSource oauth2.TokenSource, } } -func (c *DatabricksClient) configureWithGoogleForWorkspace() (func(r *http.Request) error, error) { +func (c *DatabricksClient) configureWithGoogleForWorkspace(ctx context.Context) (func(r *http.Request) error, error) { if c.GoogleServiceAccount == "" || !c.IsGcp() || c.isAccountsClient() { return nil, nil } - oidcSource, err := c.getGoogleOIDCSource() + oidcSource, err := c.getGoogleOIDCSource(ctx) if err != nil { return nil, err } diff --git a/common/gcp_test.go b/common/gcp_test.go index ac78404eb6..20df121e40 100644 --- a/common/gcp_test.go +++ b/common/gcp_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "net/http/httptest" "testing" @@ -21,7 +22,7 @@ func TestGoogleOIDC(t *testing.T) { } client.configureHTTPCLient() - _, err := client.getGoogleOIDCSource() + _, err := client.getGoogleOIDCSource(context.Background()) require.NoError(t, err) } @@ -33,11 +34,11 @@ func TestConfigureWithGoogleForAccountsAPI(t *testing.T) { } client.configureHTTPCLient() - _, err := client.configureWithGoogleForAccountsAPI() + _, err := client.configureWithGoogleForAccountsAPI(context.Background()) assert.Error(t, err) client.googleAuthOptions = []option.ClientOption{option.WithoutAuthentication()} - a, err := client.configureWithGoogleForAccountsAPI() + a, err := client.configureWithGoogleForAccountsAPI(context.Background()) require.NoError(t, err) assert.NotNil(t, a) } @@ -50,11 +51,11 @@ func TestConfigureWithGoogleForWorkspace(t *testing.T) { } client.configureHTTPCLient() - _, err := client.configureWithGoogleForWorkspace() + _, err := client.configureWithGoogleForWorkspace(context.Background()) assert.Error(t, err) client.googleAuthOptions = []option.ClientOption{option.WithoutAuthentication()} - a, err := client.configureWithGoogleForWorkspace() + a, err := client.configureWithGoogleForWorkspace(context.Background()) require.NoError(t, err) assert.NotNil(t, a) } diff --git a/common/http.go b/common/http.go index 81e04a1c25..b68dbb9a20 100644 --- a/common/http.go +++ b/common/http.go @@ -355,7 +355,7 @@ func (c *DatabricksClient) OldAPI(ctx context.Context, method, path string, requ func (c *DatabricksClient) authenticatedQuery(ctx context.Context, method, requestURL string, data interface{}, visitors ...func(*http.Request) error) (body []byte, err error) { - err = c.Authenticate() + err = c.Authenticate(ctx) if err != nil { return } diff --git a/mws/resource_workspace_test.go b/mws/resource_workspace_test.go index 81f2bf1c39..f94c9f9172 100644 --- a/mws/resource_workspace_test.go +++ b/mws/resource_workspace_test.go @@ -30,7 +30,7 @@ func TestMwsAccWorkspace(t *testing.T) { func TestGcpAccWorkspace(t *testing.T) { acctID := qa.GetEnvOrSkipTest(t, "DATABRICKS_ACCOUNT_ID") client := common.CommonEnvironmentClient() - workspacesAPI := NewWorkspacesAPI(client.InitContext, client) + workspacesAPI := NewWorkspacesAPI(context.Background(), client) workspaceList, err := workspacesAPI.List(acctID) require.NoError(t, err, err) diff --git a/provider/provider_test.go b/provider/provider_test.go index 160527a217..c79bd94c06 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -417,7 +417,8 @@ func configureProviderAndReturnClient(t *testing.T, tt providerFixture) (*common os.Setenv(k, v) } p := DatabricksProvider() - diags := p.Configure(context.Background(), terraform.NewResourceConfigRaw(tt.rawConfig())) + ctx := context.Background() + diags := p.Configure(ctx, terraform.NewResourceConfigRaw(tt.rawConfig())) if len(diags) > 0 { issues := []string{} for _, d := range diags { @@ -427,7 +428,7 @@ func configureProviderAndReturnClient(t *testing.T, tt providerFixture) (*common } client := p.Meta().(*common.DatabricksClient) client.AzureUsePATForSPN = tt.usePATForSPN - err := client.Authenticate() + err := client.Authenticate(ctx) if err != nil { return nil, err }