diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index f04842249e24b..ff2ac0c76e94e 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -4,7 +4,7 @@ import { renderEnvironment, renderTags } from './private/utils'; import * as ec2 from '../../../aws-ec2'; import * as iam from '../../../aws-iam'; import * as sfn from '../../../aws-stepfunctions'; -import { Duration, Lazy, Size, Stack } from '../../../core'; +import { Duration, Lazy, Size, Stack, Token } from '../../../core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; /** @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); } + // check that both algorithm name and image are not defined + if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) { + throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification'); + } + + // validate algorithm name + this.validateAlgorithmName(props.algorithmSpecification.algorithmName); + // set the input mode to 'File' if not defined this.algorithmSpecification = props.algorithmSpecification.trainingInputMode ? props.algorithmSpecification @@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam : {}; } + private validateAlgorithmName(algorithmName?: string): void { + if (algorithmName === undefined || Token.isUnresolved(algorithmName)) { + return; + } + + if (algorithmName.length < 1 || 170 < algorithmName.length) { + throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`); + } + + const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(? { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')), + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/); +}); + +test('create a SageMaker train task with trainingImage', () => { + + const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')), + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toMatchObject({ + Parameters: { + AlgorithmSpecification: { + 'TrainingImage.$': '$.Training.imageName', + 'TrainingInputMode': 'File', + }, + }, + }); +}); + +test('create a SageMaker train task with image URI algorithmName', () => { + + const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees', + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toMatchObject({ + Parameters: { + AlgorithmSpecification: { + AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees', + }, + }, + }); +}); + +test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'a'.repeat(171), // maximum length is 170 + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Algorithm name length must be between 1 and 170, but got 171/); +}); + +test('Cannot create a SageMaker train task with incorrect algorithmName', () => { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'Blazing_Text', // underscores are not allowed + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Expected algorithm name to match pattern/); +});