Skip to content

Commit

Permalink
Adds http timeout to aws sessions (#1370)
Browse files Browse the repository at this point in the history
  • Loading branch information
couralex6 authored Feb 4, 2021
1 parent 211340f commit 21b2e61
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 17 deletions.
81 changes: 81 additions & 0 deletions pkg/awsutils/awssession/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package awssession

import (
"fmt"
"net/http"
"os"

"strconv"
"time"

"github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
)

// Http client timeout env for sessions
const (
httpTimeoutEnv = "HTTP_TIMEOUT"
maxRetries = 15
)

var (
version string
log = logger.Get()
// HTTP timeout default value in seconds (10 seconds)
httpTimeoutValue = 10 * time.Second
)

func getHTTPTimeout() time.Duration {
httpTimeoutEnvInput := os.Getenv(httpTimeoutEnv)
// if httpTimeout is not empty, we convert value to int and overwrite default httpTimeoutValue
if httpTimeoutEnvInput != "" {
input, err := strconv.Atoi(httpTimeoutEnvInput)
if err == nil && input >= 10 {
log.Debugf("Using HTTP_TIMEOUT %v", input)
httpTimeoutValue = time.Duration(input) * time.Second
return httpTimeoutValue
}
}
log.Info("HTTP_TIMEOUT env is not set or set to less than 10 seconds, defaulting to httpTimeout to 10sec")
return httpTimeoutValue
}

// New will return an session for service clients
func New() *session.Session {
awsCfg := aws.Config{
MaxRetries: aws.Int(maxRetries),
HTTPClient: &http.Client{
Timeout: getHTTPTimeout(),
},
}
sess := session.Must(session.NewSession(&awsCfg))
//injecting session handler info
injectUserAgent(&sess.Handlers)

return sess
}

// injectUserAgent will inject app specific user-agent into awsSDK
func injectUserAgent(handlers *request.Handlers) {
handlers.Build.PushFrontNamed(request.NamedHandler{
Name: fmt.Sprintf("%s/user-agent", "amazon-vpc-cni-k8s"),
Fn: request.MakeAddToUserAgentHandler(
"amazon-vpc-cni-k8s",
fmt.Sprintf("version/%s",version)),
})
}
23 changes: 23 additions & 0 deletions pkg/awsutils/awssession/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package awssession

import (
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestHttpTimeoutReturnDefault(t *testing.T) {
os.Setenv(httpTimeoutEnv, "2")
defer os.Unsetenv(httpTimeoutEnv)
expectedHTTPTimeOut := time.Duration(10) * time.Second
assert.Equal(t, expectedHTTPTimeOut, getHTTPTimeout())
}

func TestHttpTimeoutWithValueAbove10(t *testing.T) {
os.Setenv(httpTimeoutEnv, "12")
defer os.Unsetenv(httpTimeoutEnv)
expectedHTTPTimeOut := time.Duration(12) * time.Second
assert.Equal(t, expectedHTTPTimeOut, getHTTPTimeout())
}
11 changes: 5 additions & 6 deletions pkg/awsutils/awsutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ import (
"github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger"
"github.com/prometheus/client_golang/prometheus"

"github.com/aws/amazon-vpc-cni-k8s/pkg/awsutils/awssession"
"github.com/aws/amazon-vpc-cni-k8s/pkg/ec2metadata"
"github.com/aws/amazon-vpc-cni-k8s/pkg/ec2wrapper"
"github.com/aws/amazon-vpc-cni-k8s/pkg/utils/retry"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
Expand Down Expand Up @@ -313,11 +313,10 @@ func New(useCustomNetworking bool) (*EC2InstanceMetadataCache, error) {
cache.useCustomNetworking = useCustomNetworking
log.Infof("Custom networking %v", cache.useCustomNetworking)

sess, err := session.NewSession(&aws.Config{Region: aws.String(cache.region), MaxRetries: aws.Int(15)})
if err != nil {
log.Errorf("Failed to initialize AWS SDK session %v", err)
return nil, errors.Wrap(err, "instance metadata: failed to initialize AWS SDK session")
}
sess := awssession.New()

awsCfg := aws.NewConfig().WithRegion(cache.region)
sess = sess.Copy(awsCfg)

ec2SVC := ec2wrapper.New(sess)
cache.ec2SVC = ec2SVC
Expand Down
7 changes: 2 additions & 5 deletions pkg/ec2metadatawrapper/ec2metadatawrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
package ec2metadatawrapper

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/amazon-vpc-cni-k8s/pkg/awsutils/awssession"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
)

const (
Expand Down Expand Up @@ -32,9 +31,7 @@ type ec2MetadataClientImpl struct {
// New creates an ec2metadata client to retrieve metadata
func New(client HTTPClient) EC2MetadataClient {
if client == nil {
awsSession := session.Must(session.NewSession(aws.NewConfig().
WithMaxRetries(metadataRetries),
))
awsSession := awssession.New()
return &ec2MetadataClientImpl{client: ec2metadata.New(awsSession)}
}
return &ec2MetadataClientImpl{client: client}
Expand Down
8 changes: 5 additions & 3 deletions pkg/ec2wrapper/ec2wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
package ec2wrapper

import (
"github.com/aws/amazon-vpc-cni-k8s/pkg/awsutils/awssession"
"github.com/aws/amazon-vpc-cni-k8s/pkg/ec2metadatawrapper"
"github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/pkg/errors"
Expand All @@ -30,15 +30,17 @@ type EC2Wrapper struct {

//NewMetricsClient returns an instance of the EC2 wrapper
func NewMetricsClient() (*EC2Wrapper, error) {
metricsSession := session.Must(session.NewSession())
sess := awssession.New()
ec2MetadataClient := ec2metadatawrapper.New(nil)

instanceIdentityDocument, err := ec2MetadataClient.GetInstanceIdentityDocument()
if err != nil {
return &EC2Wrapper{}, err
}

ec2ServiceClient := ec2.New(metricsSession, aws.NewConfig().WithMaxRetries(maxRetries).WithRegion(instanceIdentityDocument.Region))
awsCfg := aws.NewConfig().WithRegion(instanceIdentityDocument.Region)
sess = sess.Copy(awsCfg)
ec2ServiceClient := ec2.New(sess)

return &EC2Wrapper{
ec2ServiceClient: ec2ServiceClient,
Expand Down
9 changes: 6 additions & 3 deletions pkg/publisher/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"sync"
"time"

"github.com/aws/amazon-vpc-cni-k8s/pkg/awsutils/awssession"
"github.com/aws/amazon-vpc-cni-k8s/pkg/ec2metadatawrapper"
"github.com/aws/amazon-vpc-cni-k8s/pkg/ec2wrapper"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"

Expand Down Expand Up @@ -91,7 +91,7 @@ type cloudWatchPublisher struct {
// New returns a new instance of `Publisher`
func New(ctx context.Context) (Publisher, error) {
// Get AWS session
awsSession := session.Must(session.NewSession())
sess := awssession.New()

// Get cluster-ID
ec2Client, err := ec2wrapper.NewMetricsClient()
Expand All @@ -107,7 +107,10 @@ func New(ctx context.Context) (Publisher, error) {
if err != nil {
return nil, errors.Wrap(err, "publisher: Unable to obtain region")
}
cloudwatchClient := cloudwatch.New(awsSession, aws.NewConfig().WithMaxRetries(cloudwatchClientMaxRetries).WithRegion(region))

awsCfg := aws.NewConfig().WithRegion(region)
sess = sess.Copy(awsCfg)
cloudwatchClient := cloudwatch.New(sess)

// Build derived context
derivedContext, cancel := context.WithCancel(ctx)
Expand Down
1 change: 1 addition & 0 deletions test/integration/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ replace k8s.io/sample-controller => k8s.io/sample-controller v0.0.0-201908191433
require (
github.com/onsi/ginkgo v1.8.0
github.com/onsi/gomega v1.5.0
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 // indirect
k8s.io/api v0.0.0
k8s.io/apimachinery v0.0.0
k8s.io/client-go v0.0.0
Expand Down

0 comments on commit 21b2e61

Please sign in to comment.