Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add provider token_bucket_rate_limiter_capacity parameter #35926

Merged
merged 21 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b0b6d5d
Add 'token_bucket_rate_limiter_capacity' configuration parameter.
ewbankkit Feb 21, 2024
81fdd0b
Add CHANGELOG entry.
ewbankkit Feb 21, 2024
6a21558
Run 'go get github.com/hashicorp/aws-sdk-go-base/v2@7f2a917ddfa4169f8…
ewbankkit Feb 21, 2024
c908828
Set 'awsbase.Config.TokenBucketRateLimiterCapacity'.
ewbankkit Feb 21, 2024
f7a0c6e
'AWSClient.CredentialsProvider' pass Context.
ewbankkit Feb 21, 2024
ac0aa79
'AWSClient.AwsConfig' pass Context.
ewbankkit Feb 21, 2024
6057afa
'AWSClient.PartitionHostname' pass Context.
ewbankkit Feb 21, 2024
15bedce
'AWSClient.RegionalHostname' pass Context.
ewbankkit Feb 21, 2024
2680ff5
'AWSClient.CloudFrontDistributionHostedZoneID' pass Context.
ewbankkit Feb 21, 2024
36a0b22
'AWSClient.DefaultKMSKeyPolicy' pass Context.
ewbankkit Feb 21, 2024
2c2efd9
'AWSClient.GlobalAcceleratorHostedZoneID' pass Context.
ewbankkit Feb 21, 2024
bdde2dd
'AWSClient.S3UsePathStyle' pass Context.
ewbankkit Feb 21, 2024
03fc479
'AWSClient.SetHTTPClient' pass Context.
ewbankkit Feb 21, 2024
ca4f55c
'AWSClient.HTTPClient' pass Context.
ewbankkit Feb 21, 2024
e73cadf
'AWSClient.ReverseDNSPrefix' convert to getter.
ewbankkit Feb 21, 2024
9019c8b
'AWSClient.DNSSuffix' convert to getter.
ewbankkit Feb 21, 2024
14ea6b1
Acceptance test output:
ewbankkit Feb 21, 2024
b5fb26c
Use 'names'.
ewbankkit Feb 21, 2024
9fab229
Merge branch 'main' into f-token_bucket_rate_limiter_capacity
ewbankkit Feb 21, 2024
613ed4c
Correct CHANGELOG entry file name.
ewbankkit Feb 21, 2024
d5dc0b8
Suppress semgrep 'ci.semgrep.migrate.context-todo'.
ewbankkit Feb 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/35926.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
provider: Add `token_bucket_rate_limiter_capacity` parameter
```
5 changes: 3 additions & 2 deletions internal/acctest/vcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext
} else {
meta = new(conns.AWSClient)
}
meta.SetHTTPClient(httpClient)
meta.SetHTTPClient(ctx, httpClient)
provider.SetMeta(meta)

if v, ds := configureContextFunc(ctx, d); ds.HasError() {
Expand Down Expand Up @@ -391,14 +391,15 @@ func closeVCRRecorder(t *testing.T) {
panic(p)
}

ctx := context.TODO() // nosemgrep:ci.semgrep.migrate.context-todo
testName := t.Name()
providerMetas.Lock()
meta, ok := providerMetas[testName]
defer providerMetas.Unlock()

if ok {
if !t.Failed() {
if v, ok := meta.HTTPClient().Transport.(*recorder.Recorder); ok {
if v, ok := meta.HTTPClient(ctx).Transport.(*recorder.Recorder); ok {
t.Log("stopping VCR recorder")
if err := v.Stop(); err != nil {
t.Error(err)
Expand Down
52 changes: 30 additions & 22 deletions internal/conns/awsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
aws_sdkv2 "github.com/aws/aws-sdk-go-v2/aws"
config_sdkv2 "github.com/aws/aws-sdk-go-v2/config"
s3_sdkv2 "github.com/aws/aws-sdk-go-v2/service/s3"
endpoints_sdkv1 "github.com/aws/aws-sdk-go/aws/endpoints"
session_sdkv1 "github.com/aws/aws-sdk-go/aws/session"
apigatewayv2_sdkv1 "github.com/aws/aws-sdk-go/service/apigatewayv2"
baselogging "github.com/hashicorp/aws-sdk-go-base/v2/logging"
Expand All @@ -26,18 +25,17 @@ import (
type AWSClient struct {
AccountID string
DefaultTagsConfig *tftags.DefaultConfig
DNSSuffix string
IgnoreTagsConfig *tftags.IgnoreConfig
Partition string
Region string
ReverseDNSPrefix string
ServicePackages map[string]ServicePackage
Session *session_sdkv1.Session
TerraformVersion string

awsConfig *aws_sdkv2.Config
clients map[string]any
conns map[string]any
dnsSuffix string
endpoints map[string]string // From provider configuration.
httpClient *http.Client
lock sync.Mutex
Expand All @@ -49,29 +47,29 @@ type AWSClient struct {
}

// CredentialsProvider returns the AWS SDK for Go v2 credentials provider.
func (c *AWSClient) CredentialsProvider() aws_sdkv2.CredentialsProvider {
func (c *AWSClient) CredentialsProvider(context.Context) aws_sdkv2.CredentialsProvider {
if c.awsConfig == nil {
return nil
}
return c.awsConfig.Credentials
}

func (c *AWSClient) AwsConfig() aws_sdkv2.Config { // nosemgrep:ci.aws-in-func-name
func (c *AWSClient) AwsConfig(context.Context) aws_sdkv2.Config { // nosemgrep:ci.aws-in-func-name
return c.awsConfig.Copy()
}

// PartitionHostname returns a hostname with the provider domain suffix for the partition
// e.g. PREFIX.amazonaws.com
// The prefix should not contain a trailing period.
func (c *AWSClient) PartitionHostname(prefix string) string {
return fmt.Sprintf("%s.%s", prefix, c.DNSSuffix)
func (c *AWSClient) PartitionHostname(ctx context.Context, prefix string) string {
return fmt.Sprintf("%s.%s", prefix, c.DNSSuffix(ctx))
}

// RegionalHostname returns a hostname with the provider domain suffix for the region and partition
// e.g. PREFIX.us-west-2.amazonaws.com
// The prefix should not contain a trailing period.
func (c *AWSClient) RegionalHostname(prefix string) string {
return fmt.Sprintf("%s.%s.%s", prefix, c.Region, c.DNSSuffix)
func (c *AWSClient) RegionalHostname(ctx context.Context, prefix string) string {
return fmt.Sprintf("%s.%s.%s", prefix, c.Region, c.DNSSuffix(ctx))
}

// S3ExpressClient returns an S3 API client suitable for use with S3 Express (directory buckets).
Expand All @@ -97,20 +95,20 @@ func (c *AWSClient) S3ExpressClient(ctx context.Context) *s3_sdkv2.Client {
}

// S3UsePathStyle returns the s3_force_path_style provider configuration value.
func (c *AWSClient) S3UsePathStyle() bool {
func (c *AWSClient) S3UsePathStyle(context.Context) bool {
return c.s3UsePathStyle
}

// SetHTTPClient sets the http.Client used for AWS API calls.
// To have effect it must be called before the AWS SDK v1 Session is created.
func (c *AWSClient) SetHTTPClient(httpClient *http.Client) {
func (c *AWSClient) SetHTTPClient(_ context.Context, httpClient *http.Client) {
if c.Session == nil {
c.httpClient = httpClient
}
}

// HTTPClient returns the http.Client used for AWS API calls.
func (c *AWSClient) HTTPClient() *http.Client {
func (c *AWSClient) HTTPClient(context.Context) *http.Client {
return c.httpClient
}

Expand All @@ -121,36 +119,36 @@ func (c *AWSClient) RegisterLogger(ctx context.Context) context.Context {

// APIGatewayInvokeURL returns the Amazon API Gateway (REST APIs) invoke URL for the configured AWS Region.
// See https://docs.aws.amazon.com/apigateway/latest/developerguide/how-to-call-api.html.
func (c *AWSClient) APIGatewayInvokeURL(restAPIID, stageName string) string {
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", restAPIID)), stageName)
func (c *AWSClient) APIGatewayInvokeURL(ctx context.Context, restAPIID, stageName string) string {
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", restAPIID)), stageName)
}

// APIGatewayV2InvokeURL returns the Amazon API Gateway v2 (WebSocket & HTTP APIs) invoke URL for the configured AWS Region.
// See https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-publish.html and
// https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-set-up-websocket-deployment.html.
func (c *AWSClient) APIGatewayV2InvokeURL(protocolType, apiID, stageName string) string {
func (c *AWSClient) APIGatewayV2InvokeURL(ctx context.Context, protocolType, apiID, stageName string) string {
if protocolType == apigatewayv2_sdkv1.ProtocolTypeWebsocket {
return fmt.Sprintf("wss://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)), stageName)
return fmt.Sprintf("wss://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)), stageName)
}

if stageName == "$default" {
return fmt.Sprintf("https://%s/", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)))
return fmt.Sprintf("https://%s/", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)))
}

return fmt.Sprintf("https://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)), stageName)
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)), stageName)
}

// CloudFrontDistributionHostedZoneID returns the Route 53 hosted zone ID
// for Amazon CloudFront distributions in the configured AWS partition.
func (c *AWSClient) CloudFrontDistributionHostedZoneID() string {
if c.Partition == endpoints_sdkv1.AwsCnPartitionID {
func (c *AWSClient) CloudFrontDistributionHostedZoneID(context.Context) string {
if c.Partition == names.ChinaPartitionID {
return "Z3RFFRIM2A3IF5" // See https://docs.amazonaws.cn/en_us/aws/latest/userguide/route53.html
}
return "Z2FDTNDATAQYW2" // See https://docs.aws.amazon.com/Route53/latest/APIReference/API_AliasTarget.html#Route53-Type-AliasTarget-HostedZoneId
}

// DefaultKMSKeyPolicy returns the default policy for KMS keys in the configured AWS partition.
func (c *AWSClient) DefaultKMSKeyPolicy() string {
func (c *AWSClient) DefaultKMSKeyPolicy(context.Context) string {
return fmt.Sprintf(`
{
"Id": "default",
Expand All @@ -172,10 +170,20 @@ func (c *AWSClient) DefaultKMSKeyPolicy() string {

// GlobalAcceleratorHostedZoneID returns the Route 53 hosted zone ID
// for AWS Global Accelerator accelerators in the configured AWS partition.
func (c *AWSClient) GlobalAcceleratorHostedZoneID() string {
func (c *AWSClient) GlobalAcceleratorHostedZoneID(context.Context) string {
return "Z2BJ6XQ5FK7U4H" // See https://docs.aws.amazon.com/general/latest/gr/global_accelerator.html#global_accelerator_region
}

// DNSSuffix returns the domain suffix for the configured AWS partition.
func (c *AWSClient) DNSSuffix(context.Context) string {
return c.dnsSuffix
}

// ReverseDNSPrefix returns the reverse DNS prefix for the configured AWS partition.
func (c *AWSClient) ReverseDNSPrefix(ctx context.Context) string {
return names.ReverseDNS(c.DNSSuffix(ctx))
}

// apiClientConfig returns the AWS API client configuration parameters for the specified service.
func (c *AWSClient) apiClientConfig(ctx context.Context, servicePackageName string) map[string]any {
m := map[string]any{
Expand Down
15 changes: 9 additions & 6 deletions internal/conns/awsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
package conns

import (
"context"
"testing"
)

func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-name
t.Parallel()

ctx := context.TODO()
testCases := []struct {
Name string
AWSClient *AWSClient
Expand All @@ -19,15 +21,15 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
{
Name: "AWS Commercial",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com",
dnsSuffix: "amazonaws.com",
},
Prefix: "test",
Expected: "test.amazonaws.com",
},
{
Name: "AWS China",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com.cn",
dnsSuffix: "amazonaws.com.cn",
},
Prefix: "test",
Expected: "test.amazonaws.com.cn",
Expand All @@ -39,7 +41,7 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()

got := testCase.AWSClient.PartitionHostname(testCase.Prefix)
got := testCase.AWSClient.PartitionHostname(ctx, testCase.Prefix)

if got != testCase.Expected {
t.Errorf("got %s, expected %s", got, testCase.Expected)
Expand All @@ -51,6 +53,7 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-name
t.Parallel()

ctx := context.TODO()
testCases := []struct {
Name string
AWSClient *AWSClient
Expand All @@ -60,7 +63,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
{
Name: "AWS Commercial",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com",
dnsSuffix: "amazonaws.com",
Region: "us-west-2", //lintignore:AWSAT003
},
Prefix: "test",
Expand All @@ -69,7 +72,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
{
Name: "AWS China",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com.cn",
dnsSuffix: "amazonaws.com.cn",
Region: "cn-northwest-1", //lintignore:AWSAT003
},
Prefix: "test",
Expand All @@ -82,7 +85,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()

got := testCase.AWSClient.RegionalHostname(testCase.Prefix)
got := testCase.AWSClient.RegionalHostname(ctx, testCase.Prefix)

if got != testCase.Expected {
t.Errorf("got %s, expected %s", got, testCase.Expected)
Expand Down
69 changes: 35 additions & 34 deletions internal/conns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Config struct {
SuppressDebugLog bool
TerraformVersion string
Token string
TokenBucketRateLimiterCapacity int
UseDualStackEndpoint bool
UseFIPSEndpoint bool
}
Expand All @@ -68,35 +69,36 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
ctx, logger := logging.NewTfLogger(ctx)

awsbaseConfig := awsbase.Config{
AccessKey: c.AccessKey,
AllowedAccountIds: c.AllowedAccountIds,
APNInfo: StdUserAgentProducts(c.TerraformVersion),
AssumeRoleWithWebIdentity: c.AssumeRoleWithWebIdentity,
CallerDocumentationURL: "https://registry.terraform.io/providers/hashicorp/aws",
CallerName: "Terraform AWS Provider",
EC2MetadataServiceEnableState: c.EC2MetadataServiceEnableState,
ForbiddenAccountIds: c.ForbiddenAccountIds,
IamEndpoint: c.Endpoints[names.IAM],
Insecure: c.Insecure,
HTTPClient: client.HTTPClient(),
HTTPProxy: c.HTTPProxy,
HTTPSProxy: c.HTTPSProxy,
HTTPProxyMode: awsbase.HTTPProxyModeLegacy,
Logger: logger,
MaxRetries: c.MaxRetries,
NoProxy: c.NoProxy,
Profile: c.Profile,
Region: c.Region,
RetryMode: c.RetryMode,
SecretKey: c.SecretKey,
SkipCredsValidation: c.SkipCredsValidation,
SkipRequestingAccountId: c.SkipRequestingAccountId,
SsoEndpoint: c.Endpoints[names.SSO],
StsEndpoint: c.Endpoints[names.STS],
SuppressDebugLog: c.SuppressDebugLog,
Token: c.Token,
UseDualStackEndpoint: c.UseDualStackEndpoint,
UseFIPSEndpoint: c.UseFIPSEndpoint,
AccessKey: c.AccessKey,
AllowedAccountIds: c.AllowedAccountIds,
APNInfo: StdUserAgentProducts(c.TerraformVersion),
AssumeRoleWithWebIdentity: c.AssumeRoleWithWebIdentity,
CallerDocumentationURL: "https://registry.terraform.io/providers/hashicorp/aws",
CallerName: "Terraform AWS Provider",
EC2MetadataServiceEnableState: c.EC2MetadataServiceEnableState,
ForbiddenAccountIds: c.ForbiddenAccountIds,
IamEndpoint: c.Endpoints[names.IAM],
Insecure: c.Insecure,
HTTPClient: client.HTTPClient(ctx),
HTTPProxy: c.HTTPProxy,
HTTPSProxy: c.HTTPSProxy,
HTTPProxyMode: awsbase.HTTPProxyModeLegacy,
Logger: logger,
MaxRetries: c.MaxRetries,
NoProxy: c.NoProxy,
Profile: c.Profile,
Region: c.Region,
RetryMode: c.RetryMode,
SecretKey: c.SecretKey,
SkipCredsValidation: c.SkipCredsValidation,
SkipRequestingAccountId: c.SkipRequestingAccountId,
SsoEndpoint: c.Endpoints[names.SSO],
StsEndpoint: c.Endpoints[names.STS],
SuppressDebugLog: c.SuppressDebugLog,
Token: c.Token,
TokenBucketRateLimiterCapacity: c.TokenBucketRateLimiterCapacity,
UseDualStackEndpoint: c.UseDualStackEndpoint,
UseFIPSEndpoint: c.UseFIPSEndpoint,
}

if c.AssumeRole != nil && c.AssumeRole.RoleARN != "" {
Expand Down Expand Up @@ -189,19 +191,18 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
return nil, sdkdiag.AppendErrorf(diags, err.Error())
}

DNSSuffix := "amazonaws.com"
dnsSuffix := "amazonaws.com"
if p, ok := endpoints_sdkv1.PartitionForRegion(endpoints_sdkv1.DefaultPartitions(), c.Region); ok {
DNSSuffix = p.DNSSuffix()
dnsSuffix = p.DNSSuffix()
}

client.AccountID = accountID
client.DefaultTagsConfig = c.DefaultTagsConfig
client.DNSSuffix = DNSSuffix
client.dnsSuffix = dnsSuffix
client.IgnoreTagsConfig = c.IgnoreTagsConfig
client.Partition = partition
client.Region = c.Region
client.ReverseDNSPrefix = names.ReverseDNS(DNSSuffix)
client.SetHTTPClient(sess.Config.HTTPClient) // Must be called while client.Session is nil.
client.SetHTTPClient(ctx, sess.Config.HTTPClient) // Must be called while client.Session is nil.
client.Session = sess
client.TerraformVersion = c.TerraformVersion

Expand Down
2 changes: 1 addition & 1 deletion internal/conns/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func TestProxyConfig(t *testing.T) {

meta := p.Meta().(*conns.AWSClient)

client := meta.AwsConfig().HTTPClient
client := meta.AwsConfig(ctx).HTTPClient
bClient, ok := client.(*awshttp.BuildableClient)
if !ok {
t.Fatalf("expected awshttp.BuildableClient, got %T", client)
Expand Down
4 changes: 4 additions & 0 deletions internal/provider/fwprovider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ func (p *fwprovider) Schema(ctx context.Context, req provider.SchemaRequest, res
Optional: true,
Description: "session token. A session token is only required if you are\nusing temporary security credentials.",
},
"token_bucket_rate_limiter_capacity": schema.Int64Attribute{
Optional: true,
Description: "The capacity of the AWS SDK's token bucket rate limiter.",
},
"use_dualstack_endpoint": schema.BoolAttribute{
Optional: true,
Description: "Resolve an endpoint with DualStack capability",
Expand Down
Loading
Loading