Skip to content

Commit

Permalink
[azcore] Adding in a function create a policy.Request using an existi…
Browse files Browse the repository at this point in the history
…ng *http.Request (#23186)

Adding in a function create a policy.Request using an already created http.Request. This is useful if you want to use our pipelines in contexts where we're not in our generated clients.
  • Loading branch information
richardpark-msft committed Jul 16, 2024
1 parent 80dbc7d commit b4b4721
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added runtime.NewRequestFromRequest(), allowing for a policy.Request to be created from an existing *http.Request.

### Breaking Changes

### Bugs Fixed
Expand Down
72 changes: 72 additions & 0 deletions sdk/azcore/internal/exported/exported_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
package exported

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
Expand Down Expand Up @@ -74,3 +77,72 @@ func TestNewSASCredential(t *testing.T) {
cred.Update(val2)
require.EqualValues(t, val2, SASCredentialGet(cred))
}

func TestNewRequestFromRequest(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

expectedData := bytes.NewReader([]byte{1, 2, 3, 4, 5})

httpRequest, err := http.NewRequestWithContext(ctx, "POST", "https://example.com", expectedData)
require.NoError(t, err)

req, err := NewRequestFromRequest(httpRequest)
require.NoError(t, err)

// our stream has been drained - the func has to make a copy of the body so it can be seekable.
// so our stream should be at end.
currentPos, err := expectedData.Seek(0, io.SeekCurrent)
require.NoError(t, err)
require.Equal(t, int64(5), currentPos)

actualData, err := io.ReadAll(req.Body())
require.NoError(t, err)
require.Equal(t, []byte{1, 2, 3, 4, 5}, actualData)

// now we change stuff in the policy.Request...
replacementBuff := bytes.NewReader([]byte{6})
err = req.SetBody(NopCloser(replacementBuff), "application/coolstuff")
require.NoError(t, err)

// and it's automatically reflected in the http.Request, which helps us with interop
// with other HTTP pipelines.
require.Equal(t, "application/coolstuff", httpRequest.Header.Get("Content-Type"))
newBytes, err := io.ReadAll(httpRequest.Body)
require.NoError(t, err)
require.Equal(t, []byte{6}, newBytes)
}

func TestNewRequestFromRequest_AvoidExtraCopyIfReadSeekCloser(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

expectedData := NopCloser(bytes.NewReader([]byte{1, 2, 3, 4, 5}))

httpRequest, err := http.NewRequestWithContext(ctx, "POST", "https://example.com", expectedData)
require.NoError(t, err)

req, err := NewRequestFromRequest(httpRequest)
require.NoError(t, err)

// our stream should _NOT_ get drained since it was already an io.ReadSeekCloser
currentPos, err := expectedData.Seek(0, io.SeekCurrent)
require.NoError(t, err)
require.Equal(t, int64(0), currentPos)

actualData, err := io.ReadAll(req.Body())
require.NoError(t, err)
require.Equal(t, []byte{1, 2, 3, 4, 5}, actualData)

// now we change stuff in the policy.Request...
replacementBuff := bytes.NewReader([]byte{6})
err = req.SetBody(NopCloser(replacementBuff), "application/coolstuff")
require.NoError(t, err)

// and it's automatically reflected in the http.Request, which helps us with interop
// with other HTTP pipelines.
require.Equal(t, "application/coolstuff", httpRequest.Header.Get("Content-Type"))
newBytes, err := io.ReadAll(httpRequest.Body)
require.NoError(t, err)
require.Equal(t, []byte{6}, newBytes)
}
37 changes: 37 additions & 0 deletions sdk/azcore/internal/exported/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package exported

import (
"bytes"
"context"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -67,6 +68,42 @@ func (ov opValues) get(value any) bool {
return ok
}

// NewRequestFromRequest creates a new policy.Request with an existing *http.Request
// Exported as runtime.NewRequestFromRequest().
func NewRequestFromRequest(req *http.Request) (*Request, error) {
policyReq := &Request{req: req}

if req.Body != nil {
// we can avoid a body copy here if the underlying stream is already a
// ReadSeekCloser.
readSeekCloser, isReadSeekCloser := req.Body.(io.ReadSeekCloser)

if !isReadSeekCloser {
// since this is an already populated http.Request we want to copy
// over its body, if it has one.
bodyBytes, err := io.ReadAll(req.Body)

if err != nil {
return nil, err
}

if err := req.Body.Close(); err != nil {
return nil, err
}

readSeekCloser = NopCloser(bytes.NewReader(bodyBytes))
}

// SetBody also takes care of updating the http.Request's body
// as well, so they should stay in-sync from this point.
if err := policyReq.SetBody(readSeekCloser, req.Header.Get("Content-Type")); err != nil {
return nil, err
}
}

return policyReq, nil
}

// NewRequest creates a new Request with the specified input.
// Exported as runtime.NewRequest().
func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) {
Expand Down
6 changes: 6 additions & 0 deletions sdk/azcore/runtime/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"path"
Expand Down Expand Up @@ -45,6 +46,11 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*polic
return exported.NewRequest(ctx, httpMethod, endpoint)
}

// NewRequestFromRequest creates a new policy.Request with an existing *http.Request
func NewRequestFromRequest(req *http.Request) (*policy.Request, error) {
return exported.NewRequestFromRequest(req)
}

// EncodeQueryParams will parse and encode any query parameters in the specified URL.
// Any semicolons will automatically be escaped.
func EncodeQueryParams(u string) (string, error) {
Expand Down

0 comments on commit b4b4721

Please sign in to comment.