Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Enable async refreshes for OAuth tokens #1143

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,6 @@ func TestAzureCliCredentials_Valid(t *testing.T) {
assert.Equal(t, "...", r.Header.Get("X-Databricks-Azure-SP-Management-Token"))
}

func TestAzureCliCredentials_ReuseTokens(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
os.Setenv("EXPIRE", "10M")

// Use temporary file to store the number of calls to the AZ CLI.
tmp := t.TempDir()
count := filepath.Join(tmp, "count")
os.Setenv("COUNT", count)

aa := AzureCliCredentials{}
visitor, err := aa.Configure(context.Background(), azDummy)
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor.SetHeaders(r)
assert.NoError(t, err)

// We verify the headers in the test above.
// This test validates we do not call the AZ CLI more than we need.
buf, err := os.ReadFile(count)
require.NoError(t, err)
assert.Len(t, buf, 2, "Expected the AZ CLI to be called twice")
}

func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
Expand Down
12 changes: 3 additions & 9 deletions config/experimental/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ const (
// Default duration for the stale period. The number as been set arbitrarily
// and might be changed in the future.
defaultStaleDuration = 3 * time.Minute

// Disable the asynchronous token refresh by default. This is meant to
// change in the future once the feature is stable.
defaultDisableAsyncRefresh = true
)

// A TokenSource is anything that can return a token.
Expand Down Expand Up @@ -59,9 +55,9 @@ func WithAsyncRefresh(b bool) Option {
}
}

// NewCachedTokenProvider wraps a [oauth2.TokenSource] to cache the tokens
// it returns. By default, the cache will refresh tokens asynchronously a few
// minutes before they expire.
// NewCachedTokenProvider wraps a [TokenSource] to cache the tokens it returns.
// By default, the cache will refresh tokens asynchronously a few minutes before
// they expire.
//
// The token cache is safe for concurrent use by multiple goroutines and will
// guarantee that only one token refresh is triggered at a time.
Expand All @@ -83,8 +79,6 @@ func NewCachedTokenSource(ts TokenSource, opts ...Option) TokenSource {
cts := &cachedTokenSource{
tokenSource: ts,
staleDuration: defaultStaleDuration,
disableAsync: defaultDisableAsyncRefresh,
cachedToken: nil,
timeNow: time.Now,
}

Expand Down
9 changes: 5 additions & 4 deletions config/experimental/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func TestNewCachedTokenSource_default(t *testing.T) {
if got.staleDuration != defaultStaleDuration {
t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, defaultStaleDuration)
}
if got.disableAsync != defaultDisableAsyncRefresh {
t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, defaultDisableAsyncRefresh)
if got.disableAsync != false {
t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, false)
}
if got.cachedToken != nil {
t.Errorf("NewCachedTokenSource() cachedToken = %v, want nil", got.cachedToken)
Expand Down Expand Up @@ -221,7 +221,7 @@ func TestCachedTokenSource_Token(t *testing.T) {
desc: "[Async] stale cached token, expired token returned",
cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)},
returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)},
wantCalls: 10,
wantCalls: 1,
wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)},
},
{
Expand All @@ -244,6 +244,7 @@ func TestCachedTokenSource_Token(t *testing.T) {
timeNow: func() time.Time { return now },
tokenSource: TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) {
atomic.AddInt32(&gotCalls, 1)
time.Sleep(10 * time.Millisecond)
return tc.returnedToken, tc.returnedError
}),
}
Expand All @@ -262,7 +263,7 @@ func TestCachedTokenSource_Token(t *testing.T) {
// Wait for async refreshes to finish. This part is a little brittle
// but necessary to ensure that the async refresh is done before
// checking the results.
time.Sleep(10 * time.Millisecond)
time.Sleep(20 * time.Millisecond)

if int(gotCalls) != tc.wantCalls {
t.Errorf("want %d calls to cts.tokenSource.Token(), got %d", tc.wantCalls, gotCalls)
Expand Down
Loading