Skip to content

Commit

Permalink
Remove redundant azidentity error content (#23407)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Sep 6, 2024
1 parent 7129337 commit 09f842c
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 48 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
### Bugs Fixed

### Other Changes
* Removed redundant content from error messages

## 1.8.0-beta.2 (2024-08-06)

Expand Down
71 changes: 71 additions & 0 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,77 @@ func (t *tokenRequestCountingPolicy) Do(req *policy.Request) (*http.Response, er
return req.Next()
}

func TestResponseErrors(t *testing.T) {
// compact removes whitespace from errors to simplify validation
compact := func(s string) string {
return strings.Map(func(r rune) rune {
if r == ' ' || r == '\n' || r == '\t' {
return -1
}
return r
}, s)
}
content := "no tokens here"
statusCode := http.StatusTeapot
validate := func(t *testing.T, err error) {
require.Error(t, err)
flatErr := compact(err.Error())
actual := strings.Count(flatErr, compact(http.StatusText(statusCode)))
require.Equal(t, 1, actual, "error message should include response exactly once:\n%s", err.Error())
actual = strings.Count(flatErr, compact(content))
require.Equal(t, 1, actual, "error message should include body exactly once:\n%s", err.Error())
}

for _, client := range []struct {
name string
ctor func(co policy.ClientOptions) (azcore.TokenCredential, error)
}{
{
name: "confidential",
ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) {
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &ClientSecretCredentialOptions{ClientOptions: co})
},
},
{
name: "managed identity",
ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) {
return NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: co})
},
},
{
name: "public",
ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) {
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, "username", "password", &UsernamePasswordCredentialOptions{ClientOptions: co})
},
},
} {
t.Run(client.name, func(t *testing.T) {
cred, err := client.ctor(policy.ClientOptions{
Retry: policy.RetryOptions{MaxRetries: -1},
Transport: &mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(content)),
Status: http.StatusText(statusCode),
StatusCode: statusCode,
}
},
},
})
require.NoError(t, err)
_, err = cred.GetToken(ctx, testTRO)
validate(t, err)

t.Run("ChainedTokenCredential", func(t *testing.T) {
chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred}, nil)
require.NoError(t, err)
_, err = chain.GetToken(ctx, testTRO)
validate(t, err)
})
})
}
}

func TestTenantID(t *testing.T) {
type tc struct {
name string
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func azTokenOutput(expiresOn string, expires_on int64) []byte {
}

func mockAzTokenProviderFailure(context.Context, []string, string, string) ([]byte, error) {
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil)
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil)
}

