Skip to content

Commit

Permalink
Redesign persistent token cache API (#23114)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jul 16, 2024
1 parent 5df73f9 commit c5213b1
Show file tree
Hide file tree
Showing 29 changed files with 690 additions and 380 deletions.
9 changes: 7 additions & 2 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ var (
errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names")
)

// TokenCachePersistenceOptions contains options for persistent token caching
type TokenCachePersistenceOptions = internal.TokenCachePersistenceOptions
// Cache represents a persistent cache that makes authentication data available across processes.
// Construct one with [github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache.New]. This package's
// [persistent user authentication example] shows how to use a persistent cache to reuse logins
// across application runs.
//
// [persistent user authentication example]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity@v1.8.0-beta.1#example-package-PersistentUserAuthentication
type Cache = internal.Cache

// setAuthorityHost initializes the authority host for credentials. Precedence is:
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
Expand Down
268 changes: 130 additions & 138 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -213,6 +212,17 @@ func TestTenantID(t *testing.T) {
}
}

type testCache []byte

func (c *testCache) Export(_ context.Context, m cache.Marshaler, _ cache.ExportHints) (err error) {
*c, err = m.Marshal()
return
}

func (c *testCache) Replace(_ context.Context, u cache.Unmarshaler, _ cache.ReplaceHints) error {
return u.Unmarshal(*c)
}

func TestUserAuthentication(t *testing.T) {
type authenticater interface {
azcore.TokenCredential
Expand All @@ -221,30 +231,30 @@ func TestUserAuthentication(t *testing.T) {
for _, credential := range []struct {
name string
interactive, recordable bool
new func(*TokenCachePersistenceOptions, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error)
new func(Cache, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error)
}{
{
name: credNameBrowser,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
Cache: c,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
TokenCachePersistenceOptions: tcpo,
})
},
interactive: true,
},
{
name: credNameDeviceCode,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
o := DeviceCodeCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
Cache: c,
ClientOptions: co,
DisableAutomaticAuthentication: disableAutoAuth,
TokenCachePersistenceOptions: tcpo,
}
if recording.GetRecordMode() == recording.PlaybackMode {
o.UserPrompt = func(context.Context, DeviceCodeMessage) error { return nil }
Expand All @@ -256,12 +266,12 @@ func TestUserAuthentication(t *testing.T) {
},
{
name: credNameUserPassword,
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
opts := UsernamePasswordCredentialOptions{
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
ClientOptions: co,
TokenCachePersistenceOptions: tcpo,
AdditionallyAllowedTenants: []string{"*"},
AuthenticationRecord: ar,
Cache: c,
ClientOptions: co,
}
return NewUsernamePasswordCredential(liveUser.tenantID, developerSignOnClientID, liveUser.username, liveUser.password, &opts)
},
Expand All @@ -286,13 +296,13 @@ func TestUserAuthentication(t *testing.T) {
}}

co := azcore.ClientOptions{Cloud: cc, Transport: &sts}
cred, err := credential.new(nil, co, AuthenticationRecord{}, false)
cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false)
require.NoError(t, err)
_, err = cred.Authenticate(context.Background(), nil)
require.NoError(t, err)

