From 7d7147d7318dae611f1669d8b59ff25ab7f8114e Mon Sep 17 00:00:00 2001 From: Jason Del Ponte Date: Thu, 14 Jan 2016 12:58:54 -0800 Subject: [PATCH] Fix EC2Metadata Client overriding http.DefaultClient incorrectly The EC2Metadata Client was incorrectly overriding the http.DefaultClient when users had modified the default client's parameters. The metadata client will now only set its alternate dial timeout if no http client was provided, or if the provided client was never modified from the original http.Client{}. Also updates the logic so DefaultTransport is reused instead of hardcoded. Fix #504 --- aws/ec2metadata/service.go | 35 ++++++++++++++++------------- aws/ec2metadata/service_test.go | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 15 deletions(-) create mode 100644 aws/ec2metadata/service_test.go diff --git a/aws/ec2metadata/service.go b/aws/ec2metadata/service.go index f0dc331e012..5f7ecbba80a 100644 --- a/aws/ec2metadata/service.go +++ b/aws/ec2metadata/service.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "net" "net/http" + "reflect" "time" "github.com/aws/aws-sdk-go/aws" @@ -26,6 +27,10 @@ type EC2Metadata struct { // New creates a new instance of the EC2Metadata client with a session. // This client is safe to use across multiple goroutines. // +// If an unmodified HTTP client is provided from the stdlib default, or no client +// it is safe to override the dial's timeout and keep alive for shorter connections. +// If any client is provided which is not equal to the original default. +// // Example: // // Create a EC2Metadata client from just a session. // svc := ec2metadata.New(mySession) @@ -41,22 +46,22 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata { // a client when not using a session. Generally using just New with a session // is preferred. func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata { - // If the default http client is provided, replace it with a custom - // client using default timeouts. - if cfg.HTTPClient == http.DefaultClient { - cfg.HTTPClient = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - // use a shorter timeout than default because the metadata - // service is local if it is running, and to fail faster - // if not running on an ec2 instance. - Timeout: 5 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - }, + if cfg.HTTPClient == nil || reflect.DeepEqual(*cfg.HTTPClient, http.Client{}) { + // If a unmodified default http client is provided it is safe to add + // custom timeouts. + httpClient := *http.DefaultClient + if t, ok := http.DefaultTransport.(*http.Transport); ok { + transport := *t + transport.Dial = (&net.Dialer{ + // use a shorter timeout than default because the metadata + // service is local if it is running, and to fail faster + // if not running on an ec2 instance. + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial + httpClient.Transport = &transport } + cfg.HTTPClient = &httpClient } svc := &EC2Metadata{ diff --git a/aws/ec2metadata/service_test.go b/aws/ec2metadata/service_test.go new file mode 100644 index 00000000000..8cd04bf2e1a --- /dev/null +++ b/aws/ec2metadata/service_test.go @@ -0,0 +1,40 @@ +package ec2metadata_test + +import ( + "net/http" + "testing" + + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/stretchr/testify/assert" +) + +func TestClientOverrideDefaultHTTPClientDialTimeout(t *testing.T) { + svc := ec2metadata.New(session.New()) + + assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient) + + tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, tr) + + assert.NotNil(t, tr.Dial) +} + +func TestClientNotOverrideDefaultHTTPClientDialTimeout(t *testing.T) { + origClient := *http.DefaultClient + http.DefaultClient.Transport = &http.Transport{} + defer func() { + http.DefaultClient = &origClient + }() + + svc := ec2metadata.New(session.New()) + + assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient) + + tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, tr) + + assert.Nil(t, tr.Dial) +}