Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(aws-stepfunctions): refactor sagemaker tasks and fix default role issue #3014

Merged
merged 17 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,15 @@ export interface ResourceConfig {
* @experimental
*/
export interface VpcConfig {
/**
* VPC security groups.
*/
readonly securityGroups: ec2.ISecurityGroup[];

/**
* The VPC (an `ec2.IVpc` object, not an id) that the job's subnets are selected from.
*/
readonly vpc: ec2.Vpc;
readonly vpc: ec2.IVpc;

/**
* VPC subnets.
*/
readonly subnets: ec2.ISubnet[];
readonly subnets?: ec2.SubnetSelection;
}

/**
Expand Down
126 changes: 84 additions & 42 deletions packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ec2 = require('@aws-cdk/aws-ec2');
import iam = require('@aws-cdk/aws-iam');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Duration, Stack } from '@aws-cdk/core';
import { Duration, Lazy, Stack } from '@aws-cdk/core';
import { resourceArnSuffix } from './resource-arn-suffix';
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig,
S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types';
Expand Down Expand Up @@ -53,7 +53,7 @@ export interface SagemakerTrainTaskProps {
/**
* Tags to be applied to the train job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training.
Expand Down Expand Up @@ -88,15 +88,6 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
*/
public readonly connections: ec2.Connections = new ec2.Connections();

/**
* The execution role for the Sagemaker training job.
*
* @default new role for Amazon SageMaker to assume is automatically created.
*/
public readonly role: iam.IRole;

public readonly grantPrincipal: iam.IPrincipal;

/**
* The Algorithm Specification
*/
Expand All @@ -117,9 +108,15 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
*/
private readonly stoppingCondition: StoppingCondition;

private readonly vpc: ec2.IVpc;
private securityGroup: ec2.ISecurityGroup;
private readonly securityGroups: ec2.ISecurityGroup[] = [];
private readonly subnets: string[];
private readonly integrationPattern: sfn.ServiceIntegrationPattern;
private _role?: iam.IRole;
private _grantPrincipal?: iam.IPrincipal;

constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) {
constructor(private readonly props: SagemakerTrainTaskProps) {
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;

const supportedPatterns = [
Expand All @@ -143,8 +140,66 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
maxRuntime: Duration.hours(1)
};

// check that either algorithm name or image is defined
if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) {
throw new Error("Must define either an algorithm name or training image URI in the algorithm specification");
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.inputDataConfig = props.inputDataConfig.map(config => {
if (!config.dataSource.s3DataSource.s3DataType) {
return Object.assign({}, config, { dataSource: { s3DataSource:
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
} else {
return config;
}
});

// add the security groups to the connections object
if (props.vpcConfig) {
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved
this.vpc = props.vpcConfig.vpc;
this.subnets = (props.vpcConfig.subnets) ?
(this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds;
}
}

/**
 * IAM role that the SageMaker training job executes under.
 *
 * The backing field is assigned when the task is bound, so reading this
 * property before then throws.
 *
 * @throws Error if the role has not been assigned yet.
 */
public get role(): iam.IRole {
  const resolvedRole = this._role;
  if (resolvedRole !== undefined) {
    return resolvedRole;
  }
  throw new Error(`role not available yet--use the object in a Task first`);
}

/**
 * Principal exposed for the `iam.IGrantable` contract of this task.
 *
 * The backing field is assigned together with the role when the task is
 * bound, so reading this property before then throws.
 *
 * @throws Error if the principal has not been assigned yet.
 */
public get grantPrincipal(): iam.IPrincipal {
  const principal = this._grantPrincipal;
  if (principal !== undefined) {
    return principal;
  }
  throw new Error(`Principal not available yet--use the object in a Task first`);
}

/**
 * Register an additional security group for this training task.
 *
 * The group is appended to the task's internal security group list; the ids
 * of the groups in that list are what get rendered as the `SecurityGroupIds`
 * of the job's `VpcConfig`.
 *
 * @param securityGroup The security group to add
 */
public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void {
  this.securityGroups.push(securityGroup);
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
// set the sagemaker role or create new one
this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
inlinePolicies: {
CreateTrainingJob: new iam.PolicyDocument({
Expand All @@ -157,7 +212,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
'logs:CreateLogGroup',
'logs:DescribeLogStreams',
'ecr:GetAuthorizationToken',
...props.vpcConfig
...this.props.vpcConfig
? [
'ec2:CreateNetworkInterface',
'ec2:CreateNetworkInterfacePermission',
Expand All @@ -178,36 +233,23 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
}
});

if (props.outputDataConfig.encryptionKey) {
props.outputDataConfig.encryptionKey.grantEncrypt(this.role);
if (this.props.outputDataConfig.encryptionKey) {
this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role);
}

if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) {
props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant');
if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) {
this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant');
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.inputDataConfig = props.inputDataConfig.map(config => {
if (!config.dataSource.s3DataSource.s3DataType) {
return Object.assign({}, config, { dataSource: { s3DataSource:
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
} else {
return config;
}
});

// add the security groups to the connections object
if (this.props.vpcConfig) {
this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg));
// create a security group if not defined
if (this.vpc && this.securityGroup === undefined) {
this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', {
vpc: this.vpc
});
this.connections.addSecurityGroup(this.securityGroup);
this.securityGroups.push(this.securityGroup);
}
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
return {
resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + resourceArnSuffix.get(this.integrationPattern),
parameters: this.renderParameters(),
Expand All @@ -218,7 +260,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
private renderParameters(): {[key: string]: any} {
return {
TrainingJobName: this.props.trainingJobName,
RoleArn: this.role.roleArn,
RoleArn: this._role!.roleArn,
...(this.renderAlgorithmSpecification(this.algorithmSpecification)),
...(this.renderInputDataConfig(this.inputDataConfig)),
...(this.renderOutputDataConfig(this.props.outputDataConfig)),
Expand Down Expand Up @@ -303,8 +345,8 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn

private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} {
return (config) ? { VpcConfig: {
SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )),
Subnets: config.subnets.map(subnet => ( subnet.subnetId )),
SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }),
Subnets: this.subnets,
}} : {};
}

Expand All @@ -330,7 +372,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
}),
new iam.PolicyStatement({
actions: ['iam:PassRole'],
resources: [this.role.roleArn],
resources: [this._role!.roleArn],
conditions: {
StringEquals: { "iam:PassedToService": "sagemaker.amazonaws.com" }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ec2 = require('@aws-cdk/aws-ec2');
import iam = require('@aws-cdk/aws-iam');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Stack } from '@aws-cdk/core';
import { Stack } from '@aws-cdk/core';
import { resourceArnSuffix } from './resource-arn-suffix';
import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types';

Expand Down Expand Up @@ -37,7 +37,7 @@ export interface SagemakerTransformProps {
/**
* Environment variables to set in the Docker container.
*/
readonly environment?: {[key: string]: any};
readonly environment?: {[key: string]: string};

/**
* Maximum number of parallel requests that can be sent to each instance in a transform job.
Expand All @@ -57,7 +57,7 @@ export interface SagemakerTransformProps {
/**
* Tags to be applied to the transform job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
Expand All @@ -82,13 +82,6 @@ export interface SagemakerTransformProps {
*/
export class SagemakerTransformTask implements sfn.IStepFunctionsTask {

/**
* The execution role for the Sagemaker training job.
*
* @default new role for Amazon SageMaker to assume is automatically created.
*/
public readonly role: iam.IRole;

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
*/
Expand All @@ -98,10 +91,10 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
* ML compute instances for the transform job.
*/
private readonly transformResources: TransformResources;

private readonly integrationPattern: sfn.ServiceIntegrationPattern;
private _role?: iam.IRole;

constructor(scope: Construct, private readonly props: SagemakerTransformProps) {
constructor(private readonly props: SagemakerTransformProps) {
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;

const supportedPatterns = [
Expand All @@ -114,12 +107,9 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
}

// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
]
});
if (props.role) {
this._role = props.role;
}

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) :
Expand All @@ -140,13 +130,35 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
}

/**
 * Produce the Step Functions task configuration for the SageMaker transform job.
 *
 * If no role has been set by this point, a role assumable by
 * `sagemaker.amazonaws.com` and carrying the `AmazonSageMakerFullAccess`
 * managed policy is created with the given task as its scope.
 *
 * @param task The task this object is bound to; used as the scope of the
 *             default role and passed on when building the policy statements.
 * @returns The resource ARN, rendered parameters and policy statements.
 */
public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
  // Fall back to a default execution role only when none is set yet.
  if (this._role === undefined) {
    const sagemakerPrincipal = new iam.ServicePrincipal('sagemaker.amazonaws.com');
    const fullAccessPolicy = iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess');
    this._role = new iam.Role(task, 'SagemakerTransformRole', {
      assumedBy: sagemakerPrincipal,
      managedPolicies: [fullAccessPolicy],
    });
  }

  // The role must exist before the parameters and policy statements are rendered.
  const resourceArn = 'arn:aws:states:::sagemaker:createTransformJob' + resourceArnSuffix.get(this.integrationPattern);
  const parameters = this.renderParameters();
  const policyStatements = this.makePolicyStatements(task);

  return { resourceArn, parameters, policyStatements };
}

/**
 * IAM role that the SageMaker transform job executes under.
 *
 * Unless a role was supplied through the props, the backing field is only
 * assigned when the task is bound; reading this property before then throws.
 *
 * @throws Error if the role has not been assigned yet.
 */
public get role(): iam.IRole {
  const resolvedRole = this._role;
  if (resolvedRole !== undefined) {
    return resolvedRole;
  }
  throw new Error(`role not available yet--use the object in a Task first`);
}

private renderParameters(): {[key: string]: any} {
return {
...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {},
Expand Down
Loading