t.Setenv(azureAuthorityHost, cc.ActiveDirectoryAuthorityHost)
cred, err = credential.new(nil, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false)
cred, err = credential.new(Cache{}, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false)
require.NoError(t, err)
_, err = cred.Authenticate(context.Background(), nil)
if cc.ActiveDirectoryAuthorityHost == customCloud.ActiveDirectoryAuthorityHost {
Expand Down Expand Up @@ -320,14 +330,14 @@ func TestUserAuthentication(t *testing.T) {
counter := tokenRequestCountingPolicy{}
co.PerCallPolicies = append(co.PerCallPolicies, &counter)

cred, err := credential.new(nil, co, AuthenticationRecord{}, false)
cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false)
require.NoError(t, err)
ar, err := cred.Authenticate(context.Background(), &testTRO)
require.NoError(t, err)

// some fields of the returned AuthenticationRecord should have specific values
require.Equal(t, ar.ClientID, developerSignOnClientID)
require.Equal(t, ar.Version, supportedAuthRecordVersions[0])
require.Equal(t, developerSignOnClientID, ar.ClientID)
require.Equal(t, supportedAuthRecordVersions[0], ar.Version)
// all others should have nonempty values
v := reflect.Indirect(reflect.ValueOf(&ar))
for _, f := range reflect.VisibleFields(reflect.TypeOf(ar)) {
Expand All @@ -337,48 +347,47 @@ func TestUserAuthentication(t *testing.T) {
require.Equal(t, 1, counter.count)
})

t.Run("PersistentCache_Live/"+credential.name, func(t *testing.T) {
switch recording.GetRecordMode() {
case recording.LiveMode:
if credential.interactive && !runManualTests {
t.Skipf("set %s to run this test", azidentityRunManualTests)
}
case recording.PlaybackMode, recording.RecordingMode:
if !credential.recordable {
t.Skip("this test can't be recorded")
}
t.Run("PersistentCache/"+credential.name, func(t *testing.T) {
if credential.name == credNameBrowser && !runManualTests {
t.Skipf("set %s to run this test", azidentityRunManualTests)
}
if runtime.GOOS != "windows" {
t.Skip("this test runs only on Windows")
}
p, err := internal.CacheFilePath(t.Name())
require.NoError(t, err)
os.Remove(p)
co, stop := initRecording(t)
defer stop()
counter := tokenRequestCountingPolicy{}
co.PerCallPolicies = append(co.PerCallPolicies, &counter)
tcpo := TokenCachePersistenceOptions{Name: t.Name()}
tokenReqs := 0
c := internal.NewCache(func(bool) (cache.ExportReplace, error) {
return &testCache{}, nil
})
co := azcore.ClientOptions{Transport: &mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
tokenReqs++
return nil
},
}}

cred, err := credential.new(&tcpo, co, AuthenticationRecord{}, true)
cred, err := credential.new(c, co, AuthenticationRecord{}, false)
require.NoError(t, err)
record, err := cred.Authenticate(context.Background(), &testTRO)
record, err := cred.Authenticate(ctx, &testTRO)
require.NoError(t, err)
defer os.Remove(p)
tk, err := cred.GetToken(context.Background(), testTRO)
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err)
require.Equal(t, 1, counter.count)
require.Equal(t, 1, tokenReqs)

cred2, err := credential.new(&tcpo, co, record, true)
// cred2 should return the token cached by cred
cred2, err := credential.new(c, co, record, true)
require.NoError(t, err)
tk2, err := cred2.GetToken(context.Background(), testTRO)
_, err = cred2.GetToken(ctx, testTRO)
require.NoError(t, err)
require.Equal(t, tk.Token, tk2.Token)
require.Equal(t, 1, tokenReqs)

// cred should request a new token because the cached one isn't a CAE token
caeTRO := testTRO
caeTRO.EnableCAE = true
_, err = cred.GetToken(ctx, caeTRO)
require.NoError(t, err)
require.Equal(t, 2, tokenReqs)
})

if credential.interactive {
t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
cred, err := credential.new(Cache{}, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
require.NoError(t, err)
expected := policy.TokenRequestOptions{
Claims: "claims",
Expand All @@ -402,7 +411,7 @@ func TestUserAuthentication(t *testing.T) {
}
})
t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) {
cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true)
cred, err := credential.new(Cache{}, policy.ClientOptions{}, AuthenticationRecord{}, true)
require.NoError(t, err)
expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue}
fake := NewFakeCredential()
Expand Down Expand Up @@ -1103,107 +1112,90 @@ func TestResolveTenant(t *testing.T) {
}
}

