Skip to content

Commit

Permalink
Merge pull request #506 from aws/fixIssue504
Browse files Browse the repository at this point in the history
Fix EC2Metadata Client overriding http.DefaultClient incorrectly
  • Loading branch information
xibz committed Jan 14, 2016
2 parents 6811074 + 7d7147d commit 833c9f7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
35 changes: 20 additions & 15 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"net"
"net/http"
"reflect"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -26,6 +27,10 @@ type EC2Metadata struct {
// New creates a new instance of the EC2Metadata client with a session.
// This client is safe to use across multiple goroutines.
//
// If an unmodified HTTP client is provided from the stdlib default, or no client
// it is safe to override the dial's timeout and keep alive for shorter connections.
// If any client is provided which is not equal to the original default.
//
// Example:
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
Expand All @@ -41,22 +46,22 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
// a client when not using a session. Generally using just New with a session
// is preferred.
func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata {
// If the default http client is provided, replace it with a custom
// client using default timeouts.
if cfg.HTTPClient == http.DefaultClient {
cfg.HTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
},
if cfg.HTTPClient == nil || reflect.DeepEqual(*cfg.HTTPClient, http.Client{}) {
// If a unmodified default http client is provided it is safe to add
// custom timeouts.
httpClient := *http.DefaultClient
if t, ok := http.DefaultTransport.(*http.Transport); ok {
transport := *t
transport.Dial = (&net.Dialer{
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
httpClient.Transport = &transport
}
cfg.HTTPClient = &httpClient
}

svc := &EC2Metadata{
Expand Down
40 changes: 40 additions & 0 deletions aws/ec2metadata/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ec2metadata_test

import (
"net/http"
"testing"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/stretchr/testify/assert"
)

func TestClientOverrideDefaultHTTPClientDialTimeout(t *testing.T) {
svc := ec2metadata.New(session.New())

assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)

tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)

assert.NotNil(t, tr.Dial)
}

func TestClientNotOverrideDefaultHTTPClientDialTimeout(t *testing.T) {
origClient := *http.DefaultClient
http.DefaultClient.Transport = &http.Transport{}
defer func() {
http.DefaultClient = &origClient
}()

svc := ec2metadata.New(session.New())

assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)

tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)

assert.Nil(t, tr.Dial)
}

0 comments on commit 833c9f7

Please sign in to comment.