diff --git a/gov2/aurora/actions/clusters.go b/gov2/aurora/actions/clusters.go index 5af769c5ae3..e515a32db87 100644 --- a/gov2/aurora/actions/clusters.go +++ b/gov2/aurora/actions/clusters.go @@ -23,10 +23,10 @@ type DbClusters struct { // snippet-start:[gov2.aurora.DescribeDBClusterParameterGroups] // GetParameterGroup gets a DB cluster parameter group by name. -func (clusters *DbClusters) GetParameterGroup(parameterGroupName string) ( +func (clusters *DbClusters) GetParameterGroup(ctx context.Context, parameterGroupName string) ( *types.DBClusterParameterGroup, error) { output, err := clusters.AuroraClient.DescribeDBClusterParameterGroups( - context.TODO(), &rds.DescribeDBClusterParameterGroupsInput{ + ctx, &rds.DescribeDBClusterParameterGroupsInput{ DBClusterParameterGroupName: aws.String(parameterGroupName), }) if err != nil { @@ -49,10 +49,10 @@ func (clusters *DbClusters) GetParameterGroup(parameterGroupName string) ( // CreateParameterGroup creates a DB cluster parameter group that is based on the specified // parameter group family. func (clusters *DbClusters) CreateParameterGroup( - parameterGroupName string, parameterGroupFamily string, description string) ( + ctx context.Context, parameterGroupName string, parameterGroupFamily string, description string) ( *types.DBClusterParameterGroup, error) { - output, err := clusters.AuroraClient.CreateDBClusterParameterGroup(context.TODO(), + output, err := clusters.AuroraClient.CreateDBClusterParameterGroup(ctx, &rds.CreateDBClusterParameterGroupInput{ DBClusterParameterGroupName: aws.String(parameterGroupName), DBParameterGroupFamily: aws.String(parameterGroupFamily), @@ -71,8 +71,8 @@ func (clusters *DbClusters) CreateParameterGroup( // snippet-start:[gov2.aurora.DeleteDBClusterParameterGroup] // DeleteParameterGroup deletes the named DB cluster parameter group. -func (clusters *DbClusters) DeleteParameterGroup(parameterGroupName string) error { - _, err := clusters.AuroraClient.DeleteDBClusterParameterGroup(context.TODO(), +func (clusters *DbClusters) DeleteParameterGroup(ctx context.Context, parameterGroupName string) error { + _, err := clusters.AuroraClient.DeleteDBClusterParameterGroup(ctx, &rds.DeleteDBClusterParameterGroupInput{ DBClusterParameterGroupName: aws.String(parameterGroupName), }) @@ -89,7 +89,7 @@ func (clusters *DbClusters) DeleteParameterGroup(parameterGroupName string) erro // snippet-start:[gov2.aurora.DescribeDBClusterParameters] // GetParameters gets the parameters that are contained in a DB cluster parameter group. -func (clusters *DbClusters) GetParameters(parameterGroupName string, source string) ( +func (clusters *DbClusters) GetParameters(ctx context.Context, parameterGroupName string, source string) ( []types.Parameter, error) { var output *rds.DescribeDBClusterParametersOutput @@ -101,7 +101,7 @@ func (clusters *DbClusters) GetParameters(parameterGroupName string, source stri Source: aws.String(source), }) for parameterPaginator.HasMorePages() { - output, err = parameterPaginator.NextPage(context.TODO()) + output, err = parameterPaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get paramaeters for %v: %v\n", parameterGroupName, err) break @@ -117,8 +117,8 @@ func (clusters *DbClusters) GetParameters(parameterGroupName string, source stri // snippet-start:[gov2.aurora.ModifyDBClusterParameterGroup] // UpdateParameters updates parameters in a named DB cluster parameter group. -func (clusters *DbClusters) UpdateParameters(parameterGroupName string, params []types.Parameter) error { - _, err := clusters.AuroraClient.ModifyDBClusterParameterGroup(context.TODO(), +func (clusters *DbClusters) UpdateParameters(ctx context.Context, parameterGroupName string, params []types.Parameter) error { + _, err := clusters.AuroraClient.ModifyDBClusterParameterGroup(ctx, &rds.ModifyDBClusterParameterGroupInput{ DBClusterParameterGroupName: aws.String(parameterGroupName), Parameters: params, @@ -136,8 +136,8 @@ func (clusters *DbClusters) UpdateParameters(parameterGroupName string, params [ // snippet-start:[gov2.aurora.DescribeDBClusters] // GetDbCluster gets data about an Aurora DB cluster. -func (clusters *DbClusters) GetDbCluster(clusterName string) (*types.DBCluster, error) { - output, err := clusters.AuroraClient.DescribeDBClusters(context.TODO(), +func (clusters *DbClusters) GetDbCluster(ctx context.Context, clusterName string) (*types.DBCluster, error) { + output, err := clusters.AuroraClient.DescribeDBClusters(ctx, &rds.DescribeDBClustersInput{ DBClusterIdentifier: aws.String(clusterName), }) @@ -162,11 +162,11 @@ func (clusters *DbClusters) GetDbCluster(clusterName string) (*types.DBCluster, // CreateDbCluster creates a DB cluster that is configured to use the specified parameter group. // The newly created DB cluster contains a database that uses the specified engine and // engine version. -func (clusters *DbClusters) CreateDbCluster(clusterName string, parameterGroupName string, +func (clusters *DbClusters) CreateDbCluster(ctx context.Context, clusterName string, parameterGroupName string, dbName string, dbEngine string, dbEngineVersion string, adminName string, adminPassword string) ( *types.DBCluster, error) { - output, err := clusters.AuroraClient.CreateDBCluster(context.TODO(), &rds.CreateDBClusterInput{ + output, err := clusters.AuroraClient.CreateDBCluster(ctx, &rds.CreateDBClusterInput{ DBClusterIdentifier: aws.String(clusterName), Engine: aws.String(dbEngine), DBClusterParameterGroupName: aws.String(parameterGroupName), @@ -188,8 +188,8 @@ func (clusters *DbClusters) CreateDbCluster(clusterName string, parameterGroupNa // snippet-start:[gov2.aurora.DeleteDBCluster] // DeleteDbCluster deletes a DB cluster without keeping a final snapshot. -func (clusters *DbClusters) DeleteDbCluster(clusterName string) error { - _, err := clusters.AuroraClient.DeleteDBCluster(context.TODO(), &rds.DeleteDBClusterInput{ +func (clusters *DbClusters) DeleteDbCluster(ctx context.Context, clusterName string) error { + _, err := clusters.AuroraClient.DeleteDBCluster(ctx, &rds.DeleteDBClusterInput{ DBClusterIdentifier: aws.String(clusterName), SkipFinalSnapshot: aws.Bool(true), }) @@ -206,9 +206,9 @@ func (clusters *DbClusters) DeleteDbCluster(clusterName string) error { // snippet-start:[gov2.aurora.CreateDBClusterSnapshot] // CreateClusterSnapshot creates a snapshot of a DB cluster. -func (clusters *DbClusters) CreateClusterSnapshot(clusterName string, snapshotName string) ( +func (clusters *DbClusters) CreateClusterSnapshot(ctx context.Context, clusterName string, snapshotName string) ( *types.DBClusterSnapshot, error) { - output, err := clusters.AuroraClient.CreateDBClusterSnapshot(context.TODO(), &rds.CreateDBClusterSnapshotInput{ + output, err := clusters.AuroraClient.CreateDBClusterSnapshot(ctx, &rds.CreateDBClusterSnapshotInput{ DBClusterIdentifier: aws.String(clusterName), DBClusterSnapshotIdentifier: aws.String(snapshotName), }) @@ -225,8 +225,8 @@ func (clusters *DbClusters) CreateClusterSnapshot(clusterName string, snapshotNa // snippet-start:[gov2.aurora.DescribeDBClusterSnapshots] // GetClusterSnapshot gets a DB cluster snapshot. -func (clusters *DbClusters) GetClusterSnapshot(snapshotName string) (*types.DBClusterSnapshot, error) { - output, err := clusters.AuroraClient.DescribeDBClusterSnapshots(context.TODO(), +func (clusters *DbClusters) GetClusterSnapshot(ctx context.Context, snapshotName string) (*types.DBClusterSnapshot, error) { + output, err := clusters.AuroraClient.DescribeDBClusterSnapshots(ctx, &rds.DescribeDBClusterSnapshotsInput{ DBClusterSnapshotIdentifier: aws.String(snapshotName), }) @@ -244,9 +244,9 @@ func (clusters *DbClusters) GetClusterSnapshot(snapshotName string) (*types.DBCl // CreateInstanceInCluster creates a database instance in an existing DB cluster. The first database that is // created defaults to a read-write DB instance. -func (clusters *DbClusters) CreateInstanceInCluster(clusterName string, instanceName string, +func (clusters *DbClusters) CreateInstanceInCluster(ctx context.Context, clusterName string, instanceName string, dbEngine string, dbInstanceClass string) (*types.DBInstance, error) { - output, err := clusters.AuroraClient.CreateDBInstance(context.TODO(), &rds.CreateDBInstanceInput{ + output, err := clusters.AuroraClient.CreateDBInstance(ctx, &rds.CreateDBInstanceInput{ DBInstanceIdentifier: aws.String(instanceName), DBClusterIdentifier: aws.String(clusterName), Engine: aws.String(dbEngine), @@ -265,9 +265,9 @@ func (clusters *DbClusters) CreateInstanceInCluster(clusterName string, instance // snippet-start:[gov2.aurora.DescribeDBInstances] // GetInstance gets data about a DB instance. -func (clusters *DbClusters) GetInstance(instanceName string) ( +func (clusters *DbClusters) GetInstance(ctx context.Context, instanceName string) ( *types.DBInstance, error) { - output, err := clusters.AuroraClient.DescribeDBInstances(context.TODO(), + output, err := clusters.AuroraClient.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: aws.String(instanceName), }) @@ -290,8 +290,8 @@ func (clusters *DbClusters) GetInstance(instanceName string) ( // snippet-start:[gov2.aurora.DeleteDBInstance] // DeleteInstance deletes a DB instance. -func (clusters *DbClusters) DeleteInstance(instanceName string) error { - _, err := clusters.AuroraClient.DeleteDBInstance(context.TODO(), &rds.DeleteDBInstanceInput{ +func (clusters *DbClusters) DeleteInstance(ctx context.Context, instanceName string) error { + _, err := clusters.AuroraClient.DeleteDBInstance(ctx, &rds.DeleteDBInstanceInput{ DBInstanceIdentifier: aws.String(instanceName), SkipFinalSnapshot: aws.Bool(true), DeleteAutomatedBackups: aws.Bool(true), @@ -310,9 +310,9 @@ func (clusters *DbClusters) DeleteInstance(instanceName string) error { // GetEngineVersions gets database engine versions that are available for the specified engine // and parameter group family. -func (clusters *DbClusters) GetEngineVersions(engine string, parameterGroupFamily string) ( +func (clusters *DbClusters) GetEngineVersions(ctx context.Context, engine string, parameterGroupFamily string) ( []types.DBEngineVersion, error) { - output, err := clusters.AuroraClient.DescribeDBEngineVersions(context.TODO(), + output, err := clusters.AuroraClient.DescribeDBEngineVersions(ctx, &rds.DescribeDBEngineVersionsInput{ Engine: aws.String(engine), DBParameterGroupFamily: aws.String(parameterGroupFamily), @@ -331,7 +331,7 @@ func (clusters *DbClusters) GetEngineVersions(engine string, parameterGroupFamil // GetOrderableInstances uses a paginator to get DB instance options that can be used to create DB instances that are // compatible with a set of specifications. -func (clusters *DbClusters) GetOrderableInstances(engine string, engineVersion string) ( +func (clusters *DbClusters) GetOrderableInstances(ctx context.Context, engine string, engineVersion string) ( []types.OrderableDBInstanceOption, error) { var output *rds.DescribeOrderableDBInstanceOptionsOutput @@ -343,7 +343,7 @@ func (clusters *DbClusters) GetOrderableInstances(engine string, engineVersion s EngineVersion: aws.String(engineVersion), }) for orderablePaginator.HasMorePages() { - output, err = orderablePaginator.NextPage(context.TODO()) + output, err = orderablePaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get orderable DB instances: %v\n", err) break diff --git a/gov2/aurora/cmd/main.go b/gov2/aurora/cmd/main.go index 73e2e672ea8..e14c68c1134 100644 --- a/gov2/aurora/cmd/main.go +++ b/gov2/aurora/cmd/main.go @@ -23,7 +23,7 @@ import ( // - `clusters` - Runs the interactive DB clusters scenario that shows you how to use // Amazon Aurora commands to work with DB clusters and databases. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "clusters": runClusterScenario, } choices := make([]string, len(scenarioMap)) @@ -41,18 +41,19 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runClusterScenario(sdkConfig aws.Config) { +func runClusterScenario(ctx context.Context, sdkConfig aws.Config) { scenario := scenarios.NewGetStartedClusters(sdkConfig, demotools.NewQuestioner(), scenarios.ScenarioHelper{}) - scenario.Run("aurora-mysql", "doc-example-cluster-parameter-group", "doc-example-aurora", + scenario.Run(ctx, "aurora-mysql", "doc-example-cluster-parameter-group", "doc-example-aurora", "docexampledb") } diff --git a/gov2/aurora/hello/hello.go b/gov2/aurora/hello/hello.go index b9158990da3..d22999af315 100644 --- a/gov2/aurora/hello/hello.go +++ b/gov2/aurora/hello/hello.go @@ -19,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -28,8 +29,8 @@ func main() { auroraClient := rds.NewFromConfig(sdkConfig) const maxClusters = 20 fmt.Printf("Let's list up to %v DB clusters.\n", maxClusters) - output, err := auroraClient.DescribeDBClusters(context.TODO(), - &rds.DescribeDBClustersInput{MaxRecords: aws.Int32(maxClusters)}) + output, err := auroraClient.DescribeDBClusters( + ctx, &rds.DescribeDBClustersInput{MaxRecords: aws.Int32(maxClusters)}) if err != nil { fmt.Printf("Couldn't list DB clusters: %v\n", err) return diff --git a/gov2/aurora/scenarios/get_started_clusters.go b/gov2/aurora/scenarios/get_started_clusters.go index f48ae18adb7..a5f0d7c55f8 100644 --- a/gov2/aurora/scenarios/get_started_clusters.go +++ b/gov2/aurora/scenarios/get_started_clusters.go @@ -5,6 +5,7 @@ package scenarios import ( "aurora/actions" + "context" "fmt" "log" "slices" @@ -71,7 +72,7 @@ func NewGetStartedClusters(sdkConfig aws.Config, questioner demotools.IQuestione } // Run runs the interactive scenario. -func (scenario GetStartedClusters) Run(dbEngine string, parameterGroupName string, +func (scenario GetStartedClusters) Run(ctx context.Context, dbEngine string, parameterGroupName string, clusterName string, dbName string) { defer func() { if r := recover(); r != nil { @@ -83,14 +84,14 @@ func (scenario GetStartedClusters) Run(dbEngine string, parameterGroupName strin log.Println("Welcome to the Amazon Aurora DB Cluster demo.") log.Println(strings.Repeat("-", 88)) - parameterGroup := scenario.CreateParameterGroup(dbEngine, parameterGroupName) - scenario.SetUserParameters(parameterGroupName) - cluster := scenario.CreateCluster(clusterName, dbEngine, dbName, parameterGroup) + parameterGroup := scenario.CreateParameterGroup(ctx, dbEngine, parameterGroupName) + scenario.SetUserParameters(ctx, parameterGroupName) + cluster := scenario.CreateCluster(ctx, clusterName, dbEngine, dbName, parameterGroup) scenario.helper.Pause(5) - dbInstance := scenario.CreateInstance(cluster) + dbInstance := scenario.CreateInstance(ctx, cluster) scenario.DisplayConnection(cluster) - scenario.CreateSnapshot(clusterName) - scenario.Cleanup(dbInstance, cluster, parameterGroup) + scenario.CreateSnapshot(ctx, clusterName) + scenario.Cleanup(ctx, dbInstance, cluster, parameterGroup) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") @@ -100,18 +101,18 @@ func (scenario GetStartedClusters) Run(dbEngine string, parameterGroupName strin // CreateParameterGroup shows how to get available engine versions for a specified // database engine and create a DB cluster parameter group that is compatible with a // selected engine family. -func (scenario GetStartedClusters) CreateParameterGroup(dbEngine string, +func (scenario GetStartedClusters) CreateParameterGroup(ctx context.Context, dbEngine string, parameterGroupName string) *types.DBClusterParameterGroup { log.Printf("Checking for an existing DB cluster parameter group named %v.\n", parameterGroupName) - parameterGroup, err := scenario.dbClusters.GetParameterGroup(parameterGroupName) + parameterGroup, err := scenario.dbClusters.GetParameterGroup(ctx, parameterGroupName) if err != nil { panic(err) } if parameterGroup == nil { log.Printf("Getting available database engine versions for %v.\n", dbEngine) - engineVersions, err := scenario.dbClusters.GetEngineVersions(dbEngine, "") + engineVersions, err := scenario.dbClusters.GetEngineVersions(ctx, dbEngine, "") if err != nil { panic(err) } @@ -128,11 +129,11 @@ func (scenario GetStartedClusters) CreateParameterGroup(dbEngine string, familyIndex := scenario.questioner.AskChoice("Which family do you want to use?\n", families) log.Println("Creating a DB cluster parameter group.") _, err = scenario.dbClusters.CreateParameterGroup( - parameterGroupName, families[familyIndex], "Example parameter group.") + ctx, parameterGroupName, families[familyIndex], "Example parameter group.") if err != nil { panic(err) } - parameterGroup, err = scenario.dbClusters.GetParameterGroup(parameterGroupName) + parameterGroup, err = scenario.dbClusters.GetParameterGroup(ctx, parameterGroupName) if err != nil { panic(err) } @@ -149,9 +150,9 @@ func (scenario GetStartedClusters) CreateParameterGroup(dbEngine string, // SetUserParameters shows how to get the parameters contained in a custom parameter // group and update some of the parameter values in the group. -func (scenario GetStartedClusters) SetUserParameters(parameterGroupName string) { +func (scenario GetStartedClusters) SetUserParameters(ctx context.Context, parameterGroupName string) { log.Println("Let's set some parameter values in your parameter group.") - dbParameters, err := scenario.dbClusters.GetParameters(parameterGroupName, "") + dbParameters, err := scenario.dbClusters.GetParameters(ctx, parameterGroupName, "") if err != nil { panic(err) } @@ -171,12 +172,12 @@ func (scenario GetStartedClusters) SetUserParameters(parameterGroupName string) updateParams = append(updateParams, dbParam) } } - err = scenario.dbClusters.UpdateParameters(parameterGroupName, updateParams) + err = scenario.dbClusters.UpdateParameters(ctx, parameterGroupName, updateParams) if err != nil { panic(err) } log.Println("You can get a list of parameters you've set by specifying a source of 'user'.") - userParameters, err := scenario.dbClusters.GetParameters(parameterGroupName, "user") + userParameters, err := scenario.dbClusters.GetParameters(ctx, parameterGroupName, "user") if err != nil { panic(err) } @@ -190,11 +191,11 @@ func (scenario GetStartedClusters) SetUserParameters(parameterGroupName string) // CreateCluster shows how to create an Aurora DB cluster that contains a database // of a specified type. The database is also configured to use a custom DB cluster // parameter group. -func (scenario GetStartedClusters) CreateCluster(clusterName string, dbEngine string, +func (scenario GetStartedClusters) CreateCluster(ctx context.Context, clusterName string, dbEngine string, dbName string, parameterGroup *types.DBClusterParameterGroup) *types.DBCluster { log.Println("Checking for an existing DB cluster.") - cluster, err := scenario.dbClusters.GetDbCluster(clusterName) + cluster, err := scenario.dbClusters.GetDbCluster(ctx, clusterName) if err != nil { panic(err) } @@ -203,7 +204,7 @@ func (scenario GetStartedClusters) CreateCluster(clusterName string, dbEngine st "Enter an administrator user name for the database: ", demotools.NotEmpty{}) adminPassword := scenario.questioner.Ask( "Enter a password for the administrator (at least 8 characters): ", demotools.NotEmpty{}) - engineVersions, err := scenario.dbClusters.GetEngineVersions(dbEngine, *parameterGroup.DBParameterGroupFamily) + engineVersions, err := scenario.dbClusters.GetEngineVersions(ctx, dbEngine, *parameterGroup.DBParameterGroupFamily) if err != nil { panic(err) } @@ -219,14 +220,14 @@ func (scenario GetStartedClusters) CreateCluster(clusterName string, dbEngine st log.Printf("and selected engine %v.\n", engineChoices[engineIndex]) log.Println("This typically takes several minutes.") cluster, err = scenario.dbClusters.CreateDbCluster( - clusterName, *parameterGroup.DBClusterParameterGroupName, dbName, dbEngine, + ctx, clusterName, *parameterGroup.DBClusterParameterGroupName, dbName, dbEngine, engineChoices[engineIndex], adminUsername, adminPassword) if err != nil { panic(err) } for *cluster.Status != "available" { scenario.helper.Pause(30) - cluster, err = scenario.dbClusters.GetDbCluster(clusterName) + cluster, err = scenario.dbClusters.GetDbCluster(ctx, clusterName) if err != nil { panic(err) } @@ -248,9 +249,9 @@ func (scenario GetStartedClusters) CreateCluster(clusterName string, dbEngine st // CreateInstance shows how to create a DB instance in an existing Aurora DB cluster. // A new DB cluster contains no DB instances, so you must add one. The first DB instance // that is added to a DB cluster defaults to a read-write DB instance. -func (scenario GetStartedClusters) CreateInstance(cluster *types.DBCluster) *types.DBInstance { +func (scenario GetStartedClusters) CreateInstance(ctx context.Context, cluster *types.DBCluster) *types.DBInstance { log.Println("Checking for an existing database instance.") - dbInstance, err := scenario.dbClusters.GetInstance(*cluster.DBClusterIdentifier) + dbInstance, err := scenario.dbClusters.GetInstance(ctx, *cluster.DBClusterIdentifier) if err != nil { panic(err) } @@ -258,7 +259,7 @@ func (scenario GetStartedClusters) CreateInstance(cluster *types.DBCluster) *typ log.Println("Let's create a database instance in your DB cluster.") log.Println("First, choose a DB instance type:") instOpts, err := scenario.dbClusters.GetOrderableInstances( - *cluster.Engine, *cluster.EngineVersion) + ctx, *cluster.Engine, *cluster.EngineVersion) if err != nil { panic(err) } @@ -272,14 +273,14 @@ func (scenario GetStartedClusters) CreateInstance(cluster *types.DBCluster) *typ "Which DB instance class do you want to use?\n", instChoices) log.Println("Creating a database instance. This typically takes several minutes.") dbInstance, err = scenario.dbClusters.CreateInstanceInCluster( - *cluster.DBClusterIdentifier, *cluster.DBClusterIdentifier, *cluster.Engine, + ctx, *cluster.DBClusterIdentifier, *cluster.DBClusterIdentifier, *cluster.Engine, instChoices[instIndex]) if err != nil { panic(err) } for *dbInstance.DBInstanceStatus != "available" { scenario.helper.Pause(30) - dbInstance, err = scenario.dbClusters.GetInstance(*cluster.DBClusterIdentifier) + dbInstance, err = scenario.dbClusters.GetInstance(ctx, *cluster.DBClusterIdentifier) if err != nil { panic(err) } @@ -312,18 +313,18 @@ func (scenario GetStartedClusters) DisplayConnection(cluster *types.DBCluster) { } // CreateSnapshot shows how to create a DB cluster snapshot and wait until it's available. -func (scenario GetStartedClusters) CreateSnapshot(clusterName string) { +func (scenario GetStartedClusters) CreateSnapshot(ctx context.Context, clusterName string) { if scenario.questioner.AskBool( "Do you want to create a snapshot of your DB cluster (y/n)? ", "y") { snapshotId := fmt.Sprintf("%v-%v", clusterName, scenario.helper.UniqueId()) log.Printf("Creating a snapshot named %v. This typically takes a few minutes.\n", snapshotId) - snapshot, err := scenario.dbClusters.CreateClusterSnapshot(clusterName, snapshotId) + snapshot, err := scenario.dbClusters.CreateClusterSnapshot(ctx, clusterName, snapshotId) if err != nil { panic(err) } for *snapshot.Status != "available" { scenario.helper.Pause(30) - snapshot, err = scenario.dbClusters.GetClusterSnapshot(snapshotId) + snapshot, err = scenario.dbClusters.GetClusterSnapshot(ctx, snapshotId) if err != nil { panic(err) } @@ -343,18 +344,18 @@ func (scenario GetStartedClusters) CreateSnapshot(clusterName string) { // Cleanup shows how to clean up a DB instance, DB cluster, and DB cluster parameter group. // Before the DB cluster parameter group can be deleted, all associated DB instances and // DB clusters must first be deleted. -func (scenario GetStartedClusters) Cleanup(dbInstance *types.DBInstance, cluster *types.DBCluster, +func (scenario GetStartedClusters) Cleanup(ctx context.Context, dbInstance *types.DBInstance, cluster *types.DBCluster, parameterGroup *types.DBClusterParameterGroup) { if scenario.questioner.AskBool( "\nDo you want to delete the database instance, DB cluster, and parameter group (y/n)? ", "y") { log.Printf("Deleting database instance %v.\n", *dbInstance.DBInstanceIdentifier) - err := scenario.dbClusters.DeleteInstance(*dbInstance.DBInstanceIdentifier) + err := scenario.dbClusters.DeleteInstance(ctx, *dbInstance.DBInstanceIdentifier) if err != nil { panic(err) } log.Printf("Deleting database cluster %v.\n", *cluster.DBClusterIdentifier) - err = scenario.dbClusters.DeleteDbCluster(*cluster.DBClusterIdentifier) + err = scenario.dbClusters.DeleteDbCluster(ctx, *cluster.DBClusterIdentifier) if err != nil { panic(err) } @@ -363,20 +364,20 @@ func (scenario GetStartedClusters) Cleanup(dbInstance *types.DBInstance, cluster for dbInstance != nil || cluster != nil { scenario.helper.Pause(30) if dbInstance != nil { - dbInstance, err = scenario.dbClusters.GetInstance(*dbInstance.DBInstanceIdentifier) + dbInstance, err = scenario.dbClusters.GetInstance(ctx, *dbInstance.DBInstanceIdentifier) if err != nil { panic(err) } } if cluster != nil { - cluster, err = scenario.dbClusters.GetDbCluster(*cluster.DBClusterIdentifier) + cluster, err = scenario.dbClusters.GetDbCluster(ctx, *cluster.DBClusterIdentifier) if err != nil { panic(err) } } } log.Printf("Deleting parameter group %v.", *parameterGroup.DBClusterParameterGroupName) - err = scenario.dbClusters.DeleteParameterGroup(*parameterGroup.DBClusterParameterGroupName) + err = scenario.dbClusters.DeleteParameterGroup(ctx, *parameterGroup.DBClusterParameterGroupName) if err != nil { panic(err) } diff --git a/gov2/aurora/scenarios/get_started_clusters_integ_test.go b/gov2/aurora/scenarios/get_started_clusters_integ_test.go index 6a544003751..5fbbbcb76ac 100644 --- a/gov2/aurora/scenarios/get_started_clusters_integ_test.go +++ b/gov2/aurora/scenarios/get_started_clusters_integ_test.go @@ -53,7 +53,8 @@ func TestRunGetStartedClustersScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -65,6 +66,7 @@ func TestRunGetStartedClustersScenario_Integration(t *testing.T) { scenario := NewGetStartedClusters(sdkConfig, mockQuestioner, &helper) testId := time.Now().Unix() scenario.Run( + ctx, "aurora-mysql", fmt.Sprintf("doc-example-cluster-parameter-group-%v", testId), fmt.Sprintf("doc-example-aurora-%v", testId), diff --git a/gov2/aurora/scenarios/get_started_clusters_test.go b/gov2/aurora/scenarios/get_started_clusters_test.go index b49ecf67038..246d6d5a4e9 100644 --- a/gov2/aurora/scenarios/get_started_clusters_test.go +++ b/gov2/aurora/scenarios/get_started_clusters_test.go @@ -4,6 +4,7 @@ package scenarios import ( "aurora/stubs" + "context" "fmt" "strconv" "testing" @@ -139,7 +140,7 @@ func (scenTest *GetStartedClustersTest) RunSubTest(stubber *testtools.AwsmStubbe mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} scenario := NewGetStartedClusters(*stubber.SdkConfig, &mockQuestioner, &scenTest.helper) scenario.isTestRun = true - scenario.Run(scenTest.dbEngine, scenTest.parameterGroupName, scenTest.clusterName, scenTest.dbName) + scenario.Run(context.Background(), scenTest.dbEngine, scenTest.parameterGroupName, scenTest.clusterName, scenTest.dbName) } func (scenTest *GetStartedClustersTest) Cleanup() {} diff --git a/gov2/bedrock-runtime/actions/invoke_model.go b/gov2/bedrock-runtime/actions/invoke_model.go index 8d72be41355..bf44ebb8a06 100644 --- a/gov2/bedrock-runtime/actions/invoke_model.go +++ b/gov2/bedrock-runtime/actions/invoke_model.go @@ -43,7 +43,7 @@ type ClaudeResponse struct { // Invokes Anthropic Claude on Amazon Bedrock to run an inference using the input // provided in the request body. -func (wrapper InvokeModelWrapper) InvokeClaude(prompt string) (string, error) { +func (wrapper InvokeModelWrapper) InvokeClaude(ctx context.Context, prompt string) (string, error) { modelId := "anthropic.claude-v2" // Anthropic Claude requires enclosing the prompt as follows: @@ -60,7 +60,7 @@ func (wrapper InvokeModelWrapper) InvokeClaude(prompt string) (string, error) { log.Fatal("failed to marshal", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, @@ -104,7 +104,7 @@ type Data struct { // Invokes AI21 Labs Jurassic-2 on Amazon Bedrock to run an inference using the input // provided in the request body. -func (wrapper InvokeModelWrapper) InvokeJurassic2(prompt string) (string, error) { +func (wrapper InvokeModelWrapper) InvokeJurassic2(ctx context.Context, prompt string) (string, error) { modelId := "ai21.j2-mid-v1" body, err := json.Marshal(Jurassic2Request{ @@ -117,7 +117,7 @@ func (wrapper InvokeModelWrapper) InvokeJurassic2(prompt string) (string, error) log.Fatal("failed to marshal", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, @@ -155,7 +155,7 @@ type Llama2Response struct { // Invokes Meta Llama 2 Chat on Amazon Bedrock to run an inference using the input // provided in the request body. -func (wrapper InvokeModelWrapper) InvokeLlama2(prompt string) (string, error) { +func (wrapper InvokeModelWrapper) InvokeLlama2(ctx context.Context, prompt string) (string, error) { modelId := "meta.llama2-13b-chat-v1" body, err := json.Marshal(Llama2Request{ @@ -168,7 +168,7 @@ func (wrapper InvokeModelWrapper) InvokeLlama2(prompt string) (string, error) { log.Fatal("failed to marshal", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, @@ -213,7 +213,7 @@ type TitanImageResponse struct { // Invokes the Titan Image model to create an image using the input provided // in the request body. -func (wrapper InvokeModelWrapper) InvokeTitanImage(prompt string, seed int64) (string, error) { +func (wrapper InvokeModelWrapper) InvokeTitanImage(ctx context.Context, prompt string, seed int64) (string, error) { modelId := "amazon.titan-image-generator-v1" body, err := json.Marshal(TitanImageRequest{ @@ -235,7 +235,7 @@ func (wrapper InvokeModelWrapper) InvokeTitanImage(prompt string, seed int64) (s log.Fatal("failed to marshal", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, @@ -286,7 +286,7 @@ type Result struct { CompletionReason string `json:"completionReason"` } -func (wrapper InvokeModelWrapper) InvokeTitanText(prompt string) (string, error) { +func (wrapper InvokeModelWrapper) InvokeTitanText(ctx context.Context, prompt string) (string, error) { modelId := "amazon.titan-text-express-v1" body, err := json.Marshal(TitanTextRequest{ @@ -302,7 +302,7 @@ func (wrapper InvokeModelWrapper) InvokeTitanText(prompt string) (string, error) log.Fatal("failed to marshal", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, diff --git a/gov2/bedrock-runtime/actions/invoke_model_test.go b/gov2/bedrock-runtime/actions/invoke_model_test.go index 9500d9d061e..6824b74df50 100644 --- a/gov2/bedrock-runtime/actions/invoke_model_test.go +++ b/gov2/bedrock-runtime/actions/invoke_model_test.go @@ -6,6 +6,7 @@ package actions import ( + "context" "encoding/json" "log" "testing" @@ -33,33 +34,34 @@ func CallInvokeModelActions(sdkConfig aws.Config) { client := bedrockruntime.NewFromConfig(sdkConfig) wrapper := InvokeModelWrapper{client} + ctx := context.Background() - claudeCompletion, err := wrapper.InvokeClaude(prompt) + claudeCompletion, err := wrapper.InvokeClaude(ctx, prompt) if err != nil { panic(err) } log.Println(claudeCompletion) - jurassic2Completion, err := wrapper.InvokeJurassic2(prompt) + jurassic2Completion, err := wrapper.InvokeJurassic2(ctx, prompt) if err != nil { panic(err) } log.Println(jurassic2Completion) - llama2Completion, err := wrapper.InvokeLlama2(prompt) + llama2Completion, err := wrapper.InvokeLlama2(ctx, prompt) if err != nil { panic(err) } log.Println(llama2Completion) seed := int64(0) - titanImageCompletion, err := wrapper.InvokeTitanImage(prompt, seed) + titanImageCompletion, err := wrapper.InvokeTitanImage(ctx, prompt, seed) if err != nil { panic(err) } log.Println(titanImageCompletion) - titanTextCompletion, err := wrapper.InvokeTitanText(prompt) + titanTextCompletion, err := wrapper.InvokeTitanText(ctx, prompt) if err != nil { panic(err) } diff --git a/gov2/bedrock-runtime/actions/invoke_model_with_response_stream.go b/gov2/bedrock-runtime/actions/invoke_model_with_response_stream.go index 44c2f3a40e3..b608dc077a4 100644 --- a/gov2/bedrock-runtime/actions/invoke_model_with_response_stream.go +++ b/gov2/bedrock-runtime/actions/invoke_model_with_response_stream.go @@ -46,7 +46,7 @@ type Response struct { // Invokes Anthropic Claude on Amazon Bedrock to run an inference and asynchronously // process the response stream. -func (wrapper InvokeModelWithResponseStreamWrapper) InvokeModelWithResponseStream(prompt string) (string, error) { +func (wrapper InvokeModelWithResponseStreamWrapper) InvokeModelWithResponseStream(ctx context.Context, prompt string) (string, error) { modelId := "anthropic.claude-v2" @@ -67,7 +67,7 @@ func (wrapper InvokeModelWithResponseStreamWrapper) InvokeModelWithResponseStrea log.Panicln("Couldn't marshal the request: ", err) } - output, err := wrapper.BedrockRuntimeClient.InvokeModelWithResponseStream(context.Background(), &bedrockruntime.InvokeModelWithResponseStreamInput{ + output, err := wrapper.BedrockRuntimeClient.InvokeModelWithResponseStream(ctx, &bedrockruntime.InvokeModelWithResponseStreamInput{ Body: body, ModelId: aws.String(modelId), ContentType: aws.String("application/json"), @@ -84,7 +84,7 @@ func (wrapper InvokeModelWithResponseStreamWrapper) InvokeModelWithResponseStrea } } - resp, err := processStreamingOutput(output, func(ctx context.Context, part []byte) error { + resp, err := processStreamingOutput(ctx, output, func(ctx context.Context, part []byte) error { fmt.Print(string(part)) return nil }) @@ -99,7 +99,7 @@ func (wrapper InvokeModelWithResponseStreamWrapper) InvokeModelWithResponseStrea type StreamingOutputHandler func(ctx context.Context, part []byte) error -func processStreamingOutput(output *bedrockruntime.InvokeModelWithResponseStreamOutput, handler StreamingOutputHandler) (Response, error) { +func processStreamingOutput(ctx context.Context, output *bedrockruntime.InvokeModelWithResponseStreamOutput, handler StreamingOutputHandler) (Response, error) { var combinedResult string resp := Response{} @@ -116,7 +116,7 @@ func processStreamingOutput(output *bedrockruntime.InvokeModelWithResponseStream return resp, err } - err = handler(context.Background(), []byte(resp.Completion)) + err = handler(ctx, []byte(resp.Completion)) if err != nil { return resp, err } diff --git a/gov2/bedrock-runtime/cmd/main.go b/gov2/bedrock-runtime/cmd/main.go index 6aa67599b18..8703c768ce9 100644 --- a/gov2/bedrock-runtime/cmd/main.go +++ b/gov2/bedrock-runtime/cmd/main.go @@ -11,8 +11,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" - "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" "github.com/awsdocs/aws-doc-sdk-examples/gov2/bedrock-runtime/scenarios" + "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" ) // main loads default AWS credentials and configuration from the ~/.aws folder and runs @@ -20,12 +20,12 @@ import ( // // `-scenario` can be one of the following: // -// * `invokemodels` - Runs a scenario that shows how to invoke various image and text -// generation models on Amazon Bedrock. +// - `invokemodels` - Runs a scenario that shows how to invoke various image and text +// generation models on Amazon Bedrock. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ - "invokemodels": runInvokeModelsScenario, + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ + "invokemodels": runInvokeModelsScenario, } choices := make([]string, len(scenarioMap)) choiceIndex := 0 @@ -37,26 +37,27 @@ func main() { "scenario", "", fmt.Sprintf("The scenario to run. Must be one of %v.", choices)) - var region = flag.String("region", "us-east-1", "The AWS region") - flag.Parse() + var region = flag.String("region", "us-east-1", "The AWS region") + flag.Parse() - fmt.Printf("Using AWS region: %s\n", *region) + fmt.Printf("Using AWS region: %s\n", *region) if runScenario, ok := scenarioMap[*scenario]; !ok { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(*region)) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(*region)) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runInvokeModelsScenario(sdkConfig aws.Config) { +func runInvokeModelsScenario(ctx context.Context, sdkConfig aws.Config) { scenario := scenarios.NewInvokeModelsScenario(sdkConfig, demotools.NewQuestioner()) - scenario.Run() + scenario.Run(ctx) } diff --git a/gov2/bedrock-runtime/hello/hello.go b/gov2/bedrock-runtime/hello/hello.go index 1ab0bbbdea5..27f5c3e2067 100644 --- a/gov2/bedrock-runtime/hello/hello.go +++ b/gov2/bedrock-runtime/hello/hello.go @@ -44,7 +44,8 @@ func main() { fmt.Printf("Using AWS region: %s\n", *region) - sdkConfig, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(*region)) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(*region)) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -72,7 +73,7 @@ func main() { log.Panicln("Couldn't marshal the request: ", err) } - result, err := client.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{ + result, err := client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ ModelId: aws.String(modelId), ContentType: aws.String("application/json"), Body: body, diff --git a/gov2/bedrock-runtime/scenarios/scenario_invoke_models.go b/gov2/bedrock-runtime/scenarios/scenario_invoke_models.go index 517190babad..74824fb50ff 100644 --- a/gov2/bedrock-runtime/scenarios/scenario_invoke_models.go +++ b/gov2/bedrock-runtime/scenarios/scenario_invoke_models.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "encoding/base64" "fmt" "log" @@ -50,7 +51,7 @@ func NewInvokeModelsScenario(sdkConfig aws.Config, questioner demotools.IQuestio } // Runs the interactive scenario. -func (scenario InvokeModelsScenario) Run() { +func (scenario InvokeModelsScenario) Run(ctx context.Context) { defer func() { if r := recover(); r != nil { log.Printf("Something went wrong with the demo: %v\n", r) @@ -67,22 +68,22 @@ func (scenario InvokeModelsScenario) Run() { log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Claude with prompt: %v\n", text2textPrompt) - scenario.InvokeClaude(text2textPrompt) + scenario.InvokeClaude(ctx, text2textPrompt) log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Jurassic-2 with prompt: %v\n", text2textPrompt) - scenario.InvokeJurassic2(text2textPrompt) + scenario.InvokeJurassic2(ctx, text2textPrompt) log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Llama2 with prompt: %v\n", text2textPrompt) - scenario.InvokeLlama2(text2textPrompt) + scenario.InvokeLlama2(ctx, text2textPrompt) log.Println(strings.Repeat("=", 77)) log.Printf("Now, let's invoke Claude with the asynchronous client and process the response stream:\n\n") log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Claude with prompt: %v\n", text2textPrompt) - scenario.InvokeWithResponseStream(text2textPrompt) + scenario.InvokeWithResponseStream(ctx, text2textPrompt) log.Println(strings.Repeat("=", 77)) log.Printf("Now, let's create an image with the Amazon Titan image generation model:\n\n") @@ -92,52 +93,52 @@ func (scenario InvokeModelsScenario) Run() { log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Amazon Titan with prompt: %v\n", text2ImagePrompt) - scenario.InvokeTitanImage(text2ImagePrompt, seed) + scenario.InvokeTitanImage(ctx, text2ImagePrompt, seed) log.Println(strings.Repeat("-", 77)) log.Printf("Invoking Titan Text Express with prompt: %v\n", text2textPrompt) - scenario.InvokeTitanText(text2textPrompt) + scenario.InvokeTitanText(ctx, text2textPrompt) log.Println(strings.Repeat("=", 77)) log.Println("Thanks for watching!") log.Println(strings.Repeat("=", 77)) } -func (scenario InvokeModelsScenario) InvokeClaude(prompt string) { - completion, err := scenario.invokeModelWrapper.InvokeClaude(prompt) +func (scenario InvokeModelsScenario) InvokeClaude(ctx context.Context, prompt string) { + completion, err := scenario.invokeModelWrapper.InvokeClaude(ctx, prompt) if err != nil { panic(err) } log.Printf("\nClaude : %v\n", strings.TrimSpace(completion)) } -func (scenario InvokeModelsScenario) InvokeJurassic2(prompt string) { - completion, err := scenario.invokeModelWrapper.InvokeJurassic2(prompt) +func (scenario InvokeModelsScenario) InvokeJurassic2(ctx context.Context, prompt string) { + completion, err := scenario.invokeModelWrapper.InvokeJurassic2(ctx, prompt) if err != nil { panic(err) } log.Printf("\nJurassic-2 : %v\n", strings.TrimSpace(completion)) } -func (scenario InvokeModelsScenario) InvokeLlama2(prompt string) { - completion, err := scenario.invokeModelWrapper.InvokeLlama2(prompt) +func (scenario InvokeModelsScenario) InvokeLlama2(ctx context.Context, prompt string) { + completion, err := scenario.invokeModelWrapper.InvokeLlama2(ctx, prompt) if err != nil { panic(err) } log.Printf("\nLlama 2 : %v\n\n", strings.TrimSpace(completion)) } -func (scenario InvokeModelsScenario) InvokeWithResponseStream(prompt string) { +func (scenario InvokeModelsScenario) InvokeWithResponseStream(ctx context.Context, prompt string) { log.Println("\nClaude with response stream:") - _, err := scenario.responseStreamWrapper.InvokeModelWithResponseStream(prompt) + _, err := scenario.responseStreamWrapper.InvokeModelWithResponseStream(ctx, prompt) if err != nil { panic(err) } log.Println() } -func (scenario InvokeModelsScenario) InvokeTitanImage(prompt string, seed int64) { - base64ImageData, err := scenario.invokeModelWrapper.InvokeTitanImage(prompt, seed) +func (scenario InvokeModelsScenario) InvokeTitanImage(ctx context.Context, prompt string, seed int64) { + base64ImageData, err := scenario.invokeModelWrapper.InvokeTitanImage(ctx, prompt, seed) if err != nil { panic(err) } @@ -145,8 +146,8 @@ func (scenario InvokeModelsScenario) InvokeTitanImage(prompt string, seed int64) fmt.Printf("The generated image has been saved to %s\n", imagePath) } -func (scenario InvokeModelsScenario) InvokeTitanText(prompt string) { - completion, err := scenario.invokeModelWrapper.InvokeTitanText(prompt) +func (scenario InvokeModelsScenario) InvokeTitanText(ctx context.Context, prompt string) { + completion, err := scenario.invokeModelWrapper.InvokeTitanText(ctx, prompt) if err != nil { panic(err) } diff --git a/gov2/bedrock/actions/foundation_model.go b/gov2/bedrock/actions/foundation_model.go index 403730d06db..b570c06b70d 100644 --- a/gov2/bedrock/actions/foundation_model.go +++ b/gov2/bedrock/actions/foundation_model.go @@ -7,8 +7,8 @@ import ( "context" "log" - "github.com/aws/aws-sdk-go-v2/service/bedrock" - "github.com/aws/aws-sdk-go-v2/service/bedrock/types" + "github.com/aws/aws-sdk-go-v2/service/bedrock" + "github.com/aws/aws-sdk-go-v2/service/bedrock/types" ) // snippet-start:[gov2.bedrock.FoundationModelWrapper.complete] @@ -25,18 +25,18 @@ type FoundationModelWrapper struct { // snippet-start:[gov2.bedrock.ListFoundationModels] // ListPolicies lists Bedrock foundation models that you can use. -func (wrapper FoundationModelWrapper) ListFoundationModels() ([]types.FoundationModelSummary, error) { +func (wrapper FoundationModelWrapper) ListFoundationModels(ctx context.Context) ([]types.FoundationModelSummary, error) { - var models []types.FoundationModelSummary + var models []types.FoundationModelSummary - result, err := wrapper.BedrockClient.ListFoundationModels(context.TODO(), &bedrock.ListFoundationModelsInput{}) + result, err := wrapper.BedrockClient.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{}) - if err != nil { - log.Printf("Couldn't list foundation models. Here's why: %v\n", err) - } else { - models = result.ModelSummaries - } - return models, err + if err != nil { + log.Printf("Couldn't list foundation models. Here's why: %v\n", err) + } else { + models = result.ModelSummaries + } + return models, err } // snippet-end:[gov2.bedrock.ListFoundationModels] diff --git a/gov2/bedrock/actions/foundation_model_test.go b/gov2/bedrock/actions/foundation_model_test.go index 53e17e2d379..9dfce282419 100644 --- a/gov2/bedrock/actions/foundation_model_test.go +++ b/gov2/bedrock/actions/foundation_model_test.go @@ -6,8 +6,9 @@ package actions import ( - "testing" + "context" "log" + "testing" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrock" @@ -15,40 +16,43 @@ import ( "github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools" ) -func CallFoundationModelActions(sdkConfig aws.Config, ) { +func CallFoundationModelActions(sdkConfig aws.Config) { defer func() { if r := recover(); r != nil { log.Println(r) } }() - bedrockClient := bedrock.NewFromConfig(sdkConfig) - foundationModelWrapper := FoundationModelWrapper{bedrockClient} + bedrockClient := bedrock.NewFromConfig(sdkConfig) + foundationModelWrapper := FoundationModelWrapper{bedrockClient} - models, err := foundationModelWrapper.ListFoundationModels() - if err != nil {panic(err)} - for _, model := range models { - log.Println(*model.ModelId) - } + ctx := context.Background() + models, err := foundationModelWrapper.ListFoundationModels(ctx) + if err != nil { + panic(err) + } + for _, model := range models { + log.Println(*model.ModelId) + } - log.Printf("Thanks for watching!") + log.Printf("Thanks for watching!") } func TestCallFoundationModelActions(t *testing.T) { - scenTest := FoundationModelActionsTest{} - testtools.RunScenarioTests(&scenTest, t) + scenTest := FoundationModelActionsTest{} + testtools.RunScenarioTests(&scenTest, t) } -type FoundationModelActionsTest struct {} +type FoundationModelActionsTest struct{} func (scenTest *FoundationModelActionsTest) SetupDataAndStubs() []testtools.Stub { - var stubList []testtools.Stub - stubList = append(stubList, stubs.StubListFoundationModels(nil)) - return stubList + var stubList []testtools.Stub + stubList = append(stubList, stubs.StubListFoundationModels(nil)) + return stubList } func (scenTest *FoundationModelActionsTest) RunSubTest(stubber *testtools.AwsmStubber) { CallFoundationModelActions(*stubber.SdkConfig) } -func (scenTest *FoundationModelActionsTest) Cleanup() {} \ No newline at end of file +func (scenTest *FoundationModelActionsTest) Cleanup() {} diff --git a/gov2/bedrock/hello/hello.go b/gov2/bedrock/hello/hello.go index 0c719728355..a198123c484 100644 --- a/gov2/bedrock/hello/hello.go +++ b/gov2/bedrock/hello/hello.go @@ -20,23 +20,25 @@ const region = "us-east-1" // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) - if err != nil { - fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") - fmt.Println(err) - return - } - bedrockClient := bedrock.NewFromConfig(sdkConfig) - result, err := bedrockClient.ListFoundationModels(context.TODO(), &bedrock.ListFoundationModelsInput{}) - if err != nil { + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") + fmt.Println(err) + return + } + bedrockClient := bedrock.NewFromConfig(sdkConfig) + result, err := bedrockClient.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{}) + if err != nil { fmt.Printf("Couldn't list foundation models. Here's why: %v\n", err) return - } - if len(result.ModelSummaries) == 0 { - fmt.Println("There are no foundation models.")} - for _, modelSummary := range result.ModelSummaries { - fmt.Println(*modelSummary.ModelId) - } + } + if len(result.ModelSummaries) == 0 { + fmt.Println("There are no foundation models.") + } + for _, modelSummary := range result.ModelSummaries { + fmt.Println(*modelSummary.ModelId) + } } -// snippet-end:[gov2.bedrock.Hello] \ No newline at end of file +// snippet-end:[gov2.bedrock.Hello] diff --git a/gov2/cloudfront/CreateDistribution/CreateDistribution.go b/gov2/cloudfront/CreateDistribution/CreateDistribution.go index 2b2d8b0f3d7..86dd1c6ad76 100644 --- a/gov2/cloudfront/CreateDistribution/CreateDistribution.go +++ b/gov2/cloudfront/CreateDistribution/CreateDistribution.go @@ -22,8 +22,8 @@ import ( // CFDistributionAPI defines the interface for the CreateDistribution function. // We use this interface to test the function using a mocked service. type CFDistributionAPI interface { - CreateDistribution(bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) - createoriginAccessIdentity(domainName string) (string, error) + CreateDistribution(ctx context.Context, bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) + createoriginAccessIdentity(ctx context.Context, domainName string) (string, error) } type CFDistributionAPIImpl struct { @@ -38,8 +38,8 @@ func createCFDistribution(s3client *s3.Client, cloudfront *cloudfront.Client) CF } } -func (c *CFDistributionAPIImpl) CreateDistribution(bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) { - locationOutput, err := c.s3Client.GetBucketLocation(context.Background(), &s3.GetBucketLocationInput{Bucket: aws.String(bucketName)}) +func (c *CFDistributionAPIImpl) CreateDistribution(ctx context.Context, bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) { + locationOutput, err := c.s3Client.GetBucketLocation(ctx, &s3.GetBucketLocationInput{Bucket: aws.String(bucketName)}) if err != nil { return nil, err @@ -50,12 +50,12 @@ func (c *CFDistributionAPIImpl) CreateDistribution(bucketName, certificateSSLArn return nil, err } - originAccessIdentityID, err := c.createoriginAccessIdentity(domain) + originAccessIdentityID, err := c.createoriginAccessIdentity(ctx, domain) if err != nil { return nil, err } - cloudfrontResponse, err := c.cloudfrontClient.CreateDistribution(context.TODO(), &cloudfront.CreateDistributionInput{ + cloudfrontResponse, err := c.cloudfrontClient.CreateDistribution(ctx, &cloudfront.CreateDistributionInput{ DistributionConfig: &cloudfrontTypes.DistributionConfig{ Enabled: aws.Bool(true), CallerReference: &originDomain, @@ -118,8 +118,7 @@ func (c *CFDistributionAPIImpl) CreateDistribution(bucketName, certificateSSLArn return cloudfrontResponse, nil } -func (c *CFDistributionAPIImpl) createoriginAccessIdentity(domainName string) (string, error) { - ctx := context.Background() +func (c *CFDistributionAPIImpl) createoriginAccessIdentity(ctx context.Context, domainName string) (string, error) { oai, err := c.cloudfrontClient.CreateCloudFrontOriginAccessIdentity(ctx, &cloudfront.CreateCloudFrontOriginAccessIdentityInput{ CloudFrontOriginAccessIdentityConfig: &cloudfrontTypes.CloudFrontOriginAccessIdentityConfig{ CallerReference: aws.String(domainName), @@ -169,7 +168,8 @@ func main() { return } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") @@ -182,7 +182,7 @@ func main() { cfDistribution := createCFDistribution(s3Client, cloudfrontClient) - result, err := cfDistribution.CreateDistribution(bucketName, certificateSSLArn, domain) + result, err := cfDistribution.CreateDistribution(ctx, bucketName, certificateSSLArn, domain) if err != nil { fmt.Println("Couldn't create distribution. Please check error message and try again.") fmt.Println(err) diff --git a/gov2/cloudfront/CreateDistribution/CreateDistribution_test.go b/gov2/cloudfront/CreateDistribution/CreateDistribution_test.go index a755c8e010f..8eeb4000140 100644 --- a/gov2/cloudfront/CreateDistribution/CreateDistribution_test.go +++ b/gov2/cloudfront/CreateDistribution/CreateDistribution_test.go @@ -21,7 +21,7 @@ type MockCFDistributionAPI struct { cloudfrontClient *cloudfront.Client } -func (m *MockCFDistributionAPI) CreateDistribution(bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) { +func (m *MockCFDistributionAPI) CreateDistribution(ctx context.Context, bucketName, certificateSSLArn, domain string) (*cloudfront.CreateDistributionOutput, error) { if bucketName == "" || certificateSSLArn == "" || domain == "" { return nil, errors.New("bucket name, certificate SSL ARN, and domain are required") } @@ -38,7 +38,7 @@ func (m *MockCFDistributionAPI) CreateDistribution(bucketName, certificateSSLArn }, nil } -func (m *MockCFDistributionAPI) createoriginAccessIdentity(domainName string) (string, error) { +func (m *MockCFDistributionAPI) createoriginAccessIdentity(ctx context.Context, domainName string) (string, error) { return domainName, nil } @@ -54,7 +54,8 @@ func TestCreateDistribution(t *testing.T) { nowString := thisTime.Format("2006-01-02 15:04:05 Monday") t.Log("Starting integration test at " + nowString) - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { t.Log("Got an error ...:") @@ -71,7 +72,7 @@ func TestCreateDistribution(t *testing.T) { certificateSSLArn := "arn:aws:acm:ap-northeast-2:123456789000:certificate/000000000-0000-0000-0000-000000000000" domain := "example.com" - result, err := mockCFDistribution.CreateDistribution(bucketName, certificateSSLArn, domain) + result, err := mockCFDistribution.CreateDistribution(ctx, bucketName, certificateSSLArn, domain) if err != nil { t.Error(err) diff --git a/gov2/cognito/hello/hello.go b/gov2/cognito/hello/hello.go index a2acac94d16..f1d4b10fdce 100644 --- a/gov2/cognito/hello/hello.go +++ b/gov2/cognito/hello/hello.go @@ -21,7 +21,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -33,7 +34,7 @@ func main() { paginator := cognitoidentityprovider.NewListUserPoolsPaginator( cognitoClient, &cognitoidentityprovider.ListUserPoolsInput{MaxResults: aws.Int32(10)}) for paginator.HasMorePages() { - output, err := paginator.NextPage(context.TODO()) + output, err := paginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get user pools. Here's why: %v\n", err) } else { diff --git a/gov2/dynamodb/actions/partiql.go b/gov2/dynamodb/actions/partiql.go index 58e7b310ce6..399350608d4 100644 --- a/gov2/dynamodb/actions/partiql.go +++ b/gov2/dynamodb/actions/partiql.go @@ -30,12 +30,12 @@ type PartiQLRunner struct { // snippet-start:[gov2.dynamodb.ExecuteStatement.Insert] // AddMovie runs a PartiQL INSERT statement to add a movie to the DynamoDB table. -func (runner PartiQLRunner) AddMovie(movie Movie) error { +func (runner PartiQLRunner) AddMovie(ctx context.Context, movie Movie) error { params, err := attributevalue.MarshalList([]interface{}{movie.Title, movie.Year, movie.Info}) if err != nil { panic(err) } - _, err = runner.DynamoDbClient.ExecuteStatement(context.TODO(), &dynamodb.ExecuteStatementInput{ + _, err = runner.DynamoDbClient.ExecuteStatement(ctx, &dynamodb.ExecuteStatementInput{ Statement: aws.String( fmt.Sprintf("INSERT INTO \"%v\" VALUE {'title': ?, 'year': ?, 'info': ?}", runner.TableName)), @@ -53,13 +53,13 @@ func (runner PartiQLRunner) AddMovie(movie Movie) error { // GetMovie runs a PartiQL SELECT statement to get a movie from the DynamoDB table by // title and year. -func (runner PartiQLRunner) GetMovie(title string, year int) (Movie, error) { +func (runner PartiQLRunner) GetMovie(ctx context.Context, title string, year int) (Movie, error) { var movie Movie params, err := attributevalue.MarshalList([]interface{}{title, year}) if err != nil { panic(err) } - response, err := runner.DynamoDbClient.ExecuteStatement(context.TODO(), &dynamodb.ExecuteStatementInput{ + response, err := runner.DynamoDbClient.ExecuteStatement(ctx, &dynamodb.ExecuteStatementInput{ Statement: aws.String( fmt.Sprintf("SELECT * FROM \"%v\" WHERE title=? AND year=?", runner.TableName)), @@ -83,13 +83,13 @@ func (runner PartiQLRunner) GetMovie(title string, year int) (Movie, error) { // GetAllMovies runs a PartiQL SELECT statement to get all movies from the DynamoDB table. // pageSize is not typically required and is used to show how to paginate the results. // The results are projected to return only the title and rating of each movie. -func (runner PartiQLRunner) GetAllMovies(pageSize int32) ([]map[string]interface{}, error) { +func (runner PartiQLRunner) GetAllMovies(ctx context.Context, pageSize int32) ([]map[string]interface{}, error) { var output []map[string]interface{} var response *dynamodb.ExecuteStatementOutput var err error var nextToken *string for moreData := true; moreData; { - response, err = runner.DynamoDbClient.ExecuteStatement(context.TODO(), &dynamodb.ExecuteStatementInput{ + response, err = runner.DynamoDbClient.ExecuteStatement(ctx, &dynamodb.ExecuteStatementInput{ Statement: aws.String( fmt.Sprintf("SELECT title, info.rating FROM \"%v\"", runner.TableName)), Limit: aws.Int32(pageSize), @@ -120,12 +120,12 @@ func (runner PartiQLRunner) GetAllMovies(pageSize int32) ([]map[string]interface // UpdateMovie runs a PartiQL UPDATE statement to update the rating of a movie that // already exists in the DynamoDB table. -func (runner PartiQLRunner) UpdateMovie(movie Movie, rating float64) error { +func (runner PartiQLRunner) UpdateMovie(ctx context.Context, movie Movie, rating float64) error { params, err := attributevalue.MarshalList([]interface{}{rating, movie.Title, movie.Year}) if err != nil { panic(err) } - _, err = runner.DynamoDbClient.ExecuteStatement(context.TODO(), &dynamodb.ExecuteStatementInput{ + _, err = runner.DynamoDbClient.ExecuteStatement(ctx, &dynamodb.ExecuteStatementInput{ Statement: aws.String( fmt.Sprintf("UPDATE \"%v\" SET info.rating=? WHERE title=? AND year=?", runner.TableName)), @@ -142,12 +142,12 @@ func (runner PartiQLRunner) UpdateMovie(movie Movie, rating float64) error { // snippet-start:[gov2.dynamodb.ExecuteStatement.Delete] // DeleteMovie runs a PartiQL DELETE statement to remove a movie from the DynamoDB table. -func (runner PartiQLRunner) DeleteMovie(movie Movie) error { +func (runner PartiQLRunner) DeleteMovie(ctx context.Context, movie Movie) error { params, err := attributevalue.MarshalList([]interface{}{movie.Title, movie.Year}) if err != nil { panic(err) } - _, err = runner.DynamoDbClient.ExecuteStatement(context.TODO(), &dynamodb.ExecuteStatementInput{ + _, err = runner.DynamoDbClient.ExecuteStatement(ctx, &dynamodb.ExecuteStatementInput{ Statement: aws.String( fmt.Sprintf("DELETE FROM \"%v\" WHERE title=? AND year=?", runner.TableName)), @@ -165,7 +165,7 @@ func (runner PartiQLRunner) DeleteMovie(movie Movie) error { // AddMovieBatch runs a batch of PartiQL INSERT statements to add multiple movies to the // DynamoDB table. -func (runner PartiQLRunner) AddMovieBatch(movies []Movie) error { +func (runner PartiQLRunner) AddMovieBatch(ctx context.Context, movies []Movie) error { statementRequests := make([]types.BatchStatementRequest, len(movies)) for index, movie := range movies { params, err := attributevalue.MarshalList([]interface{}{movie.Title, movie.Year, movie.Info}) @@ -179,7 +179,7 @@ func (runner PartiQLRunner) AddMovieBatch(movies []Movie) error { } } - _, err := runner.DynamoDbClient.BatchExecuteStatement(context.TODO(), &dynamodb.BatchExecuteStatementInput{ + _, err := runner.DynamoDbClient.BatchExecuteStatement(ctx, &dynamodb.BatchExecuteStatementInput{ Statements: statementRequests, }) if err != nil { @@ -194,7 +194,7 @@ func (runner PartiQLRunner) AddMovieBatch(movies []Movie) error { // GetMovieBatch runs a batch of PartiQL SELECT statements to get multiple movies from // the DynamoDB table by title and year. -func (runner PartiQLRunner) GetMovieBatch(movies []Movie) ([]Movie, error) { +func (runner PartiQLRunner) GetMovieBatch(ctx context.Context, movies []Movie) ([]Movie, error) { statementRequests := make([]types.BatchStatementRequest, len(movies)) for index, movie := range movies { params, err := attributevalue.MarshalList([]interface{}{movie.Title, movie.Year}) @@ -208,7 +208,7 @@ func (runner PartiQLRunner) GetMovieBatch(movies []Movie) ([]Movie, error) { } } - output, err := runner.DynamoDbClient.BatchExecuteStatement(context.TODO(), &dynamodb.BatchExecuteStatementInput{ + output, err := runner.DynamoDbClient.BatchExecuteStatement(ctx, &dynamodb.BatchExecuteStatementInput{ Statements: statementRequests, }) var outMovies []Movie @@ -234,7 +234,7 @@ func (runner PartiQLRunner) GetMovieBatch(movies []Movie) ([]Movie, error) { // UpdateMovieBatch runs a batch of PartiQL UPDATE statements to update the rating of // multiple movies that already exist in the DynamoDB table. -func (runner PartiQLRunner) UpdateMovieBatch(movies []Movie, ratings []float64) error { +func (runner PartiQLRunner) UpdateMovieBatch(ctx context.Context, movies []Movie, ratings []float64) error { statementRequests := make([]types.BatchStatementRequest, len(movies)) for index, movie := range movies { params, err := attributevalue.MarshalList([]interface{}{ratings[index], movie.Title, movie.Year}) @@ -248,7 +248,7 @@ func (runner PartiQLRunner) UpdateMovieBatch(movies []Movie, ratings []float64) } } - _, err := runner.DynamoDbClient.BatchExecuteStatement(context.TODO(), &dynamodb.BatchExecuteStatementInput{ + _, err := runner.DynamoDbClient.BatchExecuteStatement(ctx, &dynamodb.BatchExecuteStatementInput{ Statements: statementRequests, }) if err != nil { @@ -263,7 +263,7 @@ func (runner PartiQLRunner) UpdateMovieBatch(movies []Movie, ratings []float64) // DeleteMovieBatch runs a batch of PartiQL DELETE statements to remove multiple movies // from the DynamoDB table. -func (runner PartiQLRunner) DeleteMovieBatch(movies []Movie) error { +func (runner PartiQLRunner) DeleteMovieBatch(ctx context.Context, movies []Movie) error { statementRequests := make([]types.BatchStatementRequest, len(movies)) for index, movie := range movies { params, err := attributevalue.MarshalList([]interface{}{movie.Title, movie.Year}) @@ -277,7 +277,7 @@ func (runner PartiQLRunner) DeleteMovieBatch(movies []Movie) error { } } - _, err := runner.DynamoDbClient.BatchExecuteStatement(context.TODO(), &dynamodb.BatchExecuteStatementInput{ + _, err := runner.DynamoDbClient.BatchExecuteStatement(ctx, &dynamodb.BatchExecuteStatementInput{ Statements: statementRequests, }) if err != nil { diff --git a/gov2/dynamodb/actions/partiql_test.go b/gov2/dynamodb/actions/partiql_test.go index c60d3974f53..f752e2bc9fc 100644 --- a/gov2/dynamodb/actions/partiql_test.go +++ b/gov2/dynamodb/actions/partiql_test.go @@ -6,6 +6,7 @@ package actions import ( + "context" "errors" "fmt" "testing" @@ -16,10 +17,11 @@ import ( "github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools" ) -func enterPartiQLTest() (*testtools.AwsmStubber, *PartiQLRunner) { +func enterPartiQLTest() (context.Context, *testtools.AwsmStubber, *PartiQLRunner) { + ctx := context.Background() stubber := testtools.NewStubber() runner := &PartiQLRunner{TableName: "test-table", DynamoDbClient: dynamodb.NewFromConfig(*stubber.SdkConfig)} - return stubber, runner + return ctx, stubber, runner } func TestPartiQL_AddMovie(t *testing.T) { @@ -28,7 +30,7 @@ func TestPartiQL_AddMovie(t *testing.T) { } func AddMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movie := Movie{Title: "Test movie", Year: 2001, Info: map[string]interface{}{ "rating": 3.5, "plot": "Not bad."}} @@ -37,7 +39,7 @@ func AddMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { fmt.Sprintf("INSERT INTO \"%v\" VALUE {'title': ?, 'year': ?, 'info': ?}", runner.TableName), []interface{}{movie.Title, movie.Year, movie.Info}, nil, nil, nil, nil, raiseErr)) - err := runner.AddMovie(movie) + err := runner.AddMovie(ctx, movie) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -49,7 +51,7 @@ func TestPartiQL_GetMovie(t *testing.T) { } func GetMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movie := Movie{Title: "Test movie", Year: 2001, Info: map[string]interface{}{ "rating": 3.5, "plot": "Not bad."}} @@ -58,7 +60,7 @@ func GetMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { fmt.Sprintf("SELECT * FROM \"%v\" WHERE title=? AND year=?", runner.TableName), []interface{}{movie.Title, movie.Year}, nil, nil, movie, nil, raiseErr)) - gotMovie, err := runner.GetMovie(movie.Title, movie.Year) + gotMovie, err := runner.GetMovie(ctx, movie.Title, movie.Year) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -76,7 +78,7 @@ func TestPartiQL_GetAllMovies(t *testing.T) { } func GetAllMoviesPartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() outProjection := map[string]interface{}{"title": "Test movie", "rating": 3.5} @@ -84,7 +86,7 @@ func GetAllMoviesPartiQL(raiseErr *testtools.StubError, t *testing.T) { fmt.Sprintf("SELECT title, info.rating FROM \"%v\"", runner.TableName), nil, aws.Int32(2), nil, outProjection, nil, raiseErr)) - gotProjections, err := runner.GetAllMovies(2) + gotProjections, err := runner.GetAllMovies(ctx, 2) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -104,7 +106,7 @@ func TestPartiQL_UpdateMovie(t *testing.T) { } func UpdateMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movie := Movie{Title: "Test movie", Year: 2001, Info: map[string]interface{}{ "rating": 3.5, "plot": "Not bad."}} @@ -114,7 +116,7 @@ func UpdateMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { fmt.Sprintf("UPDATE \"%v\" SET info.rating=? WHERE title=? AND year=?", runner.TableName), []interface{}{newRating, movie.Title, movie.Year}, nil, nil, movie, nil, raiseErr)) - err := runner.UpdateMovie(movie, newRating) + err := runner.UpdateMovie(ctx, movie, newRating) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -126,7 +128,7 @@ func TestPartiQL_DeleteMovie(t *testing.T) { } func DeleteMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movie := Movie{Title: "Test movie", Year: 2001, Info: map[string]interface{}{ "rating": 3.5, "plot": "Not bad."}} @@ -135,7 +137,7 @@ func DeleteMoviePartiQL(raiseErr *testtools.StubError, t *testing.T) { fmt.Sprintf("DELETE FROM \"%v\" WHERE title=? AND year=?", runner.TableName), []interface{}{movie.Title, movie.Year}, nil, nil, movie, nil, raiseErr)) - err := runner.DeleteMovie(movie) + err := runner.DeleteMovie(ctx, movie) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -147,7 +149,7 @@ func TestPartiQL_AddMovieBatch(t *testing.T) { } func AddMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movies := make([]Movie, 3) statements := make([]string, len(movies)) @@ -166,7 +168,7 @@ func AddMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubBatchExecuteStatement(statements, paramList, nil, raiseErr)) - err := runner.AddMovieBatch(movies) + err := runner.AddMovieBatch(ctx, movies) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -178,7 +180,7 @@ func TestPartiQL_GetMovieBatch(t *testing.T) { } func GetMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movies := make([]Movie, 3) statements := make([]string, len(movies)) @@ -201,7 +203,7 @@ func GetMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { } stubber.Add(stubs.StubBatchExecuteStatement(statements, paramList, intMovies, raiseErr)) - outMovies, err := runner.GetMovieBatch(movies) + outMovies, err := runner.GetMovieBatch(ctx, movies) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -223,7 +225,7 @@ func TestPartiQL_UpdateMovieBatch(t *testing.T) { } func UpdateMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movies := make([]Movie, 3) newRatings := make([]float64, len(movies)) @@ -242,7 +244,7 @@ func UpdateMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubBatchExecuteStatement(statements, paramList, nil, raiseErr)) - err := runner.UpdateMovieBatch(movies, newRatings) + err := runner.UpdateMovieBatch(ctx, movies, newRatings) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -254,7 +256,7 @@ func TestPartiQL_DeleteMovieBatch(t *testing.T) { } func DeleteMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { - stubber, runner := enterPartiQLTest() + ctx, stubber, runner := enterPartiQLTest() movies := make([]Movie, 3) statements := make([]string, len(movies)) @@ -271,7 +273,7 @@ func DeleteMovieBatchPartiQL(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubBatchExecuteStatement(statements, paramList, nil, raiseErr)) - err := runner.DeleteMovieBatch(movies) + err := runner.DeleteMovieBatch(ctx, movies) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) diff --git a/gov2/dynamodb/actions/table_basics.go b/gov2/dynamodb/actions/table_basics.go index d0d74fff648..b7f07d38a2f 100644 --- a/gov2/dynamodb/actions/table_basics.go +++ b/gov2/dynamodb/actions/table_basics.go @@ -31,10 +31,10 @@ type TableBasics struct { // snippet-start:[gov2.dynamodb.DescribeTable] // TableExists determines whether a DynamoDB table exists. -func (basics TableBasics) TableExists() (bool, error) { +func (basics TableBasics) TableExists(ctx context.Context) (bool, error) { exists := true _, err := basics.DynamoDbClient.DescribeTable( - context.TODO(), &dynamodb.DescribeTableInput{TableName: aws.String(basics.TableName)}, + ctx, &dynamodb.DescribeTableInput{TableName: aws.String(basics.TableName)}, ) if err != nil { var notFoundEx *types.ResourceNotFoundException @@ -57,9 +57,9 @@ func (basics TableBasics) TableExists() (bool, error) { // a string sort key named `title`, and a numeric partition key named `year`. // This function uses NewTableExistsWaiter to wait for the table to be created by // DynamoDB before it returns. -func (basics TableBasics) CreateMovieTable() (*types.TableDescription, error) { +func (basics TableBasics) CreateMovieTable(ctx context.Context) (*types.TableDescription, error) { var tableDesc *types.TableDescription - table, err := basics.DynamoDbClient.CreateTable(context.TODO(), &dynamodb.CreateTableInput{ + table, err := basics.DynamoDbClient.CreateTable(ctx, &dynamodb.CreateTableInput{ AttributeDefinitions: []types.AttributeDefinition{{ AttributeName: aws.String("year"), AttributeType: types.ScalarAttributeTypeN, @@ -84,7 +84,7 @@ func (basics TableBasics) CreateMovieTable() (*types.TableDescription, error) { log.Printf("Couldn't create table %v. Here's why: %v\n", basics.TableName, err) } else { waiter := dynamodb.NewTableExistsWaiter(basics.DynamoDbClient) - err = waiter.Wait(context.TODO(), &dynamodb.DescribeTableInput{ + err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{ TableName: aws.String(basics.TableName)}, 5*time.Minute) if err != nil { log.Printf("Wait for table exists failed. Here's why: %v\n", err) @@ -99,13 +99,13 @@ func (basics TableBasics) CreateMovieTable() (*types.TableDescription, error) { // snippet-start:[gov2.dynamodb.ListTables] // ListTables lists the DynamoDB table names for the current account. -func (basics TableBasics) ListTables() ([]string, error) { +func (basics TableBasics) ListTables(ctx context.Context) ([]string, error) { var tableNames []string var output *dynamodb.ListTablesOutput var err error tablePaginator := dynamodb.NewListTablesPaginator(basics.DynamoDbClient, &dynamodb.ListTablesInput{}) for tablePaginator.HasMorePages() { - output, err = tablePaginator.NextPage(context.TODO()) + output, err = tablePaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't list tables. Here's why: %v\n", err) break @@ -121,12 +121,12 @@ func (basics TableBasics) ListTables() ([]string, error) { // snippet-start:[gov2.dynamodb.PutItem] // AddMovie adds a movie the DynamoDB table. -func (basics TableBasics) AddMovie(movie Movie) error { +func (basics TableBasics) AddMovie(ctx context.Context, movie Movie) error { item, err := attributevalue.MarshalMap(movie) if err != nil { panic(err) } - _, err = basics.DynamoDbClient.PutItem(context.TODO(), &dynamodb.PutItemInput{ + _, err = basics.DynamoDbClient.PutItem(ctx, &dynamodb.PutItemInput{ TableName: aws.String(basics.TableName), Item: item, }) if err != nil { @@ -142,7 +142,7 @@ func (basics TableBasics) AddMovie(movie Movie) error { // UpdateMovie updates the rating and plot of a movie that already exists in the // DynamoDB table. This function uses the `expression` package to build the update // expression. -func (basics TableBasics) UpdateMovie(movie Movie) (map[string]map[string]interface{}, error) { +func (basics TableBasics) UpdateMovie(ctx context.Context, movie Movie) (map[string]map[string]interface{}, error) { var err error var response *dynamodb.UpdateItemOutput var attributeMap map[string]map[string]interface{} @@ -152,7 +152,7 @@ func (basics TableBasics) UpdateMovie(movie Movie) (map[string]map[string]interf if err != nil { log.Printf("Couldn't build expression for update. Here's why: %v\n", err) } else { - response, err = basics.DynamoDbClient.UpdateItem(context.TODO(), &dynamodb.UpdateItemInput{ + response, err = basics.DynamoDbClient.UpdateItem(ctx, &dynamodb.UpdateItemInput{ TableName: aws.String(basics.TableName), Key: movie.GetKey(), ExpressionAttributeNames: expr.Names(), @@ -179,7 +179,7 @@ func (basics TableBasics) UpdateMovie(movie Movie) (map[string]map[string]interf // AddMovieBatch adds a slice of movies to the DynamoDB table. The function sends // batches of 25 movies to DynamoDB until all movies are added or it reaches the // specified maximum. -func (basics TableBasics) AddMovieBatch(movies []Movie, maxMovies int) (int, error) { +func (basics TableBasics) AddMovieBatch(ctx context.Context, movies []Movie, maxMovies int) (int, error) { var err error var item map[string]types.AttributeValue written := 0 @@ -202,7 +202,7 @@ func (basics TableBasics) AddMovieBatch(movies []Movie, maxMovies int) (int, err ) } } - _, err = basics.DynamoDbClient.BatchWriteItem(context.TODO(), &dynamodb.BatchWriteItemInput{ + _, err = basics.DynamoDbClient.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{ RequestItems: map[string][]types.WriteRequest{basics.TableName: writeReqs}}) if err != nil { log.Printf("Couldn't add a batch of movies to %v. Here's why: %v\n", basics.TableName, err) @@ -222,9 +222,9 @@ func (basics TableBasics) AddMovieBatch(movies []Movie, maxMovies int) (int, err // GetMovie gets movie data from the DynamoDB table by using the primary composite key // made of title and year. -func (basics TableBasics) GetMovie(title string, year int) (Movie, error) { +func (basics TableBasics) GetMovie(ctx context.Context, title string, year int) (Movie, error) { movie := Movie{Title: title, Year: year} - response, err := basics.DynamoDbClient.GetItem(context.TODO(), &dynamodb.GetItemInput{ + response, err := basics.DynamoDbClient.GetItem(ctx, &dynamodb.GetItemInput{ Key: movie.GetKey(), TableName: aws.String(basics.TableName), }) if err != nil { @@ -245,7 +245,7 @@ func (basics TableBasics) GetMovie(title string, year int) (Movie, error) { // Query gets all movies in the DynamoDB table that were released in the specified year. // The function uses the `expression` package to build the key condition expression // that is used in the query. -func (basics TableBasics) Query(releaseYear int) ([]Movie, error) { +func (basics TableBasics) Query(ctx context.Context, releaseYear int) ([]Movie, error) { var err error var response *dynamodb.QueryOutput var movies []Movie @@ -261,7 +261,7 @@ func (basics TableBasics) Query(releaseYear int) ([]Movie, error) { KeyConditionExpression: expr.KeyCondition(), }) for queryPaginator.HasMorePages() { - response, err = queryPaginator.NextPage(context.TODO()) + response, err = queryPaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't query for movies released in %v. Here's why: %v\n", releaseYear, err) break @@ -288,7 +288,7 @@ func (basics TableBasics) Query(releaseYear int) ([]Movie, error) { // and projects them to return a reduced set of fields. // The function uses the `expression` package to build the filter and projection // expressions. -func (basics TableBasics) Scan(startYear int, endYear int) ([]Movie, error) { +func (basics TableBasics) Scan(ctx context.Context, startYear int, endYear int) ([]Movie, error) { var movies []Movie var err error var response *dynamodb.ScanOutput @@ -307,7 +307,7 @@ func (basics TableBasics) Scan(startYear int, endYear int) ([]Movie, error) { ProjectionExpression: expr.Projection(), }) for scanPaginator.HasMorePages() { - response, err = scanPaginator.NextPage(context.TODO()) + response, err = scanPaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't scan for movies released between %v and %v. Here's why: %v\n", startYear, endYear, err) @@ -332,8 +332,8 @@ func (basics TableBasics) Scan(startYear int, endYear int) ([]Movie, error) { // snippet-start:[gov2.dynamodb.DeleteItem] // DeleteMovie removes a movie from the DynamoDB table. -func (basics TableBasics) DeleteMovie(movie Movie) error { - _, err := basics.DynamoDbClient.DeleteItem(context.TODO(), &dynamodb.DeleteItemInput{ +func (basics TableBasics) DeleteMovie(ctx context.Context, movie Movie) error { + _, err := basics.DynamoDbClient.DeleteItem(ctx, &dynamodb.DeleteItemInput{ TableName: aws.String(basics.TableName), Key: movie.GetKey(), }) if err != nil { @@ -347,8 +347,8 @@ func (basics TableBasics) DeleteMovie(movie Movie) error { // snippet-start:[gov2.dynamodb.DeleteTable] // DeleteTable deletes the DynamoDB table and all of its data. -func (basics TableBasics) DeleteTable() error { - _, err := basics.DynamoDbClient.DeleteTable(context.TODO(), &dynamodb.DeleteTableInput{ +func (basics TableBasics) DeleteTable(ctx context.Context) error { + _, err := basics.DynamoDbClient.DeleteTable(ctx, &dynamodb.DeleteTableInput{ TableName: aws.String(basics.TableName)}) if err != nil { log.Printf("Couldn't delete table %v. Here's why: %v\n", basics.TableName, err) diff --git a/gov2/dynamodb/actions/table_basics_test.go b/gov2/dynamodb/actions/table_basics_test.go index cfccee17b41..076536d7b37 100644 --- a/gov2/dynamodb/actions/table_basics_test.go +++ b/gov2/dynamodb/actions/table_basics_test.go @@ -6,6 +6,7 @@ package actions import ( + "context" "errors" "fmt" "reflect" @@ -20,10 +21,11 @@ import ( "github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools" ) -func enterTest() (*testtools.AwsmStubber, *TableBasics) { +func enterTest() (context.Context, *testtools.AwsmStubber, *TableBasics) { + ctx := context.Background() stubber := testtools.NewStubber() basics := &TableBasics{TableName: "test-table", DynamoDbClient: dynamodb.NewFromConfig(*stubber.SdkConfig)} - return stubber, basics + return ctx, stubber, basics } func TestTableBasics_TableExists(t *testing.T) { @@ -37,10 +39,10 @@ func TestTableBasics_TableExists(t *testing.T) { } func TableExists(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() stubber.Add(stubs.StubDescribeTable(basics.TableName, raiseErr)) - exists, err := basics.TableExists() + exists, err := basics.TableExists(ctx) testtools.VerifyError(err, raiseErr, t, &types.ResourceNotFoundException{}) var nfEx *types.ResourceNotFoundException @@ -59,11 +61,11 @@ func TestTableBasics_CreateMovieTable(t *testing.T) { } func CreateMovieTable(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() stubber.Add(stubs.StubCreateTable(basics.TableName, raiseErr)) stubber.Add(stubs.StubDescribeTable(basics.TableName, raiseErr)) - tableDesc, err := basics.CreateMovieTable() + tableDesc, err := basics.CreateMovieTable(ctx) testtools.VerifyError(err, raiseErr, t) if raiseErr == nil { @@ -81,7 +83,7 @@ func TestTableBasics_AddMovie(t *testing.T) { } func AddMovie(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() movie := Movie{Title: "Test movie", Year: 2001} item, marshErr := attributevalue.MarshalMap(movie) @@ -91,7 +93,7 @@ func AddMovie(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubAddMovie(basics.TableName, item, raiseErr)) - err := basics.AddMovie(movie) + err := basics.AddMovie(ctx, movie) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -103,7 +105,7 @@ func TestTableBasics_UpdateMovie(t *testing.T) { } func UpdateMovie(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() ratingS := "3.5" plot := "Test plot." @@ -115,7 +117,7 @@ func UpdateMovie(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubUpdateMovie(basics.TableName, movie.GetKey(), ratingS, plot, raiseErr)) - attribs, err := basics.UpdateMovie(movie) + attribs, err := basics.UpdateMovie(ctx, movie) testtools.VerifyError(err, raiseErr, t) if raiseErr == nil { @@ -134,7 +136,7 @@ func TestTableBasics_AddMovieBatch(t *testing.T) { } func AddMovieBatch(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() var testData []Movie var inputRequests []types.WriteRequest @@ -158,7 +160,7 @@ func AddMovieBatch(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubAddMovieBatch(basics.TableName, inputRequests[0:25], raiseErr)) stubber.Add(stubs.StubAddMovieBatch(basics.TableName, inputRequests[25:30], raiseErr)) - count, err := basics.AddMovieBatch(testData, 200) + count, err := basics.AddMovieBatch(ctx, testData, 200) testtools.VerifyError(err, raiseErr, t) if raiseErr == nil { @@ -175,7 +177,7 @@ func TestTableBasics_GetMovie(t *testing.T) { } func GetMovie(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() rating := 3.5 ratingS := "3.5" @@ -185,7 +187,7 @@ func GetMovie(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubGetMovie(basics.TableName, movie.GetKey(), movie.Title, strconv.Itoa(movie.Year), ratingS, plot, raiseErr)) - gotMovie, err := basics.GetMovie(movie.Title, movie.Year) + gotMovie, err := basics.GetMovie(ctx, movie.Title, movie.Year) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -203,7 +205,7 @@ func TestTableBasics_Query(t *testing.T) { } func Query(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() title := "Test movie" year := 2001 @@ -211,7 +213,7 @@ func Query(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubQuery(basics.TableName, title, yearS, raiseErr)) - movies, err := basics.Query(year) + movies, err := basics.Query(ctx, year) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -231,7 +233,7 @@ func TestTableBasics_Scan(t *testing.T) { } func Scan(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() title := "Test movie" startYear := 2001 @@ -241,7 +243,7 @@ func Scan(raiseErr *testtools.StubError, t *testing.T) { stubber.Add(stubs.StubScan(basics.TableName, title, startYearS, endYearS, raiseErr)) - movies, err := basics.Scan(startYear, endYear) + movies, err := basics.Scan(ctx, startYear, endYear) testtools.VerifyError(err, raiseErr, t) if err == nil { @@ -261,13 +263,13 @@ func TestTableBasics_DeleteMovie(t *testing.T) { } func DeleteMovie(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() movie := Movie{Title: "Test title", Year: 2001} stubber.Add(stubs.StubDeleteItem(basics.TableName, movie.GetKey(), raiseErr)) - err := basics.DeleteMovie(movie) + err := basics.DeleteMovie(ctx, movie) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -279,11 +281,11 @@ func TestTableBasics_DeleteTable(t *testing.T) { } func DeleteTable(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() stubber.Add(stubs.StubDeleteTable(basics.TableName, raiseErr)) - err := basics.DeleteTable() + err := basics.DeleteTable(ctx) testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) @@ -295,13 +297,13 @@ func TestTableBasics_ListTables(t *testing.T) { } func ListTables(raiseErr *testtools.StubError, t *testing.T) { - stubber, basics := enterTest() + ctx, stubber, basics := enterTest() tableNames := []string{"Table 1", "Table 2", "Table 3"} stubber.Add(stubs.StubListTables(tableNames, raiseErr)) - tables, err := basics.ListTables() + tables, err := basics.ListTables(ctx) testtools.VerifyError(err, raiseErr, t) if err == nil { diff --git a/gov2/dynamodb/cmd/main.go b/gov2/dynamodb/cmd/main.go index c7620ed6147..ba58743fc38 100644 --- a/gov2/dynamodb/cmd/main.go +++ b/gov2/dynamodb/cmd/main.go @@ -21,14 +21,14 @@ import ( // // `-scenario` can be one of the following: // -// * `movieTable` - Runs the interactive movie table scenario that shows you how to use -// Amazon DynamoDB API commands to work with DynamoDB tables and items. -// * `partiQLSingle` - Runs a scenario that shows you how to use PartiQL statements -// to work with DynamoDB tables and items. -// * `partiQLBatch` - Runs a scenario that shows you how to use batches of PartiQL -// statements to work with DynamoDB tables and items. +// - `movieTable` - Runs the interactive movie table scenario that shows you how to use +// Amazon DynamoDB API commands to work with DynamoDB tables and items. +// - `partiQLSingle` - Runs a scenario that shows you how to use PartiQL statements +// to work with DynamoDB tables and items. +// - `partiQLBatch` - Runs a scenario that shows you how to use batches of PartiQL +// statements to work with DynamoDB tables and items. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "movieTable": runMovieScenario, "partiQLSingle": runPartiQLSingleScenario, "partiQLBatch": runPartiQLBatchScenario, @@ -48,18 +48,20 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runMovieScenario(sdkConfig aws.Config) { +func runMovieScenario(ctx context.Context, sdkConfig aws.Config) { scenarios.RunMovieScenario( + ctx, sdkConfig, demotools.NewQuestioner(), "doc-example-movie-table", @@ -67,10 +69,10 @@ func runMovieScenario(sdkConfig aws.Config) { ) } -func runPartiQLSingleScenario(sdkConfig aws.Config) { - scenarios.RunPartiQLSingleScenario(sdkConfig, "doc-example-partiql-single-table") +func runPartiQLSingleScenario(ctx context.Context, sdkConfig aws.Config) { + scenarios.RunPartiQLSingleScenario(ctx, sdkConfig, "doc-example-partiql-single-table") } -func runPartiQLBatchScenario(sdkConfig aws.Config) { - scenarios.RunPartiQLBatchScenario(sdkConfig, "doc-example-partiql-batch-table") +func runPartiQLBatchScenario(ctx context.Context, sdkConfig aws.Config) { + scenarios.RunPartiQLBatchScenario(ctx, sdkConfig, "doc-example-partiql-batch-table") } diff --git a/gov2/dynamodb/scenarios/scenario_movie_table.go b/gov2/dynamodb/scenarios/scenario_movie_table.go index ccb7085ee2a..7a96ad2f052 100644 --- a/gov2/dynamodb/scenarios/scenario_movie_table.go +++ b/gov2/dynamodb/scenarios/scenario_movie_table.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "fmt" "log" "strings" @@ -36,7 +37,7 @@ import ( // The specified movie sampler is used to get sample data from a URL that is loaded // into the named table. func RunMovieScenario( - sdkConfig aws.Config, questioner demotools.IQuestioner, tableName string, + ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, tableName string, movieSampler actions.IMovieSampler) { defer func() { if r := recover(); r != nil { @@ -51,13 +52,13 @@ func RunMovieScenario( tableBasics := actions.TableBasics{TableName: tableName, DynamoDbClient: dynamodb.NewFromConfig(sdkConfig)} - exists, err := tableBasics.TableExists() + exists, err := tableBasics.TableExists(ctx) if err != nil { panic(err) } if !exists { log.Printf("Creating table %v...\n", tableName) - _, err = tableBasics.CreateMovieTable() + _, err = tableBasics.CreateMovieTable(ctx) if err != nil { panic(err) } else { @@ -78,7 +79,7 @@ func RunMovieScenario( demotools.NotEmpty{}, demotools.InFloatRange{Lower: 1, Upper: 10}) customMovie.Info["plot"] = questioner.Ask("What's the plot? ", demotools.NotEmpty{}) - err = tableBasics.AddMovie(customMovie) + err = tableBasics.AddMovie(ctx, customMovie) if err == nil { log.Printf("Added %v to the movie table.\n", customMovie.Title) } @@ -91,7 +92,7 @@ func RunMovieScenario( log.Printf("You summarized the plot as '%v'.\n", customMovie.Info["plot"]) customMovie.Info["plot"] = questioner.Ask("What would you say now?", demotools.NotEmpty{}) - attributes, err := tableBasics.UpdateMovie(customMovie) + attributes, err := tableBasics.UpdateMovie(ctx, customMovie) if err == nil { log.Printf("Updated %v with new values.\n", customMovie.Title) for _, attVal := range attributes { @@ -105,7 +106,7 @@ func RunMovieScenario( log.Printf("Getting movie data from %v and adding 250 movies to the table...\n", movieSampler.GetURL()) movies := movieSampler.GetSampleMovies() - written, err := tableBasics.AddMovieBatch(movies, 250) + written, err := tableBasics.AddMovieBatch(ctx, movies, 250) if err != nil { panic(err) } else { @@ -124,7 +125,7 @@ func RunMovieScenario( "Enter the number of a movie to get info about it: ", demotools.InIntRange{Lower: 1, Upper: show}, ) - movie, err := tableBasics.GetMovie(movies[movieIndex-1].Title, movies[movieIndex-1].Year) + movie, err := tableBasics.GetMovie(ctx, movies[movieIndex-1].Title, movies[movieIndex-1].Year) if err == nil { log.Println(movie) } @@ -134,7 +135,7 @@ func RunMovieScenario( releaseYear := questioner.AskInt("Enter a year between 1972 and 2018: ", demotools.InIntRange{Lower: 1972, Upper: 2018}, ) - releases, err := tableBasics.Query(releaseYear) + releases, err := tableBasics.Query(ctx, releaseYear) if err == nil { if len(releases) == 0 { log.Printf("I couldn't find any movies released in %v!\n", releaseYear) @@ -151,7 +152,7 @@ func RunMovieScenario( demotools.InIntRange{Lower: 1972, Upper: 2018}) endYear := questioner.AskInt("Enter another year: ", demotools.InIntRange{Lower: 1972, Upper: 2018}) - releases, err = tableBasics.Scan(startYear, endYear) + releases, err = tableBasics.Scan(ctx, startYear, endYear) if err == nil { if len(releases) == 0 { log.Printf("I couldn't find any movies released between %v and %v!\n", startYear, endYear) @@ -168,7 +169,7 @@ func RunMovieScenario( var tables []string if questioner.AskBool("Do you want to list all of your tables? (y/n) ", "y") { - tables, err = tableBasics.ListTables() + tables, err = tableBasics.ListTables(ctx) if err == nil { log.Printf("Found %v tables:", len(tables)) for _, table := range tables { @@ -180,14 +181,14 @@ func RunMovieScenario( log.Printf("Let's remove your movie '%v'.\n", customMovie.Title) if questioner.AskBool("Do you want to delete it from the table? (y/n) ", "y") { - err = tableBasics.DeleteMovie(customMovie) + err = tableBasics.DeleteMovie(ctx, customMovie) } if err == nil { log.Printf("Deleted %v.\n", customMovie.Title) } if questioner.AskBool("Delete the table, too? (y/n)", "y") { - err = tableBasics.DeleteTable() + err = tableBasics.DeleteTable(ctx) } else { log.Println("Don't forget to delete the table when you're done or you might " + "incur charges on your account.") diff --git a/gov2/dynamodb/scenarios/scenario_movie_table_integ_test.go b/gov2/dynamodb/scenarios/scenario_movie_table_integ_test.go index eae28cff91f..c364dcbebf4 100644 --- a/gov2/dynamodb/scenarios/scenario_movie_table_integ_test.go +++ b/gov2/dynamodb/scenarios/scenario_movie_table_integ_test.go @@ -41,7 +41,8 @@ func TestRunMovieScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } diff --git a/gov2/dynamodb/scenarios/scenario_movie_table_test.go b/gov2/dynamodb/scenarios/scenario_movie_table_test.go index ba5dd54fb49..1340de2cd04 100644 --- a/gov2/dynamodb/scenarios/scenario_movie_table_test.go +++ b/gov2/dynamodb/scenarios/scenario_movie_table_test.go @@ -6,6 +6,7 @@ package scenarios import ( + "context" "fmt" "strconv" "testing" @@ -138,7 +139,7 @@ func (scenTest *MovieScenarioTest) SetupDataAndStubs() []testtools.Stub { // or without errors. func (scenTest *MovieScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} - RunMovieScenario(*stubber.SdkConfig, &mockQuestioner, scenTest.TableName, scenTest.Sampler) + RunMovieScenario(context.Background(), *stubber.SdkConfig, &mockQuestioner, scenTest.TableName, scenTest.Sampler) } func (scenTest *MovieScenarioTest) Cleanup() {} diff --git a/gov2/dynamodb/scenarios/scenario_partiql_batch.go b/gov2/dynamodb/scenarios/scenario_partiql_batch.go index 5019f79c609..ea564aac384 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_batch.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_batch.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "fmt" "log" "strings" @@ -26,7 +27,7 @@ import ( // you can replace it with a mocked or stubbed config for unit testing. // // This example creates and deletes a DynamoDB table to use during the scenario. -func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { +func RunPartiQLBatchScenario(ctx context.Context, sdkConfig aws.Config, tableName string) { defer func() { if r := recover(); r != nil { fmt.Printf("Something went wrong with the demo.") @@ -46,13 +47,13 @@ func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { TableName: tableName, } - exists, err := tableBasics.TableExists() + exists, err := tableBasics.TableExists(ctx) if err != nil { panic(err) } if !exists { log.Printf("Creating table %v...\n", tableName) - _, err = tableBasics.CreateMovieTable() + _, err = tableBasics.CreateMovieTable(ctx) if err != nil { panic(err) } else { @@ -84,14 +85,14 @@ func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { } log.Printf("Inserting a batch of movies into table '%v'.\n", tableName) - err = runner.AddMovieBatch(customMovies) + err = runner.AddMovieBatch(ctx, customMovies) if err == nil { log.Printf("Added %v movies to the table.\n", len(customMovies)) } log.Println(strings.Repeat("-", 88)) log.Println("Getting data for a batch of movies.") - movies, err := runner.GetMovieBatch(customMovies) + movies, err := runner.GetMovieBatch(ctx, customMovies) if err == nil { for _, movie := range movies { log.Println(movie) @@ -101,7 +102,7 @@ func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { newRatings := []float64{7.7, 4.4, 1.1} log.Println("Updating a batch of movies with new ratings.") - err = runner.UpdateMovieBatch(customMovies, newRatings) + err = runner.UpdateMovieBatch(ctx, customMovies, newRatings) if err == nil { log.Printf("Updated %v movies with new ratings.\n", len(customMovies)) } @@ -109,7 +110,7 @@ func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { log.Println("Getting projected data from the table to verify our update.") log.Println("Using a page size of 2 to demonstrate paging.") - projections, err := runner.GetAllMovies(2) + projections, err := runner.GetAllMovies(ctx, 2) if err == nil { log.Println("All movies:") for _, projection := range projections { @@ -119,12 +120,12 @@ func RunPartiQLBatchScenario(sdkConfig aws.Config, tableName string) { log.Println(strings.Repeat("-", 88)) log.Println("Deleting a batch of movies.") - err = runner.DeleteMovieBatch(customMovies) + err = runner.DeleteMovieBatch(ctx, customMovies) if err == nil { log.Printf("Deleted %v movies.\n", len(customMovies)) } - err = tableBasics.DeleteTable() + err = tableBasics.DeleteTable(ctx) if err == nil { log.Printf("Deleted table %v.\n", tableBasics.TableName) } diff --git a/gov2/dynamodb/scenarios/scenario_partiql_batch_integ_test.go b/gov2/dynamodb/scenarios/scenario_partiql_batch_integ_test.go index 0cda5d5f22e..a25681a9eb5 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_batch_integ_test.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_batch_integ_test.go @@ -21,7 +21,8 @@ import ( ) func TestRunPartiQLBatchScenario_Integration(t *testing.T) { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } diff --git a/gov2/dynamodb/scenarios/scenario_partiql_batch_test.go b/gov2/dynamodb/scenarios/scenario_partiql_batch_test.go index ed1b63d2839..1a52b387ddf 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_batch_test.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_batch_test.go @@ -6,6 +6,7 @@ package scenarios import ( + "context" "fmt" "testing" "time" @@ -105,7 +106,7 @@ func (scenTest *PartiQLBatchScenarioTest) SetupDataAndStubs() []testtools.Stub { // RunSubTest performs a batch test run with a set of stubs that are set up to run with // or without errors. func (scenTest *PartiQLBatchScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { - RunPartiQLBatchScenario(*stubber.SdkConfig, scenTest.TableName) + RunPartiQLBatchScenario(context.Background(), *stubber.SdkConfig, scenTest.TableName) } func (scenTest *PartiQLBatchScenarioTest) Cleanup() {} diff --git a/gov2/dynamodb/scenarios/scenario_partiql_single.go b/gov2/dynamodb/scenarios/scenario_partiql_single.go index de35cc4750c..f180be84587 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_single.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_single.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "fmt" "log" "strings" @@ -25,7 +26,7 @@ import ( // you can replace it with a mocked or stubbed config for unit testing. // // This example creates and deletes a DynamoDB table to use during the scenario. -func RunPartiQLSingleScenario(sdkConfig aws.Config, tableName string) { +func RunPartiQLSingleScenario(ctx context.Context, sdkConfig aws.Config, tableName string) { defer func() { if r := recover(); r != nil { fmt.Printf("Something went wrong with the demo.") @@ -45,13 +46,13 @@ func RunPartiQLSingleScenario(sdkConfig aws.Config, tableName string) { TableName: tableName, } - exists, err := tableBasics.TableExists() + exists, err := tableBasics.TableExists(ctx) if err != nil { panic(err) } if !exists { log.Printf("Creating table %v...\n", tableName) - _, err = tableBasics.CreateMovieTable() + _, err = tableBasics.CreateMovieTable(ctx) if err != nil { panic(err) } else { @@ -73,14 +74,14 @@ func RunPartiQLSingleScenario(sdkConfig aws.Config, tableName string) { } log.Printf("Inserting movie '%v' released in %v.", customMovie.Title, customMovie.Year) - err = runner.AddMovie(customMovie) + err = runner.AddMovie(ctx, customMovie) if err == nil { log.Printf("Added %v to the movie table.\n", customMovie.Title) } log.Println(strings.Repeat("-", 88)) log.Printf("Getting data for movie '%v' released in %v.", customMovie.Title, customMovie.Year) - movie, err := runner.GetMovie(customMovie.Title, customMovie.Year) + movie, err := runner.GetMovie(ctx, customMovie.Title, customMovie.Year) if err == nil { log.Println(movie) } @@ -88,26 +89,26 @@ func RunPartiQLSingleScenario(sdkConfig aws.Config, tableName string) { newRating := 6.6 log.Printf("Updating movie '%v' with a rating of %v.", customMovie.Title, newRating) - err = runner.UpdateMovie(customMovie, newRating) + err = runner.UpdateMovie(ctx, customMovie, newRating) if err == nil { log.Printf("Updated %v with a new rating.\n", customMovie.Title) } log.Println(strings.Repeat("-", 88)) log.Printf("Getting data again to verify the update.") - movie, err = runner.GetMovie(customMovie.Title, customMovie.Year) + movie, err = runner.GetMovie(ctx, customMovie.Title, customMovie.Year) if err == nil { log.Println(movie) } log.Println(strings.Repeat("-", 88)) log.Printf("Deleting movie '%v'.\n", customMovie.Title) - err = runner.DeleteMovie(customMovie) + err = runner.DeleteMovie(ctx, customMovie) if err == nil { log.Printf("Deleted %v.\n", customMovie.Title) } - err = tableBasics.DeleteTable() + err = tableBasics.DeleteTable(ctx) if err == nil { log.Printf("Deleted table %v.\n", tableBasics.TableName) } diff --git a/gov2/dynamodb/scenarios/scenario_partiql_single_integ_test.go b/gov2/dynamodb/scenarios/scenario_partiql_single_integ_test.go index 16accb880f4..00bd0dbbe00 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_single_integ_test.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_single_integ_test.go @@ -21,7 +21,8 @@ import ( ) func TestRunPartiQLSingleScenario_Integration(t *testing.T) { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } diff --git a/gov2/dynamodb/scenarios/scenario_partiql_single_test.go b/gov2/dynamodb/scenarios/scenario_partiql_single_test.go index cd3f20cac74..6f133ca771a 100644 --- a/gov2/dynamodb/scenarios/scenario_partiql_single_test.go +++ b/gov2/dynamodb/scenarios/scenario_partiql_single_test.go @@ -6,6 +6,7 @@ package scenarios import ( + "context" "fmt" "testing" "time" @@ -72,7 +73,7 @@ func (scenTest *PartiQLSingleScenarioTest) SetupDataAndStubs() []testtools.Stub // RunSubTest performs a single test run with a set of stubs that are set up to run with // or without errors. func (scenTest *PartiQLSingleScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { - RunPartiQLSingleScenario(*stubber.SdkConfig, scenTest.TableName) + RunPartiQLSingleScenario(context.Background(), *stubber.SdkConfig, scenTest.TableName) } func (scenTest *PartiQLSingleScenarioTest) Cleanup() {} diff --git a/gov2/iam/README.md b/gov2/iam/README.md index 3b8f22ab39b..7b2fb7ec6bf 100644 --- a/gov2/iam/README.md +++ b/gov2/iam/README.md @@ -45,7 +45,7 @@ Code examples that show you how to perform the essential operations within a ser Code excerpts that show you how to call individual service functions. -- [AttachRolePolicy](actions/roles.go#L132) +- [AttachRolePolicy](actions/roles.go#L133) - [CreateAccessKey](actions/users.go#L175) - [CreatePolicy](actions/policies.go#L65) - [CreateRole](actions/roles.go#L46) @@ -53,20 +53,20 @@ Code excerpts that show you how to call individual service functions. - [CreateUser](actions/users.go#L74) - [DeleteAccessKey](actions/users.go#L193) - [DeletePolicy](actions/policies.go#L118) -- [DeleteRole](actions/roles.go#L200) -- [DeleteServiceLinkedRole](actions/roles.go#L117) +- [DeleteRole](actions/roles.go#L201) +- [DeleteServiceLinkedRole](actions/roles.go#L118) - [DeleteUser](actions/users.go#L160) - [DeleteUserPolicy](actions/users.go#L144) -- [DetachRolePolicy](actions/roles.go#L166) +- [DetachRolePolicy](actions/roles.go#L167) - [GetAccountPasswordPolicy](actions/account.go#L26) - [GetPolicy](actions/policies.go#L100) - [GetRole](actions/roles.go#L81) - [GetUser](actions/users.go#L47) - [ListAccessKeys](actions/users.go#L209) -- [ListAttachedRolePolicies](actions/roles.go#L148) +- [ListAttachedRolePolicies](actions/roles.go#L149) - [ListGroups](actions/groups.go#L27) - [ListPolicies](actions/policies.go#L47) -- [ListRolePolicies](actions/roles.go#L182) +- [ListRolePolicies](actions/roles.go#L183) - [ListRoles](actions/roles.go#L28) - [ListSAMLProviders](actions/account.go#L44) - [ListUserPolicies](actions/users.go#L126) diff --git a/gov2/iam/actions/account.go b/gov2/iam/actions/account.go index e92a0c8babb..c5c2047b891 100644 --- a/gov2/iam/actions/account.go +++ b/gov2/iam/actions/account.go @@ -27,9 +27,9 @@ type AccountWrapper struct { // GetAccountPasswordPolicy gets the account password policy for the current account. // If no policy has been set, a NoSuchEntityException is error is returned. -func (wrapper AccountWrapper) GetAccountPasswordPolicy() (*types.PasswordPolicy, error) { +func (wrapper AccountWrapper) GetAccountPasswordPolicy(ctx context.Context) (*types.PasswordPolicy, error) { var pwPolicy *types.PasswordPolicy - result, err := wrapper.IamClient.GetAccountPasswordPolicy(context.TODO(), + result, err := wrapper.IamClient.GetAccountPasswordPolicy(ctx, &iam.GetAccountPasswordPolicyInput{}) if err != nil { log.Printf("Couldn't get account password policy. Here's why: %v\n", err) @@ -44,9 +44,9 @@ func (wrapper AccountWrapper) GetAccountPasswordPolicy() (*types.PasswordPolicy, // snippet-start:[gov2.iam.ListSAMLProviders] // ListSAMLProviders gets the SAML providers for the account. -func (wrapper AccountWrapper) ListSAMLProviders() ([]types.SAMLProviderListEntry, error) { +func (wrapper AccountWrapper) ListSAMLProviders(ctx context.Context) ([]types.SAMLProviderListEntry, error) { var providers []types.SAMLProviderListEntry - result, err := wrapper.IamClient.ListSAMLProviders(context.TODO(), &iam.ListSAMLProvidersInput{}) + result, err := wrapper.IamClient.ListSAMLProviders(ctx, &iam.ListSAMLProvidersInput{}) if err != nil { log.Printf("Couldn't list SAML providers. Here's why: %v\n", err) } else { diff --git a/gov2/iam/actions/groups.go b/gov2/iam/actions/groups.go index 64997902e6f..aee6780e3d9 100644 --- a/gov2/iam/actions/groups.go +++ b/gov2/iam/actions/groups.go @@ -27,9 +27,9 @@ type GroupWrapper struct { // snippet-start:[gov2.iam.ListGroups] // ListGroups lists up to maxGroups number of groups. -func (wrapper GroupWrapper) ListGroups(maxGroups int32) ([]types.Group, error) { +func (wrapper GroupWrapper) ListGroups(ctx context.Context, maxGroups int32) ([]types.Group, error) { var groups []types.Group - result, err := wrapper.IamClient.ListGroups(context.TODO(), &iam.ListGroupsInput{ + result, err := wrapper.IamClient.ListGroups(ctx, &iam.ListGroupsInput{ MaxItems: aws.Int32(maxGroups), }) if err != nil { diff --git a/gov2/iam/actions/non_scenario_action_integ_test.go b/gov2/iam/actions/non_scenario_action_integ_test.go index 25ff3e05e49..a3e794ae67a 100644 --- a/gov2/iam/actions/non_scenario_action_integ_test.go +++ b/gov2/iam/actions/non_scenario_action_integ_test.go @@ -22,7 +22,8 @@ import ( // live AWS services. This test is used to verify that the actions not used in scenarios // run successfully when making calls to AWS. func TestCallNonScenarioActions_Integration(t *testing.T) { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -31,7 +32,7 @@ func TestCallNonScenarioActions_Integration(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - CallNonScenarioActions(sdkConfig) + CallNonScenarioActions(ctx, sdkConfig) log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/iam/actions/non_scenario_action_test.go b/gov2/iam/actions/non_scenario_action_test.go index c9aece12ad2..513d6012d53 100644 --- a/gov2/iam/actions/non_scenario_action_test.go +++ b/gov2/iam/actions/non_scenario_action_test.go @@ -6,6 +6,7 @@ package actions import ( + "context" "errors" "log" "testing" @@ -26,7 +27,7 @@ const linkService = "batch.amazonaws.com" // CallNonScenarioActions calls the actions not used in scenarios to verify that they // run as expected. This script can be run as a unit test using stubs so that AWS // is not called, or as an integration test to verify it works when calling live AWS services. -func CallNonScenarioActions(sdkConfig aws.Config, ) { +func CallNonScenarioActions(ctx context.Context, sdkConfig aws.Config) { defer func() { if r := recover(); r != nil { log.Println(r) @@ -37,10 +38,10 @@ func CallNonScenarioActions(sdkConfig aws.Config, ) { accountWrapper := AccountWrapper{IamClient: iamClient} groupWrapper := GroupWrapper{IamClient: iamClient} policyWrapper := PolicyWrapper{IamClient: iamClient} - roleWrapper:= RoleWrapper{IamClient: iamClient} - userWrapper:= UserWrapper{IamClient: iamClient} + roleWrapper := RoleWrapper{IamClient: iamClient} + userWrapper := UserWrapper{IamClient: iamClient} - pwPolicy, err := accountWrapper.GetAccountPasswordPolicy() + pwPolicy, err := accountWrapper.GetAccountPasswordPolicy(ctx) if err != nil { var apiError smithy.APIError if errors.As(err, &apiError) { @@ -55,53 +56,73 @@ func CallNonScenarioActions(sdkConfig aws.Config, ) { log.Printf("Policy min length: %v\n", pwPolicy.MinimumPasswordLength) } - providers, err := accountWrapper.ListSAMLProviders() - if err != nil {panic(err)} + providers, err := accountWrapper.ListSAMLProviders(ctx) + if err != nil { + panic(err) + } for _, prov := range providers { log.Println(*prov.Arn) } - groups, err := groupWrapper.ListGroups(maxThings) - if err != nil {panic(err)} + groups, err := groupWrapper.ListGroups(ctx, maxThings) + if err != nil { + panic(err) + } for _, group := range groups { log.Println(*group.GroupName) } - policies, err := policyWrapper.ListPolicies(maxThings) - if err != nil {panic(err)} + policies, err := policyWrapper.ListPolicies(ctx, maxThings) + if err != nil { + panic(err) + } for _, pol := range policies { log.Println(*pol.PolicyName) } - policy, err := policyWrapper.GetPolicy(policyArn) - if err != nil {panic(err)} + policy, err := policyWrapper.GetPolicy(ctx, policyArn) + if err != nil { + panic(err) + } log.Println(*policy.Arn) - roles, err := roleWrapper.ListRoles(maxThings) - if err != nil {panic(err)} + roles, err := roleWrapper.ListRoles(ctx, maxThings) + if err != nil { + panic(err) + } for _, r := range roles { log.Println(*r.RoleName) } - role, err := roleWrapper.GetRole(roleName) - if err != nil {panic(err)} + role, err := roleWrapper.GetRole(ctx, roleName) + if err != nil { + panic(err) + } log.Println(*role.RoleName) - svcRole, err := roleWrapper.CreateServiceLinkedRole("batch.amazonaws.com", "test") - if err != nil {panic(err)} + svcRole, err := roleWrapper.CreateServiceLinkedRole(ctx, "batch.amazonaws.com", "test") + if err != nil { + panic(err) + } log.Println(*svcRole.RoleName) - err = roleWrapper.DeleteServiceLinkedRole(*svcRole.RoleName) - if err != nil {panic(err)} + err = roleWrapper.DeleteServiceLinkedRole(ctx, *svcRole.RoleName) + if err != nil { + panic(err) + } - rPols, err := roleWrapper.ListRolePolicies(roleName) - if err != nil {panic(err)} + rPols, err := roleWrapper.ListRolePolicies(ctx, roleName) + if err != nil { + panic(err) + } for _, rPol := range rPols { log.Println(rPol) } - users, err := userWrapper.ListUsers(maxThings) - if err != nil {panic(err)} + users, err := userWrapper.ListUsers(ctx, maxThings) + if err != nil { + panic(err) + } for _, user := range users { log.Println(user.UserName) } @@ -118,7 +139,7 @@ func TestCallNonScenarioActions(t *testing.T) { } // NonScenarioActionsTest encapsulates data for a scenario test. -type NonScenarioActionsTest struct {} +type NonScenarioActionsTest struct{} // SetupDataAndStubs sets up test data and builds the stubs that are used to return // mocked data. @@ -142,7 +163,7 @@ func (scenTest *NonScenarioActionsTest) SetupDataAndStubs() []testtools.Stub { // RunSubTest performs a single test run with a set of stubs set up to run with // or without errors. func (scenTest *NonScenarioActionsTest) RunSubTest(stubber *testtools.AwsmStubber) { - CallNonScenarioActions(*stubber.SdkConfig) + CallNonScenarioActions(context.Background(), *stubber.SdkConfig) } func (scenTest *NonScenarioActionsTest) Cleanup() {} diff --git a/gov2/iam/actions/policies.go b/gov2/iam/actions/policies.go index 36383721c8e..0c945973b4a 100644 --- a/gov2/iam/actions/policies.go +++ b/gov2/iam/actions/policies.go @@ -19,16 +19,16 @@ import ( // PolicyDocument defines a policy document as a Go struct that can be serialized // to JSON. type PolicyDocument struct { - Version string + Version string Statement []PolicyStatement } // PolicyStatement defines a statement in a policy document. type PolicyStatement struct { - Effect string - Action []string + Effect string + Action []string Principal map[string]string `json:",omitempty"` - Resource *string `json:",omitempty"` + Resource *string `json:",omitempty"` } // snippet-end:[gov2.iam.PolicyDocument.struct] @@ -47,9 +47,9 @@ type PolicyWrapper struct { // snippet-start:[gov2.iam.ListPolicies] // ListPolicies gets up to maxPolicies policies. -func (wrapper PolicyWrapper) ListPolicies(maxPolicies int32) ([]types.Policy, error) { +func (wrapper PolicyWrapper) ListPolicies(ctx context.Context, maxPolicies int32) ([]types.Policy, error) { var policies []types.Policy - result, err := wrapper.IamClient.ListPolicies(context.TODO(), &iam.ListPoliciesInput{ + result, err := wrapper.IamClient.ListPolicies(ctx, &iam.ListPoliciesInput{ MaxItems: aws.Int32(maxPolicies), }) if err != nil { @@ -67,14 +67,14 @@ func (wrapper PolicyWrapper) ListPolicies(maxPolicies int32) ([]types.Policy, er // CreatePolicy creates a policy that grants a list of actions to the specified resource. // PolicyDocument shows how to work with a policy document as a data structure and // serialize it to JSON by using Go's JSON marshaler. -func (wrapper PolicyWrapper) CreatePolicy(policyName string, actions []string, - resourceArn string) (*types.Policy, error) { +func (wrapper PolicyWrapper) CreatePolicy(ctx context.Context, policyName string, actions []string, + resourceArn string) (*types.Policy, error) { var policy *types.Policy policyDoc := PolicyDocument{ - Version: "2012-10-17", + Version: "2012-10-17", Statement: []PolicyStatement{{ - Effect: "Allow", - Action: actions, + Effect: "Allow", + Action: actions, Resource: aws.String(resourceArn), }}, } @@ -83,7 +83,7 @@ func (wrapper PolicyWrapper) CreatePolicy(policyName string, actions []string, log.Printf("Couldn't create policy document for %v. Here's why: %v\n", resourceArn, err) return nil, err } - result, err := wrapper.IamClient.CreatePolicy(context.TODO(), &iam.CreatePolicyInput{ + result, err := wrapper.IamClient.CreatePolicy(ctx, &iam.CreatePolicyInput{ PolicyDocument: aws.String(string(policyBytes)), PolicyName: aws.String(policyName), }) @@ -100,9 +100,9 @@ func (wrapper PolicyWrapper) CreatePolicy(policyName string, actions []string, // snippet-start:[gov2.iam.GetPolicy] // GetPolicy gets data about a policy. -func (wrapper PolicyWrapper) GetPolicy(policyArn string) (*types.Policy, error) { +func (wrapper PolicyWrapper) GetPolicy(ctx context.Context, policyArn string) (*types.Policy, error) { var policy *types.Policy - result, err := wrapper.IamClient.GetPolicy(context.TODO(), &iam.GetPolicyInput{ + result, err := wrapper.IamClient.GetPolicy(ctx, &iam.GetPolicyInput{ PolicyArn: aws.String(policyArn), }) if err != nil { @@ -118,8 +118,8 @@ func (wrapper PolicyWrapper) GetPolicy(policyArn string) (*types.Policy, error) // snippet-start:[gov2.iam.DeletePolicy] // DeletePolicy deletes a policy. -func (wrapper PolicyWrapper) DeletePolicy(policyArn string) error { - _, err := wrapper.IamClient.DeletePolicy(context.TODO(), &iam.DeletePolicyInput{ +func (wrapper PolicyWrapper) DeletePolicy(ctx context.Context, policyArn string) error { + _, err := wrapper.IamClient.DeletePolicy(ctx, &iam.DeletePolicyInput{ PolicyArn: aws.String(policyArn), }) if err != nil { diff --git a/gov2/iam/actions/roles.go b/gov2/iam/actions/roles.go index 48dbfd2a676..596d916a4f7 100644 --- a/gov2/iam/actions/roles.go +++ b/gov2/iam/actions/roles.go @@ -28,9 +28,9 @@ type RoleWrapper struct { // snippet-start:[gov2.iam.ListRoles] // ListRoles gets up to maxRoles roles. -func (wrapper RoleWrapper) ListRoles(maxRoles int32) ([]types.Role, error) { +func (wrapper RoleWrapper) ListRoles(ctx context.Context, maxRoles int32) ([]types.Role, error) { var roles []types.Role - result, err := wrapper.IamClient.ListRoles(context.TODO(), + result, err := wrapper.IamClient.ListRoles(ctx, &iam.ListRolesInput{MaxItems: aws.Int32(maxRoles)}, ) if err != nil { @@ -49,14 +49,14 @@ func (wrapper RoleWrapper) ListRoles(maxRoles int32) ([]types.Role, error) { // the role to acquire its permissions. // PolicyDocument shows how to work with a policy document as a data structure and // serialize it to JSON by using Go's JSON marshaler. -func (wrapper RoleWrapper) CreateRole(roleName string, trustedUserArn string) (*types.Role, error) { +func (wrapper RoleWrapper) CreateRole(ctx context.Context, roleName string, trustedUserArn string) (*types.Role, error) { var role *types.Role trustPolicy := PolicyDocument{ - Version: "2012-10-17", + Version: "2012-10-17", Statement: []PolicyStatement{{ - Effect: "Allow", + Effect: "Allow", Principal: map[string]string{"AWS": trustedUserArn}, - Action: []string{"sts:AssumeRole"}, + Action: []string{"sts:AssumeRole"}, }}, } policyBytes, err := json.Marshal(trustPolicy) @@ -64,7 +64,7 @@ func (wrapper RoleWrapper) CreateRole(roleName string, trustedUserArn string) (* log.Printf("Couldn't create trust policy for %v. Here's why: %v\n", trustedUserArn, err) return nil, err } - result, err := wrapper.IamClient.CreateRole(context.TODO(), &iam.CreateRoleInput{ + result, err := wrapper.IamClient.CreateRole(ctx, &iam.CreateRoleInput{ AssumeRolePolicyDocument: aws.String(string(policyBytes)), RoleName: aws.String(roleName), }) @@ -81,9 +81,9 @@ func (wrapper RoleWrapper) CreateRole(roleName string, trustedUserArn string) (* // snippet-start:[gov2.iam.GetRole] // GetRole gets data about a role. -func (wrapper RoleWrapper) GetRole(roleName string) (*types.Role, error) { +func (wrapper RoleWrapper) GetRole(ctx context.Context, roleName string) (*types.Role, error) { var role *types.Role - result, err := wrapper.IamClient.GetRole(context.TODO(), + result, err := wrapper.IamClient.GetRole(ctx, &iam.GetRoleInput{RoleName: aws.String(roleName)}) if err != nil { log.Printf("Couldn't get role %v. Here's why: %v\n", roleName, err) @@ -98,9 +98,10 @@ func (wrapper RoleWrapper) GetRole(roleName string) (*types.Role, error) { // snippet-start:[gov2.iam.CreateServiceLinkedRole] // CreateServiceLinkedRole creates a service-linked role that is owned by the specified service. -func (wrapper RoleWrapper) CreateServiceLinkedRole(serviceName string, description string) (*types.Role, error) { +func (wrapper RoleWrapper) CreateServiceLinkedRole(ctx context.Context, serviceName string, description string) ( + *types.Role, error) { var role *types.Role - result, err := wrapper.IamClient.CreateServiceLinkedRole(context.TODO(), &iam.CreateServiceLinkedRoleInput{ + result, err := wrapper.IamClient.CreateServiceLinkedRole(ctx, &iam.CreateServiceLinkedRoleInput{ AWSServiceName: aws.String(serviceName), Description: aws.String(description), }) @@ -117,8 +118,8 @@ func (wrapper RoleWrapper) CreateServiceLinkedRole(serviceName string, descripti // snippet-start:[gov2.iam.DeleteServiceLinkedRole] // DeleteServiceLinkedRole deletes a service-linked role. -func (wrapper RoleWrapper) DeleteServiceLinkedRole(roleName string) error { - _, err := wrapper.IamClient.DeleteServiceLinkedRole(context.TODO(), &iam.DeleteServiceLinkedRoleInput{ +func (wrapper RoleWrapper) DeleteServiceLinkedRole(ctx context.Context, roleName string) error { + _, err := wrapper.IamClient.DeleteServiceLinkedRole(ctx, &iam.DeleteServiceLinkedRoleInput{ RoleName: aws.String(roleName)}, ) if err != nil { @@ -132,8 +133,8 @@ func (wrapper RoleWrapper) DeleteServiceLinkedRole(roleName string) error { // snippet-start:[gov2.iam.AttachRolePolicy] // AttachRolePolicy attaches a policy to a role. -func (wrapper RoleWrapper) AttachRolePolicy(policyArn string, roleName string) error { - _, err := wrapper.IamClient.AttachRolePolicy(context.TODO(), &iam.AttachRolePolicyInput{ +func (wrapper RoleWrapper) AttachRolePolicy(ctx context.Context, policyArn string, roleName string) error { + _, err := wrapper.IamClient.AttachRolePolicy(ctx, &iam.AttachRolePolicyInput{ PolicyArn: aws.String(policyArn), RoleName: aws.String(roleName), }) @@ -148,9 +149,9 @@ func (wrapper RoleWrapper) AttachRolePolicy(policyArn string, roleName string) e // snippet-start:[gov2.iam.ListAttachedRolePolicies] // ListAttachedRolePolicies lists the policies that are attached to the specified role. -func (wrapper RoleWrapper) ListAttachedRolePolicies(roleName string) ([]types.AttachedPolicy, error) { +func (wrapper RoleWrapper) ListAttachedRolePolicies(ctx context.Context, roleName string) ([]types.AttachedPolicy, error) { var policies []types.AttachedPolicy - result, err := wrapper.IamClient.ListAttachedRolePolicies(context.TODO(), &iam.ListAttachedRolePoliciesInput{ + result, err := wrapper.IamClient.ListAttachedRolePolicies(ctx, &iam.ListAttachedRolePoliciesInput{ RoleName: aws.String(roleName), }) if err != nil { @@ -166,8 +167,8 @@ func (wrapper RoleWrapper) ListAttachedRolePolicies(roleName string) ([]types.At // snippet-start:[gov2.iam.DetachRolePolicy] // DetachRolePolicy detaches a policy from a role. -func (wrapper RoleWrapper) DetachRolePolicy(roleName string, policyArn string) error { - _, err := wrapper.IamClient.DetachRolePolicy(context.TODO(), &iam.DetachRolePolicyInput{ +func (wrapper RoleWrapper) DetachRolePolicy(ctx context.Context, roleName string, policyArn string) error { + _, err := wrapper.IamClient.DetachRolePolicy(ctx, &iam.DetachRolePolicyInput{ PolicyArn: aws.String(policyArn), RoleName: aws.String(roleName), }) @@ -182,9 +183,9 @@ func (wrapper RoleWrapper) DetachRolePolicy(roleName string, policyArn string) e // snippet-start:[gov2.iam.ListRolePolicies] // ListRolePolicies lists the inline policies for a role. -func (wrapper RoleWrapper) ListRolePolicies(roleName string) ([]string, error) { +func (wrapper RoleWrapper) ListRolePolicies(ctx context.Context, roleName string) ([]string, error) { var policies []string - result, err := wrapper.IamClient.ListRolePolicies(context.TODO(), &iam.ListRolePoliciesInput{ + result, err := wrapper.IamClient.ListRolePolicies(ctx, &iam.ListRolePoliciesInput{ RoleName: aws.String(roleName), }) if err != nil { @@ -201,8 +202,8 @@ func (wrapper RoleWrapper) ListRolePolicies(roleName string) ([]string, error) { // DeleteRole deletes a role. All attached policies must be detached before a // role can be deleted. -func (wrapper RoleWrapper) DeleteRole(roleName string) error { - _, err := wrapper.IamClient.DeleteRole(context.TODO(), &iam.DeleteRoleInput{ +func (wrapper RoleWrapper) DeleteRole(ctx context.Context, roleName string) error { + _, err := wrapper.IamClient.DeleteRole(ctx, &iam.DeleteRoleInput{ RoleName: aws.String(roleName), }) if err != nil { diff --git a/gov2/iam/actions/users.go b/gov2/iam/actions/users.go index 41b32f37aa1..ff7477a97d1 100644 --- a/gov2/iam/actions/users.go +++ b/gov2/iam/actions/users.go @@ -29,9 +29,9 @@ type UserWrapper struct { // snippet-start:[gov2.iam.ListUsers] // ListUsers gets up to maxUsers number of users. -func (wrapper UserWrapper) ListUsers(maxUsers int32) ([]types.User, error) { +func (wrapper UserWrapper) ListUsers(ctx context.Context, maxUsers int32) ([]types.User, error) { var users []types.User - result, err := wrapper.IamClient.ListUsers(context.TODO(), &iam.ListUsersInput{ + result, err := wrapper.IamClient.ListUsers(ctx, &iam.ListUsersInput{ MaxItems: aws.Int32(maxUsers), }) if err != nil { @@ -47,9 +47,9 @@ func (wrapper UserWrapper) ListUsers(maxUsers int32) ([]types.User, error) { // snippet-start:[gov2.iam.GetUser] // GetUser gets data about a user. -func (wrapper UserWrapper) GetUser(userName string) (*types.User, error) { +func (wrapper UserWrapper) GetUser(ctx context.Context, userName string) (*types.User, error) { var user *types.User - result, err := wrapper.IamClient.GetUser(context.TODO(), &iam.GetUserInput{ + result, err := wrapper.IamClient.GetUser(ctx, &iam.GetUserInput{ UserName: aws.String(userName), }) if err != nil { @@ -74,9 +74,9 @@ func (wrapper UserWrapper) GetUser(userName string) (*types.User, error) { // snippet-start:[gov2.iam.CreateUser] // CreateUser creates a new user with the specified name. -func (wrapper UserWrapper) CreateUser(userName string) (*types.User, error) { +func (wrapper UserWrapper) CreateUser(ctx context.Context, userName string) (*types.User, error) { var user *types.User - result, err := wrapper.IamClient.CreateUser(context.TODO(), &iam.CreateUserInput{ + result, err := wrapper.IamClient.CreateUser(ctx, &iam.CreateUserInput{ UserName: aws.String(userName), }) if err != nil { @@ -95,13 +95,13 @@ func (wrapper UserWrapper) CreateUser(userName string) (*types.User, error) { // grants a list of actions on a specified role. // PolicyDocument shows how to work with a policy document as a data structure and // serialize it to JSON by using Go's JSON marshaler. -func (wrapper UserWrapper) CreateUserPolicy(userName string, policyName string, actions []string, - roleArn string) error { +func (wrapper UserWrapper) CreateUserPolicy(ctx context.Context, userName string, policyName string, actions []string, + roleArn string) error { policyDoc := PolicyDocument{ - Version: "2012-10-17", + Version: "2012-10-17", Statement: []PolicyStatement{{ - Effect: "Allow", - Action: actions, + Effect: "Allow", + Action: actions, Resource: aws.String(roleArn), }}, } @@ -110,7 +110,7 @@ func (wrapper UserWrapper) CreateUserPolicy(userName string, policyName string, log.Printf("Couldn't create policy document for %v. Here's why: %v\n", roleArn, err) return err } - _, err = wrapper.IamClient.PutUserPolicy(context.TODO(), &iam.PutUserPolicyInput{ + _, err = wrapper.IamClient.PutUserPolicy(ctx, &iam.PutUserPolicyInput{ PolicyDocument: aws.String(string(policyBytes)), PolicyName: aws.String(policyName), UserName: aws.String(userName), @@ -126,9 +126,9 @@ func (wrapper UserWrapper) CreateUserPolicy(userName string, policyName string, // snippet-start:[gov2.iam.ListUserPolicies] // ListUserPolicies lists the inline policies for the specified user. -func (wrapper UserWrapper) ListUserPolicies(userName string) ([]string, error) { +func (wrapper UserWrapper) ListUserPolicies(ctx context.Context, userName string) ([]string, error) { var policies []string - result, err := wrapper.IamClient.ListUserPolicies(context.TODO(), &iam.ListUserPoliciesInput{ + result, err := wrapper.IamClient.ListUserPolicies(ctx, &iam.ListUserPoliciesInput{ UserName: aws.String(userName), }) if err != nil { @@ -144,8 +144,8 @@ func (wrapper UserWrapper) ListUserPolicies(userName string) ([]string, error) { // snippet-start:[gov2.iam.DeleteUserPolicy] // DeleteUserPolicy deletes an inline policy from a user. -func (wrapper UserWrapper) DeleteUserPolicy(userName string, policyName string) error { - _, err := wrapper.IamClient.DeleteUserPolicy(context.TODO(), &iam.DeleteUserPolicyInput{ +func (wrapper UserWrapper) DeleteUserPolicy(ctx context.Context, userName string, policyName string) error { + _, err := wrapper.IamClient.DeleteUserPolicy(ctx, &iam.DeleteUserPolicyInput{ PolicyName: aws.String(policyName), UserName: aws.String(userName), }) @@ -160,8 +160,8 @@ func (wrapper UserWrapper) DeleteUserPolicy(userName string, policyName string) // snippet-start:[gov2.iam.DeleteUser] // DeleteUser deletes a user. -func (wrapper UserWrapper) DeleteUser(userName string) error { - _, err := wrapper.IamClient.DeleteUser(context.TODO(), &iam.DeleteUserInput{ +func (wrapper UserWrapper) DeleteUser(ctx context.Context, userName string) error { + _, err := wrapper.IamClient.DeleteUser(ctx, &iam.DeleteUserInput{ UserName: aws.String(userName), }) if err != nil { @@ -176,9 +176,9 @@ func (wrapper UserWrapper) DeleteUser(userName string) error { // CreateAccessKeyPair creates an access key for a user. The returned access key contains // the ID and secret credentials needed to use the key. -func (wrapper UserWrapper) CreateAccessKeyPair(userName string) (*types.AccessKey, error) { +func (wrapper UserWrapper) CreateAccessKeyPair(ctx context.Context, userName string) (*types.AccessKey, error) { var key *types.AccessKey - result, err := wrapper.IamClient.CreateAccessKey(context.TODO(), &iam.CreateAccessKeyInput{ + result, err := wrapper.IamClient.CreateAccessKey(ctx, &iam.CreateAccessKeyInput{ UserName: aws.String(userName)}) if err != nil { log.Printf("Couldn't create access key pair for user %v. Here's why: %v\n", userName, err) @@ -193,8 +193,8 @@ func (wrapper UserWrapper) CreateAccessKeyPair(userName string) (*types.AccessKe // snippet-start:[gov2.iam.DeleteAccessKey] // DeleteAccessKey deletes an access key from a user. -func (wrapper UserWrapper) DeleteAccessKey(userName string, keyId string) error { - _, err := wrapper.IamClient.DeleteAccessKey(context.TODO(), &iam.DeleteAccessKeyInput{ +func (wrapper UserWrapper) DeleteAccessKey(ctx context.Context, userName string, keyId string) error { + _, err := wrapper.IamClient.DeleteAccessKey(ctx, &iam.DeleteAccessKeyInput{ AccessKeyId: aws.String(keyId), UserName: aws.String(userName), }) @@ -209,9 +209,9 @@ func (wrapper UserWrapper) DeleteAccessKey(userName string, keyId string) error // snippet-start:[gov2.iam.ListAccessKeys] // ListAccessKeys lists the access keys for the specified user. -func (wrapper UserWrapper) ListAccessKeys(userName string) ([]types.AccessKeyMetadata, error) { +func (wrapper UserWrapper) ListAccessKeys(ctx context.Context, userName string) ([]types.AccessKeyMetadata, error) { var keys []types.AccessKeyMetadata - result, err := wrapper.IamClient.ListAccessKeys(context.TODO(), &iam.ListAccessKeysInput{ + result, err := wrapper.IamClient.ListAccessKeys(ctx, &iam.ListAccessKeysInput{ UserName: aws.String(userName), }) if err != nil { diff --git a/gov2/iam/cmd/main.go b/gov2/iam/cmd/main.go index e25c02fe3f9..dce56e70aee 100644 --- a/gov2/iam/cmd/main.go +++ b/gov2/iam/cmd/main.go @@ -22,11 +22,11 @@ import ( // // `-scenario` can be one of the following: // -// * `assumerole` - Runs an interactive scenario that shows you how to assume a role -// with limited permissions and perform actions on AWS services. +// - `assumerole` - Runs an interactive scenario that shows you how to assume a role +// with limited permissions and perform actions on AWS services. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ - "assumerole": runAssumeRoleScenario, + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ + "assumerole": runAssumeRoleScenario, } choices := make([]string, len(scenarioMap)) choiceIndex := 0 @@ -43,21 +43,22 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runAssumeRoleScenario(sdkConfig aws.Config) { +func runAssumeRoleScenario(ctx context.Context, sdkConfig aws.Config) { helper := scenarios.ScenarioHelper{ Prefix: "doc-example-assumerole-", Random: rand.New(rand.NewSource(time.Now().Unix())), } scenario := scenarios.NewAssumeRoleScenario(sdkConfig, demotools.NewQuestioner(), &helper) - scenario.Run() + scenario.Run(ctx) } diff --git a/gov2/iam/hello/hello.go b/gov2/iam/hello/hello.go index a689fc22405..25f39d9b345 100644 --- a/gov2/iam/hello/hello.go +++ b/gov2/iam/hello/hello.go @@ -19,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -28,7 +29,7 @@ func main() { iamClient := iam.NewFromConfig(sdkConfig) const maxPols = 10 fmt.Printf("Let's list up to %v policies for your account.\n", maxPols) - result, err := iamClient.ListPolicies(context.TODO(), &iam.ListPoliciesInput{ + result, err := iamClient.ListPolicies(ctx, &iam.ListPoliciesInput{ MaxItems: aws.Int32(maxPols), }) if err != nil { diff --git a/gov2/iam/scenarios/scenario_assume_role.go b/gov2/iam/scenarios/scenario_assume_role.go index 845d61328b4..703bbe8e31e 100644 --- a/gov2/iam/scenarios/scenario_assume_role.go +++ b/gov2/iam/scenarios/scenario_assume_role.go @@ -32,6 +32,7 @@ type IScenarioHelper interface { } const rMax = 100000 + type ScenarioHelper struct { Prefix string Random *rand.Rand @@ -52,32 +53,32 @@ func (helper ScenarioHelper) Pause(secs int) { // AssumeRoleScenario shows you how to use the AWS Identity and Access Management (IAM) // service to perform the following actions: // -// 1. Create a user who has no permissions. -// 2. Create a role that grants permission to list Amazon Simple Storage Service -// (Amazon S3) buckets for the account. -// 3. Add a policy to let the user assume the role. -// 4. Try and fail to list buckets without permissions. -// 5. Assume the role and list S3 buckets using temporary credentials. -// 6. Delete the policy, role, and user. +// 1. Create a user who has no permissions. +// 2. Create a role that grants permission to list Amazon Simple Storage Service +// (Amazon S3) buckets for the account. +// 3. Add a policy to let the user assume the role. +// 4. Try and fail to list buckets without permissions. +// 5. Assume the role and list S3 buckets using temporary credentials. +// 6. Delete the policy, role, and user. type AssumeRoleScenario struct { - sdkConfig aws.Config + sdkConfig aws.Config accountWrapper actions.AccountWrapper - policyWrapper actions.PolicyWrapper - roleWrapper actions.RoleWrapper - userWrapper actions.UserWrapper - questioner demotools.IQuestioner - helper IScenarioHelper - isTestRun bool + policyWrapper actions.PolicyWrapper + roleWrapper actions.RoleWrapper + userWrapper actions.UserWrapper + questioner demotools.IQuestioner + helper IScenarioHelper + isTestRun bool } // NewAssumeRoleScenario constructs an AssumeRoleScenario instance from a configuration. // It uses the specified config to get an IAM client and create wrappers for the actions // used in the scenario. func NewAssumeRoleScenario(sdkConfig aws.Config, questioner demotools.IQuestioner, - helper IScenarioHelper) AssumeRoleScenario { + helper IScenarioHelper) AssumeRoleScenario { iamClient := iam.NewFromConfig(sdkConfig) return AssumeRoleScenario{ - sdkConfig: sdkConfig, + sdkConfig: sdkConfig, accountWrapper: actions.AccountWrapper{IamClient: iamClient}, policyWrapper: actions.PolicyWrapper{IamClient: iamClient}, roleWrapper: actions.RoleWrapper{IamClient: iamClient}, @@ -97,7 +98,7 @@ func (scenario AssumeRoleScenario) addTestOptions(scenarioConfig *aws.Config) { } // Run runs the interactive scenario. -func (scenario AssumeRoleScenario) Run() { +func (scenario AssumeRoleScenario) Run(ctx context.Context) { defer func() { if r := recover(); r != nil { log.Printf("Something went wrong with the demo.\n") @@ -109,12 +110,12 @@ func (scenario AssumeRoleScenario) Run() { log.Println("Welcome to the AWS Identity and Access Management (IAM) assume role demo.") log.Println(strings.Repeat("-", 88)) - user := scenario.CreateUser() - accessKey := scenario.CreateAccessKey(user) - role := scenario.CreateRoleAndPolicies(user) - noPermsConfig := scenario.ListBucketsWithoutPermissions(accessKey) - scenario.ListBucketsWithAssumedRole(noPermsConfig, role) - scenario.Cleanup(user, role) + user := scenario.CreateUser(ctx) + accessKey := scenario.CreateAccessKey(ctx, user) + role := scenario.CreateRoleAndPolicies(ctx, user) + noPermsConfig := scenario.ListBucketsWithoutPermissions(ctx, accessKey) + scenario.ListBucketsWithAssumedRole(ctx, noPermsConfig, role) + scenario.Cleanup(ctx, user, role) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") @@ -122,15 +123,15 @@ func (scenario AssumeRoleScenario) Run() { } // CreateUser creates a new IAM user. This user has no permissions. -func (scenario AssumeRoleScenario) CreateUser() *types.User { +func (scenario AssumeRoleScenario) CreateUser(ctx context.Context) *types.User { log.Println("Let's create an example user with no permissions.") userName := scenario.questioner.Ask("Enter a name for the example user:", demotools.NotEmpty{}) - user, err := scenario.userWrapper.GetUser(userName) + user, err := scenario.userWrapper.GetUser(ctx, userName) if err != nil { panic(err) } if user == nil { - user, err = scenario.userWrapper.CreateUser(userName) + user, err = scenario.userWrapper.CreateUser(ctx, userName) if err != nil { panic(err) } @@ -143,8 +144,8 @@ func (scenario AssumeRoleScenario) CreateUser() *types.User { } // CreateAccessKey creates an access key for the user. -func (scenario AssumeRoleScenario) CreateAccessKey(user *types.User) *types.AccessKey { - accessKey, err := scenario.userWrapper.CreateAccessKeyPair(*user.UserName) +func (scenario AssumeRoleScenario) CreateAccessKey(ctx context.Context, user *types.User) *types.AccessKey { + accessKey, err := scenario.userWrapper.CreateAccessKeyPair(ctx, *user.UserName) if err != nil { panic(err) } @@ -158,23 +159,31 @@ func (scenario AssumeRoleScenario) CreateAccessKey(user *types.User) *types.Acce // CreateRoleAndPolicies creates a policy that grants permission to list S3 buckets for // the current account and attaches the policy to a newly created role. It also adds an // inline policy to the specified user that grants the user permission to assume the role. -func (scenario AssumeRoleScenario) CreateRoleAndPolicies(user *types.User) *types.Role { +func (scenario AssumeRoleScenario) CreateRoleAndPolicies(ctx context.Context, user *types.User) *types.Role { log.Println("Let's create a role and policy that grant permission to list S3 buckets.") scenario.questioner.Ask("Press Enter when you're ready.") - listBucketsRole, err := scenario.roleWrapper.CreateRole(scenario.helper.GetName(), *user.Arn) - if err != nil {panic(err)} + listBucketsRole, err := scenario.roleWrapper.CreateRole(ctx, scenario.helper.GetName(), *user.Arn) + if err != nil { + panic(err) + } log.Printf("Created role %v.\n", *listBucketsRole.RoleName) listBucketsPolicy, err := scenario.policyWrapper.CreatePolicy( - scenario.helper.GetName(), []string{"s3:ListAllMyBuckets"}, "arn:aws:s3:::*") - if err != nil {panic(err)} + ctx, scenario.helper.GetName(), []string{"s3:ListAllMyBuckets"}, "arn:aws:s3:::*") + if err != nil { + panic(err) + } log.Printf("Created policy %v.\n", *listBucketsPolicy.PolicyName) - err = scenario.roleWrapper.AttachRolePolicy(*listBucketsPolicy.Arn, *listBucketsRole.RoleName) - if err != nil {panic(err)} + err = scenario.roleWrapper.AttachRolePolicy(ctx, *listBucketsPolicy.Arn, *listBucketsRole.RoleName) + if err != nil { + panic(err) + } log.Printf("Attached policy %v to role %v.\n", *listBucketsPolicy.PolicyName, *listBucketsRole.RoleName) - err = scenario.userWrapper.CreateUserPolicy(*user.UserName, scenario.helper.GetName(), + err = scenario.userWrapper.CreateUserPolicy(ctx, *user.UserName, scenario.helper.GetName(), []string{"sts:AssumeRole"}, *listBucketsRole.Arn) - if err != nil {panic(err)} + if err != nil { + panic(err) + } log.Printf("Created an inline policy for user %v that lets the user assume the role.\n", *user.UserName) log.Println("Let's give AWS a few seconds to propagate these new resources and connections...") @@ -186,22 +195,24 @@ func (scenario AssumeRoleScenario) CreateRoleAndPolicies(user *types.User) *type // ListBucketsWithoutPermissions creates an Amazon S3 client from the user's access key // credentials and tries to list buckets for the account. Because the user does not have // permission to perform this action, the action fails. -func (scenario AssumeRoleScenario) ListBucketsWithoutPermissions(accessKey *types.AccessKey) *aws.Config { - log.Println("Let's try to list buckets without permissions. This should return an AccessDenied error.") - scenario.questioner.Ask("Press Enter when you're ready.") - noPermsConfig, err := config.LoadDefaultConfig(context.TODO(), +func (scenario AssumeRoleScenario) ListBucketsWithoutPermissions(ctx context.Context, accessKey *types.AccessKey) *aws.Config { + log.Println("Let's try to list buckets without permissions. This should return an AccessDenied error.") + scenario.questioner.Ask("Press Enter when you're ready.") + noPermsConfig, err := config.LoadDefaultConfig(ctx, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( *accessKey.AccessKeyId, *accessKey.SecretAccessKey, ""), - )) - if err != nil {panic(err)} + )) + if err != nil { + panic(err) + } - // Add test options if this is a test run. This is needed only for testing purposes. + // Add test options if this is a test run. This is needed only for testing purposes. scenario.addTestOptions(&noPermsConfig) - s3Client := s3.NewFromConfig(noPermsConfig) - _, err = s3Client.ListBuckets(context.TODO(), &s3.ListBucketsInput{}) - if err != nil { - // The SDK for Go does not model the AccessDenied error, so check ErrorCode directly. + s3Client := s3.NewFromConfig(noPermsConfig) + _, err = s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + // The SDK for Go does not model the AccessDenied error, so check ErrorCode directly. var ae smithy.APIError if errors.As(err, &ae) { switch ae.ErrorCode() { @@ -213,51 +224,53 @@ func (scenario AssumeRoleScenario) ListBucketsWithoutPermissions(accessKey *type panic(err) } } - } else { - log.Println("Expected AccessDenied error when calling ListBuckets without permissions,\n" + - "but the call succeeded. Continuing the example anyway...") + } else { + log.Println("Expected AccessDenied error when calling ListBuckets without permissions,\n" + + "but the call succeeded. Continuing the example anyway...") } log.Println(strings.Repeat("-", 88)) - return &noPermsConfig + return &noPermsConfig } // ListBucketsWithAssumedRole performs the following actions: // -// 1. Creates an AWS Security Token Service (AWS STS) client from the config created from -// the user's access key credentials. -// 2. Gets temporary credentials by assuming the role that grants permission to list the -// buckets. -// 3. Creates an Amazon S3 client from the temporary credentials. -// 4. Lists buckets for the account. Because the temporary credentials are generated by -// assuming the role that grants permission, the action succeeds. -func (scenario AssumeRoleScenario) ListBucketsWithAssumedRole(noPermsConfig *aws.Config, role *types.Role) { +// 1. Creates an AWS Security Token Service (AWS STS) client from the config created from +// the user's access key credentials. +// 2. Gets temporary credentials by assuming the role that grants permission to list the +// buckets. +// 3. Creates an Amazon S3 client from the temporary credentials. +// 4. Lists buckets for the account. Because the temporary credentials are generated by +// assuming the role that grants permission, the action succeeds. +func (scenario AssumeRoleScenario) ListBucketsWithAssumedRole(ctx context.Context, noPermsConfig *aws.Config, role *types.Role) { log.Println("Let's assume the role that grants permission to list buckets and try again.") scenario.questioner.Ask("Press Enter when you're ready.") stsClient := sts.NewFromConfig(*noPermsConfig) - tempCredentials, err := stsClient.AssumeRole(context.TODO(), &sts.AssumeRoleInput{ - RoleArn: role.Arn, - RoleSessionName: aws.String("AssumeRoleExampleSession"), - DurationSeconds: aws.Int32(900), + tempCredentials, err := stsClient.AssumeRole(ctx, &sts.AssumeRoleInput{ + RoleArn: role.Arn, + RoleSessionName: aws.String("AssumeRoleExampleSession"), + DurationSeconds: aws.Int32(900), }) if err != nil { log.Printf("Couldn't assume role %v.\n", *role.RoleName) panic(err) } log.Printf("Assumed role %v, got temporary credentials.\n", *role.RoleName) - assumeRoleConfig, err := config.LoadDefaultConfig(context.TODO(), + assumeRoleConfig, err := config.LoadDefaultConfig(ctx, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( *tempCredentials.Credentials.AccessKeyId, *tempCredentials.Credentials.SecretAccessKey, *tempCredentials.Credentials.SessionToken), ), ) - if err != nil {panic(err)} + if err != nil { + panic(err) + } // Add test options if this is a test run. This is needed only for testing purposes. scenario.addTestOptions(&assumeRoleConfig) s3Client := s3.NewFromConfig(assumeRoleConfig) - result, err := s3Client.ListBuckets(context.TODO(), &s3.ListBucketsInput{}) + result, err := s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) if err != nil { log.Println("Couldn't list buckets with assumed role credentials.") panic(err) @@ -271,42 +284,60 @@ func (scenario AssumeRoleScenario) ListBucketsWithAssumedRole(noPermsConfig *aws } // Cleanup deletes all resources created for the scenario. -func (scenario AssumeRoleScenario) Cleanup(user *types.User, role *types.Role) { +func (scenario AssumeRoleScenario) Cleanup(ctx context.Context, user *types.User, role *types.Role) { if scenario.questioner.AskBool( "Do you want to delete the resources created for this example? (y/n)", "y", ) { - policies, err := scenario.roleWrapper.ListAttachedRolePolicies(*role.RoleName) - if err != nil {panic(err)} - for _, policy := range policies { - err = scenario.roleWrapper.DetachRolePolicy(*role.RoleName, *policy.PolicyArn) - if err != nil {panic(err)} - err = scenario.policyWrapper.DeletePolicy(*policy.PolicyArn) - if err != nil {panic(err)} - log.Printf("Detached policy %v from role %v and deleted the policy.\n", - *policy.PolicyName, *role.RoleName) - } - err = scenario.roleWrapper.DeleteRole(*role.RoleName) - if err != nil {panic(err)} - log.Printf("Deleted role %v.\n", *role.RoleName) + policies, err := scenario.roleWrapper.ListAttachedRolePolicies(ctx, *role.RoleName) + if err != nil { + panic(err) + } + for _, policy := range policies { + err = scenario.roleWrapper.DetachRolePolicy(ctx, *role.RoleName, *policy.PolicyArn) + if err != nil { + panic(err) + } + err = scenario.policyWrapper.DeletePolicy(ctx, *policy.PolicyArn) + if err != nil { + panic(err) + } + log.Printf("Detached policy %v from role %v and deleted the policy.\n", + *policy.PolicyName, *role.RoleName) + } + err = scenario.roleWrapper.DeleteRole(ctx, *role.RoleName) + if err != nil { + panic(err) + } + log.Printf("Deleted role %v.\n", *role.RoleName) - userPols, err := scenario.userWrapper.ListUserPolicies(*user.UserName) - if err != nil {panic(err)} - for _, userPol := range userPols { - err = scenario.userWrapper.DeleteUserPolicy(*user.UserName, userPol) - if err != nil {panic(err)} - log.Printf("Deleted policy %v from user %v.\n", userPol, *user.UserName) - } - keys, err := scenario.userWrapper.ListAccessKeys(*user.UserName) - if err != nil {panic(err)} - for _, key := range keys { - err = scenario.userWrapper.DeleteAccessKey(*user.UserName, *key.AccessKeyId) - if err != nil {panic(err)} - log.Printf("Deleted access key %v from user %v.\n", *key.AccessKeyId, *user.UserName) - } - err = scenario.userWrapper.DeleteUser(*user.UserName) - if err != nil {panic(err)} - log.Printf("Deleted user %v.\n", *user.UserName) - log.Println(strings.Repeat("-", 88)) + userPols, err := scenario.userWrapper.ListUserPolicies(ctx, *user.UserName) + if err != nil { + panic(err) + } + for _, userPol := range userPols { + err = scenario.userWrapper.DeleteUserPolicy(ctx, *user.UserName, userPol) + if err != nil { + panic(err) + } + log.Printf("Deleted policy %v from user %v.\n", userPol, *user.UserName) + } + keys, err := scenario.userWrapper.ListAccessKeys(ctx, *user.UserName) + if err != nil { + panic(err) + } + for _, key := range keys { + err = scenario.userWrapper.DeleteAccessKey(ctx, *user.UserName, *key.AccessKeyId) + if err != nil { + panic(err) + } + log.Printf("Deleted access key %v from user %v.\n", *key.AccessKeyId, *user.UserName) + } + err = scenario.userWrapper.DeleteUser(ctx, *user.UserName) + if err != nil { + panic(err) + } + log.Printf("Deleted user %v.\n", *user.UserName) + log.Println(strings.Repeat("-", 88)) } } diff --git a/gov2/iam/scenarios/scenario_assume_role_integ_test.go b/gov2/iam/scenarios/scenario_assume_role_integ_test.go index 02537a97a28..2cd8963a17f 100644 --- a/gov2/iam/scenarios/scenario_assume_role_integ_test.go +++ b/gov2/iam/scenarios/scenario_assume_role_integ_test.go @@ -45,7 +45,8 @@ func TestRunAssumeRoleScenario_Integration(t *testing.T) { Answers: []string{helper.GetName(), "", "", "", "y"}, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -55,7 +56,7 @@ func TestRunAssumeRoleScenario_Integration(t *testing.T) { log.SetOutput(&buf) scenario := NewAssumeRoleScenario(sdkConfig, mockQuestioner, &helper) - scenario.Run() + scenario.Run(ctx) log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/iam/scenarios/scenario_assume_role_test.go b/gov2/iam/scenarios/scenario_assume_role_test.go index dc3556edcbd..c1ce2fdb856 100644 --- a/gov2/iam/scenarios/scenario_assume_role_test.go +++ b/gov2/iam/scenarios/scenario_assume_role_test.go @@ -6,6 +6,7 @@ package scenarios import ( + "context" "testing" "github.com/aws/aws-sdk-go-v2/service/iam/types" @@ -25,8 +26,8 @@ func TestRunAssumeRoleScenario(t *testing.T) { // AssumeRoleScenarioTest encapsulates data for a scenario test. type AssumeRoleScenarioTest struct { - Answers []string - helper testHelper + Answers []string + helper testHelper } // SetupDataAndStubs sets up test data and builds the stubs that are used to return @@ -43,7 +44,6 @@ func (scenTest *AssumeRoleScenarioTest) SetupDataAndStubs() []testtools.Stub { listBucketsPolicyArn := listBucketsPolicy + "-arn" userPolicy := "test-user-policy" - scenTest.helper = testHelper{names: []string{roleName, listBucketsPolicy, userPolicy}} scenTest.Answers = []string{userName, "", "", "", "y"} @@ -81,7 +81,7 @@ func (scenTest *AssumeRoleScenarioTest) RunSubTest(stubber *testtools.AwsmStubbe mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} scenario := NewAssumeRoleScenario(*stubber.SdkConfig, &mockQuestioner, &scenTest.helper) scenario.isTestRun = true - scenario.Run() + scenario.Run(context.Background()) } func (scenTest *AssumeRoleScenarioTest) Cleanup() {} diff --git a/gov2/lambda/actions/functions.go b/gov2/lambda/actions/functions.go index 248af33b308..55ff0bc461e 100644 --- a/gov2/lambda/actions/functions.go +++ b/gov2/lambda/actions/functions.go @@ -29,9 +29,9 @@ type FunctionWrapper struct { // snippet-start:[gov2.lambda.GetFunction] // GetFunction gets data about the Lambda function specified by functionName. -func (wrapper FunctionWrapper) GetFunction(functionName string) types.State { +func (wrapper FunctionWrapper) GetFunction(ctx context.Context, functionName string) types.State { var state types.State - funcOutput, err := wrapper.LambdaClient.GetFunction(context.TODO(), &lambda.GetFunctionInput{ + funcOutput, err := wrapper.LambdaClient.GetFunction(ctx, &lambda.GetFunctionInput{ FunctionName: aws.String(functionName), }) if err != nil { @@ -53,10 +53,10 @@ func (wrapper FunctionWrapper) GetFunction(functionName string) types.State { // When the function already exists, types.StateActive is returned. // When the function is created, a lambda.FunctionActiveV2Waiter is used to wait until the // function is active. -func (wrapper FunctionWrapper) CreateFunction(functionName string, handlerName string, +func (wrapper FunctionWrapper) CreateFunction(ctx context.Context, functionName string, handlerName string, iamRoleArn *string, zipPackage *bytes.Buffer) types.State { var state types.State - _, err := wrapper.LambdaClient.CreateFunction(context.TODO(), &lambda.CreateFunctionInput{ + _, err := wrapper.LambdaClient.CreateFunction(ctx, &lambda.CreateFunctionInput{ Code: &types.FunctionCode{ZipFile: zipPackage.Bytes()}, FunctionName: aws.String(functionName), Role: iamRoleArn, @@ -74,7 +74,7 @@ func (wrapper FunctionWrapper) CreateFunction(functionName string, handlerName s } } else { waiter := lambda.NewFunctionActiveV2Waiter(wrapper.LambdaClient) - funcOutput, err := waiter.WaitForOutput(context.TODO(), &lambda.GetFunctionInput{ + funcOutput, err := waiter.WaitForOutput(ctx, &lambda.GetFunctionInput{ FunctionName: aws.String(functionName)}, 1*time.Minute) if err != nil { log.Panicf("Couldn't wait for function %v to be active. Here's why: %v\n", functionName, err) @@ -93,16 +93,16 @@ func (wrapper FunctionWrapper) CreateFunction(functionName string, handlerName s // The existing code for the Lambda function is entirely replaced by the code in the // zipPackage buffer. After the update action is called, a lambda.FunctionUpdatedV2Waiter // is used to wait until the update is successful. -func (wrapper FunctionWrapper) UpdateFunctionCode(functionName string, zipPackage *bytes.Buffer) types.State { +func (wrapper FunctionWrapper) UpdateFunctionCode(ctx context.Context, functionName string, zipPackage *bytes.Buffer) types.State { var state types.State - _, err := wrapper.LambdaClient.UpdateFunctionCode(context.TODO(), &lambda.UpdateFunctionCodeInput{ + _, err := wrapper.LambdaClient.UpdateFunctionCode(ctx, &lambda.UpdateFunctionCodeInput{ FunctionName: aws.String(functionName), ZipFile: zipPackage.Bytes(), }) if err != nil { log.Panicf("Couldn't update code for function %v. Here's why: %v\n", functionName, err) } else { waiter := lambda.NewFunctionUpdatedV2Waiter(wrapper.LambdaClient) - funcOutput, err := waiter.WaitForOutput(context.TODO(), &lambda.GetFunctionInput{ + funcOutput, err := waiter.WaitForOutput(ctx, &lambda.GetFunctionInput{ FunctionName: aws.String(functionName)}, 1*time.Minute) if err != nil { log.Panicf("Couldn't wait for function %v to be active. Here's why: %v\n", functionName, err) @@ -119,8 +119,8 @@ func (wrapper FunctionWrapper) UpdateFunctionCode(functionName string, zipPackag // UpdateFunctionConfiguration updates a map of environment variables configured for // the Lambda function specified by functionName. -func (wrapper FunctionWrapper) UpdateFunctionConfiguration(functionName string, envVars map[string]string) { - _, err := wrapper.LambdaClient.UpdateFunctionConfiguration(context.TODO(), &lambda.UpdateFunctionConfigurationInput{ +func (wrapper FunctionWrapper) UpdateFunctionConfiguration(ctx context.Context, functionName string, envVars map[string]string) { + _, err := wrapper.LambdaClient.UpdateFunctionConfiguration(ctx, &lambda.UpdateFunctionConfigurationInput{ FunctionName: aws.String(functionName), Environment: &types.Environment{Variables: envVars}, }) @@ -135,13 +135,13 @@ func (wrapper FunctionWrapper) UpdateFunctionConfiguration(functionName string, // ListFunctions lists up to maxItems functions for the account. This function uses a // lambda.ListFunctionsPaginator to paginate the results. -func (wrapper FunctionWrapper) ListFunctions(maxItems int) []types.FunctionConfiguration { +func (wrapper FunctionWrapper) ListFunctions(ctx context.Context, maxItems int) []types.FunctionConfiguration { var functions []types.FunctionConfiguration paginator := lambda.NewListFunctionsPaginator(wrapper.LambdaClient, &lambda.ListFunctionsInput{ MaxItems: aws.Int32(int32(maxItems)), }) for paginator.HasMorePages() && len(functions) < maxItems { - pageOutput, err := paginator.NextPage(context.TODO()) + pageOutput, err := paginator.NextPage(ctx) if err != nil { log.Panicf("Couldn't list functions for your account. Here's why: %v\n", err) } @@ -155,8 +155,8 @@ func (wrapper FunctionWrapper) ListFunctions(maxItems int) []types.FunctionConfi // snippet-start:[gov2.lambda.DeleteFunction] // DeleteFunction deletes the Lambda function specified by functionName. -func (wrapper FunctionWrapper) DeleteFunction(functionName string) { - _, err := wrapper.LambdaClient.DeleteFunction(context.TODO(), &lambda.DeleteFunctionInput{ +func (wrapper FunctionWrapper) DeleteFunction(ctx context.Context, functionName string) { + _, err := wrapper.LambdaClient.DeleteFunction(ctx, &lambda.DeleteFunctionInput{ FunctionName: aws.String(functionName), }) if err != nil { @@ -171,7 +171,7 @@ func (wrapper FunctionWrapper) DeleteFunction(functionName string) { // Invoke invokes the Lambda function specified by functionName, passing the parameters // as a JSON payload. When getLog is true, types.LogTypeTail is specified, which tells // Lambda to include the last few log lines in the returned result. -func (wrapper FunctionWrapper) Invoke(functionName string, parameters any, getLog bool) *lambda.InvokeOutput { +func (wrapper FunctionWrapper) Invoke(ctx context.Context, functionName string, parameters any, getLog bool) *lambda.InvokeOutput { logType := types.LogTypeNone if getLog { logType = types.LogTypeTail @@ -180,7 +180,7 @@ func (wrapper FunctionWrapper) Invoke(functionName string, parameters any, getLo if err != nil { log.Panicf("Couldn't marshal parameters to JSON. Here's why %v\n", err) } - invokeOutput, err := wrapper.LambdaClient.Invoke(context.TODO(), &lambda.InvokeInput{ + invokeOutput, err := wrapper.LambdaClient.Invoke(ctx, &lambda.InvokeInput{ FunctionName: aws.String(functionName), LogType: logType, Payload: payload, diff --git a/gov2/lambda/cmd/main.go b/gov2/lambda/cmd/main.go index 12d42a85975..103f7f2b7a5 100644 --- a/gov2/lambda/cmd/main.go +++ b/gov2/lambda/cmd/main.go @@ -23,7 +23,7 @@ import ( // - `functions` - Runs an interactive scenario that shows you how to create and // invoke an AWS Lambda function, then update the code and invoke it again. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "functions": runGetStartedFunctionsScenario, } choices := make([]string, len(scenarioMap)) @@ -41,18 +41,19 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runGetStartedFunctionsScenario(sdkConfig aws.Config) { +func runGetStartedFunctionsScenario(ctx context.Context, sdkConfig aws.Config) { helper := scenarios.ScenarioHelper{HandlerPath: "handlers/"} scenario := scenarios.NewGetStartedFunctionsScenario(sdkConfig, demotools.NewQuestioner(), &helper) - scenario.Run() + scenario.Run(ctx) } diff --git a/gov2/lambda/hello/hello.go b/gov2/lambda/hello/hello.go index f5bc6b907bd..c7ebd7083ae 100644 --- a/gov2/lambda/hello/hello.go +++ b/gov2/lambda/hello/hello.go @@ -19,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -29,7 +30,7 @@ func main() { maxItems := 10 fmt.Printf("Let's list up to %v functions for your account.\n", maxItems) - result, err := lambdaClient.ListFunctions(context.TODO(), &lambda.ListFunctionsInput{ + result, err := lambdaClient.ListFunctions(ctx, &lambda.ListFunctionsInput{ MaxItems: aws.Int32(int32(maxItems)), }) if err != nil { diff --git a/gov2/lambda/scenarios/scenario_get_started_functions.go b/gov2/lambda/scenarios/scenario_get_started_functions.go index f5cad638750..3193e145306 100644 --- a/gov2/lambda/scenarios/scenario_get_started_functions.go +++ b/gov2/lambda/scenarios/scenario_get_started_functions.go @@ -108,7 +108,7 @@ func NewGetStartedFunctionsScenario(sdkConfig aws.Config, questioner demotools.I } // Run runs the interactive scenario. -func (scenario GetStartedFunctionsScenario) Run() { +func (scenario GetStartedFunctionsScenario) Run(ctx context.Context) { defer func() { if r := recover(); r != nil { log.Printf("Something went wrong with the demo.\n") @@ -119,13 +119,13 @@ func (scenario GetStartedFunctionsScenario) Run() { log.Println("Welcome to the AWS Lambda get started with functions demo.") log.Println(strings.Repeat("-", 88)) - role := scenario.GetOrCreateRole() - funcName := scenario.CreateFunction(role) - scenario.InvokeIncrement(funcName) - scenario.UpdateFunction(funcName) - scenario.InvokeCalculator(funcName) - scenario.ListFunctions() - scenario.Cleanup(role, funcName) + role := scenario.GetOrCreateRole(ctx) + funcName := scenario.CreateFunction(ctx, role) + scenario.InvokeIncrement(ctx, funcName) + scenario.UpdateFunction(ctx, funcName) + scenario.InvokeCalculator(ctx, funcName) + scenario.ListFunctions(ctx) + scenario.Cleanup(ctx, role, funcName) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") @@ -136,12 +136,12 @@ func (scenario GetStartedFunctionsScenario) Run() { // Otherwise, a role is created that specifies Lambda as a trusted principal. // The AWSLambdaBasicExecutionRole managed policy is attached to the role and the role // is returned. -func (scenario GetStartedFunctionsScenario) GetOrCreateRole() *iamtypes.Role { +func (scenario GetStartedFunctionsScenario) GetOrCreateRole(ctx context.Context) *iamtypes.Role { var role *iamtypes.Role iamClient := iam.NewFromConfig(scenario.sdkConfig) log.Println("First, we need an IAM role that Lambda can assume.") roleName := scenario.questioner.Ask("Enter a name for the role:", demotools.NotEmpty{}) - getOutput, err := iamClient.GetRole(context.TODO(), &iam.GetRoleInput{ + getOutput, err := iamClient.GetRole(ctx, &iam.GetRoleInput{ RoleName: aws.String(roleName)}) if err != nil { var noSuch *iamtypes.NoSuchEntityException @@ -165,7 +165,7 @@ func (scenario GetStartedFunctionsScenario) GetOrCreateRole() *iamtypes.Role { }}, } policyArn := "arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole" - createOutput, err := iamClient.CreateRole(context.TODO(), &iam.CreateRoleInput{ + createOutput, err := iamClient.CreateRole(ctx, &iam.CreateRoleInput{ AssumeRolePolicyDocument: aws.String(trustPolicy.String()), RoleName: aws.String(roleName), }) @@ -173,7 +173,7 @@ func (scenario GetStartedFunctionsScenario) GetOrCreateRole() *iamtypes.Role { log.Panicf("Couldn't create role %v. Here's why: %v\n", roleName, err) } role = createOutput.Role - _, err = iamClient.AttachRolePolicy(context.TODO(), &iam.AttachRolePolicyInput{ + _, err = iamClient.AttachRolePolicy(ctx, &iam.AttachRolePolicyInput{ PolicyArn: aws.String(policyArn), RoleName: aws.String(roleName), }) @@ -190,14 +190,14 @@ func (scenario GetStartedFunctionsScenario) GetOrCreateRole() *iamtypes.Role { // CreateFunction creates a Lambda function and uploads a handler written in Python. // The code for the Python handler is packaged as a []byte in .zip format. -func (scenario GetStartedFunctionsScenario) CreateFunction(role *iamtypes.Role) string { +func (scenario GetStartedFunctionsScenario) CreateFunction(ctx context.Context, role *iamtypes.Role) string { log.Println("Let's create a function that increments a number.\n" + "The function uses the 'lambda_handler_basic.py' script found in the \n" + "'handlers' directory of this project.") funcName := scenario.questioner.Ask("Enter a name for the Lambda function:", demotools.NotEmpty{}) zipPackage := scenario.helper.CreateDeploymentPackage("lambda_handler_basic.py", fmt.Sprintf("%v.py", funcName)) log.Printf("Creating function %v and waiting for it to be ready.", funcName) - funcState := scenario.functionWrapper.CreateFunction(funcName, fmt.Sprintf("%v.lambda_handler", funcName), + funcState := scenario.functionWrapper.CreateFunction(ctx, funcName, fmt.Sprintf("%v.lambda_handler", funcName), role.Arn, zipPackage) log.Printf("Your function is %v.", funcState) log.Println(strings.Repeat("-", 88)) @@ -208,12 +208,12 @@ func (scenario GetStartedFunctionsScenario) CreateFunction(role *iamtypes.Role) // parameters are contained in a Go struct that is used to serialize the parameters to // a JSON payload that is passed to the function. // The result payload is deserialized into a Go struct that contains an int value. -func (scenario GetStartedFunctionsScenario) InvokeIncrement(funcName string) { +func (scenario GetStartedFunctionsScenario) InvokeIncrement(ctx context.Context, funcName string) { parameters := actions.IncrementParameters{Action: "increment"} log.Println("Let's invoke our function. This function increments a number.") parameters.Number = scenario.questioner.AskInt("Enter a number to increment:", demotools.NotEmpty{}) log.Printf("Invoking %v with %v...\n", funcName, parameters.Number) - invokeOutput := scenario.functionWrapper.Invoke(funcName, parameters, false) + invokeOutput := scenario.functionWrapper.Invoke(ctx, funcName, parameters, false) var payload actions.LambdaResultInt err := json.Unmarshal(invokeOutput.Payload, &payload) if err != nil { @@ -229,7 +229,7 @@ func (scenario GetStartedFunctionsScenario) InvokeIncrement(funcName string) { // []byte in .zip format. // After the code is updated, the configuration is also updated with a new log // level that instructs the handler to log additional information. -func (scenario GetStartedFunctionsScenario) UpdateFunction(funcName string) { +func (scenario GetStartedFunctionsScenario) UpdateFunction(ctx context.Context, funcName string) { log.Println("Let's update the function to an arithmetic calculator.\n" + "The function uses the 'lambda_handler_calculator.py' script found in the \n" + "'handlers' directory of this project.") @@ -238,11 +238,11 @@ func (scenario GetStartedFunctionsScenario) UpdateFunction(funcName string) { zipPackage := scenario.helper.CreateDeploymentPackage("lambda_handler_calculator.py", fmt.Sprintf("%v.py", funcName)) log.Println("...and updating the Lambda function and waiting for it to be ready.") - funcState := scenario.functionWrapper.UpdateFunctionCode(funcName, zipPackage) + funcState := scenario.functionWrapper.UpdateFunctionCode(ctx, funcName, zipPackage) log.Printf("Updated function %v. Its current state is %v.", funcName, funcState) log.Println("This function uses an environment variable to control logging level.") log.Println("Let's set it to DEBUG to get the most logging.") - scenario.functionWrapper.UpdateFunctionConfiguration(funcName, + scenario.functionWrapper.UpdateFunctionConfiguration(ctx, funcName, map[string]string{"LOG_LEVEL": "DEBUG"}) log.Println(strings.Repeat("-", 88)) } @@ -252,7 +252,7 @@ func (scenario GetStartedFunctionsScenario) UpdateFunction(funcName string) { // to the function. // The result payload is deserialized to a Go struct that stores the result as either an // int or float32, depending on the kind of operation that was specified. -func (scenario GetStartedFunctionsScenario) InvokeCalculator(funcName string) { +func (scenario GetStartedFunctionsScenario) InvokeCalculator(ctx context.Context, funcName string) { wantInvoke := true choices := []string{"plus", "minus", "times", "divided-by"} for wantInvoke { @@ -265,7 +265,7 @@ func (scenario GetStartedFunctionsScenario) InvokeCalculator(funcName string) { X: x, Y: y, } - invokeOutput := scenario.functionWrapper.Invoke(funcName, calcParameters, true) + invokeOutput := scenario.functionWrapper.Invoke(ctx, funcName, calcParameters, true) var payload any if choice == 3 { // divide-by results in a float. payload = actions.LambdaResultFloat{} @@ -291,10 +291,10 @@ func (scenario GetStartedFunctionsScenario) InvokeCalculator(funcName string) { } // ListFunctions lists up to the specified number of functions for your account. -func (scenario GetStartedFunctionsScenario) ListFunctions() { +func (scenario GetStartedFunctionsScenario) ListFunctions(ctx context.Context) { count := scenario.questioner.AskInt( "Let's list functions for your account. How many do you want to see?", demotools.NotEmpty{}) - functions := scenario.functionWrapper.ListFunctions(count) + functions := scenario.functionWrapper.ListFunctions(ctx, count) log.Printf("Found %v functions:", len(functions)) for _, function := range functions { log.Printf("\t%v", *function.FunctionName) @@ -303,18 +303,18 @@ func (scenario GetStartedFunctionsScenario) ListFunctions() { } // Cleanup removes the IAM and Lambda resources created by the example. -func (scenario GetStartedFunctionsScenario) Cleanup(role *iamtypes.Role, funcName string) { +func (scenario GetStartedFunctionsScenario) Cleanup(ctx context.Context, role *iamtypes.Role, funcName string) { if scenario.questioner.AskBool("Do you want to clean up resources created for this example? (y/n)", "y") { iamClient := iam.NewFromConfig(scenario.sdkConfig) - policiesOutput, err := iamClient.ListAttachedRolePolicies(context.TODO(), + policiesOutput, err := iamClient.ListAttachedRolePolicies(ctx, &iam.ListAttachedRolePoliciesInput{RoleName: role.RoleName}) if err != nil { log.Panicf("Couldn't get policies attached to role %v. Here's why: %v\n", *role.RoleName, err) } for _, policy := range policiesOutput.AttachedPolicies { - _, err = iamClient.DetachRolePolicy(context.TODO(), &iam.DetachRolePolicyInput{ + _, err = iamClient.DetachRolePolicy(ctx, &iam.DetachRolePolicyInput{ PolicyArn: policy.PolicyArn, RoleName: role.RoleName, }) if err != nil { @@ -322,13 +322,13 @@ func (scenario GetStartedFunctionsScenario) Cleanup(role *iamtypes.Role, funcNam *policy.PolicyArn, *role.RoleName, err) } } - _, err = iamClient.DeleteRole(context.TODO(), &iam.DeleteRoleInput{RoleName: role.RoleName}) + _, err = iamClient.DeleteRole(ctx, &iam.DeleteRoleInput{RoleName: role.RoleName}) if err != nil { log.Panicf("Couldn't delete role %v. Here's why: %v\n", *role.RoleName, err) } log.Printf("Deleted role %v.\n", *role.RoleName) - scenario.functionWrapper.DeleteFunction(funcName) + scenario.functionWrapper.DeleteFunction(ctx, funcName) log.Printf("Deleted function %v.\n", funcName) } else { log.Println("Okay. Don't forget to delete the resources when you're done with them.") diff --git a/gov2/lambda/scenarios/scenario_get_started_functions_integ_test.go b/gov2/lambda/scenarios/scenario_get_started_functions_integ_test.go index faf3f81aa25..6be410a8aab 100644 --- a/gov2/lambda/scenarios/scenario_get_started_functions_integ_test.go +++ b/gov2/lambda/scenarios/scenario_get_started_functions_integ_test.go @@ -47,7 +47,8 @@ func TestRunGetStartedFunctionsScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -57,7 +58,7 @@ func TestRunGetStartedFunctionsScenario_Integration(t *testing.T) { log.SetOutput(&buf) scenario := NewGetStartedFunctionsScenario(sdkConfig, mockQuestioner, &helper) - scenario.Run() + scenario.Run(ctx) log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/lambda/scenarios/scenario_get_started_functions_test.go b/gov2/lambda/scenarios/scenario_get_started_functions_test.go index 0f79852f275..3ecc4d91be0 100644 --- a/gov2/lambda/scenarios/scenario_get_started_functions_test.go +++ b/gov2/lambda/scenarios/scenario_get_started_functions_test.go @@ -5,6 +5,7 @@ package scenarios import ( "bytes" + "context" "encoding/json" "fmt" "testing" @@ -122,7 +123,7 @@ func (scenTest *GetStartedFunctionsScenarioTest) RunSubTest(stubber *testtools.A mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} scenario := NewGetStartedFunctionsScenario(*stubber.SdkConfig, &mockQuestioner, &scenTest.helper) scenario.isTestRun = true - scenario.Run() + scenario.Run(context.Background()) } func (scenTest *GetStartedFunctionsScenarioTest) Cleanup() {} diff --git a/gov2/rds/actions/instances.go b/gov2/rds/actions/instances.go index 539781d9b42..04aeab1041a 100644 --- a/gov2/rds/actions/instances.go +++ b/gov2/rds/actions/instances.go @@ -24,10 +24,10 @@ type DbInstances struct { // snippet-start:[gov2.rds.DescribeDBParameterGroups] // GetParameterGroup gets a DB parameter group by name. -func (instances *DbInstances) GetParameterGroup(parameterGroupName string) ( +func (instances *DbInstances) GetParameterGroup(ctx context.Context, parameterGroupName string) ( *types.DBParameterGroup, error) { output, err := instances.RdsClient.DescribeDBParameterGroups( - context.TODO(), &rds.DescribeDBParameterGroupsInput{ + ctx, &rds.DescribeDBParameterGroupsInput{ DBParameterGroupName: aws.String(parameterGroupName), }) if err != nil { @@ -51,10 +51,10 @@ func (instances *DbInstances) GetParameterGroup(parameterGroupName string) ( // CreateParameterGroup creates a DB parameter group that is based on the specified // parameter group family. func (instances *DbInstances) CreateParameterGroup( - parameterGroupName string, parameterGroupFamily string, description string) ( + ctx context.Context, parameterGroupName string, parameterGroupFamily string, description string) ( *types.DBParameterGroup, error) { - output, err := instances.RdsClient.CreateDBParameterGroup(context.TODO(), + output, err := instances.RdsClient.CreateDBParameterGroup(ctx, &rds.CreateDBParameterGroupInput{ DBParameterGroupName: aws.String(parameterGroupName), DBParameterGroupFamily: aws.String(parameterGroupFamily), @@ -73,8 +73,8 @@ func (instances *DbInstances) CreateParameterGroup( // snippet-start:[gov2.rds.DeleteDBParameterGroup] // DeleteParameterGroup deletes the named DB parameter group. -func (instances *DbInstances) DeleteParameterGroup(parameterGroupName string) error { - _, err := instances.RdsClient.DeleteDBParameterGroup(context.TODO(), +func (instances *DbInstances) DeleteParameterGroup(ctx context.Context, parameterGroupName string) error { + _, err := instances.RdsClient.DeleteDBParameterGroup(ctx, &rds.DeleteDBParameterGroupInput{ DBParameterGroupName: aws.String(parameterGroupName), }) @@ -91,7 +91,7 @@ func (instances *DbInstances) DeleteParameterGroup(parameterGroupName string) er // snippet-start:[gov2.rds.DescribeDBParameters] // GetParameters gets the parameters that are contained in a DB parameter group. -func (instances *DbInstances) GetParameters(parameterGroupName string, source string) ( +func (instances *DbInstances) GetParameters(ctx context.Context, parameterGroupName string, source string) ( []types.Parameter, error) { var output *rds.DescribeDBParametersOutput @@ -103,7 +103,7 @@ func (instances *DbInstances) GetParameters(parameterGroupName string, source st Source: aws.String(source), }) for parameterPaginator.HasMorePages() { - output, err = parameterPaginator.NextPage(context.TODO()) + output, err = parameterPaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get parameters for %v: %v\n", parameterGroupName, err) break @@ -119,8 +119,8 @@ func (instances *DbInstances) GetParameters(parameterGroupName string, source st // snippet-start:[gov2.rds.ModifyDBParameterGroup] // UpdateParameters updates parameters in a named DB parameter group. -func (instances *DbInstances) UpdateParameters(parameterGroupName string, params []types.Parameter) error { - _, err := instances.RdsClient.ModifyDBParameterGroup(context.TODO(), +func (instances *DbInstances) UpdateParameters(ctx context.Context, parameterGroupName string, params []types.Parameter) error { + _, err := instances.RdsClient.ModifyDBParameterGroup(ctx, &rds.ModifyDBParameterGroupInput{ DBParameterGroupName: aws.String(parameterGroupName), Parameters: params, @@ -138,9 +138,9 @@ func (instances *DbInstances) UpdateParameters(parameterGroupName string, params // snippet-start:[gov2.rds.CreateDBSnapshot] // CreateSnapshot creates a snapshot of a DB instance. -func (instances *DbInstances) CreateSnapshot(instanceName string, snapshotName string) ( +func (instances *DbInstances) CreateSnapshot(ctx context.Context, instanceName string, snapshotName string) ( *types.DBSnapshot, error) { - output, err := instances.RdsClient.CreateDBSnapshot(context.TODO(), &rds.CreateDBSnapshotInput{ + output, err := instances.RdsClient.CreateDBSnapshot(ctx, &rds.CreateDBSnapshotInput{ DBInstanceIdentifier: aws.String(instanceName), DBSnapshotIdentifier: aws.String(snapshotName), }) @@ -157,8 +157,8 @@ func (instances *DbInstances) CreateSnapshot(instanceName string, snapshotName s // snippet-start:[gov2.rds.DescribeDBSnapshots] // GetSnapshot gets a DB instance snapshot. -func (instances *DbInstances) GetSnapshot(snapshotName string) (*types.DBSnapshot, error) { - output, err := instances.RdsClient.DescribeDBSnapshots(context.TODO(), +func (instances *DbInstances) GetSnapshot(ctx context.Context, snapshotName string) (*types.DBSnapshot, error) { + output, err := instances.RdsClient.DescribeDBSnapshots(ctx, &rds.DescribeDBSnapshotsInput{ DBSnapshotIdentifier: aws.String(snapshotName), }) @@ -175,11 +175,11 @@ func (instances *DbInstances) GetSnapshot(snapshotName string) (*types.DBSnapsho // snippet-start:[gov2.rds.CreateDBInstance] // CreateInstance creates a DB instance. -func (instances *DbInstances) CreateInstance(instanceName string, dbName string, +func (instances *DbInstances) CreateInstance(ctx context.Context, instanceName string, dbName string, dbEngine string, dbEngineVersion string, parameterGroupName string, dbInstanceClass string, storageType string, allocatedStorage int32, adminName string, adminPassword string) ( *types.DBInstance, error) { - output, err := instances.RdsClient.CreateDBInstance(context.TODO(), &rds.CreateDBInstanceInput{ + output, err := instances.RdsClient.CreateDBInstance(ctx, &rds.CreateDBInstanceInput{ DBInstanceIdentifier: aws.String(instanceName), DBName: aws.String(dbName), DBParameterGroupName: aws.String(parameterGroupName), @@ -204,9 +204,9 @@ func (instances *DbInstances) CreateInstance(instanceName string, dbName string, // snippet-start:[gov2.rds.DescribeDBInstances] // GetInstance gets data about a DB instance. -func (instances *DbInstances) GetInstance(instanceName string) ( +func (instances *DbInstances) GetInstance(ctx context.Context, instanceName string) ( *types.DBInstance, error) { - output, err := instances.RdsClient.DescribeDBInstances(context.TODO(), + output, err := instances.RdsClient.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: aws.String(instanceName), }) @@ -229,8 +229,8 @@ func (instances *DbInstances) GetInstance(instanceName string) ( // snippet-start:[gov2.rds.DeleteDBInstance] // DeleteInstance deletes a DB instance. -func (instances *DbInstances) DeleteInstance(instanceName string) error { - _, err := instances.RdsClient.DeleteDBInstance(context.TODO(), &rds.DeleteDBInstanceInput{ +func (instances *DbInstances) DeleteInstance(ctx context.Context, instanceName string) error { + _, err := instances.RdsClient.DeleteDBInstance(ctx, &rds.DeleteDBInstanceInput{ DBInstanceIdentifier: aws.String(instanceName), SkipFinalSnapshot: aws.Bool(true), DeleteAutomatedBackups: aws.Bool(true), @@ -249,9 +249,9 @@ func (instances *DbInstances) DeleteInstance(instanceName string) error { // GetEngineVersions gets database engine versions that are available for the specified engine // and parameter group family. -func (instances *DbInstances) GetEngineVersions(engine string, parameterGroupFamily string) ( +func (instances *DbInstances) GetEngineVersions(ctx context.Context, engine string, parameterGroupFamily string) ( []types.DBEngineVersion, error) { - output, err := instances.RdsClient.DescribeDBEngineVersions(context.TODO(), + output, err := instances.RdsClient.DescribeDBEngineVersions(ctx, &rds.DescribeDBEngineVersionsInput{ Engine: aws.String(engine), DBParameterGroupFamily: aws.String(parameterGroupFamily), @@ -270,7 +270,7 @@ func (instances *DbInstances) GetEngineVersions(engine string, parameterGroupFam // GetOrderableInstances uses a paginator to get DB instance options that can be used to create DB instances that are // compatible with a set of specifications. -func (instances *DbInstances) GetOrderableInstances(engine string, engineVersion string) ( +func (instances *DbInstances) GetOrderableInstances(ctx context.Context, engine string, engineVersion string) ( []types.OrderableDBInstanceOption, error) { var output *rds.DescribeOrderableDBInstanceOptionsOutput @@ -282,7 +282,7 @@ func (instances *DbInstances) GetOrderableInstances(engine string, engineVersion EngineVersion: aws.String(engineVersion), }) for orderablePaginator.HasMorePages() { - output, err = orderablePaginator.NextPage(context.TODO()) + output, err = orderablePaginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get orderable DB instance options: %v\n", err) break diff --git a/gov2/rds/cmd/main.go b/gov2/rds/cmd/main.go index 68af89cfa31..718c9157103 100644 --- a/gov2/rds/cmd/main.go +++ b/gov2/rds/cmd/main.go @@ -23,7 +23,7 @@ import ( // - `instances` - Runs the interactive DB instances scenario that shows you how to use // Amazon Relational Database Service (Amazon RDS) commands to work with DB instances and databases. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "instances": runInstanceScenario, } choices := make([]string, len(scenarioMap)) @@ -41,17 +41,18 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runInstanceScenario(sdkConfig aws.Config) { +func runInstanceScenario(ctx context.Context, sdkConfig aws.Config) { scenario := scenarios.NewGetStartedInstances(sdkConfig, demotools.NewQuestioner(), scenarios.ScenarioHelper{}) - scenario.Run("mysql", "doc-example-parameter-group", "doc-example-instance", "docexampledb") + scenario.Run(ctx, "mysql", "doc-example-parameter-group", "doc-example-instance", "docexampledb") } diff --git a/gov2/rds/hello/hello.go b/gov2/rds/hello/hello.go index 0e2bacdc8d9..5bacce68b56 100644 --- a/gov2/rds/hello/hello.go +++ b/gov2/rds/hello/hello.go @@ -19,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -28,7 +29,7 @@ func main() { rdsClient := rds.NewFromConfig(sdkConfig) const maxInstances = 20 fmt.Printf("Let's list up to %v DB instances.\n", maxInstances) - output, err := rdsClient.DescribeDBInstances(context.TODO(), + output, err := rdsClient.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{MaxRecords: aws.Int32(maxInstances)}) if err != nil { fmt.Printf("Couldn't list DB instances: %v\n", err) diff --git a/gov2/rds/scenarios/get_started_instances.go b/gov2/rds/scenarios/get_started_instances.go index 9707f8d7a8b..062c43b0bef 100644 --- a/gov2/rds/scenarios/get_started_instances.go +++ b/gov2/rds/scenarios/get_started_instances.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "fmt" "log" "sort" @@ -70,7 +71,7 @@ func NewGetStartedInstances(sdkConfig aws.Config, questioner demotools.IQuestion } // Run runs the interactive scenario. -func (scenario GetStartedInstances) Run(dbEngine string, parameterGroupName string, +func (scenario GetStartedInstances) Run(ctx context.Context, dbEngine string, parameterGroupName string, instanceName string, dbName string) { defer func() { if r := recover(); r != nil { @@ -82,12 +83,12 @@ func (scenario GetStartedInstances) Run(dbEngine string, parameterGroupName stri log.Println("Welcome to the Amazon Relational Database Service (Amazon RDS) DB Instance demo.") log.Println(strings.Repeat("-", 88)) - parameterGroup := scenario.CreateParameterGroup(dbEngine, parameterGroupName) - scenario.SetUserParameters(parameterGroupName) - instance := scenario.CreateInstance(instanceName, dbEngine, dbName, parameterGroup) + parameterGroup := scenario.CreateParameterGroup(ctx, dbEngine, parameterGroupName) + scenario.SetUserParameters(ctx, parameterGroupName) + instance := scenario.CreateInstance(ctx, instanceName, dbEngine, dbName, parameterGroup) scenario.DisplayConnection(instance) - scenario.CreateSnapshot(instance) - scenario.Cleanup(instance, parameterGroup) + scenario.CreateSnapshot(ctx, instance) + scenario.Cleanup(ctx, instance, parameterGroup) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") @@ -97,18 +98,18 @@ func (scenario GetStartedInstances) Run(dbEngine string, parameterGroupName stri // CreateParameterGroup shows how to get available engine versions for a specified // database engine and create a DB parameter group that is compatible with a // selected engine family. -func (scenario GetStartedInstances) CreateParameterGroup(dbEngine string, +func (scenario GetStartedInstances) CreateParameterGroup(ctx context.Context, dbEngine string, parameterGroupName string) *types.DBParameterGroup { log.Printf("Checking for an existing DB parameter group named %v.\n", parameterGroupName) - parameterGroup, err := scenario.instances.GetParameterGroup(parameterGroupName) + parameterGroup, err := scenario.instances.GetParameterGroup(ctx, parameterGroupName) if err != nil { panic(err) } if parameterGroup == nil { log.Printf("Getting available database engine versions for %v.\n", dbEngine) - engineVersions, err := scenario.instances.GetEngineVersions(dbEngine, "") + engineVersions, err := scenario.instances.GetEngineVersions(ctx, dbEngine, "") if err != nil { panic(err) } @@ -125,11 +126,11 @@ func (scenario GetStartedInstances) CreateParameterGroup(dbEngine string, familyIndex := scenario.questioner.AskChoice("Which family do you want to use?\n", families) log.Println("Creating a DB parameter group.") _, err = scenario.instances.CreateParameterGroup( - parameterGroupName, families[familyIndex], "Example parameter group.") + ctx, parameterGroupName, families[familyIndex], "Example parameter group.") if err != nil { panic(err) } - parameterGroup, err = scenario.instances.GetParameterGroup(parameterGroupName) + parameterGroup, err = scenario.instances.GetParameterGroup(ctx, parameterGroupName) if err != nil { panic(err) } @@ -145,9 +146,9 @@ func (scenario GetStartedInstances) CreateParameterGroup(dbEngine string, // SetUserParameters shows how to get the parameters contained in a custom parameter // group and update some of the parameter values in the group. -func (scenario GetStartedInstances) SetUserParameters(parameterGroupName string) { +func (scenario GetStartedInstances) SetUserParameters(ctx context.Context, parameterGroupName string) { log.Println("Let's set some parameter values in your parameter group.") - dbParameters, err := scenario.instances.GetParameters(parameterGroupName, "") + dbParameters, err := scenario.instances.GetParameters(ctx, parameterGroupName, "") if err != nil { panic(err) } @@ -167,12 +168,12 @@ func (scenario GetStartedInstances) SetUserParameters(parameterGroupName string) updateParams = append(updateParams, dbParam) } } - err = scenario.instances.UpdateParameters(parameterGroupName, updateParams) + err = scenario.instances.UpdateParameters(ctx, parameterGroupName, updateParams) if err != nil { panic(err) } log.Println("To get a list of parameters that you set previously, specify a source of 'user'.") - userParameters, err := scenario.instances.GetParameters(parameterGroupName, "user") + userParameters, err := scenario.instances.GetParameters(ctx, parameterGroupName, "user") if err != nil { panic(err) } @@ -185,11 +186,11 @@ func (scenario GetStartedInstances) SetUserParameters(parameterGroupName string) // CreateInstance shows how to create a DB instance that contains a database of a // specified type. The database is also configured to use a custom DB parameter group. -func (scenario GetStartedInstances) CreateInstance(instanceName string, dbEngine string, +func (scenario GetStartedInstances) CreateInstance(ctx context.Context, instanceName string, dbEngine string, dbName string, parameterGroup *types.DBParameterGroup) *types.DBInstance { log.Println("Checking for an existing DB instance.") - instance, err := scenario.instances.GetInstance(instanceName) + instance, err := scenario.instances.GetInstance(ctx, instanceName) if err != nil { panic(err) } @@ -198,7 +199,7 @@ func (scenario GetStartedInstances) CreateInstance(instanceName string, dbEngine "Enter an administrator username for the database: ", demotools.NotEmpty{}) adminPassword := scenario.questioner.AskPassword( "Enter a password for the administrator (at least 8 characters): ", 7) - engineVersions, err := scenario.instances.GetEngineVersions(dbEngine, + engineVersions, err := scenario.instances.GetEngineVersions(ctx, dbEngine, *parameterGroup.DBParameterGroupFamily) if err != nil { panic(err) @@ -210,7 +211,7 @@ func (scenario GetStartedInstances) CreateInstance(instanceName string, dbEngine engineIndex := scenario.questioner.AskChoice( "The available engines for your parameter group are:\n", engineChoices) engineSelection := engineVersions[engineIndex] - instOpts, err := scenario.instances.GetOrderableInstances(*engineSelection.Engine, + instOpts, err := scenario.instances.GetOrderableInstances(ctx, *engineSelection.Engine, *engineSelection.EngineVersion) if err != nil { panic(err) @@ -239,7 +240,7 @@ func (scenario GetStartedInstances) CreateInstance(instanceName string, dbEngine instanceName, dbName, *parameterGroup.DBParameterGroupName, *engineSelection.EngineVersion, optChoices[optIndex], allocatedStorage, storageType) instance, err = scenario.instances.CreateInstance( - instanceName, dbName, *engineSelection.Engine, *engineSelection.EngineVersion, + ctx, instanceName, dbName, *engineSelection.Engine, *engineSelection.EngineVersion, *parameterGroup.DBParameterGroupName, optChoices[optIndex], storageType, allocatedStorage, adminUsername, adminPassword) if err != nil { @@ -247,7 +248,7 @@ func (scenario GetStartedInstances) CreateInstance(instanceName string, dbEngine } for *instance.DBInstanceStatus != "available" { scenario.helper.Pause(30) - instance, err = scenario.instances.GetInstance(instanceName) + instance, err = scenario.instances.GetInstance(ctx, instanceName) if err != nil { panic(err) } @@ -281,18 +282,18 @@ func (scenario GetStartedInstances) DisplayConnection(instance *types.DBInstance } // CreateSnapshot shows how to create a DB instance snapshot and wait until it's available. -func (scenario GetStartedInstances) CreateSnapshot(instance *types.DBInstance) { +func (scenario GetStartedInstances) CreateSnapshot(ctx context.Context, instance *types.DBInstance) { if scenario.questioner.AskBool( "Do you want to create a snapshot of your DB instance (y/n)? ", "y") { snapshotId := fmt.Sprintf("%v-%v", *instance.DBInstanceIdentifier, scenario.helper.UniqueId()) log.Printf("Creating a snapshot named %v. This typically takes a few minutes.\n", snapshotId) - snapshot, err := scenario.instances.CreateSnapshot(*instance.DBInstanceIdentifier, snapshotId) + snapshot, err := scenario.instances.CreateSnapshot(ctx, *instance.DBInstanceIdentifier, snapshotId) if err != nil { panic(err) } for *snapshot.Status != "available" { scenario.helper.Pause(30) - snapshot, err = scenario.instances.GetSnapshot(snapshotId) + snapshot, err = scenario.instances.GetSnapshot(ctx, snapshotId) if err != nil { panic(err) } @@ -312,12 +313,12 @@ func (scenario GetStartedInstances) CreateSnapshot(instance *types.DBInstance) { // Cleanup shows how to clean up a DB instance and DB parameter group. // Before the DB parameter group can be deleted, all associated DB instances must first be deleted. func (scenario GetStartedInstances) Cleanup( - instance *types.DBInstance, parameterGroup *types.DBParameterGroup) { + ctx context.Context, instance *types.DBInstance, parameterGroup *types.DBParameterGroup) { if scenario.questioner.AskBool( "\nDo you want to delete the database instance and parameter group (y/n)? ", "y") { log.Printf("Deleting database instance %v.\n", *instance.DBInstanceIdentifier) - err := scenario.instances.DeleteInstance(*instance.DBInstanceIdentifier) + err := scenario.instances.DeleteInstance(ctx, *instance.DBInstanceIdentifier) if err != nil { panic(err) } @@ -325,13 +326,13 @@ func (scenario GetStartedInstances) Cleanup( "Waiting for the DB instance to delete. This typically takes several minutes.") for instance != nil { scenario.helper.Pause(30) - instance, err = scenario.instances.GetInstance(*instance.DBInstanceIdentifier) + instance, err = scenario.instances.GetInstance(ctx, *instance.DBInstanceIdentifier) if err != nil { panic(err) } } log.Printf("Deleting parameter group %v.", *parameterGroup.DBParameterGroupName) - err = scenario.instances.DeleteParameterGroup(*parameterGroup.DBParameterGroupName) + err = scenario.instances.DeleteParameterGroup(ctx, *parameterGroup.DBParameterGroupName) if err != nil { panic(err) } diff --git a/gov2/rds/scenarios/get_started_instances_integ_test.go b/gov2/rds/scenarios/get_started_instances_integ_test.go index c44d30f9cc7..c6f85c6300e 100644 --- a/gov2/rds/scenarios/get_started_instances_integ_test.go +++ b/gov2/rds/scenarios/get_started_instances_integ_test.go @@ -51,7 +51,8 @@ func TestRunGetStartedClustersScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -63,7 +64,7 @@ func TestRunGetStartedClustersScenario_Integration(t *testing.T) { scenario := NewGetStartedInstances(sdkConfig, mockQuestioner, &helper) testId := time.Now().Unix() scenario.Run( - "mysql", + ctx, "mysql", fmt.Sprintf("doc-example-parameter-group-%v", testId), fmt.Sprintf("doc-example-instance-%v", testId), "docexampledbinteg") diff --git a/gov2/rds/scenarios/get_started_instances_test.go b/gov2/rds/scenarios/get_started_instances_test.go index 91440cf6da3..13a4e8b2a64 100644 --- a/gov2/rds/scenarios/get_started_instances_test.go +++ b/gov2/rds/scenarios/get_started_instances_test.go @@ -3,6 +3,7 @@ package scenarios import ( + "context" "fmt" "strconv" "testing" @@ -127,7 +128,7 @@ func (scenTest *GetStartedInstancesTest) RunSubTest(stubber *testtools.AwsmStubb mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} scenario := NewGetStartedInstances(*stubber.SdkConfig, &mockQuestioner, &scenTest.helper) scenario.isTestRun = true - scenario.Run(scenTest.dbEngine, scenTest.parameterGroupName, scenTest.instanceName, scenTest.dbName) + scenario.Run(context.Background(), scenTest.dbEngine, scenTest.parameterGroupName, scenTest.instanceName, scenTest.dbName) } func (scenTest *GetStartedInstancesTest) Cleanup() {} diff --git a/gov2/redshift/README.md b/gov2/redshift/README.md index 554c5a724fc..28b7417b7ee 100644 --- a/gov2/redshift/README.md +++ b/gov2/redshift/README.md @@ -45,10 +45,10 @@ Code examples that show you how to perform the essential operations within a ser Code excerpts that show you how to call individual service functions. -- [CreateCluster](actions/redshift_actions.go#L29) -- [DeleteCluster](actions/redshift_actions.go#L84) -- [DescribeClusters](actions/redshift_actions.go#L108) -- [ModifyCluster](actions/redshift_actions.go#L58) +- [CreateCluster](actions/redshift_actions.go#L30) +- [DeleteCluster](actions/redshift_actions.go#L85) +- [DescribeClusters](actions/redshift_actions.go#L115) +- [ModifyCluster](actions/redshift_actions.go#L59) diff --git a/gov2/redshift/actions/redshift_actions.go b/gov2/redshift/actions/redshift_actions.go index cdc0239caa5..0a9b3ca390a 100644 --- a/gov2/redshift/actions/redshift_actions.go +++ b/gov2/redshift/actions/redshift_actions.go @@ -8,11 +8,12 @@ package actions import ( "context" "errors" + "log" + "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go-v2/service/redshift/types" - "log" - "time" ) // snippet-end:[gov2.redshift.Imports] @@ -85,12 +86,11 @@ func (actor RedshiftActions) ModifyCluster(ctx context.Context, clusterId string // DeleteCluster deletes the given cluster. func (actor RedshiftActions) DeleteCluster(ctx context.Context, clusterId string) (bool, error) { - // Delete the specified Redshift cluster - - waiter := redshift.NewClusterDeletedWaiter(actor.RedshiftClient) - err := waiter.Wait(ctx, &redshift.DescribeClustersInput{ - ClusterIdentifier: aws.String(clusterId), - }, 5*time.Minute) + input := redshift.DeleteClusterInput{ + ClusterIdentifier: aws.String(clusterId), + SkipFinalClusterSnapshot: aws.Bool(true), + } + _, err := actor.RedshiftClient.DeleteCluster(ctx, &input) var opErr *types.ClusterNotFoundFault if err != nil && errors.As(err, &opErr) { log.Println("Cluster was not found. Where could it be?") @@ -99,7 +99,14 @@ func (actor RedshiftActions) DeleteCluster(ctx context.Context, clusterId string log.Printf("Failed to delete Redshift cluster: %v\n", err) return false, err } - log.Printf("The %s was deleted\n", clusterId) + waiter := redshift.NewClusterDeletedWaiter(actor.RedshiftClient) + err = waiter.Wait(ctx, &redshift.DescribeClustersInput{ + ClusterIdentifier: aws.String(clusterId), + }, 5*time.Minute) + if err != nil { + log.Printf("Wait time exceeded for deleting cluster, continuing: %v\n", err) + } + log.Printf("The cluster %s was deleted\n", clusterId) return true, nil } diff --git a/gov2/redshift/cmd/main.go b/gov2/redshift/cmd/main.go index bdf6faedf02..90ae69f4d6e 100644 --- a/gov2/redshift/cmd/main.go +++ b/gov2/redshift/cmd/main.go @@ -7,12 +7,13 @@ import ( "context" "flag" "fmt" - "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" - "github.com/awsdocs/aws-doc-sdk-examples/gov2/redshift/scenarios" "log" "math/rand" "time" + "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" + "github.com/awsdocs/aws-doc-sdk-examples/gov2/redshift/scenarios" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" ) @@ -24,7 +25,7 @@ import ( // // - `basics` - Runs the interactive Basics scenario to show core Redshift actions. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config, helper scenarios.IScenarioHelper){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config, helper scenarios.IScenarioHelper){ "basics": runRedshiftBasicsScenario, } choices := make([]string, len(scenarioMap)) @@ -42,7 +43,8 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -52,12 +54,12 @@ func main() { Prefix: "redshift_basics", Random: rand.New(rand.NewSource(time.Now().UnixNano())), } - runScenario(sdkConfig, helper) + runScenario(ctx, sdkConfig, helper) } } -func runRedshiftBasicsScenario(sdkConfig aws.Config, helper scenarios.IScenarioHelper) { +func runRedshiftBasicsScenario(ctx context.Context, sdkConfig aws.Config, helper scenarios.IScenarioHelper) { pauser := demotools.Pauser{} scenario := scenarios.RedshiftBasics(sdkConfig, demotools.NewQuestioner(), pauser, demotools.NewStandardFileSystem(), helper) - scenario.Run() + scenario.Run(ctx) } diff --git a/gov2/redshift/hello/hello.go b/gov2/redshift/hello/hello.go index b690c0c6e95..adaf70c0722 100644 --- a/gov2/redshift/hello/hello.go +++ b/gov2/redshift/hello/hello.go @@ -8,6 +8,7 @@ package main import ( "context" "fmt" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/redshift" @@ -18,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -27,7 +29,7 @@ func main() { redshiftClient := redshift.NewFromConfig(sdkConfig) count := 20 fmt.Printf("Let's list up to %v clusters for your account.\n", count) - result, err := redshiftClient.DescribeClusters(context.TODO(), &redshift.DescribeClustersInput{ + result, err := redshiftClient.DescribeClusters(ctx, &redshift.DescribeClustersInput{ MaxRecords: aws.Int32(int32(count)), }) if err != nil { diff --git a/gov2/redshift/scenarios/redshift_basics.go b/gov2/redshift/scenarios/redshift_basics.go index 0bba40177dc..b1fb44f2ad5 100644 --- a/gov2/redshift/scenarios/redshift_basics.go +++ b/gov2/redshift/scenarios/redshift_basics.go @@ -10,16 +10,17 @@ import ( "encoding/json" "errors" "fmt" + "log" + "math/rand" + "strings" + "time" + "github.com/aws/aws-sdk-go-v2/aws" redshift_types "github.com/aws/aws-sdk-go-v2/service/redshift/types" redshiftdata_types "github.com/aws/aws-sdk-go-v2/service/redshiftdata/types" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" "github.com/awsdocs/aws-doc-sdk-examples/gov2/redshift/actions" - "log" - "math/rand" - "strings" - "time" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go-v2/service/redshiftdata" @@ -111,7 +112,7 @@ type User struct { // // It uses a questioner from the `demotools` package to get input during the example. // This package can be found in the ..\..\demotools folder of this repo. -func (runner *RedshiftBasicsScenario) Run() { +func (runner *RedshiftBasicsScenario) Run(ctx context.Context) { user := User{} secretId := "s3express/basics/secrets" @@ -122,7 +123,6 @@ func (runner *RedshiftBasicsScenario) Run() { fileName := "Movies.json" nodeType := "ra3.xlplus" clusterType := "single-node" - ctx := context.TODO() defer func() { if r := recover(); r != nil { diff --git a/gov2/redshift/scenarios/redshift_basics_integration_test.go b/gov2/redshift/scenarios/redshift_basics_integration_test.go index e8debb238b0..9f2a6ddadca 100644 --- a/gov2/redshift/scenarios/redshift_basics_integration_test.go +++ b/gov2/redshift/scenarios/redshift_basics_integration_test.go @@ -12,13 +12,14 @@ package scenarios import ( "bytes" "context" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" "log" "math/rand" "os" "strings" "testing" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" ) // MockPauser holds the pausable object. @@ -35,7 +36,8 @@ func TestBasicsScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -51,7 +53,7 @@ func TestBasicsScenario_Integration(t *testing.T) { Random: rand.New(rand.NewSource(0)), } scenario := RedshiftBasics(sdkConfig, mockQuestioner, demotools.Pauser{}, demotools.NewMockFileSystem(file), helper) - scenario.Run() + scenario.Run(ctx) _ = os.Remove(outFile) diff --git a/gov2/redshift/scenarios/redshift_basics_test.go b/gov2/redshift/scenarios/redshift_basics_test.go index 28a3281c852..58917477d35 100644 --- a/gov2/redshift/scenarios/redshift_basics_test.go +++ b/gov2/redshift/scenarios/redshift_basics_test.go @@ -6,14 +6,17 @@ package scenarios import ( + "context" "encoding/json" "fmt" - "github.com/awsdocs/aws-doc-sdk-examples/gov2/redshift/stubs" "io" "math/rand" "testing" "time" + "github.com/aws/aws-sdk-go-v2/service/redshift/types" + "github.com/awsdocs/aws-doc-sdk-examples/gov2/redshift/stubs" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/awsdocs/aws-doc-sdk-examples/gov2/demotools" "github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools" @@ -110,18 +113,23 @@ func (scenarioTest *BasicsScenarioTest) SetupDataAndStubs() []testtools.Stub { stubList = append(stubList, stubs.StubDescribeStatement(testId, nil)) stubList = append(stubList, stubs.StubGetStatementResult(nil)) stubList = append(stubList, stubs.StubModifyCluster(nil)) - stubList = append(stubList, stubs.StubDeleteCluster(clusterId)) + stubList = append(stubList, stubs.StubDeleteCluster(clusterId, nil)) + stubList = append(stubList, stubs.StubDescribeClusters(clusterId, &testtools.StubError{Err: &types.ClusterNotFoundFault{}})) return stubList } +type TestPauser struct{} + +func (tp TestPauser) Pause(secs int) {} + // RunSubTest performs a single test run with a set of stubs set up to run with // or without errors. func (scenarioTest *BasicsScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { mockQuestioner := demotools.MockQuestioner{Answers: scenarioTest.Answers} - scenario := RedshiftBasics(*stubber.SdkConfig, &mockQuestioner, demotools.Pauser{}, demotools.NewMockFileSystem(scenarioTest.File), scenarioTest.Helper) + scenario := RedshiftBasics(*stubber.SdkConfig, &mockQuestioner, TestPauser{}, demotools.NewMockFileSystem(scenarioTest.File), scenarioTest.Helper) - scenario.Run() + scenario.Run(context.Background()) } func (scenarioTest *BasicsScenarioTest) Cleanup() { diff --git a/gov2/redshift/stubs/redshift_stubs.go b/gov2/redshift/stubs/redshift_stubs.go index ccade45569d..afd17159dfd 100644 --- a/gov2/redshift/stubs/redshift_stubs.go +++ b/gov2/redshift/stubs/redshift_stubs.go @@ -4,12 +4,12 @@ package stubs import ( + "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go-v2/service/redshift/types" - "github.com/aws/smithy-go" "github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools" - "time" ) func StubDescribeClusters(clusterId string, raiseErr *testtools.StubError) testtools.Stub { @@ -27,9 +27,7 @@ func StubDescribeClusters(clusterId string, raiseErr *testtools.StubError) testt Output: &redshift.DescribeClustersOutput{ Clusters: clusters, }, - SkipErrorTest: false, - IgnoreFields: nil, - Error: raiseErr, + Error: raiseErr, } } @@ -71,18 +69,11 @@ func StubModifyCluster(raiseErr *testtools.StubError) testtools.Stub { } } -func StubDeleteCluster(clusterId string) testtools.Stub { +func StubDeleteCluster(clusterId string, raiseErr *testtools.StubError) testtools.Stub { return testtools.Stub{ - OperationName: "DescribeClusters", // Because a waiter is used, this is the actual called mocked. - Input: &redshift.DescribeClustersInput{ - ClusterIdentifier: &clusterId, - }, - SkipErrorTest: true, - Error: &testtools.StubError{ - Err: &smithy.GenericAPIError{ - Code: "ClusterNotFound", - Message: "ClusterNotFound", - }, - }, + OperationName: "DeleteCluster", + Input: &redshift.DeleteClusterInput{ClusterIdentifier: aws.String(clusterId), SkipFinalClusterSnapshot: aws.Bool(true)}, + Output: &redshift.DeleteClusterOutput{}, + Error: raiseErr, } } diff --git a/gov2/s3/actions/bucket_basics.go b/gov2/s3/actions/bucket_basics.go index b54ab9aa62b..c65f7d08316 100644 --- a/gov2/s3/actions/bucket_basics.go +++ b/gov2/s3/actions/bucket_basics.go @@ -35,8 +35,8 @@ type BucketBasics struct { // snippet-start:[gov2.s3.ListBuckets] // ListBuckets lists the buckets in the current account. -func (basics BucketBasics) ListBuckets() ([]types.Bucket, error) { - result, err := basics.S3Client.ListBuckets(context.TODO(), &s3.ListBucketsInput{}) +func (basics BucketBasics) ListBuckets(ctx context.Context) ([]types.Bucket, error) { + result, err := basics.S3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) var buckets []types.Bucket if err != nil { log.Printf("Couldn't list buckets for your account. Here's why: %v\n", err) @@ -51,8 +51,8 @@ func (basics BucketBasics) ListBuckets() ([]types.Bucket, error) { // snippet-start:[gov2.s3.HeadBucket] // BucketExists checks whether a bucket exists in the current account. -func (basics BucketBasics) BucketExists(bucketName string) (bool, error) { - _, err := basics.S3Client.HeadBucket(context.TODO(), &s3.HeadBucketInput{ +func (basics BucketBasics) BucketExists(ctx context.Context, bucketName string) (bool, error) { + _, err := basics.S3Client.HeadBucket(ctx, &s3.HeadBucketInput{ Bucket: aws.String(bucketName), }) exists := true @@ -81,8 +81,8 @@ func (basics BucketBasics) BucketExists(bucketName string) (bool, error) { // snippet-start:[gov2.s3.CreateBucket] // CreateBucket creates a bucket with the specified name in the specified Region. -func (basics BucketBasics) CreateBucket(name string, region string) error { - _, err := basics.S3Client.CreateBucket(context.TODO(), &s3.CreateBucketInput{ +func (basics BucketBasics) CreateBucket(ctx context.Context, name string, region string) error { + _, err := basics.S3Client.CreateBucket(ctx, &s3.CreateBucketInput{ Bucket: aws.String(name), CreateBucketConfiguration: &types.CreateBucketConfiguration{ LocationConstraint: types.BucketLocationConstraint(region), @@ -100,13 +100,13 @@ func (basics BucketBasics) CreateBucket(name string, region string) error { // snippet-start:[gov2.s3.PutObject] // UploadFile reads from a file and puts the data into an object in a bucket. -func (basics BucketBasics) UploadFile(bucketName string, objectKey string, fileName string) error { +func (basics BucketBasics) UploadFile(ctx context.Context, bucketName string, objectKey string, fileName string) error { file, err := os.Open(fileName) if err != nil { log.Printf("Couldn't open file %v to upload. Here's why: %v\n", fileName, err) } else { defer file.Close() - _, err = basics.S3Client.PutObject(context.TODO(), &s3.PutObjectInput{ + _, err = basics.S3Client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), Body: file, @@ -125,13 +125,13 @@ func (basics BucketBasics) UploadFile(bucketName string, objectKey string, fileN // UploadLargeObject uses an upload manager to upload data to an object in a bucket. // The upload manager breaks large data into parts and uploads the parts concurrently. -func (basics BucketBasics) UploadLargeObject(bucketName string, objectKey string, largeObject []byte) error { +func (basics BucketBasics) UploadLargeObject(ctx context.Context, bucketName string, objectKey string, largeObject []byte) error { largeBuffer := bytes.NewReader(largeObject) var partMiBs int64 = 10 uploader := manager.NewUploader(basics.S3Client, func(u *manager.Uploader) { u.PartSize = partMiBs * 1024 * 1024 }) - _, err := uploader.Upload(context.TODO(), &s3.PutObjectInput{ + _, err := uploader.Upload(ctx, &s3.PutObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), Body: largeBuffer, @@ -149,8 +149,8 @@ func (basics BucketBasics) UploadLargeObject(bucketName string, objectKey string // snippet-start:[gov2.s3.GetObject] // DownloadFile gets an object from a bucket and stores it in a local file. -func (basics BucketBasics) DownloadFile(bucketName string, objectKey string, fileName string) error { - result, err := basics.S3Client.GetObject(context.TODO(), &s3.GetObjectInput{ +func (basics BucketBasics) DownloadFile(ctx context.Context, bucketName string, objectKey string, fileName string) error { + result, err := basics.S3Client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }) @@ -180,13 +180,13 @@ func (basics BucketBasics) DownloadFile(bucketName string, objectKey string, fil // DownloadLargeObject uses a download manager to download an object from a bucket. // The download manager gets the data in parts and writes them to a buffer until all of // the data has been downloaded. -func (basics BucketBasics) DownloadLargeObject(bucketName string, objectKey string) ([]byte, error) { +func (basics BucketBasics) DownloadLargeObject(ctx context.Context, bucketName string, objectKey string) ([]byte, error) { var partMiBs int64 = 10 downloader := manager.NewDownloader(basics.S3Client, func(d *manager.Downloader) { d.PartSize = partMiBs * 1024 * 1024 }) buffer := manager.NewWriteAtBuffer([]byte{}) - _, err := downloader.Download(context.TODO(), buffer, &s3.GetObjectInput{ + _, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }) @@ -202,8 +202,8 @@ func (basics BucketBasics) DownloadLargeObject(bucketName string, objectKey stri // snippet-start:[gov2.s3.CopyObject] // CopyToFolder copies an object in a bucket to a subfolder in the same bucket. -func (basics BucketBasics) CopyToFolder(bucketName string, objectKey string, folderName string) error { - _, err := basics.S3Client.CopyObject(context.TODO(), &s3.CopyObjectInput{ +func (basics BucketBasics) CopyToFolder(ctx context.Context, bucketName string, objectKey string, folderName string) error { + _, err := basics.S3Client.CopyObject(ctx, &s3.CopyObjectInput{ Bucket: aws.String(bucketName), CopySource: aws.String(fmt.Sprintf("%v/%v", bucketName, objectKey)), Key: aws.String(fmt.Sprintf("%v/%v", folderName, objectKey)), @@ -220,8 +220,8 @@ func (basics BucketBasics) CopyToFolder(bucketName string, objectKey string, fol // snippet-start:[gov2.s3.CopyObject.ToBucket] // CopyToBucket copies an object in a bucket to another bucket. -func (basics BucketBasics) CopyToBucket(sourceBucket string, destinationBucket string, objectKey string) error { - _, err := basics.S3Client.CopyObject(context.TODO(), &s3.CopyObjectInput{ +func (basics BucketBasics) CopyToBucket(ctx context.Context, sourceBucket string, destinationBucket string, objectKey string) error { + _, err := basics.S3Client.CopyObject(ctx, &s3.CopyObjectInput{ Bucket: aws.String(destinationBucket), CopySource: aws.String(fmt.Sprintf("%v/%v", sourceBucket, objectKey)), Key: aws.String(objectKey), @@ -238,8 +238,8 @@ func (basics BucketBasics) CopyToBucket(sourceBucket string, destinationBucket s // snippet-start:[gov2.s3.ListObjectsV2] // ListObjects lists the objects in a bucket. -func (basics BucketBasics) ListObjects(bucketName string) ([]types.Object, error) { - result, err := basics.S3Client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{ +func (basics BucketBasics) ListObjects(ctx context.Context, bucketName string) ([]types.Object, error) { + result, err := basics.S3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(bucketName), }) var contents []types.Object @@ -256,12 +256,12 @@ func (basics BucketBasics) ListObjects(bucketName string) ([]types.Object, error // snippet-start:[gov2.s3.DeleteObjects] // DeleteObjects deletes a list of objects from a bucket. -func (basics BucketBasics) DeleteObjects(bucketName string, objectKeys []string) error { +func (basics BucketBasics) DeleteObjects(ctx context.Context, bucketName string, objectKeys []string) error { var objectIds []types.ObjectIdentifier for _, key := range objectKeys { objectIds = append(objectIds, types.ObjectIdentifier{Key: aws.String(key)}) } - output, err := basics.S3Client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{ + output, err := basics.S3Client.DeleteObjects(ctx, &s3.DeleteObjectsInput{ Bucket: aws.String(bucketName), Delete: &types.Delete{Objects: objectIds}, }) @@ -278,8 +278,8 @@ func (basics BucketBasics) DeleteObjects(bucketName string, objectKeys []string) // snippet-start:[gov2.s3.DeleteBucket] // DeleteBucket deletes a bucket. The bucket must be empty or an error is returned. -func (basics BucketBasics) DeleteBucket(bucketName string) error { - _, err := basics.S3Client.DeleteBucket(context.TODO(), &s3.DeleteBucketInput{ +func (basics BucketBasics) DeleteBucket(ctx context.Context, bucketName string) error { + _, err := basics.S3Client.DeleteBucket(ctx, &s3.DeleteBucketInput{ Bucket: aws.String(bucketName)}) if err != nil { log.Printf("Couldn't delete bucket %v. Here's why: %v\n", bucketName, err) diff --git a/gov2/s3/actions/bucket_basics_test.go b/gov2/s3/actions/bucket_basics_test.go index 41d9e91069d..46bd481d6d1 100644 --- a/gov2/s3/actions/bucket_basics_test.go +++ b/gov2/s3/actions/bucket_basics_test.go @@ -6,6 +6,7 @@ package actions import ( + "context" "errors" "testing" @@ -28,8 +29,9 @@ func TestBucketBasics_CopyToBucket(t *testing.T) { func CopyToBucket(raiseErr *testtools.StubError, t *testing.T) { stubber, basics := enterTest() stubber.Add(stubs.StubCopyObject("source-bucket", "object-key", "dest-bucket", "object-key", raiseErr)) + ctx := context.Background() - err := basics.CopyToBucket("source-bucket", "dest-bucket", "object-key") + err := basics.CopyToBucket(ctx, "source-bucket", "dest-bucket", "object-key") testtools.VerifyError(err, raiseErr, t) testtools.ExitTest(stubber, t) diff --git a/gov2/s3/actions/presigner.go b/gov2/s3/actions/presigner.go index 40bd93d1a11..24a8d4719f3 100644 --- a/gov2/s3/actions/presigner.go +++ b/gov2/s3/actions/presigner.go @@ -31,8 +31,8 @@ type Presigner struct { // GetObject makes a presigned request that can be used to get an object from a bucket. // The presigned request is valid for the specified number of seconds. func (presigner Presigner) GetObject( - bucketName string, objectKey string, lifetimeSecs int64) (*v4.PresignedHTTPRequest, error) { - request, err := presigner.PresignClient.PresignGetObject(context.TODO(), &s3.GetObjectInput{ + ctx context.Context, bucketName string, objectKey string, lifetimeSecs int64) (*v4.PresignedHTTPRequest, error) { + request, err := presigner.PresignClient.PresignGetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }, func(opts *s3.PresignOptions) { @@ -52,8 +52,8 @@ func (presigner Presigner) GetObject( // PutObject makes a presigned request that can be used to put an object in a bucket. // The presigned request is valid for the specified number of seconds. func (presigner Presigner) PutObject( - bucketName string, objectKey string, lifetimeSecs int64) (*v4.PresignedHTTPRequest, error) { - request, err := presigner.PresignClient.PresignPutObject(context.TODO(), &s3.PutObjectInput{ + ctx context.Context, bucketName string, objectKey string, lifetimeSecs int64) (*v4.PresignedHTTPRequest, error) { + request, err := presigner.PresignClient.PresignPutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }, func(opts *s3.PresignOptions) { @@ -71,8 +71,8 @@ func (presigner Presigner) PutObject( // snippet-start:[gov2.s3.PresignDeleteObject] // DeleteObject makes a presigned request that can be used to delete an object from a bucket. -func (presigner Presigner) DeleteObject(bucketName string, objectKey string) (*v4.PresignedHTTPRequest, error) { - request, err := presigner.PresignClient.PresignDeleteObject(context.TODO(), &s3.DeleteObjectInput{ +func (presigner Presigner) DeleteObject(ctx context.Context, bucketName string, objectKey string) (*v4.PresignedHTTPRequest, error) { + request, err := presigner.PresignClient.PresignDeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }) @@ -86,8 +86,8 @@ func (presigner Presigner) DeleteObject(bucketName string, objectKey string) (*v // snippet-start:[gov2.s3.PresignPostObject] -func (presigner Presigner) PresignPostObject(bucketName string, objectKey string, lifetimeSecs int64) (*s3.PresignedPostRequest, error) { - request, err := presigner.PresignClient.PresignPostObject(context.TODO(), &s3.PutObjectInput{ +func (presigner Presigner) PresignPostObject(ctx context.Context, bucketName string, objectKey string, lifetimeSecs int64) (*s3.PresignedPostRequest, error) { + request, err := presigner.PresignClient.PresignPostObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), }, func(options *s3.PresignPostOptions) { diff --git a/gov2/s3/cmd/main.go b/gov2/s3/cmd/main.go index ff5fd603b0b..ba7c6114459 100644 --- a/gov2/s3/cmd/main.go +++ b/gov2/s3/cmd/main.go @@ -27,7 +27,7 @@ import ( // get presigned requests that contain temporary credentials // and can be used to make requests from any HTTP client. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "getstarted": runGetStartedScenario, "presigning": runPresigningScenario, } @@ -46,20 +46,21 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runGetStartedScenario(sdkConfig aws.Config) { - scenarios.RunGetStartedScenario(sdkConfig, demotools.NewQuestioner()) +func runGetStartedScenario(ctx context.Context, sdkConfig aws.Config) { + scenarios.RunGetStartedScenario(ctx, sdkConfig, demotools.NewQuestioner()) } -func runPresigningScenario(sdkConfig aws.Config) { - scenarios.RunPresigningScenario(sdkConfig, demotools.NewQuestioner(), scenarios.HttpRequester{}) +func runPresigningScenario(ctx context.Context, sdkConfig aws.Config) { + scenarios.RunPresigningScenario(ctx, sdkConfig, demotools.NewQuestioner(), scenarios.HttpRequester{}) } diff --git a/gov2/s3/hello/hello.go b/gov2/s3/hello/hello.go index 0ff6db165d3..073a9825a83 100644 --- a/gov2/s3/hello/hello.go +++ b/gov2/s3/hello/hello.go @@ -18,7 +18,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -27,7 +28,7 @@ func main() { s3Client := s3.NewFromConfig(sdkConfig) count := 10 fmt.Printf("Let's list up to %v buckets for your account.\n", count) - result, err := s3Client.ListBuckets(context.TODO(), &s3.ListBucketsInput{}) + result, err := s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) if err != nil { fmt.Printf("Couldn't list buckets for your account. Here's why: %v\n", err) return diff --git a/gov2/s3/scenarios/scenario_get_started.go b/gov2/s3/scenarios/scenario_get_started.go index 54695790d0d..3b04d861344 100644 --- a/gov2/s3/scenarios/scenario_get_started.go +++ b/gov2/s3/scenarios/scenario_get_started.go @@ -4,6 +4,7 @@ package scenarios import ( + "context" "crypto/rand" "fmt" "log" @@ -36,7 +37,7 @@ import ( // // It uses a questioner from the `demotools` package to get input during the example. // This package can be found in the ..\..\demotools folder of this repo. -func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestioner) { +func RunGetStartedScenario(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner) { defer func() { if r := recover(); r != nil { fmt.Println("Something went wrong with the demo.\n", r) @@ -52,7 +53,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione count := 10 log.Printf("Let's list up to %v buckets for your account:", count) - buckets, err := bucketBasics.ListBuckets() + buckets, err := bucketBasics.ListBuckets(ctx) if err != nil { panic(err) } @@ -69,12 +70,12 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione bucketName := questioner.Ask("Let's create a bucket. Enter a name for your bucket:", demotools.NotEmpty{}) - bucketExists, err := bucketBasics.BucketExists(bucketName) + bucketExists, err := bucketBasics.BucketExists(ctx, bucketName) if err != nil { panic(err) } if !bucketExists { - err = bucketBasics.CreateBucket(bucketName, sdkConfig.Region) + err = bucketBasics.CreateBucket(ctx, bucketName, sdkConfig.Region) if err != nil { panic(err) } else { @@ -87,7 +88,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione smallFile := questioner.Ask("Enter the path to a file you want to upload:", demotools.NotEmpty{}) const smallKey = "doc-example-key" - err = bucketBasics.UploadFile(bucketName, smallKey, smallFile) + err = bucketBasics.UploadFile(ctx, bucketName, smallKey, smallFile) if err != nil { panic(err) } @@ -101,7 +102,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione _, _ = rand.Read(largeBytes) largeKey := "doc-example-large" log.Println("Uploading...") - err = bucketBasics.UploadLargeObject(bucketName, largeKey, largeBytes) + err = bucketBasics.UploadLargeObject(ctx, bucketName, largeKey, largeBytes) if err != nil { panic(err) } @@ -110,7 +111,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Printf("Let's download %v to a file.", smallKey) downloadFileName := questioner.Ask("Enter a name for the downloaded file:", demotools.NotEmpty{}) - err = bucketBasics.DownloadFile(bucketName, smallKey, downloadFileName) + err = bucketBasics.DownloadFile(ctx, bucketName, smallKey, downloadFileName) if err != nil { panic(err) } @@ -120,7 +121,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Printf("Let's download the %v MiB object.", mibs) questioner.Ask("Press Enter when you're ready.") log.Println("Downloading...") - largeDownload, err := bucketBasics.DownloadLargeObject(bucketName, largeKey) + largeDownload, err := bucketBasics.DownloadLargeObject(ctx, bucketName, largeKey) if err != nil { panic(err) } @@ -129,7 +130,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Printf("Let's copy %v to a folder in the same bucket.", smallKey) folderName := questioner.Ask("Enter a folder name: ", demotools.NotEmpty{}) - err = bucketBasics.CopyToFolder(bucketName, smallKey, folderName) + err = bucketBasics.CopyToFolder(ctx, bucketName, smallKey, folderName) if err != nil { panic(err) } @@ -138,7 +139,7 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Println("Let's list the objects in your bucket.") questioner.Ask("Press Enter when you're ready.") - objects, err := bucketBasics.ListObjects(bucketName) + objects, err := bucketBasics.ListObjects(ctx, bucketName) if err != nil { panic(err) } @@ -153,12 +154,12 @@ func RunGetStartedScenario(sdkConfig aws.Config, questioner demotools.IQuestione if questioner.AskBool("Do you want to delete your bucket and all of its "+ "contents? (y/n)", "y") { log.Println("Deleting objects.") - err = bucketBasics.DeleteObjects(bucketName, objKeys) + err = bucketBasics.DeleteObjects(ctx, bucketName, objKeys) if err != nil { panic(err) } log.Println("Deleting bucket.") - err = bucketBasics.DeleteBucket(bucketName) + err = bucketBasics.DeleteBucket(ctx, bucketName) if err != nil { panic(err) } diff --git a/gov2/s3/scenarios/scenario_get_started_integ_test.go b/gov2/s3/scenarios/scenario_get_started_integ_test.go index 39426531d24..03c01ae9cf1 100644 --- a/gov2/s3/scenarios/scenario_get_started_integ_test.go +++ b/gov2/s3/scenarios/scenario_get_started_integ_test.go @@ -29,7 +29,8 @@ func TestGetStartedScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -38,7 +39,7 @@ func TestGetStartedScenario_Integration(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - RunGetStartedScenario(sdkConfig, mockQuestioner) + RunGetStartedScenario(ctx, sdkConfig, mockQuestioner) _ = os.Remove(outFile) diff --git a/gov2/s3/scenarios/scenario_get_started_test.go b/gov2/s3/scenarios/scenario_get_started_test.go index 9f197491e14..ee9d1ceb8b4 100644 --- a/gov2/s3/scenarios/scenario_get_started_test.go +++ b/gov2/s3/scenarios/scenario_get_started_test.go @@ -102,7 +102,7 @@ func (scenTest *GetStartedScenarioTest) SetupDataAndStubs() []testtools.Stub { // or without errors. func (scenTest *GetStartedScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} - RunGetStartedScenario(*stubber.SdkConfig, &mockQuestioner) + RunGetStartedScenario(context.Background(), *stubber.SdkConfig, &mockQuestioner) } // Cleanup deletes the output file created by the download test. diff --git a/gov2/s3/scenarios/scenario_presigning.go b/gov2/s3/scenarios/scenario_presigning.go index 3a1398be85b..976386658a5 100644 --- a/gov2/s3/scenarios/scenario_presigning.go +++ b/gov2/s3/scenarios/scenario_presigning.go @@ -5,6 +5,7 @@ package scenarios import ( "bytes" + "context" "fmt" "io" "log" @@ -122,7 +123,7 @@ func sendMultipartRequest(url string, fields map[string]string, file *os.File, f // // It uses an IHttpRequester interface to abstract HTTP requests so they can be mocked // during testing. -func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestioner, httpRequester IHttpRequester) { +func RunPresigningScenario(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, httpRequester IHttpRequester) { defer func() { if r := recover(); r != nil { fmt.Printf("Something went wrong with the demo") @@ -140,12 +141,12 @@ func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestione bucketName := questioner.Ask("We'll need a bucket. Enter a name for a bucket "+ "you own or one you want to create:", demotools.NotEmpty{}) - bucketExists, err := bucketBasics.BucketExists(bucketName) + bucketExists, err := bucketBasics.BucketExists(ctx, bucketName) if err != nil { panic(err) } if !bucketExists { - err = bucketBasics.CreateBucket(bucketName, sdkConfig.Region) + err = bucketBasics.CreateBucket(ctx, bucketName, sdkConfig.Region) if err != nil { panic(err) } else { @@ -164,7 +165,7 @@ func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestione panic(err) } defer uploadFile.Close() - presignedPutRequest, err := presigner.PutObject(bucketName, uploadKey, 60) + presignedPutRequest, err := presigner.PutObject(ctx, bucketName, uploadKey, 60) if err != nil { panic(err) } @@ -185,7 +186,7 @@ func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Printf("Let's presign a request to download the object.") questioner.Ask("Press Enter when you're ready.") - presignedGetRequest, err := presigner.GetObject(bucketName, uploadKey, 60) + presignedGetRequest, err := presigner.GetObject(ctx, bucketName, uploadKey, 60) if err != nil { panic(err) } @@ -209,7 +210,8 @@ func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Println(strings.Repeat("-", 88)) log.Println("Now we'll create a new request to put the same object using a presigned post request") - presignPostRequest, err := presigner.PresignPostObject(bucketName, uploadKey, 60) + questioner.Ask("Press Enter when you're ready.") + presignPostRequest, err := presigner.PresignPostObject(ctx, bucketName, uploadKey, 60) if err != nil { panic(err) } @@ -228,7 +230,7 @@ func RunPresigningScenario(sdkConfig aws.Config, questioner demotools.IQuestione log.Println("Let's presign a request to delete the object.") questioner.Ask("Press Enter when you're ready.") - presignedDelRequest, err := presigner.DeleteObject(bucketName, uploadKey) + presignedDelRequest, err := presigner.DeleteObject(ctx, bucketName, uploadKey) if err != nil { panic(err) } diff --git a/gov2/s3/scenarios/scenario_presigning_integ_test.go b/gov2/s3/scenarios/scenario_presigning_integ_test.go index 19f04104fd6..0c33543a18d 100644 --- a/gov2/s3/scenarios/scenario_presigning_integ_test.go +++ b/gov2/s3/scenarios/scenario_presigning_integ_test.go @@ -24,11 +24,12 @@ import ( func TestRunPresigningScenario_Integration(t *testing.T) { mockQuestioner := &demotools.MockQuestioner{ Answers: []string{ - "doc-example-go-test-bucket", "../README.md", "test-object", "", "", + "doc-example-go-test-bucket", "../README.md", "test-object", "", "", "", }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -37,7 +38,7 @@ func TestRunPresigningScenario_Integration(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - RunPresigningScenario(sdkConfig, mockQuestioner, HttpRequester{}) + RunPresigningScenario(ctx, sdkConfig, mockQuestioner, HttpRequester{}) log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/s3/scenarios/scenario_presigning_test.go b/gov2/s3/scenarios/scenario_presigning_test.go index 26331aa3b68..1aaa4b35b5b 100644 --- a/gov2/s3/scenarios/scenario_presigning_test.go +++ b/gov2/s3/scenarios/scenario_presigning_test.go @@ -62,7 +62,7 @@ func (scenTest *PresigningScenarioTest) SetupDataAndStubs() []testtools.Stub { objectKey := "doc-example-key" scenTest.TestBody = io.NopCloser(strings.NewReader("Test data!")) scenTest.Answers = []string{ - bucketName, "../README.md", objectKey, "", "", + bucketName, "../README.md", objectKey, "", "", "", } var stubList []testtools.Stub @@ -81,7 +81,7 @@ func (scenTest *PresigningScenarioTest) SetupDataAndStubs() []testtools.Stub { // or without errors. func (scenTest *PresigningScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} - RunPresigningScenario(*stubber.SdkConfig, &mockQuestioner, MockHttpRequester{GetBody: scenTest.TestBody}) + RunPresigningScenario(context.Background(), *stubber.SdkConfig, &mockQuestioner, MockHttpRequester{GetBody: scenTest.TestBody}) } func (scenTest *PresigningScenarioTest) Cleanup() {} diff --git a/gov2/sns/hello/hello.go b/gov2/sns/hello/hello.go index a7b8432f205..ac159947444 100644 --- a/gov2/sns/hello/hello.go +++ b/gov2/sns/hello/hello.go @@ -20,7 +20,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -31,7 +32,7 @@ func main() { var topics []types.Topic paginator := sns.NewListTopicsPaginator(snsClient, &sns.ListTopicsInput{}) for paginator.HasMorePages() { - output, err := paginator.NextPage(context.TODO()) + output, err := paginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get topics. Here's why: %v\n", err) break diff --git a/gov2/sqs/hello/hello.go b/gov2/sqs/hello/hello.go index 2415c61bc5e..5b565d5d39f 100644 --- a/gov2/sqs/hello/hello.go +++ b/gov2/sqs/hello/hello.go @@ -19,7 +19,8 @@ import ( // This example uses the default settings specified in your shared credentials // and config files. func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { fmt.Println("Couldn't load default configuration. Have you set up your AWS account?") fmt.Println(err) @@ -30,7 +31,7 @@ func main() { var queueUrls []string paginator := sqs.NewListQueuesPaginator(sqsClient, &sqs.ListQueuesInput{}) for paginator.HasMorePages() { - output, err := paginator.NextPage(context.TODO()) + output, err := paginator.NextPage(ctx) if err != nil { log.Printf("Couldn't get queues. Here's why: %v\n", err) break diff --git a/gov2/workflows/topics_and_queues/actions/sns_actions.go b/gov2/workflows/topics_and_queues/actions/sns_actions.go index 10977068cf1..dba7ed9dc82 100644 --- a/gov2/workflows/topics_and_queues/actions/sns_actions.go +++ b/gov2/workflows/topics_and_queues/actions/sns_actions.go @@ -29,7 +29,7 @@ type SnsActions struct { // CreateTopic creates an Amazon SNS topic with the specified name. You can optionally // specify that the topic is created as a FIFO topic and whether it uses content-based // deduplication instead of ID-based deduplication. -func (actor SnsActions) CreateTopic(topicName string, isFifoTopic bool, contentBasedDeduplication bool) (string, error) { +func (actor SnsActions) CreateTopic(ctx context.Context, topicName string, isFifoTopic bool, contentBasedDeduplication bool) (string, error) { var topicArn string topicAttributes := map[string]string{} if isFifoTopic { @@ -38,7 +38,7 @@ func (actor SnsActions) CreateTopic(topicName string, isFifoTopic bool, contentB if contentBasedDeduplication { topicAttributes["ContentBasedDeduplication"] = "true" } - topic, err := actor.SnsClient.CreateTopic(context.TODO(), &sns.CreateTopicInput{ + topic, err := actor.SnsClient.CreateTopic(ctx, &sns.CreateTopicInput{ Name: aws.String(topicName), Attributes: topicAttributes, }) @@ -56,8 +56,8 @@ func (actor SnsActions) CreateTopic(topicName string, isFifoTopic bool, contentB // snippet-start:[gov2.sns.DeleteTopic] // DeleteTopic delete an Amazon SNS topic. -func (actor SnsActions) DeleteTopic(topicArn string) error { - _, err := actor.SnsClient.DeleteTopic(context.TODO(), &sns.DeleteTopicInput{ +func (actor SnsActions) DeleteTopic(ctx context.Context, topicArn string) error { + _, err := actor.SnsClient.DeleteTopic(ctx, &sns.DeleteTopicInput{ TopicArn: aws.String(topicArn)}) if err != nil { log.Printf("Couldn't delete topic %v. Here's why: %v\n", topicArn, err) @@ -72,7 +72,7 @@ func (actor SnsActions) DeleteTopic(topicArn string) error { // SubscribeQueue subscribes an Amazon Simple Queue Service (Amazon SQS) queue to an // Amazon SNS topic. When filterMap is not nil, it is used to specify a filter policy // so that messages are only sent to the queue when the message has the specified attributes. -func (actor SnsActions) SubscribeQueue(topicArn string, queueArn string, filterMap map[string][]string) (string, error) { +func (actor SnsActions) SubscribeQueue(ctx context.Context, topicArn string, queueArn string, filterMap map[string][]string) (string, error) { var subscriptionArn string var attributes map[string]string if filterMap != nil { @@ -83,7 +83,7 @@ func (actor SnsActions) SubscribeQueue(topicArn string, queueArn string, filterM } attributes = map[string]string{"FilterPolicy": string(filterBytes)} } - output, err := actor.SnsClient.Subscribe(context.TODO(), &sns.SubscribeInput{ + output, err := actor.SnsClient.Subscribe(ctx, &sns.SubscribeInput{ Protocol: aws.String("sqs"), TopicArn: aws.String(topicArn), Attributes: attributes, @@ -109,7 +109,7 @@ func (actor SnsActions) SubscribeQueue(topicArn string, queueArn string, filterM // and, when ID-based deduplication is used, a deduplication ID. An optional key-value // filter attribute can be specified so that the message can be filtered according to // a filter policy. -func (actor SnsActions) Publish(topicArn string, message string, groupId string, dedupId string, filterKey string, filterValue string) error { +func (actor SnsActions) Publish(ctx context.Context, topicArn string, message string, groupId string, dedupId string, filterKey string, filterValue string) error { publishInput := sns.PublishInput{TopicArn: aws.String(topicArn), Message: aws.String(message)} if groupId != "" { publishInput.MessageGroupId = aws.String(groupId) @@ -122,7 +122,7 @@ func (actor SnsActions) Publish(topicArn string, message string, groupId string, filterKey: {DataType: aws.String("String"), StringValue: aws.String(filterValue)}, } } - _, err := actor.SnsClient.Publish(context.TODO(), &publishInput) + _, err := actor.SnsClient.Publish(ctx, &publishInput) if err != nil { log.Printf("Couldn't publish message to topic %v. Here's why: %v", topicArn, err) } diff --git a/gov2/workflows/topics_and_queues/actions/sqs_actions.go b/gov2/workflows/topics_and_queues/actions/sqs_actions.go index 3c432b27b0b..dc6a9016e5d 100644 --- a/gov2/workflows/topics_and_queues/actions/sqs_actions.go +++ b/gov2/workflows/topics_and_queues/actions/sqs_actions.go @@ -29,13 +29,13 @@ type SqsActions struct { // CreateQueue creates an Amazon SQS queue with the specified name. You can specify // whether the queue is created as a FIFO queue. -func (actor SqsActions) CreateQueue(queueName string, isFifoQueue bool) (string, error) { +func (actor SqsActions) CreateQueue(ctx context.Context, queueName string, isFifoQueue bool) (string, error) { var queueUrl string queueAttributes := map[string]string{} if isFifoQueue { queueAttributes["FifoQueue"] = "true" } - queue, err := actor.SqsClient.CreateQueue(context.TODO(), &sqs.CreateQueueInput{ + queue, err := actor.SqsClient.CreateQueue(ctx, &sqs.CreateQueueInput{ QueueName: aws.String(queueName), Attributes: queueAttributes, }) @@ -54,10 +54,10 @@ func (actor SqsActions) CreateQueue(queueName string, isFifoQueue bool) (string, // GetQueueArn uses the GetQueueAttributes action to get the Amazon Resource Name (ARN) // of an Amazon SQS queue. -func (actor SqsActions) GetQueueArn(queueUrl string) (string, error) { +func (actor SqsActions) GetQueueArn(ctx context.Context, queueUrl string) (string, error) { var queueArn string arnAttributeName := types.QueueAttributeNameQueueArn - attribute, err := actor.SqsClient.GetQueueAttributes(context.TODO(), &sqs.GetQueueAttributesInput{ + attribute, err := actor.SqsClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: aws.String(queueUrl), AttributeNames: []types.QueueAttributeName{arnAttributeName}, }) @@ -76,7 +76,7 @@ func (actor SqsActions) GetQueueArn(queueUrl string) (string, error) { // AttachSendMessagePolicy uses the SetQueueAttributes action to attach a policy to an // Amazon SQS queue that allows the specified Amazon SNS topic to send messages to the // queue. -func (actor SqsActions) AttachSendMessagePolicy(queueUrl string, queueArn string, topicArn string) error { +func (actor SqsActions) AttachSendMessagePolicy(ctx context.Context, queueUrl string, queueArn string, topicArn string) error { policyDoc := PolicyDocument{ Version: "2012-10-17", Statement: []PolicyStatement{{ @@ -92,7 +92,7 @@ func (actor SqsActions) AttachSendMessagePolicy(queueUrl string, queueArn string log.Printf("Couldn't create policy document. Here's why: %v\n", err) return err } - _, err = actor.SqsClient.SetQueueAttributes(context.TODO(), &sqs.SetQueueAttributesInput{ + _, err = actor.SqsClient.SetQueueAttributes(ctx, &sqs.SetQueueAttributesInput{ Attributes: map[string]string{ string(types.QueueAttributeNamePolicy): string(policyBytes), }, @@ -128,9 +128,9 @@ type PolicyCondition map[string]map[string]string // snippet-start:[gov2.sqs.ReceiveMessage] // GetMessages uses the ReceiveMessage action to get messages from an Amazon SQS queue. -func (actor SqsActions) GetMessages(queueUrl string, maxMessages int32, waitTime int32) ([]types.Message, error) { +func (actor SqsActions) GetMessages(ctx context.Context, queueUrl string, maxMessages int32, waitTime int32) ([]types.Message, error) { var messages []types.Message - result, err := actor.SqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{ + result, err := actor.SqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(queueUrl), MaxNumberOfMessages: maxMessages, WaitTimeSeconds: waitTime, @@ -149,13 +149,13 @@ func (actor SqsActions) GetMessages(queueUrl string, maxMessages int32, waitTime // DeleteMessages uses the DeleteMessageBatch action to delete a batch of messages from // an Amazon SQS queue. -func (actor SqsActions) DeleteMessages(queueUrl string, messages []types.Message) error { +func (actor SqsActions) DeleteMessages(ctx context.Context, queueUrl string, messages []types.Message) error { entries := make([]types.DeleteMessageBatchRequestEntry, len(messages)) for msgIndex := range messages { entries[msgIndex].Id = aws.String(fmt.Sprintf("%v", msgIndex)) entries[msgIndex].ReceiptHandle = messages[msgIndex].ReceiptHandle } - _, err := actor.SqsClient.DeleteMessageBatch(context.TODO(), &sqs.DeleteMessageBatchInput{ + _, err := actor.SqsClient.DeleteMessageBatch(ctx, &sqs.DeleteMessageBatchInput{ Entries: entries, QueueUrl: aws.String(queueUrl), }) @@ -170,8 +170,8 @@ func (actor SqsActions) DeleteMessages(queueUrl string, messages []types.Message // snippet-start:[gov2.sqs.DeleteQueue] // DeleteQueue deletes an Amazon SQS queue. -func (actor SqsActions) DeleteQueue(queueUrl string) error { - _, err := actor.SqsClient.DeleteQueue(context.TODO(), &sqs.DeleteQueueInput{ +func (actor SqsActions) DeleteQueue(ctx context.Context, queueUrl string) error { + _, err := actor.SqsClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{ QueueUrl: aws.String(queueUrl)}) if err != nil { log.Printf("Couldn't delete queue %v. Here's why: %v\n", queueUrl, err) diff --git a/gov2/workflows/topics_and_queues/cmd/main.go b/gov2/workflows/topics_and_queues/cmd/main.go index 7910e7b1e5c..413d5709d7c 100644 --- a/gov2/workflows/topics_and_queues/cmd/main.go +++ b/gov2/workflows/topics_and_queues/cmd/main.go @@ -24,7 +24,7 @@ import ( // you how to create an Amazon SNS topic and Amazon SQS queues and publish messages // to the topic that are forwarded to the subscribed queues. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config){ "topics_and_queues": runTopicsAndQueuesScenario, } choices := make([]string, len(scenarioMap)) @@ -42,16 +42,17 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } log.SetFlags(0) - runScenario(sdkConfig) + runScenario(ctx, sdkConfig) } } -func runTopicsAndQueuesScenario(sdkConfig aws.Config) { - workflows.RunTopicsAndQueuesScenario(sdkConfig, demotools.NewQuestioner()) +func runTopicsAndQueuesScenario(ctx context.Context, sdkConfig aws.Config) { + workflows.RunTopicsAndQueuesScenario(ctx, sdkConfig, demotools.NewQuestioner()) } diff --git a/gov2/workflows/topics_and_queues/workflows/resources.go b/gov2/workflows/topics_and_queues/workflows/resources.go index 2cc176c4655..b473b73b634 100644 --- a/gov2/workflows/topics_and_queues/workflows/resources.go +++ b/gov2/workflows/topics_and_queues/workflows/resources.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "fmt" "log" "topics_and_queues/actions" @@ -19,7 +20,7 @@ type Resources struct { } // Cleanup deletes all AWS resources created during an example. -func (resources Resources) Cleanup() { +func (resources Resources) Cleanup(ctx context.Context) { defer func() { if r := recover(); r != nil { fmt.Println("Something went wrong during cleanup. Use the AWS Management Console\n" + @@ -30,7 +31,7 @@ func (resources Resources) Cleanup() { var err error if resources.topicArn != "" { log.Printf("Deleting topic %v.\n", resources.topicArn) - err = resources.snsActor.DeleteTopic(resources.topicArn) + err = resources.snsActor.DeleteTopic(ctx, resources.topicArn) if err != nil { panic(err) } @@ -38,7 +39,7 @@ func (resources Resources) Cleanup() { for _, queueUrl := range resources.queueUrls { log.Printf("Deleting queue %v.\n", queueUrl) - err = resources.sqsActor.DeleteQueue(queueUrl) + err = resources.sqsActor.DeleteQueue(ctx, queueUrl) if err != nil { panic(err) } diff --git a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues.go b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues.go index 835402a8e30..c28536c8b2f 100644 --- a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues.go +++ b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "encoding/json" "fmt" "log" @@ -37,7 +38,7 @@ type ScenarioRunner struct { sqsActor *actions.SqsActions } -func (runner ScenarioRunner) CreateTopic() (string, string, bool, bool) { +func (runner ScenarioRunner) CreateTopic(ctx context.Context) (string, string, bool, bool) { log.Println("SNS topics can be configured as FIFO (First-In-First-Out) or standard.\n" + "FIFO topics deliver messages in order and support deduplication and message filtering.") isFifoTopic := runner.questioner.AskBool("\nWould you like to work with FIFO topics? (y/n) ", "y") @@ -64,7 +65,7 @@ func (runner ScenarioRunner) CreateTopic() (string, string, bool, bool) { "the topic name.", FIFO_SUFFIX) } - topicArn, err := runner.snsActor.CreateTopic(topicName, isFifoTopic, contentBasedDeduplication) + topicArn, err := runner.snsActor.CreateTopic(ctx, topicName, isFifoTopic, contentBasedDeduplication) if err != nil { panic(err) } @@ -74,7 +75,7 @@ func (runner ScenarioRunner) CreateTopic() (string, string, bool, bool) { return topicName, topicArn, isFifoTopic, contentBasedDeduplication } -func (runner ScenarioRunner) CreateQueue(ordinal string, isFifoTopic bool) (string, string) { +func (runner ScenarioRunner) CreateQueue(ctx context.Context, ordinal string, isFifoTopic bool) (string, string) { queueName := runner.questioner.Ask(fmt.Sprintf("Enter a name for the %v SQS queue. ", ordinal)) if isFifoTopic { queueName = fmt.Sprintf("%v%v", queueName, FIFO_SUFFIX) @@ -83,7 +84,7 @@ func (runner ScenarioRunner) CreateQueue(ordinal string, isFifoTopic bool) (stri "be appended to the queue name.\n", FIFO_SUFFIX) } } - queueUrl, err := runner.sqsActor.CreateQueue(queueName, isFifoTopic) + queueUrl, err := runner.sqsActor.CreateQueue(ctx, queueName, isFifoTopic) if err != nil { panic(err) } @@ -94,16 +95,16 @@ func (runner ScenarioRunner) CreateQueue(ordinal string, isFifoTopic bool) (stri } func (runner ScenarioRunner) SubscribeQueueToTopic( - queueName string, queueUrl string, topicName string, topicArn string, ordinal string, + ctx context.Context, queueName string, queueUrl string, topicName string, topicArn string, ordinal string, isFifoTopic bool) (string, bool) { - queueArn, err := runner.sqsActor.GetQueueArn(queueUrl) + queueArn, err := runner.sqsActor.GetQueueArn(ctx, queueUrl) if err != nil { panic(err) } log.Printf("The ARN of your queue is: %v.\n", queueArn) - err = runner.sqsActor.AttachSendMessagePolicy(queueUrl, queueArn, topicArn) + err = runner.sqsActor.AttachSendMessagePolicy(ctx, queueUrl, queueArn, topicArn) if err != nil { panic(err) } @@ -141,7 +142,7 @@ func (runner ScenarioRunner) SubscribeQueueToTopic( } } - subscriptionArn, err := runner.snsActor.SubscribeQueue(topicArn, queueArn, filterPolicy) + subscriptionArn, err := runner.snsActor.SubscribeQueue(ctx, topicArn, queueArn, filterPolicy) if err != nil { panic(err) } @@ -151,7 +152,7 @@ func (runner ScenarioRunner) SubscribeQueueToTopic( return subscriptionArn, filterPolicy != nil } -func (runner ScenarioRunner) PublishMessages(topicArn string, isFifoTopic bool, contentBasedDeduplication bool, usingFilters bool) { +func (runner ScenarioRunner) PublishMessages(ctx context.Context, topicArn string, isFifoTopic bool, contentBasedDeduplication bool, usingFilters bool) { var message string var groupId string var dedupId string @@ -180,7 +181,7 @@ func (runner ScenarioRunner) PublishMessages(topicArn string, isFifoTopic bool, } } - err := runner.snsActor.Publish(topicArn, message, groupId, dedupId, TONE_KEY, toneSelection) + err := runner.snsActor.Publish(ctx, topicArn, message, groupId, dedupId, TONE_KEY, toneSelection) if err != nil { panic(err) } @@ -190,12 +191,12 @@ func (runner ScenarioRunner) PublishMessages(topicArn string, isFifoTopic bool, } } -func (runner ScenarioRunner) PollForMessages(queueUrls []string) { +func (runner ScenarioRunner) PollForMessages(ctx context.Context, queueUrls []string) { log.Println("Polling queues for messages...") for _, queueUrl := range queueUrls { var messages []types.Message for { - currentMsgs, err := runner.sqsActor.GetMessages(queueUrl, 10, 1) + currentMsgs, err := runner.sqsActor.GetMessages(ctx, queueUrl, 10, 1) if err != nil { panic(err) } @@ -223,7 +224,7 @@ func (runner ScenarioRunner) PollForMessages(queueUrls []string) { if len(messages) > 0 { log.Printf("Deleting %v messages from queue %v.\n", len(messages), queueUrl) - err := runner.sqsActor.DeleteMessages(queueUrl, messages) + err := runner.sqsActor.DeleteMessages(ctx, queueUrl, messages) if err != nil { panic(err) } @@ -246,13 +247,13 @@ func (runner ScenarioRunner) PollForMessages(queueUrls []string) { // It uses a questioner from the `demotools` package to get input during the example. // This package can be found in the ..\..\demotools folder of this repo. func RunTopicsAndQueuesScenario( - sdkConfig aws.Config, questioner demotools.IQuestioner) { + ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner) { resources := Resources{} defer func() { if r := recover(); r != nil { log.Println("Something went wrong with the demo.\n" + "Cleaning up any resources that were created...") - resources.Cleanup() + resources.Cleanup(ctx) } }() queueCount := 2 @@ -274,7 +275,7 @@ func RunTopicsAndQueuesScenario( resources.snsActor = runner.snsActor resources.sqsActor = runner.sqsActor - topicName, topicArn, isFifoTopic, contentBasedDeduplication := runner.CreateTopic() + topicName, topicArn, isFifoTopic, contentBasedDeduplication := runner.CreateTopic(ctx) resources.topicArn = topicArn log.Println(strings.Repeat("-", 88)) @@ -282,24 +283,24 @@ func RunTopicsAndQueuesScenario( ordinals := []string{"first", "next"} usingFilters := false for _, ordinal := range ordinals { - queueName, queueUrl := runner.CreateQueue(ordinal, isFifoTopic) + queueName, queueUrl := runner.CreateQueue(ctx, ordinal, isFifoTopic) resources.queueUrls = append(resources.queueUrls, queueUrl) - _, filtering := runner.SubscribeQueueToTopic(queueName, queueUrl, topicName, topicArn, ordinal, isFifoTopic) + _, filtering := runner.SubscribeQueueToTopic(ctx, queueName, queueUrl, topicName, topicArn, ordinal, isFifoTopic) usingFilters = usingFilters || filtering } log.Println(strings.Repeat("-", 88)) - runner.PublishMessages(topicArn, isFifoTopic, contentBasedDeduplication, usingFilters) + runner.PublishMessages(ctx, topicArn, isFifoTopic, contentBasedDeduplication, usingFilters) log.Println(strings.Repeat("-", 88)) - runner.PollForMessages(resources.queueUrls) + runner.PollForMessages(ctx, resources.queueUrls) log.Println(strings.Repeat("-", 88)) wantCleanup := questioner.AskBool("Do you want to remove all AWS resources created for this scenario? (y/n) ", "y") if wantCleanup { log.Println("Cleaning up resources...") - resources.Cleanup() + resources.Cleanup(ctx) } log.Println(strings.Repeat("-", 88)) diff --git a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_integ_test.go b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_integ_test.go index 32b216d082c..23d9287f2f2 100644 --- a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_integ_test.go +++ b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_integ_test.go @@ -36,7 +36,8 @@ func TestRunTopicsAndQueuesScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -45,7 +46,7 @@ func TestRunTopicsAndQueuesScenario_Integration(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - RunTopicsAndQueuesScenario(sdkConfig, mockQuestioner) + RunTopicsAndQueuesScenario(ctx, sdkConfig, mockQuestioner) log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_test.go b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_test.go index 8e4fe149bce..b9f92537eae 100644 --- a/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_test.go +++ b/gov2/workflows/topics_and_queues/workflows/scenario_topics_and_queues_test.go @@ -6,6 +6,7 @@ package workflows import ( + "context" "fmt" "testing" "topics_and_queues/stubs" @@ -103,7 +104,7 @@ func (scenTest *TopicsAndQueuesScenarioTest) SetupDataAndStubs() []testtools.Stu // or without errors. func (scenTest *TopicsAndQueuesScenarioTest) RunSubTest(stubber *testtools.AwsmStubber) { mockQuestioner := demotools.MockQuestioner{Answers: scenTest.Answers} - RunTopicsAndQueuesScenario(*stubber.SdkConfig, &mockQuestioner) + RunTopicsAndQueuesScenario(context.Background(), *stubber.SdkConfig, &mockQuestioner) } func (scenTest *TopicsAndQueuesScenarioTest) Cleanup() {} diff --git a/gov2/workflows/user_pools_and_lambda_triggers/actions/cloud_formation_actions.go b/gov2/workflows/user_pools_and_lambda_triggers/actions/cloud_formation_actions.go index dfcd1f2df02..78cd90443e8 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/actions/cloud_formation_actions.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/actions/cloud_formation_actions.go @@ -21,8 +21,8 @@ type CloudFormationActions struct { } // GetOutputs gets the outputs from a CloudFormation stack and puts them into a structured format. -func (actor CloudFormationActions) GetOutputs(stackName string) StackOutputs { - output, err := actor.CfnClient.DescribeStacks(context.TODO(), &cloudformation.DescribeStacksInput{ +func (actor CloudFormationActions) GetOutputs(ctx context.Context, stackName string) StackOutputs { + output, err := actor.CfnClient.DescribeStacks(ctx, &cloudformation.DescribeStacksInput{ StackName: aws.String(stackName), }) if err != nil || len(output.Stacks) == 0 { diff --git a/gov2/workflows/user_pools_and_lambda_triggers/actions/cloudwatch_logs_actions.go b/gov2/workflows/user_pools_and_lambda_triggers/actions/cloudwatch_logs_actions.go index 57c746146bd..08b233bb624 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/actions/cloudwatch_logs_actions.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/actions/cloudwatch_logs_actions.go @@ -20,10 +20,10 @@ type CloudWatchLogsActions struct { } // GetLatestLogStream gets the most recent log stream for a Lambda function. -func (actor CloudWatchLogsActions) GetLatestLogStream(functionName string) (types.LogStream, error) { +func (actor CloudWatchLogsActions) GetLatestLogStream(ctx context.Context, functionName string) (types.LogStream, error) { var logStream types.LogStream logGroupName := fmt.Sprintf("/aws/lambda/%s", functionName) - output, err := actor.CwlClient.DescribeLogStreams(context.TODO(), &cloudwatchlogs.DescribeLogStreamsInput{ + output, err := actor.CwlClient.DescribeLogStreams(ctx, &cloudwatchlogs.DescribeLogStreamsInput{ Descending: aws.Bool(true), Limit: aws.Int32(1), LogGroupName: aws.String(logGroupName), @@ -38,11 +38,11 @@ func (actor CloudWatchLogsActions) GetLatestLogStream(functionName string) (type } // GetLogEvents gets the most recent eventCount events from the specified log stream. -func (actor CloudWatchLogsActions) GetLogEvents(functionName string, logStreamName string, eventCount int32) ( +func (actor CloudWatchLogsActions) GetLogEvents(ctx context.Context, functionName string, logStreamName string, eventCount int32) ( []types.OutputLogEvent, error) { var events []types.OutputLogEvent logGroupName := fmt.Sprintf("/aws/lambda/%s", functionName) - output, err := actor.CwlClient.GetLogEvents(context.TODO(), &cloudwatchlogs.GetLogEventsInput{ + output, err := actor.CwlClient.GetLogEvents(ctx, &cloudwatchlogs.GetLogEventsInput{ LogStreamName: aws.String(logStreamName), Limit: aws.Int32(eventCount), LogGroupName: aws.String(logGroupName), diff --git a/gov2/workflows/user_pools_and_lambda_triggers/actions/cognito_actions.go b/gov2/workflows/user_pools_and_lambda_triggers/actions/cognito_actions.go index 474e79616c0..58e08741e99 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/actions/cognito_actions.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/actions/cognito_actions.go @@ -41,8 +41,8 @@ type TriggerInfo struct { // UpdateTriggers adds or removes Lambda triggers for a user pool. When a trigger is specified with a `nil` value, // it is removed from the user pool. -func (actor CognitoActions) UpdateTriggers(userPoolId string, triggers ...TriggerInfo) error { - output, err := actor.CognitoClient.DescribeUserPool(context.TODO(), &cognitoidentityprovider.DescribeUserPoolInput{ +func (actor CognitoActions) UpdateTriggers(ctx context.Context, userPoolId string, triggers ...TriggerInfo) error { + output, err := actor.CognitoClient.DescribeUserPool(ctx, &cognitoidentityprovider.DescribeUserPoolInput{ UserPoolId: aws.String(userPoolId), }) if err != nil { @@ -60,7 +60,7 @@ func (actor CognitoActions) UpdateTriggers(userPoolId string, triggers ...Trigge lambdaConfig.PostAuthentication = trigger.HandlerArn } } - _, err = actor.CognitoClient.UpdateUserPool(context.TODO(), &cognitoidentityprovider.UpdateUserPoolInput{ + _, err = actor.CognitoClient.UpdateUserPool(ctx, &cognitoidentityprovider.UpdateUserPoolInput{ UserPoolId: aws.String(userPoolId), LambdaConfig: lambdaConfig, }) @@ -75,9 +75,9 @@ func (actor CognitoActions) UpdateTriggers(userPoolId string, triggers ...Trigge // snippet-start:[gov2.cognito-identity-provider.SignUp] // SignUp signs up a user with Amazon Cognito. -func (actor CognitoActions) SignUp(clientId string, userName string, password string, userEmail string) (bool, error) { +func (actor CognitoActions) SignUp(ctx context.Context, clientId string, userName string, password string, userEmail string) (bool, error) { confirmed := false - output, err := actor.CognitoClient.SignUp(context.TODO(), &cognitoidentityprovider.SignUpInput{ + output, err := actor.CognitoClient.SignUp(ctx, &cognitoidentityprovider.SignUpInput{ ClientId: aws.String(clientId), Password: aws.String(password), Username: aws.String(userName), @@ -103,9 +103,9 @@ func (actor CognitoActions) SignUp(clientId string, userName string, password st // snippet-start:[gov2.cognito-identity-provider.InitiateAuth] // SignIn signs in a user to Amazon Cognito using a username and password authentication flow. -func (actor CognitoActions) SignIn(clientId string, userName string, password string) (*types.AuthenticationResultType, error) { +func (actor CognitoActions) SignIn(ctx context.Context, clientId string, userName string, password string) (*types.AuthenticationResultType, error) { var authResult *types.AuthenticationResultType - output, err := actor.CognitoClient.InitiateAuth(context.TODO(), &cognitoidentityprovider.InitiateAuthInput{ + output, err := actor.CognitoClient.InitiateAuth(ctx, &cognitoidentityprovider.InitiateAuthInput{ AuthFlow: "USER_PASSWORD_AUTH", ClientId: aws.String(clientId), AuthParameters: map[string]string{"USERNAME": userName, "PASSWORD": password}, @@ -129,8 +129,8 @@ func (actor CognitoActions) SignIn(clientId string, userName string, password st // ForgotPassword starts a password recovery flow for a user. This flow typically sends a confirmation code // to the user's configured notification destination, such as email. -func (actor CognitoActions) ForgotPassword(clientId string, userName string) (*types.CodeDeliveryDetailsType, error) { - output, err := actor.CognitoClient.ForgotPassword(context.TODO(), &cognitoidentityprovider.ForgotPasswordInput{ +func (actor CognitoActions) ForgotPassword(ctx context.Context, clientId string, userName string) (*types.CodeDeliveryDetailsType, error) { + output, err := actor.CognitoClient.ForgotPassword(ctx, &cognitoidentityprovider.ForgotPasswordInput{ ClientId: aws.String(clientId), Username: aws.String(userName), }) @@ -145,8 +145,8 @@ func (actor CognitoActions) ForgotPassword(clientId string, userName string) (*t // snippet-start:[gov2.cognito-identity-provider.ConfirmForgotPassword] // ConfirmForgotPassword confirms a user with a confirmation code and a new password. -func (actor CognitoActions) ConfirmForgotPassword(clientId string, code string, userName string, password string) error { - _, err := actor.CognitoClient.ConfirmForgotPassword(context.TODO(), &cognitoidentityprovider.ConfirmForgotPasswordInput{ +func (actor CognitoActions) ConfirmForgotPassword(ctx context.Context, clientId string, code string, userName string, password string) error { + _, err := actor.CognitoClient.ConfirmForgotPassword(ctx, &cognitoidentityprovider.ConfirmForgotPasswordInput{ ClientId: aws.String(clientId), ConfirmationCode: aws.String(code), Password: aws.String(password), @@ -168,8 +168,8 @@ func (actor CognitoActions) ConfirmForgotPassword(clientId string, code string, // snippet-start:[gov2.cognito-identity-provider.DeleteUser] // DeleteUser removes a user from the user pool. -func (actor CognitoActions) DeleteUser(userAccessToken string) error { - _, err := actor.CognitoClient.DeleteUser(context.TODO(), &cognitoidentityprovider.DeleteUserInput{ +func (actor CognitoActions) DeleteUser(ctx context.Context, userAccessToken string) error { + _, err := actor.CognitoClient.DeleteUser(ctx, &cognitoidentityprovider.DeleteUserInput{ AccessToken: aws.String(userAccessToken), }) if err != nil { @@ -184,8 +184,8 @@ func (actor CognitoActions) DeleteUser(userAccessToken string) error { // AdminCreateUser uses administrator credentials to add a user to a user pool. This method leaves the user // in a state that requires they enter a new password next time they sign in. -func (actor CognitoActions) AdminCreateUser(userPoolId string, userName string, userEmail string) error { - _, err := actor.CognitoClient.AdminCreateUser(context.TODO(), &cognitoidentityprovider.AdminCreateUserInput{ +func (actor CognitoActions) AdminCreateUser(ctx context.Context, userPoolId string, userName string, userEmail string) error { + _, err := actor.CognitoClient.AdminCreateUser(ctx, &cognitoidentityprovider.AdminCreateUserInput{ UserPoolId: aws.String(userPoolId), Username: aws.String(userName), MessageAction: types.MessageActionTypeSuppress, @@ -209,8 +209,8 @@ func (actor CognitoActions) AdminCreateUser(userPoolId string, userName string, // AdminSetUserPassword uses administrator credentials to set a password for a user without requiring a // temporary password. -func (actor CognitoActions) AdminSetUserPassword(userPoolId string, userName string, password string) error { - _, err := actor.CognitoClient.AdminSetUserPassword(context.TODO(), &cognitoidentityprovider.AdminSetUserPasswordInput{ +func (actor CognitoActions) AdminSetUserPassword(ctx context.Context, userPoolId string, userName string, password string) error { + _, err := actor.CognitoClient.AdminSetUserPassword(ctx, &cognitoidentityprovider.AdminSetUserPasswordInput{ Password: aws.String(password), UserPoolId: aws.String(userPoolId), Username: aws.String(userName), diff --git a/gov2/workflows/user_pools_and_lambda_triggers/actions/dynamo_actions.go b/gov2/workflows/user_pools_and_lambda_triggers/actions/dynamo_actions.go index faea9823bd4..09fe45c9554 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/actions/dynamo_actions.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/actions/dynamo_actions.go @@ -51,7 +51,7 @@ func (users *UserList) UserNameList() []string { } // PopulateTable adds a set of test users to the table. -func (actor DynamoActions) PopulateTable(tableName string) error { +func (actor DynamoActions) PopulateTable(ctx context.Context, tableName string) error { var err error var item map[string]types.AttributeValue var writeReqs []types.WriteRequest @@ -63,7 +63,7 @@ func (actor DynamoActions) PopulateTable(tableName string) error { } writeReqs = append(writeReqs, types.WriteRequest{PutRequest: &types.PutRequest{Item: item}}) } - _, err = actor.DynamoClient.BatchWriteItem(context.TODO(), &dynamodb.BatchWriteItemInput{ + _, err = actor.DynamoClient.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{ RequestItems: map[string][]types.WriteRequest{tableName: writeReqs}, }) if err != nil { @@ -73,9 +73,9 @@ func (actor DynamoActions) PopulateTable(tableName string) error { } // Scan scans the table for all items. -func (actor DynamoActions) Scan(tableName string) (UserList, error) { +func (actor DynamoActions) Scan(ctx context.Context, tableName string) (UserList, error) { var userList UserList - output, err := actor.DynamoClient.Scan(context.TODO(), &dynamodb.ScanInput{ + output, err := actor.DynamoClient.Scan(ctx, &dynamodb.ScanInput{ TableName: aws.String(tableName), }) if err != nil { @@ -90,12 +90,12 @@ func (actor DynamoActions) Scan(tableName string) (UserList, error) { } // AddUser adds a user item to a table. -func (actor DynamoActions) AddUser(tableName string, user User) error { +func (actor DynamoActions) AddUser(ctx context.Context, tableName string, user User) error { userItem, err := attributevalue.MarshalMap(user) if err != nil { log.Printf("Couldn't marshall user to item. Here's why: %v\n", err) } - _, err = actor.DynamoClient.PutItem(context.TODO(), &dynamodb.PutItemInput{ + _, err = actor.DynamoClient.PutItem(ctx, &dynamodb.PutItemInput{ Item: userItem, TableName: aws.String(tableName), }) diff --git a/gov2/workflows/user_pools_and_lambda_triggers/cmd/main.go b/gov2/workflows/user_pools_and_lambda_triggers/cmd/main.go index f7e4b83d845..1c351042252 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/cmd/main.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/cmd/main.go @@ -27,7 +27,7 @@ import ( // - `activity_log` - Runs an interactive scenario that shows you how to use an Amazon Cognito // Lambda trigger to log custom activity data. func main() { - scenarioMap := map[string]func(sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, stack string){ + scenarioMap := map[string]func(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, stack string){ "auto_confirm": runAutoConfirmScenario, "migrate_user": runMigrateUserScenario, "activity_log": runActivityLogScenario, @@ -50,7 +50,8 @@ func main() { fmt.Printf("'%v' is not a valid scenario.\n", *scenario) flag.Usage() } else { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -58,24 +59,24 @@ func main() { log.SetFlags(0) questioner := demotools.NewQuestioner() helper := workflows.NewScenarioHelper(sdkConfig, questioner) - runScenario(sdkConfig, questioner, helper, *stack) + runScenario(ctx, sdkConfig, questioner, helper, *stack) } } -func runAutoConfirmScenario(sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, +func runAutoConfirmScenario(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, stack string) { workflow := workflows.NewAutoConfirm(sdkConfig, questioner, helper) - workflow.Run(stack) + workflow.Run(ctx, stack) } -func runMigrateUserScenario(sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, +func runMigrateUserScenario(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, stack string) { workflow := workflows.NewMigrateUser(sdkConfig, questioner, helper) - workflow.Run(stack) + workflow.Run(ctx, stack) } -func runActivityLogScenario(sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, +func runActivityLogScenario(ctx context.Context, sdkConfig aws.Config, questioner demotools.IQuestioner, helper workflows.IScenarioHelper, stack string) { workflow := workflows.NewActivityLog(sdkConfig, questioner, helper) - workflow.Run(stack) + workflow.Run(ctx, stack) } diff --git a/gov2/workflows/user_pools_and_lambda_triggers/handlers/activity_log/activity_log_handler.go b/gov2/workflows/user_pools_and_lambda_triggers/handlers/activity_log/activity_log_handler.go index f8d6268455d..d81649749a5 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/handlers/activity_log/activity_log_handler.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/handlers/activity_log/activity_log_handler.go @@ -89,7 +89,8 @@ func (h *handler) HandleRequest(ctx context.Context, event events.CognitoEventUs } func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Panicln(err) } diff --git a/gov2/workflows/user_pools_and_lambda_triggers/handlers/auto_confirm/auto_confirm_handler.go b/gov2/workflows/user_pools_and_lambda_triggers/handlers/auto_confirm/auto_confirm_handler.go index 7f55b711d62..c451709e0fa 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/handlers/auto_confirm/auto_confirm_handler.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/handlers/auto_confirm/auto_confirm_handler.go @@ -85,7 +85,8 @@ func (h *handler) HandleRequest(ctx context.Context, event events.CognitoEventUs } func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Panicln(err) } diff --git a/gov2/workflows/user_pools_and_lambda_triggers/handlers/migrate_user/migrate_user_handler.go b/gov2/workflows/user_pools_and_lambda_triggers/handlers/migrate_user/migrate_user_handler.go index 4290ad2112e..38684015d80 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/handlers/migrate_user/migrate_user_handler.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/handlers/migrate_user/migrate_user_handler.go @@ -84,7 +84,8 @@ func (h *handler) HandleRequest(ctx context.Context, event events.CognitoEventUs } func main() { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Panicln(err) } diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/resources.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/resources.go index 1fbeb59517e..76ac23af69a 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/resources.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/resources.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "log" "user_pools_and_lambda_triggers/actions" @@ -31,7 +32,7 @@ func (resources *Resources) init(cognitoActor *actions.CognitoActions, questione } // Cleanup deletes all AWS resources created during an example. -func (resources *Resources) Cleanup() { +func (resources *Resources) Cleanup(ctx context.Context) { defer func() { if r := recover(); r != nil { log.Printf("Something went wrong during cleanup.\n%v\n", r) @@ -44,7 +45,7 @@ func (resources *Resources) Cleanup() { "during this demo (y/n)?", "y") if wantDelete { for _, accessToken := range resources.userAccessTokens { - err := resources.cognitoActor.DeleteUser(accessToken) + err := resources.cognitoActor.DeleteUser(ctx, accessToken) if err != nil { log.Println("Couldn't delete user during cleanup.") panic(err) @@ -55,7 +56,7 @@ func (resources *Resources) Cleanup() { for i := 0; i < len(resources.triggers); i++ { triggerList[i] = actions.TriggerInfo{Trigger: resources.triggers[i], HandlerArn: nil} } - err := resources.cognitoActor.UpdateTriggers(resources.userPoolId, triggerList...) + err := resources.cognitoActor.UpdateTriggers(ctx, resources.userPoolId, triggerList...) if err != nil { log.Println("Couldn't update Cognito triggers during cleanup.") panic(err) diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log.go index b73554b7d39..403d0cb8354 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "errors" "log" "strings" @@ -39,15 +40,15 @@ func NewActivityLog(sdkConfig aws.Config, questioner demotools.IQuestioner, help } // AddUserToPool selects a user from the known users table and uses administrator credentials to add the user to the user pool. -func (runner *ActivityLog) AddUserToPool(userPoolId string, tableName string) (string, string) { +func (runner *ActivityLog) AddUserToPool(ctx context.Context, userPoolId string, tableName string) (string, string) { log.Println("To facilitate this example, let's add a user to the user pool using administrator privileges.") - users, err := runner.helper.GetKnownUsers(tableName) + users, err := runner.helper.GetKnownUsers(ctx, tableName) if err != nil { panic(err) } user := users.Users[0] log.Printf("Adding known user %v to the user pool.\n", user.UserName) - err = runner.cognitoActor.AdminCreateUser(userPoolId, user.UserName, user.UserEmail) + err = runner.cognitoActor.AdminCreateUser(ctx, userPoolId, user.UserName, user.UserEmail) if err != nil { panic(err) } @@ -56,7 +57,7 @@ func (runner *ActivityLog) AddUserToPool(userPoolId string, tableName string) (s "(the password will not display as you type):", 8) for !pwSet { log.Printf("\nSetting password for user '%v'.\n", user.UserName) - err = runner.cognitoActor.AdminSetUserPassword(userPoolId, user.UserName, password) + err = runner.cognitoActor.AdminSetUserPassword(ctx, userPoolId, user.UserName, password) if err != nil { var invalidPassword *types.InvalidPasswordException if errors.As(err, &invalidPassword) { @@ -75,12 +76,12 @@ func (runner *ActivityLog) AddUserToPool(userPoolId string, tableName string) (s } // AddActivityLogTrigger adds a Lambda handler as an invocation target for the PostAuthentication trigger. -func (runner *ActivityLog) AddActivityLogTrigger(userPoolId string, activityLogArn string) { +func (runner *ActivityLog) AddActivityLogTrigger(ctx context.Context, userPoolId string, activityLogArn string) { log.Println("Let's add a Lambda function to handle the PostAuthentication trigger from Cognito.\n" + "This trigger happens after a user is authenticated, and lets your function take action, such as logging\n" + "the outcome.") err := runner.cognitoActor.UpdateTriggers( - userPoolId, + ctx, userPoolId, actions.TriggerInfo{Trigger: actions.PostAuthentication, HandlerArn: aws.String(activityLogArn)}) if err != nil { panic(err) @@ -93,10 +94,10 @@ func (runner *ActivityLog) AddActivityLogTrigger(userPoolId string, activityLogA } // SignInUser signs in as the specified user. -func (runner *ActivityLog) SignInUser(clientId string, userName string, password string) { +func (runner *ActivityLog) SignInUser(ctx context.Context, clientId string, userName string, password string) { log.Printf("Now we'll sign in user %v and check the results in the logs and the DynamoDB table.", userName) runner.questioner.Ask("Press Enter when you're ready.") - authResult, err := runner.cognitoActor.SignIn(clientId, userName, password) + authResult, err := runner.cognitoActor.SignIn(ctx, clientId, userName, password) if err != nil { panic(err) } @@ -107,10 +108,10 @@ func (runner *ActivityLog) SignInUser(clientId string, userName string, password } // GetKnownUserLastLogin gets the login info for a user from the Amazon DynamoDB table and displays it. -func (runner *ActivityLog) GetKnownUserLastLogin(tableName string, userName string) { +func (runner *ActivityLog) GetKnownUserLastLogin(ctx context.Context, tableName string, userName string) { log.Println("The PostAuthentication handler also writes login data to the DynamoDB table.") runner.questioner.Ask("Press Enter when you're ready to continue.") - users, err := runner.helper.GetKnownUsers(tableName) + users, err := runner.helper.GetKnownUsers(ctx, tableName) if err != nil { panic(err) } @@ -124,11 +125,11 @@ func (runner *ActivityLog) GetKnownUserLastLogin(tableName string, userName stri } // Run runs the scenario. -func (runner *ActivityLog) Run(stackName string) { +func (runner *ActivityLog) Run(ctx context.Context, stackName string) { defer func() { if r := recover(); r != nil { log.Println("Something went wrong with the demo.") - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) } }() @@ -137,20 +138,20 @@ func (runner *ActivityLog) Run(stackName string) { log.Println(strings.Repeat("-", 88)) - stackOutputs, err := runner.helper.GetStackOutputs(stackName) + stackOutputs, err := runner.helper.GetStackOutputs(ctx, stackName) if err != nil { panic(err) } runner.resources.userPoolId = stackOutputs["UserPoolId"] - runner.helper.PopulateUserTable(stackOutputs["TableName"]) - userName, password := runner.AddUserToPool(stackOutputs["UserPoolId"], stackOutputs["TableName"]) + runner.helper.PopulateUserTable(ctx, stackOutputs["TableName"]) + userName, password := runner.AddUserToPool(ctx, stackOutputs["UserPoolId"], stackOutputs["TableName"]) - runner.AddActivityLogTrigger(stackOutputs["UserPoolId"], stackOutputs["ActivityLogFunctionArn"]) - runner.SignInUser(stackOutputs["UserPoolClientId"], userName, password) - runner.helper.ListRecentLogEvents(stackOutputs["ActivityLogFunction"]) - runner.GetKnownUserLastLogin(stackOutputs["TableName"], userName) + runner.AddActivityLogTrigger(ctx, stackOutputs["UserPoolId"], stackOutputs["ActivityLogFunctionArn"]) + runner.SignInUser(ctx, stackOutputs["UserPoolClientId"], userName, password) + runner.helper.ListRecentLogEvents(ctx, stackOutputs["ActivityLogFunction"]) + runner.GetKnownUserLastLogin(ctx, stackOutputs["TableName"], userName) - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_integ_test.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_integ_test.go index 0130051422e..f0f8c2b90b5 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_integ_test.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_integ_test.go @@ -31,7 +31,8 @@ func TestRunActivityLogScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -44,7 +45,7 @@ func TestRunActivityLogScenario_Integration(t *testing.T) { NewScenarioHelper(sdkConfig, mockQuestioner), } scenario := NewActivityLog(sdkConfig, mockQuestioner, &helper) - scenario.Run("PoolsAndTriggersStackForGo") + scenario.Run(ctx, "PoolsAndTriggersStackForGo") log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_test.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_test.go index d83dc7ee9a5..0f9a0bd0001 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_test.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_activity_log_test.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "fmt" "testing" "user_pools_and_lambda_triggers/stubs" @@ -130,7 +131,7 @@ func (scenTest *ActivityLogScenarioTest) RunSubTest(stubber *testtools.AwsmStubb helper := NewScenarioHelper(*stubber.SdkConfig, &mockQuestioner) helper.isTestRun = true scenario := NewActivityLog(*stubber.SdkConfig, &mockQuestioner, &helper) - scenario.Run(scenTest.stackName) + scenario.Run(context.Background(), scenTest.stackName) } func (scenTest *ActivityLogScenarioTest) Cleanup() {} diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts.go index 96bc2293689..bf2c8a8b46d 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "errors" "log" "strings" @@ -39,12 +40,12 @@ func NewAutoConfirm(sdkConfig aws.Config, questioner demotools.IQuestioner, help } // AddPreSignUpTrigger adds a Lambda handler as an invocation target for the PreSignUp trigger. -func (runner *AutoConfirm) AddPreSignUpTrigger(userPoolId string, functionArn string) { +func (runner *AutoConfirm) AddPreSignUpTrigger(ctx context.Context, userPoolId string, functionArn string) { log.Printf("Let's add a Lambda function to handle the PreSignUp trigger from Cognito.\n" + "This trigger happens when a user signs up, and lets your function take action before the main Cognito\n" + "sign up processing occurs.\n") err := runner.cognitoActor.UpdateTriggers( - userPoolId, + ctx, userPoolId, actions.TriggerInfo{Trigger: actions.PreSignUp, HandlerArn: aws.String(functionArn)}) if err != nil { panic(err) @@ -54,11 +55,11 @@ func (runner *AutoConfirm) AddPreSignUpTrigger(userPoolId string, functionArn st } // SignUpUser signs up a user from the known user table with a password you specify. -func (runner *AutoConfirm) SignUpUser(clientId string, usersTable string) (string, string) { +func (runner *AutoConfirm) SignUpUser(ctx context.Context, clientId string, usersTable string) (string, string) { log.Println("Let's sign up a user to your Cognito user pool. When the user's email matches an email in the\n" + "DynamoDB known users table, it is automatically verified and the user is confirmed.") - knownUsers, err := runner.helper.GetKnownUsers(usersTable) + knownUsers, err := runner.helper.GetKnownUsers(ctx, usersTable) if err != nil { panic(err) } @@ -71,7 +72,7 @@ func (runner *AutoConfirm) SignUpUser(clientId string, usersTable string) (strin "(the password will not display as you type):", 8) for !signedUp { log.Printf("Signing up user '%v' with email '%v' to Cognito.\n", user.UserName, user.UserEmail) - userConfirmed, err = runner.cognitoActor.SignUp(clientId, user.UserName, password, user.UserEmail) + userConfirmed, err = runner.cognitoActor.SignUp(ctx, clientId, user.UserName, password, user.UserEmail) if err != nil { var invalidPassword *types.InvalidPasswordException if errors.As(err, &invalidPassword) { @@ -91,10 +92,10 @@ func (runner *AutoConfirm) SignUpUser(clientId string, usersTable string) (strin } // SignInUser signs in a user. -func (runner *AutoConfirm) SignInUser(clientId string, userName string, password string) string { +func (runner *AutoConfirm) SignInUser(ctx context.Context, clientId string, userName string, password string) string { runner.questioner.Ask("Press Enter when you're ready to continue.") log.Printf("Let's sign in as %v...\n", userName) - authResult, err := runner.cognitoActor.SignIn(clientId, userName, password) + authResult, err := runner.cognitoActor.SignIn(ctx, clientId, userName, password) if err != nil { panic(err) } @@ -104,11 +105,11 @@ func (runner *AutoConfirm) SignInUser(clientId string, userName string, password } // Run runs the scenario. -func (runner *AutoConfirm) Run(stackName string) { +func (runner *AutoConfirm) Run(ctx context.Context, stackName string) { defer func() { if r := recover(); r != nil { log.Println("Something went wrong with the demo.") - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) } }() @@ -117,21 +118,21 @@ func (runner *AutoConfirm) Run(stackName string) { log.Println(strings.Repeat("-", 88)) - stackOutputs, err := runner.helper.GetStackOutputs(stackName) + stackOutputs, err := runner.helper.GetStackOutputs(ctx, stackName) if err != nil { panic(err) } runner.resources.userPoolId = stackOutputs["UserPoolId"] - runner.helper.PopulateUserTable(stackOutputs["TableName"]) + runner.helper.PopulateUserTable(ctx, stackOutputs["TableName"]) - runner.AddPreSignUpTrigger(stackOutputs["UserPoolId"], stackOutputs["AutoConfirmFunctionArn"]) + runner.AddPreSignUpTrigger(ctx, stackOutputs["UserPoolId"], stackOutputs["AutoConfirmFunctionArn"]) runner.resources.triggers = append(runner.resources.triggers, actions.PreSignUp) - userName, password := runner.SignUpUser(stackOutputs["UserPoolClientId"], stackOutputs["TableName"]) - runner.helper.ListRecentLogEvents(stackOutputs["AutoConfirmFunction"]) + userName, password := runner.SignUpUser(ctx, stackOutputs["UserPoolClientId"], stackOutputs["TableName"]) + runner.helper.ListRecentLogEvents(ctx, stackOutputs["AutoConfirmFunction"]) runner.resources.userAccessTokens = append(runner.resources.userAccessTokens, - runner.SignInUser(stackOutputs["UserPoolClientId"], userName, password)) + runner.SignInUser(ctx, stackOutputs["UserPoolClientId"], userName, password)) - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_integ_test.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_integ_test.go index b9573b9304a..38628c2c2e6 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_integ_test.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_integ_test.go @@ -43,7 +43,8 @@ func TestRunAutoConfirmScenario_Integration(t *testing.T) { }, } - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) + ctx := context.Background() + sdkConfig, err := config.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("unable to load SDK config, %v", err) } @@ -56,7 +57,7 @@ func TestRunAutoConfirmScenario_Integration(t *testing.T) { NewScenarioHelper(sdkConfig, mockQuestioner), } scenario := NewAutoConfirm(sdkConfig, mockQuestioner, &helper) - scenario.Run("PoolsAndTriggersStackForGo") + scenario.Run(ctx, "PoolsAndTriggersStackForGo") log.SetOutput(os.Stderr) if !strings.Contains(buf.String(), "Thanks for watching") { diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_test.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_test.go index ff69c4fb90d..6c196eb2e05 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_test.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_auto_confirm_trusted_accounts_test.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "fmt" "testing" "user_pools_and_lambda_triggers/stubs" @@ -114,7 +115,7 @@ func (scenTest *AutoConfirmScenarioTest) RunSubTest(stubber *testtools.AwsmStubb helper := NewScenarioHelper(*stubber.SdkConfig, &mockQuestioner) helper.isTestRun = true scenario := NewAutoConfirm(*stubber.SdkConfig, &mockQuestioner, &helper) - scenario.Run(scenTest.stackName) + scenario.Run(context.Background(), scenTest.stackName) } func (scenTest *AutoConfirmScenarioTest) Cleanup() {} diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_common.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_common.go index 6a2dd06fa80..6dbc76fc937 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_common.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_common.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "log" "strings" "time" @@ -21,11 +22,11 @@ import ( // IScenarioHelper defines common functions used by the workflows in this example. type IScenarioHelper interface { Pause(secs int) - GetStackOutputs(stackName string) (actions.StackOutputs, error) - PopulateUserTable(tableName string) - GetKnownUsers(tableName string) (actions.UserList, error) - AddKnownUser(tableName string, user actions.User) - ListRecentLogEvents(functionName string) + GetStackOutputs(ctx context.Context, stackName string) (actions.StackOutputs, error) + PopulateUserTable(ctx context.Context, tableName string) + GetKnownUsers(ctx context.Context, tableName string) (actions.UserList, error) + AddKnownUser(ctx context.Context, tableName string, user actions.User) + ListRecentLogEvents(ctx context.Context, functionName string) } // ScenarioHelper contains AWS wrapper structs used by the workflows in this example. @@ -56,22 +57,22 @@ func (helper ScenarioHelper) Pause(secs int) { } // GetStackOutputs gets the outputs from the specified CloudFormation stack in a structured format. -func (helper ScenarioHelper) GetStackOutputs(stackName string) (actions.StackOutputs, error) { - return helper.cfnActor.GetOutputs(stackName), nil +func (helper ScenarioHelper) GetStackOutputs(ctx context.Context, stackName string) (actions.StackOutputs, error) { + return helper.cfnActor.GetOutputs(ctx, stackName), nil } // PopulateUserTable fills the known user table with example data. -func (helper ScenarioHelper) PopulateUserTable(tableName string) { +func (helper ScenarioHelper) PopulateUserTable(ctx context.Context, tableName string) { log.Printf("First, let's add some users to the DynamoDB %v table we'll use for this example.\n", tableName) - err := helper.dynamoActor.PopulateTable(tableName) + err := helper.dynamoActor.PopulateTable(ctx, tableName) if err != nil { panic(err) } } // GetKnownUsers gets the users from the known users table in a structured format. -func (helper ScenarioHelper) GetKnownUsers(tableName string) (actions.UserList, error) { - knownUsers, err := helper.dynamoActor.Scan(tableName) +func (helper ScenarioHelper) GetKnownUsers(ctx context.Context, tableName string) (actions.UserList, error) { + knownUsers, err := helper.dynamoActor.Scan(ctx, tableName) if err != nil { log.Printf("Couldn't get known users from table %v. Here's why: %v\n", tableName, err) } @@ -79,26 +80,26 @@ func (helper ScenarioHelper) GetKnownUsers(tableName string) (actions.UserList, } // AddKnownUser adds a user to the known users table. -func (helper ScenarioHelper) AddKnownUser(tableName string, user actions.User) { +func (helper ScenarioHelper) AddKnownUser(ctx context.Context, tableName string, user actions.User) { log.Printf("Adding user '%v' with email '%v' to the DynamoDB known users table...\n", user.UserName, user.UserEmail) - err := helper.dynamoActor.AddUser(tableName, user) + err := helper.dynamoActor.AddUser(ctx, tableName, user) if err != nil { panic(err) } } // ListRecentLogEvents gets the most recent log stream and events for the specified Lambda function and displays them. -func (helper ScenarioHelper) ListRecentLogEvents(functionName string) { +func (helper ScenarioHelper) ListRecentLogEvents(ctx context.Context, functionName string) { log.Println("Waiting a few seconds to let Lambda write to CloudWatch Logs...") helper.Pause(10) log.Println("Okay, let's check the logs to find what's happened recently with your Lambda function.") - logStream, err := helper.cwlActor.GetLatestLogStream(functionName) + logStream, err := helper.cwlActor.GetLatestLogStream(ctx, functionName) if err != nil { panic(err) } log.Printf("Getting some recent events from log stream %v\n", *logStream.LogStreamName) - events, err := helper.cwlActor.GetLogEvents(functionName, *logStream.LogStreamName, 10) + events, err := helper.cwlActor.GetLogEvents(ctx, functionName, *logStream.LogStreamName, 10) if err != nil { panic(err) } diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user.go index 283d1e01f56..b69bbc5e82b 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user.go @@ -6,6 +6,7 @@ package workflows // snippet-start:[gov2.workflows.PoolsAndTriggers.MigrateUser] import ( + "context" "errors" "fmt" "log" @@ -40,12 +41,12 @@ func NewMigrateUser(sdkConfig aws.Config, questioner demotools.IQuestioner, help } // AddMigrateUserTrigger adds a Lambda handler as an invocation target for the MigrateUser trigger. -func (runner *MigrateUser) AddMigrateUserTrigger(userPoolId string, functionArn string) { +func (runner *MigrateUser) AddMigrateUserTrigger(ctx context.Context, userPoolId string, functionArn string) { log.Printf("Let's add a Lambda function to handle the MigrateUser trigger from Cognito.\n" + "This trigger happens when an unknown user signs in, and lets your function take action before Cognito\n" + "rejects the user.\n\n") err := runner.cognitoActor.UpdateTriggers( - userPoolId, + ctx, userPoolId, actions.TriggerInfo{Trigger: actions.UserMigration, HandlerArn: aws.String(functionArn)}) if err != nil { panic(err) @@ -57,7 +58,7 @@ func (runner *MigrateUser) AddMigrateUserTrigger(userPoolId string, functionArn } // SignInUser adds a new user to the known users table and signs that user in to Amazon Cognito. -func (runner *MigrateUser) SignInUser(usersTable string, clientId string) (bool, actions.User) { +func (runner *MigrateUser) SignInUser(ctx context.Context, usersTable string, clientId string) (bool, actions.User) { log.Println("Let's sign in a user to your Cognito user pool. When the username and email matches an entry in the\n" + "DynamoDB known users table, the email is automatically verified and the user is migrated to the Cognito user pool.") @@ -66,7 +67,7 @@ func (runner *MigrateUser) SignInUser(usersTable string, clientId string) (bool, user.UserEmail = runner.questioner.Ask("\nEnter an email that you own. This email will be used to confirm user migration\n" + "during this example:") - runner.helper.AddKnownUser(usersTable, user) + runner.helper.AddKnownUser(ctx, usersTable, user) var err error var resetRequired *types.PasswordResetRequiredException @@ -74,7 +75,7 @@ func (runner *MigrateUser) SignInUser(usersTable string, clientId string) (bool, signedIn := false for !signedIn && resetRequired == nil { log.Printf("Signing in to Cognito as user '%v'. The expected result is a PasswordResetRequiredException.\n\n", user.UserName) - authResult, err = runner.cognitoActor.SignIn(clientId, user.UserName, "_") + authResult, err = runner.cognitoActor.SignIn(ctx, clientId, user.UserName, "_") if err != nil { if errors.As(err, &resetRequired) { log.Printf("\nUser '%v' is not in the Cognito user pool but was found in the DynamoDB known users table.\n"+ @@ -97,7 +98,7 @@ func (runner *MigrateUser) SignInUser(usersTable string, clientId string) (bool, } // ResetPassword starts a password recovery flow. -func (runner *MigrateUser) ResetPassword(clientId string, user actions.User) { +func (runner *MigrateUser) ResetPassword(ctx context.Context, clientId string, user actions.User) { wantCode := runner.questioner.AskBool(fmt.Sprintf("In order to migrate the user to Cognito, you must be able to receive a confirmation\n"+ "code by email at %v. Do you want to send a code (y/n)?", user.UserEmail), "y") if !wantCode { @@ -105,7 +106,7 @@ func (runner *MigrateUser) ResetPassword(clientId string, user actions.User) { "you own that can receive a confirmation code.") return } - codeDelivery, err := runner.cognitoActor.ForgotPassword(clientId, user.UserName) + codeDelivery, err := runner.cognitoActor.ForgotPassword(ctx, clientId, user.UserName) if err != nil { panic(err) } @@ -117,7 +118,7 @@ func (runner *MigrateUser) ResetPassword(clientId string, user actions.User) { "(the password will not display as you type):", 8) for !confirmed { log.Printf("\nConfirming password reset for user '%v'.\n", user.UserName) - err = runner.cognitoActor.ConfirmForgotPassword(clientId, code, user.UserName, password) + err = runner.cognitoActor.ConfirmForgotPassword(ctx, clientId, code, user.UserName, password) if err != nil { var invalidPassword *types.InvalidPasswordException if errors.As(err, &invalidPassword) { @@ -131,7 +132,7 @@ func (runner *MigrateUser) ResetPassword(clientId string, user actions.User) { } log.Printf("User '%v' successfully confirmed and migrated.\n", user.UserName) log.Println("Signing in with your username and password...") - authResult, err := runner.cognitoActor.SignIn(clientId, user.UserName, password) + authResult, err := runner.cognitoActor.SignIn(ctx, clientId, user.UserName, password) if err != nil { panic(err) } @@ -142,11 +143,11 @@ func (runner *MigrateUser) ResetPassword(clientId string, user actions.User) { } // Run runs the scenario. -func (runner *MigrateUser) Run(stackName string) { +func (runner *MigrateUser) Run(ctx context.Context, stackName string) { defer func() { if r := recover(); r != nil { log.Println("Something went wrong with the demo.") - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) } }() @@ -155,21 +156,21 @@ func (runner *MigrateUser) Run(stackName string) { log.Println(strings.Repeat("-", 88)) - stackOutputs, err := runner.helper.GetStackOutputs(stackName) + stackOutputs, err := runner.helper.GetStackOutputs(ctx, stackName) if err != nil { panic(err) } runner.resources.userPoolId = stackOutputs["UserPoolId"] - runner.AddMigrateUserTrigger(stackOutputs["UserPoolId"], stackOutputs["MigrateUserFunctionArn"]) + runner.AddMigrateUserTrigger(ctx, stackOutputs["UserPoolId"], stackOutputs["MigrateUserFunctionArn"]) runner.resources.triggers = append(runner.resources.triggers, actions.UserMigration) - resetNeeded, user := runner.SignInUser(stackOutputs["TableName"], stackOutputs["UserPoolClientId"]) + resetNeeded, user := runner.SignInUser(ctx, stackOutputs["TableName"], stackOutputs["UserPoolClientId"]) if resetNeeded { - runner.helper.ListRecentLogEvents(stackOutputs["MigrateUserFunction"]) - runner.ResetPassword(stackOutputs["UserPoolClientId"], user) + runner.helper.ListRecentLogEvents(ctx, stackOutputs["MigrateUserFunction"]) + runner.ResetPassword(ctx, stackOutputs["UserPoolClientId"], user) } - runner.resources.Cleanup() + runner.resources.Cleanup(ctx) log.Println(strings.Repeat("-", 88)) log.Println("Thanks for watching!") diff --git a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user_test.go b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user_test.go index c68f2a006bb..8f9e924b7dc 100644 --- a/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user_test.go +++ b/gov2/workflows/user_pools_and_lambda_triggers/workflows/scenario_migrate_user_test.go @@ -4,6 +4,7 @@ package workflows import ( + "context" "testing" "user_pools_and_lambda_triggers/stubs" @@ -106,7 +107,7 @@ func (scenTest *MigrateUserScenarioTest) RunSubTest(stubber *testtools.AwsmStubb helper := NewScenarioHelper(*stubber.SdkConfig, &mockQuestioner) helper.isTestRun = true scenario := NewMigrateUser(*stubber.SdkConfig, &mockQuestioner, &helper) - scenario.Run(scenTest.stackName) + scenario.Run(context.Background(), scenTest.stackName) } func (scenTest *MigrateUserScenarioTest) Cleanup() {}