Skip to content

Commit

Permalink
add more unit tests for ground truth
Browse files Browse the repository at this point in the history
  • Loading branch information
akartsky committed May 26, 2020
1 parent 49b9c1a commit f367652
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,61 @@ def test_get_labeling_job_output_from_job(self):

output_manifest, active_learning_model_arn = _utils.get_labeling_job_outputs(mock_client, 'labeling-job', True)
self.assertEqual(output_manifest, 's3://path/')
self.assertEqual(active_learning_model_arn, 'fake-arn')
self.assertEqual(active_learning_model_arn, 'fake-arn')

def test_pass_most_args(self):
required_args = [
'--region', 'us-west-2',
'--role', 'arn:aws:iam::123456789012:user/Development/product_1234/*',
'--job_name', 'test_job',
'--manifest_location', 's3://fake-bucket/manifest',
'--output_location', 's3://fake-bucket/output',
'--task_type', 'image classification',
'--worker_type', 'fake_worker',
'--ui_template', 's3://fake-bucket/ui_template',
'--title', 'fake-image-labelling-work',
'--description', 'fake job',
'--num_workers_per_object', '1',
'--time_limit', '180',
]
arguments = required_args + ['--label_attribute_name', 'fake-attribute',
'--max_human_labeled_objects', '10',
'--max_percent_objects', '50',
'--enable_auto_labeling', 'True',
'--initial_model_arn', 'fake-model-arn',
'--task_availibility', '30',
'--max_concurrent_tasks', '10',
'--task_keywords', 'fake-keyword',
'--worker_type', 'public',
'--no_adult_content', 'True',
'--no_ppi', 'True',
'--tags', '{"fake_key": "fake_value"}'
]
response = _utils.create_labeling_job_request(vars(self.parser.parse_args(arguments)))
print(response)
self.assertEqual(response, {'LabelingJobName': 'test_job',
'LabelAttributeName': 'fake-attribute',
'InputConfig': {'DataSource': {'S3DataSource': {'ManifestS3Uri': 's3://fake-bucket/manifest'}},
'DataAttributes': {'ContentClassifiers': ['FreeOfAdultContent', 'FreeOfPersonallyIdentifiableInformation']}},
'OutputConfig': {'S3OutputPath': 's3://fake-bucket/output', 'KmsKeyId': ''},
'RoleArn': 'arn:aws:iam::123456789012:user/Development/product_1234/*',
'LabelCategoryConfigS3Uri': '',
'StoppingConditions': {'MaxHumanLabeledObjectCount': 10, 'MaxPercentageOfInputDatasetLabeled': 50},
'LabelingJobAlgorithmsConfig': {'LabelingJobAlgorithmSpecificationArn': 'arn:aws:sagemaker:us-west-2:027400017018:labeling-job-algorithm-specification/image-classification',
'InitialActiveLearningModelArn': 'fake-model-arn',
'LabelingJobResourceConfig': {'VolumeKmsKeyId': ''}},
'HumanTaskConfig': {'WorkteamArn': 'arn:aws:sagemaker:us-west-2:394669845002:workteam/public-crowd/default',
'UiConfig': {'UiTemplateS3Uri': 's3://fake-bucket/ui_template'},
'PreHumanTaskLambdaArn': 'arn:aws:lambda:us-west-2:081040173940:function:PRE-ImageMultiClass',
'TaskKeywords': ['fake-keyword'],
'TaskTitle': 'fake-image-labelling-work',
'TaskDescription': 'fake job',
'NumberOfHumanWorkersPerDataObject': 1,
'TaskTimeLimitInSeconds': 180,
'TaskAvailabilityLifetimeInSeconds': 30,
'MaxConcurrentTaskCount': 10,
'AnnotationConsolidationConfig': {'AnnotationConsolidationLambdaArn': 'arn:aws:lambda:us-west-2:081040173940:function:ACS-ImageMultiClass'},
'PublicWorkforceTaskPrice': {'AmountInUsd': {'Dollars': 0, 'Cents': 0, 'TenthFractionsOfACent': 0}}},
'Tags': [{'Key': 'fake_key', 'Value': 'fake_value'}]}
)

8 changes: 1 addition & 7 deletions components/aws/sagemaker/tests/unit_tests/tests/test_hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_create_hyperparameter_tuning_job(self):
'RecordWrapperType': 'None',
'InputMode': 'File'}],
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 'test-output-location'},
'ResourceConfig': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 1, 'VolumeSizeInGB': 1, 'VolumeKmsKeyId': ''},
'ResourceConfig': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 1, 'VolumeSizeInGB': 30, 'VolumeKmsKeyId': ''},
'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
'EnableNetworkIsolation': True,
'EnableInterContainerTrafficEncryption': False,
Expand Down Expand Up @@ -328,12 +328,6 @@ def test_tags(self):
self.assertIn({'Key': 'key1', 'Value': 'val1'}, response['Tags'])
self.assertIn({'Key': 'key2', 'Value': 'val2'}, response['Tags'])

def test_invalid_instance_type(self):
invalid_instance_args = required_args + ['--instance_type', 'invalid-instance']

with self.assertRaises(SystemExit):
self.parser.parse_args(invalid_instance_args)

def test_valid_hyperparameters(self):
hyperparameters_str = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'
categorical_params = '[{"Name" : "categorical", "Values": ["A", "B"]}]'
Expand Down

0 comments on commit f367652

Please sign in to comment.