From f482e78f17453330cc74ec6c1e59ec5b70242e8b Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Thu, 6 Feb 2025 23:16:52 +0100 Subject: [PATCH 1/2] Enable async refrehses --- config/experimental/auth/auth.go | 12 +++--------- config/experimental/auth/auth_test.go | 9 +++++---- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go index abea1aef..ca1d4dc2 100644 --- a/config/experimental/auth/auth.go +++ b/config/experimental/auth/auth.go @@ -16,10 +16,6 @@ const ( // Default duration for the stale period. The number as been set arbitrarily // and might be changed in the future. defaultStaleDuration = 3 * time.Minute - - // Disable the asynchronous token refresh by default. This is meant to - // change in the future once the feature is stable. - defaultDisableAsyncRefresh = true ) // A TokenSource is anything that can return a token. @@ -45,9 +41,9 @@ func WithAsyncRefresh(b bool) Option { } } -// NewCachedTokenProvider wraps a [oauth2.TokenSource] to cache the tokens -// it returns. By default, the cache will refresh tokens asynchronously a few -// minutes before they expire. +// NewCachedTokenProvider wraps a [TokenSource] to cache the tokens it returns. +// By default, the cache will refresh tokens asynchronously a few minutes before +// they expire. // // The token cache is safe for concurrent use by multiple goroutines and will // guarantee that only one token refresh is triggered at a time. @@ -69,8 +65,6 @@ func NewCachedTokenSource(ts TokenSource, opts ...Option) TokenSource { cts := &cachedTokenSource{ tokenSource: ts, staleDuration: defaultStaleDuration, - disableAsync: defaultDisableAsyncRefresh, - cachedToken: nil, timeNow: time.Now, } diff --git a/config/experimental/auth/auth_test.go b/config/experimental/auth/auth_test.go index 24a0d13b..9d143f29 100644 --- a/config/experimental/auth/auth_test.go +++ b/config/experimental/auth/auth_test.go @@ -39,8 +39,8 @@ func TestNewCachedTokenSource_default(t *testing.T) { if got.staleDuration != defaultStaleDuration { t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, defaultStaleDuration) } - if got.disableAsync != defaultDisableAsyncRefresh { - t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, defaultDisableAsyncRefresh) + if got.disableAsync != false { + t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, false) } if got.cachedToken != nil { t.Errorf("NewCachedTokenSource() cachedToken = %v, want nil", got.cachedToken) @@ -227,7 +227,7 @@ func TestCachedTokenSource_Token(t *testing.T) { desc: "[Async] stale cached token, expired token returned", cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, - wantCalls: 10, + wantCalls: 1, wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, }, { @@ -250,6 +250,7 @@ func TestCachedTokenSource_Token(t *testing.T) { timeNow: func() time.Time { return now }, tokenSource: mockTokenSource(func() (*oauth2.Token, error) { atomic.AddInt32(&gotCalls, 1) + time.Sleep(10 * time.Millisecond) return tc.returnedToken, tc.returnedError }), } @@ -268,7 +269,7 @@ func TestCachedTokenSource_Token(t *testing.T) { // Wait for async refreshes to finish. This part is a little brittle // but necessary to ensure that the async refresh is done before // checking the results. - time.Sleep(10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) if int(gotCalls) != tc.wantCalls { t.Errorf("want %d calls to cts.tokenSource.Token(), got %d", tc.wantCalls, gotCalls) From 33f9a012282ebe68fb4db513d63883d71dbf602b Mon Sep 17 00:00:00 2001 From: Renaud Hartert Date: Fri, 7 Feb 2025 16:47:49 +0100 Subject: [PATCH 2/2] Remove test --- config/auth_azure_cli_test.go | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/config/auth_azure_cli_test.go b/config/auth_azure_cli_test.go index c2c8f742..af02e97c 100644 --- a/config/auth_azure_cli_test.go +++ b/config/auth_azure_cli_test.go @@ -118,31 +118,6 @@ func TestAzureCliCredentials_Valid(t *testing.T) { assert.Equal(t, "...", r.Header.Get("X-Databricks-Azure-SP-Management-Token")) } -func TestAzureCliCredentials_ReuseTokens(t *testing.T) { - env.CleanupEnvironment(t) - os.Setenv("PATH", testdataPath()) - os.Setenv("EXPIRE", "10M") - - // Use temporary file to store the number of calls to the AZ CLI. - tmp := t.TempDir() - count := filepath.Join(tmp, "count") - os.Setenv("COUNT", count) - - aa := AzureCliCredentials{} - visitor, err := aa.Configure(context.Background(), azDummy) - assert.NoError(t, err) - - r := &http.Request{Header: http.Header{}} - err = visitor.SetHeaders(r) - assert.NoError(t, err) - - // We verify the headers in the test above. - // This test validates we do not call the AZ CLI more than we need. - buf, err := os.ReadFile(count) - require.NoError(t, err) - assert.Len(t, buf, 2, "Expected the AZ CLI to be called twice") -} - func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) { env.CleanupEnvironment(t) os.Setenv("PATH", testdataPath())