func mockAzTokenProviderSuccess(context.Context, []string, string, string) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/azure_developer_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var (
`), nil
}
mockAzdTokenProviderFailure = func(context.Context, []string, string) ([]byte, error) {
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil)
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil)
}
)

Expand Down
12 changes: 6 additions & 6 deletions sdk/azidentity/azure_pipelines_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,33 +114,33 @@ func (a *AzurePipelinesCredential) getAssertion(ctx context.Context) (string, er
url := a.oidcURI + "?api-version=" + oidcAPIVersion + "&serviceConnectionId=" + a.connectionID
url, err := runtime.EncodeQueryParams(url)
if err != nil {
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't encode OIDC URL: "+err.Error(), nil, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't encode OIDC URL: "+err.Error(), nil)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't create OIDC token request: "+err.Error(), nil, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't create OIDC token request: "+err.Error(), nil)
}
req.Header.Set("Authorization", "Bearer "+a.systemAccessToken)
res, err := doForClient(a.cred.client.azClient, req)
if err != nil {
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't send OIDC token request: "+err.Error(), nil, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't send OIDC token request: "+err.Error(), nil)
}
if res.StatusCode != http.StatusOK {
msg := res.Status + " response from the OIDC endpoint. Check service connection ID and Pipeline configuration"
// include the response because its body, if any, probably contains an error message.
// OK responses aren't included with errors because they probably contain secrets
return "", newAuthenticationFailedError(credNameAzurePipelines, msg, res, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, msg, res)
}
b, err := runtime.Payload(res)
if err != nil {
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't read OIDC response content: "+err.Error(), nil, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't read OIDC response content: "+err.Error(), nil)
}
var r struct {
OIDCToken string `json:"oidcToken"`
}
err = json.Unmarshal(b, &r)
if err != nil {
return "", newAuthenticationFailedError(credNameAzurePipelines, "unexpected response from OIDC endpoint", nil, nil)
return "", newAuthenticationFailedError(credNameAzurePipelines, "unexpected response from OIDC endpoint", nil)
}
return r.OIDCToken, nil
}
16 changes: 12 additions & 4 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,19 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
if err != nil {
// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
msg := createChainedErrorMessage(errs)
if errors.As(err, &unavailableErr) {
var authFailedErr *AuthenticationFailedError
switch {
case errors.As(err, &authFailedErr):
err = newAuthenticationFailedError(c.name, msg, authFailedErr.RawResponse)
if af, ok := err.(*AuthenticationFailedError); ok {
// stop Error() printing the response again; it's already in msg
af.omitResponse = true
}
case errors.As(err, &unavailableErr):
err = newCredentialUnavailableError(c.name, msg)
} else {
default:
res := getResponseFromError(err)
err = newAuthenticationFailedError(c.name, msg, res, err)
err = newAuthenticationFailedError(c.name, msg, res)
}
}
return token, err
Expand All @@ -126,7 +134,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
func createChainedErrorMessage(errs []error) string {
msg := "failed to acquire a token.\nAttempted credentials:"
for _, err := range errs {
msg += fmt.Sprintf("\n\t%s", err.Error())
msg += fmt.Sprintf("\n\t%s", strings.ReplaceAll(err.Error(), "\n", "\n\t\t"))
}
return msg
}
Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) {

func TestChainedTokenCredential_GetTokenFail(t *testing.T) {
c := NewFakeCredential()
c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil, nil))
c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil))
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c}, nil)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenAuthenticationFailed(
c2 := NewFakeCredential()
c2.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("unavailableCredential2", "Unavailable expected error"))
c3 := NewFakeCredential()
c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil, nil))
c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil))
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c1, c2, c3}, nil)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -259,7 +259,7 @@ func TestChainedTokenCredential_Race(t *testing.T) {
successFake := NewFakeCredential()
successFake.SetResponse(azcore.AccessToken{Token: "*", ExpiresOn: time.Now().Add(time.Hour)}, nil)
authFailFake := NewFakeCredential()
authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil, nil))
authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil))
unavailableFake := NewFakeCredential()
unavailableFake.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("", ""))

Expand Down
12 changes: 6 additions & 6 deletions sdk/azidentity/confidential_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque
}
}
if err != nil {
// We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code.
// We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError.
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
res := getResponseFromError(err)
err = newAuthenticationFailedError(c.name, err.Error(), res, err)
var (
authFailedErr *AuthenticationFailedError
unavailableErr credentialUnavailable
)
if !(errors.As(err, &unavailableErr) || errors.As(err, &authFailedErr)) {
err = newAuthenticationFailedErrorFromMSAL(c.name, err)
}
} else {
msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", c.name, strings.Join(ar.GrantedScopes, ", "))
Expand Down
26 changes: 19 additions & 7 deletions sdk/azidentity/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,30 @@ type AuthenticationFailedError struct {
// RawResponse is the HTTP response motivating the error, if available.
RawResponse *http.Response

credType string
message string
err error
credType, message string
omitResponse bool
}

func newAuthenticationFailedError(credType string, message string, resp *http.Response, err error) error {
return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp, err: err}
func newAuthenticationFailedError(credType, message string, resp *http.Response) error {
return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp}
}

// newAuthenticationFailedErrorFromMSAL creates an AuthenticationFailedError from an MSAL error.
// If the error is an MSAL CallErr, the new error includes an HTTP response and not the MSAL error
// message, because that message is redundant given the response. If the original error isn't a
// CallErr, the returned error incorporates its message.
func newAuthenticationFailedErrorFromMSAL(credType string, err error) error {
msg := ""
res := getResponseFromError(err)
if res == nil {
msg = err.Error()
}
return newAuthenticationFailedError(credType, msg, res)
}

// Error implements the error interface. Note that the message contents are not contractual and can change over time.
func (e *AuthenticationFailedError) Error() string {
if e.RawResponse == nil {
if e.RawResponse == nil || e.omitResponse {
return e.credType + ": " + e.message
}
msg := &bytes.Buffer{}
Expand All @@ -62,7 +74,7 @@ func (e *AuthenticationFailedError) Error() string {
fmt.Fprintln(msg, "Request information not available")
}
fmt.Fprintln(msg, "--------------------------------------------------------------------------------")
fmt.Fprintf(msg, "RESPONSE %s\n", e.RawResponse.Status)
fmt.Fprintf(msg, "RESPONSE %d: %s\n", e.RawResponse.StatusCode, e.RawResponse.Status)
fmt.Fprintln(msg, "--------------------------------------------------------------------------------")
body, err := runtime.Payload(e.RawResponse)
switch {
Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestAuthenticationFailedErrorInterface(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(resBodyString)),
Request: req,
}
err = newAuthenticationFailedError(credNameAzureCLI, "error message", res, nil)
err = newAuthenticationFailedError(credNameAzureCLI, "error message", res)
if e, ok := err.(*AuthenticationFailedError); ok {
if e.RawResponse == nil {
t.Fatal("expected a non-nil RawResponse")
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestAuthenticationFailedErrorInterface(t *testing.T) {
}

func TestAuthenticationFailedErrorWithoutResponse(t *testing.T) {
err := newAuthenticationFailedError(credNameAzureCLI, "error message", nil, nil)
err := newAuthenticationFailedError(credNameAzureCLI, "error message", nil)
if _, ok := err.(*AuthenticationFailedError); !ok {
t.Fatalf("expected AuthenticationFailedError, received %T", err)
}
Expand All @@ -79,7 +79,7 @@ func TestAuthenticationFailedErrorWithoutRequest(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(resBodyString)),
Request: nil,
}
err := newAuthenticationFailedError(credNameAzureCLI, "error message", res, nil)
err := newAuthenticationFailedError(credNameAzureCLI, "error message", res)
if e, ok := err.(*AuthenticationFailedError); ok {
if e.RawResponse == nil {
t.Fatal("expected a non-nil RawResponse")
Expand Down
30 changes: 15 additions & 15 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi

resp, err := c.azClient.Pipeline().Do(msg)
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil)
}

if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
Expand All @@ -261,7 +261,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
switch resp.StatusCode {
case http.StatusBadRequest:
if id != nil {
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp)
}
msg := "failed to authenticate a system assigned identity"
if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 {
Expand All @@ -278,7 +278,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
}
}

return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp, nil)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "", resp)
}

func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.AccessToken, error) {
Expand Down Expand Up @@ -306,10 +306,10 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
if expiresOn, err := strconv.Atoi(v); err == nil {
return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil
}
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res, nil)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res)
default:
msg := fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res, nil)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res)
}
}

Expand All @@ -324,7 +324,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
key, err := c.getAzureArcSecretKey(ctx, scopes)
if err != nil {
msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil)
}
return c.createAzureArcAuthRequest(ctx, scopes, key)
case msiTypeAzureML:
Expand Down Expand Up @@ -399,9 +399,9 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
case miClientID:
q.Set("clientid", id.String())
case miObjectID:
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by object ID", nil, nil)
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by object ID", nil)
case miResourceID:
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil, nil)
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil)
}
}
request.Raw().URL.RawQuery = q.Encode()
Expand Down Expand Up @@ -442,34 +442,34 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
// of the secret key file. Any other status code indicates an error in the request.
if response.StatusCode != 401 {
msg := fmt.Sprintf("expected a 401 response, received %d", response.StatusCode)
return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response)
}
header := response.Header.Get("WWW-Authenticate")
if len(header) == 0 {
return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil)
}
// the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key
_, p, found := strings.Cut(header, "=")
if !found {
return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil)
}
expected, err := arcKeyDirectory()
if err != nil {
return "", err
}
if filepath.Dir(p) != expected || !strings.HasSuffix(p, ".key") {
return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil)
}
f, err := os.Stat(p)
if err != nil {
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil)
}
if s := f.Size(); s > 4096 {
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil)
}
key, err := os.ReadFile(p)
if err != nil {
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil, nil)
return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil)
}
return string(key), nil
}
Expand Down
Loading

0 comments on commit 09f842c

Please sign in to comment.