func TestTokenCachePersistenceOptions(t *testing.T) {
af := filepath.Join(t.TempDir(), t.Name()+credNameWorkloadIdentity)
if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil {
t.Fatal(err)
}
before := internal.NewCache
t.Cleanup(func() { internal.NewCache = before })
for _, test := range []struct {
desc string
options *TokenCachePersistenceOptions
err error
func TestConfidentialClientPersistentCache(t *testing.T) {
// for WorkloadIdentityCredential
tfp := filepath.Join(t.TempDir(), "tokenfile")
require.NoError(t, os.WriteFile(tfp, []byte("token"), 0600))
for _, credential := range []struct {
name string
new func(azcore.ClientOptions, Cache) (azcore.TokenCredential, error)
}{
{
desc: "nil options",
name: credNameAssertion,
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
o := ClientAssertionCredentialOptions{Cache: c, ClientOptions: co}
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
},
},
// TODO: set SYSTEM_OIDC_REQUEST_URI, fake response
// {
// name: credNameAzurePipelines,
// new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
// o := AzurePipelinesCredentialOptions{Cache: c, ClientOptions: co}
// return NewAzurePipelinesCredential(fakeTenantID, fakeClientID, "service-connection", tokenValue, &o)
// },
// },
{
desc: "default options",
options: &TokenCachePersistenceOptions{},
name: credNameCert,
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
o := ClientCertificateCredentialOptions{Cache: c, ClientOptions: co}
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
},
},
{
desc: "all options set",
options: &TokenCachePersistenceOptions{AllowUnencryptedStorage: true, Name: "name"},
name: credNameSecret,
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
o := ClientSecretCredentialOptions{Cache: c, ClientOptions: co}
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
},
},
} {
internal.NewCache = func(o *internal.TokenCachePersistenceOptions, _ bool) (cache.ExportReplace, error) {
if (test.options == nil) != (o == nil) {
t.Fatalf("expected %v, got %v", test.options, o)
}
if test.options != nil {
if test.options.AllowUnencryptedStorage != o.AllowUnencryptedStorage {
t.Fatalf("expected AllowUnencryptedStorage %v, got %v", test.options.AllowUnencryptedStorage, o.AllowUnencryptedStorage)
}
if test.options.Name != o.Name {
t.Fatalf("expected Name %q, got %q", test.options.Name, o.Name)
{
name: credNameWorkloadIdentity,
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
o := WorkloadIdentityCredentialOptions{
Cache: c,
ClientID: fakeClientID,
ClientOptions: co,
TenantID: fakeTenantID,
TokenFilePath: tfp,
}
}
return nil, nil
}
for _, subtest := range []struct {
ctor func(azcore.ClientOptions, *TokenCachePersistenceOptions) (azcore.TokenCredential, error)
env map[string]string
name string
}{
{
name: credNameAssertion,
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
o := ClientAssertionCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
},
},
{
name: credNameCert,
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
o := ClientCertificateCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
},
},
{
name: credNameDeviceCode,
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
o := DeviceCodeCredentialOptions{
ClientOptions: co,
TokenCachePersistenceOptions: tco,
UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil },
}
return NewDeviceCodeCredential(&o)
},
},
{
name: credNameSecret,
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
o := ClientSecretCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
},
},
{
name: credNameUserPassword,
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
o := UsernamePasswordCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o)
},
return NewWorkloadIdentityCredential(&o)
},
} {
t.Run(fmt.Sprintf("%s/%s", subtest.name, test.desc), func(t *testing.T) {
for k, v := range subtest.env {
t.Setenv(k, v)
}
c, err := subtest.ctor(policy.ClientOptions{Transport: &mockSTS{}}, test.options)
if err != nil {
t.Fatal(err)
}
_, err = c.GetToken(context.Background(), testTRO)
if err != nil {
if !errors.Is(err, test.err) {
t.Fatalf("expected %v, got %v", test.err, err)
}
} else if test.err != nil {
t.Fatal("expected an error")
}
},
} {
t.Run(credential.name, func(t *testing.T) {
tokenReqs := 0
c := internal.NewCache(func(bool) (cache.ExportReplace, error) {
return &testCache{}, nil
})
}
sts := mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
tokenReqs++
return nil
},
}
cred, err := credential.new(policy.ClientOptions{Transport: &sts}, c)
require.NoError(t, err)
_, err = cred.GetToken(context.Background(), testTRO)
require.NoError(t, err)
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err)
require.Equal(t, 1, tokenReqs)

// cred2 should return the token cached by cred
cred2, err := credential.new(policy.ClientOptions{Transport: &sts}, c)
require.NoError(t, err)
_, err = cred2.GetToken(ctx, testTRO)
require.NoError(t, err)
require.Equal(t, 1, tokenReqs)

// cred should request a new token because the cached one isn't a CAE token
caeTRO := testTRO
caeTRO.EnableCAE = true
_, err = cred.GetToken(ctx, caeTRO)
require.NoError(t, err)
require.Equal(t, 2, tokenReqs)
})
}
}

Expand Down
Loading

0 comments on commit c5213b1

Please sign in to comment.