Skip to content

Commit

Permalink
Propagate context properly
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx committed Aug 30, 2021
1 parent ca32dc7 commit 35df71e
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 45 deletions.
2 changes: 1 addition & 1 deletion access/resource_secret_scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (a SecretScopesAPI) Create(s SecretScope) error {
BackendType: "DATABRICKS",
}
if s.KeyvaultMetadata != nil {
if err := a.client.Authenticate(); err != nil {
if err := a.client.Authenticate(a.context); err != nil {
return err
}
if !a.client.IsAzure() {
Expand Down
10 changes: 6 additions & 4 deletions common/azure_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (aa *DatabricksClient) IsAzureClientSecretSet() bool {
return aa.AzureClientID != "" && aa.AzureClientSecret != "" && aa.AzureTenantID != ""
}

func (aa *DatabricksClient) configureWithClientSecret() (func(r *http.Request) error, error) {
func (aa *DatabricksClient) configureWithClientSecret(ctx context.Context) (func(*http.Request) error, error) {
if !aa.IsAzure() {
return nil, nil
}
Expand All @@ -110,11 +110,13 @@ func (aa *DatabricksClient) configureWithClientSecret() (func(r *http.Request) e
}

log.Printf("[INFO] Generating AAD token for Azure Service Principal")
return aa.simpleAADRequestVisitor(aa.InitContext, aa.getClientSecretAuthorizer, aa.addSpManagementTokenVisitor)
return aa.simpleAADRequestVisitor(ctx, aa.getClientSecretAuthorizer, aa.addSpManagementTokenVisitor)
}

func (aa *DatabricksClient) configureWithManagedIdentity() (func(r *http.Request) error, error) {
ctx := context.TODO()
func (aa *DatabricksClient) configureWithManagedIdentity(ctx context.Context) (func(*http.Request) error, error) {
if !aa.IsAzure() {
return nil, nil
}
if !adal.MSIAvailable(ctx, aa.httpClient.HTTPClient) {
return nil, nil
}
Expand Down
10 changes: 6 additions & 4 deletions common/azure_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ func TestDatabricksClient_ensureWorkspaceURL(t *testing.T) {

func TestDatabricksClient_configureWithClientSecretPAT(t *testing.T) {
client := DatabricksClient{InsecureSkipVerify: true}
auth, err := client.configureWithClientSecret()
ctx := context.Background()
auth, err := client.configureWithClientSecret(ctx)
assert.Nil(t, auth)
assert.NoError(t, err)

client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c"
auth, err = client.configureWithClientSecret()
auth, err = client.configureWithClientSecret(ctx)
assert.Nil(t, auth)
assert.NoError(t, err)

Expand Down Expand Up @@ -245,7 +246,7 @@ func TestDatabricksClient_configureWithClientSecretPAT(t *testing.T) {
client.AzureEnvironment = &azure.Environment{
ResourceManagerEndpoint: fmt.Sprintf("%s/", server.URL),
}
auth, err = client.configureWithClientSecret()
auth, err = client.configureWithClientSecret(ctx)
assert.NotNil(t, auth)
assert.NoError(t, err)

Expand Down Expand Up @@ -307,7 +308,8 @@ func TestDatabricksClient_configureWithClientSecretAAD(t *testing.T) {
ResourceManagerEndpoint: fmt.Sprintf("%s/", server.URL),
}
client.configureHTTPCLient()
auth, err := client.configureWithClientSecret()
ctx := context.Background()
auth, err := client.configureWithClientSecret(ctx)
assert.NoError(t, err)

client.authVisitor = auth
Expand Down
4 changes: 2 additions & 2 deletions common/azure_cli_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (aa *DatabricksClient) cliAuthorizer(resource string) (autorest.Authorizer,
return autorest.NewBearerAuthorizer(&rct), nil
}

func (aa *DatabricksClient) configureWithAzureCLI() (func(r *http.Request) error, error) {
func (aa *DatabricksClient) configureWithAzureCLI(ctx context.Context) (func(*http.Request) error, error) {
if !aa.IsAzure() {
return nil, nil
}
Expand Down Expand Up @@ -126,5 +126,5 @@ func (aa *DatabricksClient) configureWithAzureCLI() (func(r *http.Request) error
}, nil
}
log.Printf("[INFO] Using Azure CLI authentication with AAD tokens")
return aa.simpleAADRequestVisitor(context.TODO(), aa.cliAuthorizer)
return aa.simpleAADRequestVisitor(ctx, aa.cliAuthorizer)
}
9 changes: 5 additions & 4 deletions common/azure_cli_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ func TestConfigureWithAzureCLI_SP(t *testing.T) {
AzureTenantID: "c",
AzureDatabricksResourceID: "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c",
}
auth, err := aa.configureWithAzureCLI()
ctx := context.Background()
auth, err := aa.configureWithAzureCLI(ctx)
assert.NoError(t, err)
assert.Nil(t, auth)
}
Expand All @@ -193,7 +194,7 @@ func TestConfigureWithAzureCLI(t *testing.T) {
client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c"
client.AzureUsePATForCLI = true

auth, err := client.configureWithAzureCLI()
auth, err := client.configureWithAzureCLI(context.Background())
assert.NoError(t, err)

err = auth(httptest.NewRequest("GET", "/clusters/list", http.NoBody))
Expand All @@ -213,7 +214,7 @@ func TestConfigureWithAzureCLI_Error(t *testing.T) {
client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c"
client.AzureUsePATForCLI = true

auth, err := client.configureWithAzureCLI()
auth, err := client.configureWithAzureCLI(context.Background())
assert.NoError(t, err)

err = auth(httptest.NewRequest("GET", "/clusters/list", http.NoBody))
Expand All @@ -235,7 +236,7 @@ func TestConfigureWithAzureCLI_NotInstalled(t *testing.T) {
client.AzureDatabricksResourceID = "/subscriptions/a/resourceGroups/b/providers/Microsoft.Databricks/workspaces/c"
client.AzureUsePATForCLI = true

_, err := client.configureWithAzureCLI()
_, err := client.configureWithAzureCLI(context.Background())
require.Error(t, err)
assert.True(t, strings.HasPrefix(err.Error(), "most likely Azure CLI is not installed"),
"Actual message: %s", err.Error())
Expand Down
18 changes: 5 additions & 13 deletions common/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ type DatabricksClient struct {
// options used to enable unit testing mode for OIDC
googleAuthOptions []option.ClientOption

// Context used during provider initialisation,
// mostly for OAuth-based validation.
InitContext context.Context

// Mutex used by Authenticate method to guard `authVisitor`, which
// has to be lazily created on the first request to Databricks API.
// It is done because databricks host and token may often be available
Expand Down Expand Up @@ -208,8 +204,7 @@ func (c *DatabricksClient) Configure(attrsUsed ...string) error {
}

// Authenticate lazily authenticates across authorizers or returns error
func (c *DatabricksClient) Authenticate() error {
// TODO: add context
func (c *DatabricksClient) Authenticate(ctx context.Context) error {
if c.authVisitor != nil {
return nil
}
Expand All @@ -218,7 +213,7 @@ func (c *DatabricksClient) Authenticate() error {
if c.authVisitor != nil {
return nil
}
authorizers := []func() (func(r *http.Request) error, error){
authorizers := []func(context.Context) (func(*http.Request) error, error){
c.configureAuthWithDirectParams,
c.configureWithClientSecret,
c.configureWithManagedIdentity,
Expand All @@ -228,7 +223,7 @@ func (c *DatabricksClient) Authenticate() error {
c.configureFromDatabricksCfg,
}
for _, authProvider := range authorizers {
authorizer, err := authProvider()
authorizer, err := authProvider(ctx)
if err != nil {
return c.niceError(fmt.Sprintf("cannot configure auth: %s", err))
}
Expand Down Expand Up @@ -263,7 +258,7 @@ func (c *DatabricksClient) fixHost() {
}
}

func (c *DatabricksClient) configureAuthWithDirectParams() (func(r *http.Request) error, error) {
func (c *DatabricksClient) configureAuthWithDirectParams(ctx context.Context) (func(*http.Request) error, error) {
authType := "Bearer"
var needsHostBecause string
if c.Username != "" && c.Password != "" {
Expand All @@ -285,7 +280,7 @@ func (c *DatabricksClient) configureAuthWithDirectParams() (func(r *http.Request
return c.authorizer(authType, c.Token), nil
}

func (c *DatabricksClient) configureFromDatabricksCfg() (func(r *http.Request) error, error) {
func (c *DatabricksClient) configureFromDatabricksCfg(ctx context.Context) (func(r *http.Request) error, error) {
configFile := c.ConfigFile
if configFile == "" {
configFile = "~/.databrickscfg"
Expand Down Expand Up @@ -354,9 +349,6 @@ func (c *DatabricksClient) configureHTTPCLient() {
if c.RateLimitPerSecond == 0 {
c.RateLimitPerSecond = DefaultRateLimitPerSecond
}
if c.InitContext == nil {
c.InitContext = context.Background()
}
c.rateLimiter = rate.NewLimiter(rate.Limit(c.RateLimitPerSecond), 1)
// Set up a retryable HTTP Client to handle cases where the service returns
// a transient error on initial creation
Expand Down
3 changes: 2 additions & 1 deletion common/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"os"
"strings"
"testing"
Expand All @@ -17,7 +18,7 @@ func configureAndAuthenticate(dc *DatabricksClient) (*DatabricksClient, error) {
if err != nil {
return dc, err
}
return dc, dc.Authenticate()
return dc, dc.Authenticate(context.Background())
}

func TestDatabricksClientConfigure_Nothing(t *testing.T) {
Expand Down
15 changes: 8 additions & 7 deletions common/gcp.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package common

import (
"context"
"fmt"
"net/http"

"golang.org/x/oauth2"
"google.golang.org/api/impersonate"
)

func (c *DatabricksClient) getGoogleOIDCSource() (oauth2.TokenSource, error) {
func (c *DatabricksClient) getGoogleOIDCSource(ctx context.Context) (oauth2.TokenSource, error) {
// source for generateIdToken
ts, err := impersonate.IDTokenSource(c.InitContext, impersonate.IDTokenConfig{
ts, err := impersonate.IDTokenSource(ctx, impersonate.IDTokenConfig{
Audience: c.Host,
TargetPrincipal: c.GoogleServiceAccount,
IncludeEmail: true,
Expand All @@ -24,16 +25,16 @@ func (c *DatabricksClient) getGoogleOIDCSource() (oauth2.TokenSource, error) {
return ts, nil
}

func (c *DatabricksClient) configureWithGoogleForAccountsAPI() (func(r *http.Request) error, error) {
func (c *DatabricksClient) configureWithGoogleForAccountsAPI(ctx context.Context) (func(*http.Request) error, error) {
if c.GoogleServiceAccount == "" || !c.IsGcp() || !c.isAccountsClient() {
return nil, nil
}
oidcSource, err := c.getGoogleOIDCSource()
oidcSource, err := c.getGoogleOIDCSource(ctx)
if err != nil {
return nil, err
}
// source for generateAccessToken
platformSource, err := impersonate.CredentialsTokenSource(c.InitContext, impersonate.CredentialsConfig{
platformSource, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
TargetPrincipal: c.GoogleServiceAccount,
Scopes: []string{
"https://www.googleapis.com/auth/cloud-platform",
Expand Down Expand Up @@ -63,11 +64,11 @@ func newOidcAuthorizerForAccountsAPI(oidcSource oauth2.TokenSource,
}
}

func (c *DatabricksClient) configureWithGoogleForWorkspace() (func(r *http.Request) error, error) {
func (c *DatabricksClient) configureWithGoogleForWorkspace(ctx context.Context) (func(r *http.Request) error, error) {
if c.GoogleServiceAccount == "" || !c.IsGcp() || c.isAccountsClient() {
return nil, nil
}
oidcSource, err := c.getGoogleOIDCSource()
oidcSource, err := c.getGoogleOIDCSource(ctx)
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions common/gcp_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"net/http/httptest"
"testing"

Expand All @@ -21,7 +22,7 @@ func TestGoogleOIDC(t *testing.T) {
}
client.configureHTTPCLient()

_, err := client.getGoogleOIDCSource()
_, err := client.getGoogleOIDCSource(context.Background())
require.NoError(t, err)
}

Expand All @@ -33,11 +34,11 @@ func TestConfigureWithGoogleForAccountsAPI(t *testing.T) {
}
client.configureHTTPCLient()

_, err := client.configureWithGoogleForAccountsAPI()
_, err := client.configureWithGoogleForAccountsAPI(context.Background())
assert.Error(t, err)

client.googleAuthOptions = []option.ClientOption{option.WithoutAuthentication()}
a, err := client.configureWithGoogleForAccountsAPI()
a, err := client.configureWithGoogleForAccountsAPI(context.Background())
require.NoError(t, err)
assert.NotNil(t, a)
}
Expand All @@ -50,11 +51,11 @@ func TestConfigureWithGoogleForWorkspace(t *testing.T) {
}
client.configureHTTPCLient()

_, err := client.configureWithGoogleForWorkspace()
_, err := client.configureWithGoogleForWorkspace(context.Background())
assert.Error(t, err)

client.googleAuthOptions = []option.ClientOption{option.WithoutAuthentication()}
a, err := client.configureWithGoogleForWorkspace()
a, err := client.configureWithGoogleForWorkspace(context.Background())
require.NoError(t, err)
assert.NotNil(t, a)
}
Expand Down
2 changes: 1 addition & 1 deletion common/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func (c *DatabricksClient) OldAPI(ctx context.Context, method, path string, requ

func (c *DatabricksClient) authenticatedQuery(ctx context.Context, method, requestURL string,
data interface{}, visitors ...func(*http.Request) error) (body []byte, err error) {
err = c.Authenticate()
err = c.Authenticate(ctx)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion mws/resource_workspace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestMwsAccWorkspace(t *testing.T) {
func TestGcpAccWorkspace(t *testing.T) {
acctID := qa.GetEnvOrSkipTest(t, "DATABRICKS_ACCOUNT_ID")
client := common.CommonEnvironmentClient()
workspacesAPI := NewWorkspacesAPI(client.InitContext, client)
workspacesAPI := NewWorkspacesAPI(context.Background(), client)

workspaceList, err := workspacesAPI.List(acctID)
require.NoError(t, err, err)
Expand Down
5 changes: 3 additions & 2 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ func configureProviderAndReturnClient(t *testing.T, tt providerFixture) (*common
os.Setenv(k, v)
}
p := DatabricksProvider()
diags := p.Configure(context.Background(), terraform.NewResourceConfigRaw(tt.rawConfig()))
ctx := context.Background()
diags := p.Configure(ctx, terraform.NewResourceConfigRaw(tt.rawConfig()))
if len(diags) > 0 {
issues := []string{}
for _, d := range diags {
Expand All @@ -427,7 +428,7 @@ func configureProviderAndReturnClient(t *testing.T, tt providerFixture) (*common
}
client := p.Meta().(*common.DatabricksClient)
client.AzureUsePATForSPN = tt.usePATForSPN
err := client.Authenticate()
err := client.Authenticate(ctx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 35df71e

Please sign in to comment.