Skip to content

Commit

Permalink
fix(azidentity): do not strip away request headers in doForClient
Browse files Browse the repository at this point in the history
Some authorities might require certain headers to be passed.

For example, in our dSTS auth flow, the request form contains
client_info, which needs to be accompanied by X-Client-SKU=MSAL.Go
header, else the API call produces

AADSTS501791: Client_info is only supported for MSAL/ADAL,
please ensure that MSAL/ADAL custom headers are being sent.

The `doForClient` function creates new `runtime.Request` from the incoming
request, but it fails to propagate the respective headers.

This commits is addressing that.

Signed-off-by: HandsomeJack <dusek.honza@gmail.com>
  • Loading branch information
handsomejack-42 committed Dec 14, 2023
1 parent dd65c25 commit 4962931
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Breaking Changes

### Bugs Fixed
* `azidentity.doForClient` method no longer removes headers from the incoming request

### Other Changes

Expand Down
11 changes: 11 additions & 0 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ func doForClient(client *azcore.Client, r *http.Request) (*http.Response, error)
return nil, err
}
}

// copy headers to the new request, ignoring any for which the new request has a value
h := req.Raw().Header
for key, vals := range r.Header {
if _, has := h[key]; !has {
for _, val := range vals {
h.Add(key, val)
}
}
}

resp, err := client.Pipeline().Do(req)
if err != nil {
return nil, err
Expand Down
4 changes: 4 additions & 0 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,10 @@ func TestDoForClient(t *testing.T) {
assert.Empty(t, rb)
}

for k, v := range tt.headers {
assert.Equal(t, v, req.Header[k])
}

assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName))

rw.Header().Set("content-type", "application/json")
Expand Down

0 comments on commit 4962931

Please sign in to comment.