diff --git a/aws/ec2metadata/service.go b/aws/ec2metadata/service.go index f0dc331e012..5f7ecbba80a 100644 --- a/aws/ec2metadata/service.go +++ b/aws/ec2metadata/service.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "net" "net/http" + "reflect" "time" "github.com/aws/aws-sdk-go/aws" @@ -26,6 +27,10 @@ type EC2Metadata struct { // New creates a new instance of the EC2Metadata client with a session. // This client is safe to use across multiple goroutines. // +// If an unmodified HTTP client is provided from the stdlib default, or no client +// it is safe to override the dial's timeout and keep alive for shorter connections. +// If any client is provided which is not equal to the original default. +// // Example: // // Create a EC2Metadata client from just a session. // svc := ec2metadata.New(mySession) @@ -41,22 +46,22 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata { // a client when not using a session. Generally using just New with a session // is preferred. func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata { - // If the default http client is provided, replace it with a custom - // client using default timeouts. - if cfg.HTTPClient == http.DefaultClient { - cfg.HTTPClient = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - // use a shorter timeout than default because the metadata - // service is local if it is running, and to fail faster - // if not running on an ec2 instance. - Timeout: 5 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - }, + if cfg.HTTPClient == nil || reflect.DeepEqual(*cfg.HTTPClient, http.Client{}) { + // If a unmodified default http client is provided it is safe to add + // custom timeouts. + httpClient := *http.DefaultClient + if t, ok := http.DefaultTransport.(*http.Transport); ok { + transport := *t + transport.Dial = (&net.Dialer{ + // use a shorter timeout than default because the metadata + // service is local if it is running, and to fail faster + // if not running on an ec2 instance. + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial + httpClient.Transport = &transport } + cfg.HTTPClient = &httpClient } svc := &EC2Metadata{ diff --git a/aws/ec2metadata/service_test.go b/aws/ec2metadata/service_test.go new file mode 100644 index 00000000000..8cd04bf2e1a --- /dev/null +++ b/aws/ec2metadata/service_test.go @@ -0,0 +1,40 @@ +package ec2metadata_test + +import ( + "net/http" + "testing" + + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/stretchr/testify/assert" +) + +func TestClientOverrideDefaultHTTPClientDialTimeout(t *testing.T) { + svc := ec2metadata.New(session.New()) + + assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient) + + tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, tr) + + assert.NotNil(t, tr.Dial) +} + +func TestClientNotOverrideDefaultHTTPClientDialTimeout(t *testing.T) { + origClient := *http.DefaultClient + http.DefaultClient.Transport = &http.Transport{} + defer func() { + http.DefaultClient = &origClient + }() + + svc := ec2metadata.New(session.New()) + + assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient) + + tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, tr) + + assert.Nil(t, tr.Dial) +}