diff --git a/.doc_gen/metadata/sagemaker_metadata.yaml b/.doc_gen/metadata/sagemaker_metadata.yaml index a95467b8ce0..5053c248210 100644 --- a/.doc_gen/metadata/sagemaker_metadata.yaml +++ b/.doc_gen/metadata/sagemaker_metadata.yaml @@ -5,6 +5,15 @@ sagemaker_Hello: synopsis: get started using &SM;. category: Hello languages: + Java: + versions: + - sdk_version: 2 + github: javav2/example_code/sagemaker + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.list_books.main .NET: versions: - sdk_version: 3 @@ -22,6 +31,15 @@ sagemaker_CreatePipeline: synopsis: create or update a pipeline in &SM;. category: languages: + Java: + versions: + - sdk_version: 2 + github: javav2/usecases/workflow_sagemaker_pipes + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.create_pipeline.main .NET: versions: - sdk_version: 3 @@ -39,6 +57,15 @@ sagemaker_ExecutePipeline: synopsis: start a pipeline execution in &SM;. category: languages: + Java: + versions: + - sdk_version: 2 + github: javav2/usecases/workflow_sagemaker_pipes + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.execute_pipeline.main .NET: versions: - sdk_version: 3 @@ -56,6 +83,15 @@ sagemaker_DeletePipeline: synopsis: delete a pipeline in &SM;. category: languages: + Java: + versions: + - sdk_version: 2 + github: javav2/usecases/workflow_sagemaker_pipes + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.delete_pipeline.main .NET: versions: - sdk_version: 3 @@ -73,6 +109,15 @@ sagemaker_DescribePipelineExecution: synopsis: describe a pipeline execution in &SM;. category: languages: + Java: + versions: + - sdk_version: 2 + github: javav2/usecases/workflow_sagemaker_pipes + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.describe_pipeline_execution.main .NET: versions: - sdk_version: 3 @@ -307,6 +352,15 @@ sagemaker_Scenario_Pipelines: - Clean up resources. category: Scenarios languages: + Java: + versions: + - sdk_version: 2 + github: javav2/usecases/workflow_sagemaker_pipes + sdkguide: + excerpts: + - description: + snippet_tags: + - sagemaker.java2.sc.main .NET: versions: - sdk_version: 3 diff --git a/.github/pre_validate/pre_validate.py b/.github/pre_validate/pre_validate.py index ccb62ce3300..8c73132e0ef 100644 --- a/.github/pre_validate/pre_validate.py +++ b/.github/pre_validate/pre_validate.py @@ -197,6 +197,8 @@ 'aws/dynamodb/model/BatchWriteItemRequest', 'aws/rds/model/DescribeDBInstancesRequest', 'aws/rds/model/DescribeDBSnapshotsRequest', + 'role/AmazonSageMakerGeospatialFullAccess', + 'VectorEnrichmentJobDataSourceConfigInput', } def check_files(root, quiet): diff --git a/javav2/example_code/sagemaker/src/main/java/com/example/sage/ListNotebooks.java b/javav2/example_code/sagemaker/src/main/java/com/example/sage/HelloSageMaker.java similarity index 90% rename from javav2/example_code/sagemaker/src/main/java/com/example/sage/ListNotebooks.java rename to javav2/example_code/sagemaker/src/main/java/com/example/sage/HelloSageMaker.java index cd72d7854d8..fec6243ad39 100644 --- a/javav2/example_code/sagemaker/src/main/java/com/example/sage/ListNotebooks.java +++ b/javav2/example_code/sagemaker/src/main/java/com/example/sage/HelloSageMaker.java @@ -1,58 +1,56 @@ -//snippet-sourcedescription:[ListNotebooks.java demonstrates how to list notebooks.] -//snippet-keyword:[AWS SDK for Java v2] -//snippet-keyword:[Amazon SageMaker] - -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
- SPDX-License-Identifier: Apache-2.0 -*/ - -package com.example.sage; - -//snippet-start:[sagemaker.java2.list_books.import] -import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sagemaker.SageMakerClient; -import software.amazon.awssdk.services.sagemaker.model.ListNotebookInstancesResponse; -import software.amazon.awssdk.services.sagemaker.model.NotebookInstanceSummary; -import software.amazon.awssdk.services.sagemaker.model.SageMakerException; -import java.util.List; -//snippet-end:[sagemaker.java2.list_books.import] - -/** - * Before running this Java V2 code example, set up your development environment, including your credentials. - * - * For more information, see the following documentation topic: - * - * https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/get-started.html - */ -public class ListNotebooks { - - public static void main(String[] args) { - - Region region = Region.US_WEST_2; - SageMakerClient sageMakerClient = SageMakerClient.builder() - .region(region) - .credentialsProvider(ProfileCredentialsProvider.create()) - .build(); - - listBooks(sageMakerClient); - sageMakerClient.close(); - } - - //snippet-start:[sagemaker.java2.list_books.main] - public static void listBooks(SageMakerClient sageMakerClient) { - try { - ListNotebookInstancesResponse notebookInstancesResponse = sageMakerClient.listNotebookInstances(); - List items = notebookInstancesResponse.notebookInstances(); - for (NotebookInstanceSummary item: items) { - System.out.println("The notebook name is: "+item.notebookInstanceName()); - } - - } catch (SageMakerException e) { - System.err.println(e.awsErrorDetails().errorMessage()); - System.exit(1); - } - } - //snippet-end:[sagemaker.java2.list_books.main] -} +//snippet-sourcedescription:[HelloSageMaker.java demonstrates how to list notebooks.] +//snippet-keyword:[AWS SDK for Java v2] +//snippet-keyword:[Amazon SageMaker] + +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + SPDX-License-Identifier: Apache-2.0 +*/ + +package com.example.sage; + +//snippet-start:[sagemaker.java2.list_books.import] +import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sagemaker.SageMakerClient; +import software.amazon.awssdk.services.sagemaker.model.ListNotebookInstancesResponse; +import software.amazon.awssdk.services.sagemaker.model.NotebookInstanceSummary; +import software.amazon.awssdk.services.sagemaker.model.SageMakerException; +import java.util.List; +//snippet-end:[sagemaker.java2.list_books.import] + +//snippet-start:[sagemaker.java2.list_books.main] +/** + * Before running this Java V2 code example, set up your development environment, including your credentials. 
+ *
+ * For more information, see the following documentation topic:
+ *
+ * https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/get-started.html
+ */
+public class HelloSageMaker {
+    public static void main(String[] args) {
+        Region region = Region.US_WEST_2;
+        SageMakerClient sageMakerClient = SageMakerClient.builder()
+            .region(region)
+            .credentialsProvider(ProfileCredentialsProvider.create())
+            .build();
+
+        listBooks(sageMakerClient);
+        sageMakerClient.close();
+    }
+
+    public static void listBooks(SageMakerClient sageMakerClient) {
+        try {
+            ListNotebookInstancesResponse notebookInstancesResponse = sageMakerClient.listNotebookInstances();
+            List<NotebookInstanceSummary> items = notebookInstancesResponse.notebookInstances();
+            for (NotebookInstanceSummary item : items) {
+                System.out.println("The notebook name is: " + item.notebookInstanceName());
+            }
+
+        } catch (SageMakerException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        }
+    }
+    //snippet-end:[sagemaker.java2.list_books.main]
+}
diff --git a/javav2/usecases/workflow_sagemaker_lambda/Readme.md b/javav2/usecases/workflow_sagemaker_lambda/Readme.md
new file mode 100644
index 00000000000..bc4236d454b
--- /dev/null
+++ b/javav2/usecases/workflow_sagemaker_lambda/Readme.md
@@ -0,0 +1,37 @@
+# Create the SageMaker geospatial Lambda function using the Lambda Java runtime API
+
+This example demonstrates how to create a Lambda function for the Amazon SageMaker pipeline and geospatial job example.
+
+A [SageMaker pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines.html) is a series of
+interconnected steps that can be used to automate machine learning workflows. You can create and run pipelines from SageMaker Studio by using Python, but you can also do this by using AWS SDKs in other
+languages. Using the SDKs, you can create and run SageMaker pipelines and also monitor operations for them.
+
+You must build this Lambda function to complete the Java example successfully. You can find the full example under **workflow_sagemaker_pipes**.
+
+### Prerequisites
+
+To use this tutorial, you need the following:
+
++ An AWS account.
++ A Java IDE.
++ Java 1.8 JDK or later.
++ Maven 3.6 or later.
++ Set up your development environment. For more information, see [Get started with the SDK for Java](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/setup-basics.html).
+
+### Create a .jar file
+
+You can compile the project into a .jar file, which serves as input for [Create and run a SageMaker geospatial pipeline using the SDK for Java V2](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/usecases/workflow_sagemaker_pipes). Build it by using the following Maven command.
+
+    mvn package
+
+The .jar file is located in the target folder.
+
+## Additional resources
+
+* [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/whatis.html)
+* [SageMaker API Reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/Welcome.html)
+* [Java Developer Guide](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html)
+
+---
+
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
\ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_lambda/pom.xml b/javav2/usecases/workflow_sagemaker_lambda/pom.xml new file mode 100644 index 00000000000..d3b0a455468 --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_lambda/pom.xml @@ -0,0 +1,176 @@ + + + 4.0.0 + org.example + SageMakerLambda + 1.0-SNAPSHOT + + 17 + 17 + UTF-8 + + + + + software.amazon.awssdk + bom + 2.20.45 + pom + import + + + + + + com.amazonaws + aws-lambda-java-core + 1.2.1 + + + + com.fasterxml.jackson.core + jackson-core + 2.15.1 + + + com.fasterxml.jackson.core + jackson-databind + 2.15.1 + + + + com.google.code.gson + gson + 2.8.6 + + + com.fasterxml.jackson.core + jackson-core + 2.15.1 + + + org.apache.logging.log4j + log4j-api + 2.10.0 + + + org.apache.logging.log4j + log4j-core + 2.13.0 + test + + + software.amazon.awssdk + sagemaker + + + software.amazon.awssdk + sagemakergeospatial + 2.20.78 + + + com.amazonaws + aws-lambda-java-events + 3.11.2 + + + software.amazon.awssdk + sagemakerruntime + 2.20.26 + + + org.apache.logging.log4j + log4j-slf4j18-impl + 2.13.3 + test + + + org.junit.jupiter + junit-jupiter-api + 5.6.0 + test + + + org.junit.jupiter + junit-jupiter-engine + 5.6.0 + test + + + com.googlecode.json-simple + json-simple + 1.1.1 + + + software.amazon.awssdk + s3 + + + software.amazon.awssdk + s3 + + + software.amazon.awssdk + dynamodb + 2.5.10 + + + software.amazon.awssdk + dynamodb-enhanced + 2.11.4-PREVIEW + + + software.amazon.awssdk + rekognition + + + javax.mail + javax.mail-api + 1.5.5 + + + com.sun.mail + javax.mail + 1.5.5 + + + software.amazon.awssdk + ses + + + + + + maven-surefire-plugin + 2.22.2 + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.2 + + false + + + + package + + shade + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.1 + + 1.8 + 1.8 + + + + + \ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/QueuePayload.java b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/QueuePayload.java new file mode 100644 index 00000000000..91c360a44d6 --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/QueuePayload.java @@ -0,0 +1,57 @@ +package org.example; + +import java.util.HashMap; + +public class QueuePayload { + // The payload job token. + private String token; + + // The Amazon Resource Name (ARN) of the pipeline run. + private String pipelineExecutionArn; + + // The status of the job. + private String status; + + // A dictionary of payload arguments. 
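+    // For this example, the Lambda function and the pipeline callback exchange the
+    // Vector Enrichment Job ARN under the "vej_arn" key of this map.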
private HashMap<String, String> arguments;
+
+    // Constructor.
+    public QueuePayload() {
+    }
+
+    // Getter and setter methods for token.
+    public String getToken() {
+        return token;
+    }
+
+    public void setToken(String token) {
+        this.token = token;
+    }
+
+    // Getter and setter methods for pipelineExecutionArn.
+    public String getPipelineExecutionArn() {
+        return pipelineExecutionArn;
+    }
+
+    public void setPipelineExecutionArn(String pipelineExecutionArn) {
+        this.pipelineExecutionArn = pipelineExecutionArn;
+    }
+
+    // Getter and setter methods for status.
+    public String getStatus() {
+        return status;
+    }
+
+    public void setStatus(String status) {
+        this.status = status;
+    }
+
+    // Getter and setter methods for arguments.
+    public HashMap<String, String> getArguments() {
+        return arguments;
+    }
+
+    public void setArguments(HashMap<String, String> arguments) {
+        this.arguments = arguments;
+    }
+}
diff --git a/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/SageMakerLambdaFunction.java b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/SageMakerLambdaFunction.java
new file mode 100644
index 00000000000..705ac85697a
--- /dev/null
+++ b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/org/example/SageMakerLambdaFunction.java
@@ -0,0 +1,260 @@
+package org.example;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.LambdaLogger;
+import com.amazonaws.services.lambda.runtime.RequestHandler;
+import org.json.simple.JSONObject;
+import org.json.simple.parser.JSONParser;
+import org.json.simple.parser.ParseException;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.sagemaker.model.OutputParameter;
+import software.amazon.awssdk.services.sagemaker.model.SendPipelineExecutionStepFailureRequest;
+import software.amazon.awssdk.services.sagemaker.model.SendPipelineExecutionStepSuccessRequest;
+import software.amazon.awssdk.services.sagemakergeospatial.SageMakerGeospatialAsyncClient;
+import software.amazon.awssdk.services.sagemaker.SageMakerClient;
+import java.util.*;
+import java.util.concurrent.CompletableFuture;
+import com.google.gson.Gson;
+import software.amazon.awssdk.services.sagemakergeospatial.model.ExportVectorEnrichmentJobOutputConfig;
+import software.amazon.awssdk.services.sagemakergeospatial.model.ExportVectorEnrichmentJobRequest;
+import software.amazon.awssdk.services.sagemakergeospatial.model.ExportVectorEnrichmentJobResponse;
+import software.amazon.awssdk.services.sagemakergeospatial.model.GetVectorEnrichmentJobRequest;
+import software.amazon.awssdk.services.sagemakergeospatial.model.GetVectorEnrichmentJobResponse;
+import software.amazon.awssdk.services.sagemakergeospatial.model.ReverseGeocodingConfig;
+import software.amazon.awssdk.services.sagemakergeospatial.model.StartVectorEnrichmentJobRequest;
+import software.amazon.awssdk.services.sagemakergeospatial.model.StartVectorEnrichmentJobResponse;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobConfig;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobDataSourceConfigInput;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobDocumentType;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobInputConfig;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobS3Data;
+import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobStatus;
+
+import
com.amazonaws.services.lambda.runtime.events.SQSEvent;
+
+// The AWS Lambda function handler for the Amazon SageMaker pipeline.
+public class SageMakerLambdaFunction implements RequestHandler<HashMap<String, Object>, Map<String, String>> {
+    @Override
+    public Map<String, String> handleRequest(HashMap<String, Object> requestObject, Context context) {
+        LambdaLogger logger = context.getLogger();
+        Region region = Region.US_WEST_2;
+
+        SageMakerClient sageMakerClient = SageMakerClient.builder()
+            .region(region)
+            .build();
+
+        SageMakerGeospatialAsyncClient asyncClient = SageMakerGeospatialAsyncClient.builder()
+            .region(region)
+            .build();
+        Gson gson = new Gson();
+        if (requestObject == null) {
+            logger.log("*** Request is null");
+        } else {
+            logger.log("*** Request is NOT null");
+            logger.log("*** REQUEST: " + requestObject);
+        }
+
+        // Log the values from the request. The request object is a HashMap.
+        logger.log("*** vej_export_config: " + requestObject.get("vej_export_config"));
+        logger.log("*** vej_name: " + requestObject.get("vej_name"));
+        logger.log("*** vej_config: " + requestObject.get("vej_config"));
+        logger.log("*** vej_input_config: " + requestObject.get("vej_input_config"));
+        logger.log("*** role: " + requestObject.get("role"));
+
+        // The Records array is populated only if this request came from the queue.
+        logger.log("*** records: " + requestObject.get("Records"));
+
+        // The response dictionary.
+        Map<String, String> responseDictionary = new HashMap<>();
+
+        if (requestObject.get("Records") != null) {
+            logger.log("Records found, this is a queue event. Processing the queue records.");
+            ArrayList<HashMap<String, String>> queueMessages = (ArrayList<HashMap<String, String>>) requestObject.get("Records");
+            for (HashMap<String, String> message : queueMessages) {
+                processMessage(asyncClient, sageMakerClient, message.get("body"), context);
+            }
+        } else if (requestObject.get("vej_export_config") != null) {
+            logger.log("*** Export configuration found. Start the Vector Enrichment Job (VEJ) export.");
+
+            JSONObject jsonObject = null;
+            JSONParser parser = new JSONParser();
+
+            try {
+                jsonObject = (JSONObject) parser.parse((String) requestObject.get("vej_export_config"));
+            } catch (ParseException e) {
+                throw new RuntimeException("Problem parsing export config.", e);
+            }
+
+            JSONObject s3Data = (JSONObject) jsonObject.get("S3Data");
+            String s3Uri = (String) s3Data.get("S3Uri");
+            logger.log("**** NEW S3URI: " + s3Uri);
+
+            VectorEnrichmentJobS3Data jobS3Data = VectorEnrichmentJobS3Data.builder()
+                .s3Uri(s3Uri)
+                .build();
+
+            ExportVectorEnrichmentJobOutputConfig jobOutputConfig = ExportVectorEnrichmentJobOutputConfig.builder()
+                .s3Data(jobS3Data)
+                .build();
+
+            ExportVectorEnrichmentJobRequest exportRequest = ExportVectorEnrichmentJobRequest.builder()
+                .arn((String) requestObject.get("vej_arn"))
+                .executionRoleArn((String) requestObject.get("role"))
+                .outputConfig(jobOutputConfig)
+                .build();
+
+            CompletableFuture<ExportVectorEnrichmentJobResponse> futureResponse =
+                asyncClient.exportVectorEnrichmentJob(exportRequest);
+            futureResponse.whenComplete((response, exception) -> {
+                logger.log("*** IN whenComplete BLOCK");
+                if (exception != null) {
+                    // Handle the exception.
+                    logger.log("Error occurred during the asynchronous operation: " + exception.getMessage());
+                } else {
+                    // Process the response.
+                    logger.log("Export response: " + response.toString());
+                    responseDictionary.put("export_eoj_status", response.exportStatusAsString());
+                    responseDictionary.put("vej_arn", response.arn());
+                }
+            });
+
+            /*
+             Calling futureResponse.join() makes the handler thread wait for the
+             CompletableFuture to complete before the Lambda function returns, which
+             gives the whenComplete callback a chance to run.
+ */ + futureResponse.join(); + } else if (requestObject.get("vej_name") != null ) { + logger.log("*** NEW Vector Enrichment Job name found, starting the job."); + + JSONObject jsonObject = null; + JSONParser parser = new JSONParser(); + + try { + jsonObject = (JSONObject) parser.parse((String)requestObject.get("vej_input_config")); + } catch (ParseException e) { + throw new RuntimeException("Problem parsing input config."); + } + + JSONObject dataSourceConfig = (JSONObject) jsonObject.get("DataSourceConfig"); + JSONObject s3Data = (JSONObject) dataSourceConfig.get("S3Data"); + String s3Uri = (String) s3Data.get("S3Uri"); + System.out.println("**** NEW S3URI: " + s3Uri); + + VectorEnrichmentJobS3Data s3DataOb = VectorEnrichmentJobS3Data.builder() + .s3Uri(s3Uri) + .build(); + + VectorEnrichmentJobInputConfig inputConfig = VectorEnrichmentJobInputConfig.builder() + .documentType(VectorEnrichmentJobDocumentType.CSV) + .dataSourceConfig(VectorEnrichmentJobDataSourceConfigInput.fromS3Data(s3DataOb)) + .build(); + + ReverseGeocodingConfig geocodingConfig = ReverseGeocodingConfig.builder() + .xAttributeName("Longitude") + .yAttributeName("Latitude") + .build(); + + VectorEnrichmentJobConfig jobConfig = VectorEnrichmentJobConfig.builder() + .reverseGeocodingConfig(geocodingConfig) + .build(); + + StartVectorEnrichmentJobRequest jobRequest = StartVectorEnrichmentJobRequest.builder() + .inputConfig(inputConfig) + .executionRoleArn((String)requestObject.get("role")) + .name((String)requestObject.get("vej_name")) + .jobConfig(jobConfig) + .build(); + + logger.log("*** INVOKE geoSpatialClient.startVectorEnrichmentJob with asyncClient"); + CompletableFuture futureResponse = asyncClient.startVectorEnrichmentJob(jobRequest); + futureResponse.whenComplete((response, exception) -> { + logger.log("*** IN whenComplete BLOCK"); + if (exception != null) { + // Handle the exception here + logger.log("Error occurred during the asynchronous operation: " + exception.getMessage()); + } else { + // Process the response here + logger.log("Asynchronous job started successfully. Job Status is: " + response.toString()); + String vej_arnValue = response.arn(); + logger.log("vej_arn: " + vej_arnValue); + String status = response.statusAsString(); + logger.log("STATUS: " + status); + + responseDictionary.put("statusCode", status); + responseDictionary.put("vej_arn", vej_arnValue); + } + }); + + /* + By adding futureResponse.join(), + you ensure that the main thread will wait for the CompletableFuture to complete before the Lambda function terminates. This should allow the "whenComplete" block to be executed properly and reach the Complete When block. + */ + futureResponse.join(); + logger.log("*** OUT OF whenComplete BLOCK"); + } + logger.log("Returning:" + responseDictionary); + return responseDictionary; + } + + private void processMessage(SageMakerGeospatialAsyncClient asyncClient, SageMakerClient sageMakerClient, String messageBody, Context context ) throws RuntimeException { + Gson gson = new Gson(); + LambdaLogger logger = context.getLogger(); + logger.log("Processing message with body:" + messageBody); + + QueuePayload queuePayload = gson.fromJson(messageBody, QueuePayload.class); + String token = queuePayload.getToken(); + logger.log("Payload token " + token); + + if (queuePayload.getArguments().containsKey("vej_arn")) { + // Use the job ARN and the token to get the job status. 
+ String job_arn = queuePayload.getArguments().get("vej_arn"); + logger.log("Token: " + token + ", arn " + job_arn); + + GetVectorEnrichmentJobRequest jobInfoRequest = GetVectorEnrichmentJobRequest.builder() + .arn(job_arn) + .build(); + + CompletableFuture futureResponse = asyncClient.getVectorEnrichmentJob(jobInfoRequest); + + /* + By adding futureResponse.join(), + you ensure that the main thread will wait for the CompletableFuture to complete before the Lambda function terminates. This should allow the "whenComplete" block to be executed properly and reach the Complete When block. + */ + GetVectorEnrichmentJobResponse jobResponse = futureResponse.join(); + + logger.log("Job info: " + jobResponse.toString()); + + if (jobResponse.status().equals(VectorEnrichmentJobStatus.COMPLETED)) { + logger.log("Status completed, resuming pipeline..."); + + OutputParameter out = OutputParameter.builder() + .name("export_status") + .value(String.valueOf(jobResponse.status())) + .build(); + + SendPipelineExecutionStepSuccessRequest successRequest = SendPipelineExecutionStepSuccessRequest.builder() + .callbackToken(token) + .outputParameters(Collections.singletonList(out)) + .build(); + + sageMakerClient.sendPipelineExecutionStepSuccess(successRequest); + + } else if (jobResponse.status().equals(VectorEnrichmentJobStatus.FAILED)) { + logger.log("Status failed, stopping pipeline..."); + SendPipelineExecutionStepFailureRequest failureRequest = SendPipelineExecutionStepFailureRequest.builder() + .callbackToken(token) + .failureReason(jobResponse.errorDetails().errorMessage()) + .build(); + + sageMakerClient.sendPipelineExecutionStepFailure(failureRequest); + + } else if (jobResponse.status().equals(VectorEnrichmentJobStatus.IN_PROGRESS)) { + // Put this message back in the queue to reprocess later. 
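+                // Throwing an exception fails this invocation, so the SQS event source
+                // mapping does not delete the message; it becomes visible again after
+                // the visibility timeout and the status check runs once more.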
+ logger.log("Status still in progress, check back later."); + throw new RuntimeException("Job still running."); + } + } + } +} \ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_lambda/src/main/java/test/TestJSON.java b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/test/TestJSON.java new file mode 100644 index 00000000000..6991df84b52 --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_lambda/src/main/java/test/TestJSON.java @@ -0,0 +1,32 @@ +package test; + +import org.json.simple.JSONObject; +import org.json.simple.parser.JSONParser; +import org.json.simple.parser.ParseException; + +public class TestJSON { + + public static void main(String[] args) { + String inputConfigJSON = "{\"DataSourceConfig\":{\"Type\":\"S3_DATA\",\"S3Data\":{\"KmsKeyId\":\"\",\"S3Uri\":\"s3:\\/\\/sagemaker-sdk-example-bucket\\/samplefiles\\/latlongtest.csv\"}},\"DocumentType\":{\"Value\":\"CSV\"}}"; + + + JSONParser parser = new JSONParser(); + JSONObject jsonObject = null; + try { + jsonObject = (JSONObject) parser.parse(inputConfigJSON); + } catch (ParseException e) { + throw new RuntimeException(e); + } + + // Get the DataSourceConfig object + JSONObject dataSourceConfig = (JSONObject) jsonObject.get("DataSourceConfig"); + + // Get the S3Data object + JSONObject s3DataOb = (JSONObject) dataSourceConfig.get("S3Data"); + + // Extract the S3URI + String s3Uri = (String) s3DataOb.get("S3Uri"); + System.out.println("**** NEW S3URI: " + s3Uri); + + } +} \ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_pipes/Readme.md b/javav2/usecases/workflow_sagemaker_pipes/Readme.md new file mode 100644 index 00000000000..0e8f662bf23 --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_pipes/Readme.md @@ -0,0 +1,92 @@ +# Create and run a SageMaker geospatial pipeline using the SDK for Java V2 + +## Overview + +This scenario demonstrates how to work with Amazon SageMaker pipelines and geospatial jobs. + +A [SageMaker pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines.html) is a series of +interconnected steps that can be used to automate machine learning workflows. You can create and run pipelines from SageMaker Studio by using Python, but you can also do this by using AWS SDKs in other +languages. Using the SDKs, you can create and run SageMaker pipelines and also monitor operations for them. + +### Pipeline steps +This example pipeline includes an [AWS Lambda step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-lambda) +and a [callback step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-callback). +Both steps are processed by the same example Lambda function. + +This Lambda code is included as part of this example, with the following functionality: +- Starts the SageMaker Vector Enrichment Job with the provided job configuration. +- Starts the export function with the provided export configuration. +- Processes Amazon Simple Queue Service (Amazon SQS) messages from the SageMaker pipeline. + +![AWS Tracking Application](images/pipes.png) + +### Pipeline parameters +The example pipeline uses [parameters](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-parameters.html) that you can reference throughout the steps. You can also use the parameters to change +values between runs. 
In this example, the parameters are used to set the Amazon Simple Storage Service (Amazon S3)
+locations for the input and output files, along with the identifiers for the role and queue to use in the pipeline.
+The example demonstrates how to set and access these parameters.
+
+### Geospatial jobs
+A SageMaker pipeline can be used for model training, setup, testing, or validation. This example uses a simple job
+for demonstration purposes: a [Vector Enrichment Job (VEJ)](https://docs.aws.amazon.com/sagemaker/latest/dg/geospatial-vej.html) that processes a set of coordinates to produce human-readable
+addresses powered by Amazon Location Service. Other types of jobs could be substituted in the pipeline instead.
+
+## ⚠ Important
+
+* Running this code might result in charges to your AWS account.
+* Running the tests might result in charges to your AWS account.
+* We recommend that you grant your code least privilege. At most, grant only the minimum permissions required to perform the task. For more information, see [Grant least privilege](https://docs.aws.amazon.com/IAM/latest/UserGuide/best-practices.html#grant-least-privilege).
+* This code is not tested in every AWS Region. For more information, see [AWS Regional Services](https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services).
+
+## Scenario
+
+### Prerequisites
+
+To use this tutorial, you need the following:
+
++ An AWS account.
++ A Java IDE.
++ Java 1.8 JDK or later.
++ Maven 3.6 or later.
++ Set up your development environment. For more information, see [Get started with the SDK for Java](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/setup-basics.html).
+
+To view pipelines in SageMaker Studio, you need to [set up an Amazon SageMaker Domain](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-studio-onboard.html).
+To use geospatial capabilities, [you need to use a supported Region](https://docs.aws.amazon.com/sagemaker/latest/dg/geospatial.html).
+
+You must download and use these files to successfully run this code example:
+
++ GeoSpatialPipeline.json
++ latlongtest.csv
+
+These files are located on GitHub in the [AWS SDK for .NET SageMaker Scenarios folder](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/dotnetv3/SageMaker/Scenarios).
+
+### Java Lambda function
+
+To run this example successfully, you must first create the Java SageMaker Lambda function. You can find the project here: [Create the SageMaker geospatial Lambda function using the Lambda Java runtime API](https://github.com/awsdocs/aws-doc-sdk-examples/tree/main/javav2/usecases/workflow_sagemaker_lambda). That project creates a JAR file that is input to this code example.
+
+After you create the Java Lambda project, build the required JAR file by using the **mvn package** command. The JAR file is created in the target folder and is used as input to this code example.
+
+### Instructions
+
+You can run this Java code example from within your Java IDE. A sample command-line invocation is shown after the following list.
+
+#### Get started with geospatial jobs and pipelines
+
+This example shows you how to do the following:
+
+* Set up resources for a pipeline.
+* Set up a pipeline that runs a geospatial job.
+* Start a pipeline run.
+* Monitor the status of the run.
+* View the output of the pipeline.
+* Clean up resources.
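+You can also start the scenario from the command line after building it. The following
+invocation is a hypothetical sketch: every argument value below is a placeholder (the
+roles, function, queue, bucket, and pipeline are created or looked up by name), and the
+JAR path points at the Lambda JAR built in the project mentioned above. Run it with a
+classpath that includes the project dependencies.
+
+    java com.example.sage.SagemakerWorkflow sagemaker-example-role lambda-example-role \
+        ./SageMakerLambda-1.0-SNAPSHOT.jar SageMakerExampleFunction sagemaker-example-queue \
+        sagemaker-sdk-example-bucket ./latlongtest.csv ./GeoSpatialPipeline.json \
+        sagemaker-sdk-example-pipeline
+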
+ +## Additional resources + +* [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/whatis.html) +* [SageMaker API Reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/Welcome.html) +* [Java Developer Guide](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/home.html) + +--- + +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. \ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_pipes/images/pipes.png b/javav2/usecases/workflow_sagemaker_pipes/images/pipes.png new file mode 100644 index 00000000000..f3215db0099 Binary files /dev/null and b/javav2/usecases/workflow_sagemaker_pipes/images/pipes.png differ diff --git a/javav2/usecases/workflow_sagemaker_pipes/pom.xml b/javav2/usecases/workflow_sagemaker_pipes/pom.xml new file mode 100644 index 00000000000..ca3f7189ccb --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_pipes/pom.xml @@ -0,0 +1,152 @@ + + + 4.0.0 + + org.example + SageMakerPipelines + 1.0-SNAPSHOT + + + UTF-8 + 1.8 + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + ${java.version} + ${java.version} + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.1.2 + + checkstyle.xml + UTF-8 + true + true + false + + + + validate + validate + + check + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.22.1 + + IntegrationTest + + + + + + + + software.amazon.awssdk + bom + 2.20.45 + pom + import + + + + + + software.amazon.awssdk + dynamodb-enhanced + 2.20.26 + + + software.amazon.awssdk + sagemakerruntime + 2.20.26 + + + com.googlecode.json-simple + json-simple + 1.1.1 + + + org.json + json + 20210307 + + + org.junit.jupiter + junit-jupiter-api + 5.9.2 + test + + + com.google.code.gson + gson + 2.10 + + + software.amazon.awssdk + s3 + + + software.amazon.awssdk + lambda + + + software.amazon.awssdk + sqs + + + software.amazon.awssdk + iam + + + software.amazon.awssdk + sagemakergeospatial + 2.20.78 + + + software.amazon.awssdk + secretsmanager + + + com.google.code.gson + gson + 2.10.1 + + + org.junit.jupiter + junit-jupiter-engine + 5.9.2 + test + + + org.junit.platform + junit-platform-commons + 1.9.2 + + + org.junit.platform + junit-platform-launcher + 1.9.2 + test + + + software.amazon.awssdk + sagemaker + + + \ No newline at end of file diff --git a/javav2/usecases/workflow_sagemaker_pipes/src/main/java/com/example/sage/SagemakerWorkflow.java b/javav2/usecases/workflow_sagemaker_pipes/src/main/java/com/example/sage/SagemakerWorkflow.java new file mode 100644 index 00000000000..e1416287f17 --- /dev/null +++ b/javav2/usecases/workflow_sagemaker_pipes/src/main/java/com/example/sage/SagemakerWorkflow.java @@ -0,0 +1,939 @@ +//snippet-sourcedescription:[SagemakerWorkflow.java is a multiple service example that demonstrates how to set up and run an Amazon SageMaker pipeline.] +//snippet-keyword:[AWS SDK for Java v2] +//snippet-keyword:[Amazon SageMaker] + +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ SPDX-License-Identifier: Apache-2.0 +*/ + +package com.example.sage; + +import com.google.gson.FieldNamingPolicy; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import org.json.simple.JSONArray; +import org.json.simple.JSONObject; +import org.json.simple.parser.JSONParser; +import org.json.simple.parser.ParseException; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.core.waiters.WaiterResponse; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.iam.IamClient; +import software.amazon.awssdk.services.iam.model.AttachRolePolicyRequest; +import software.amazon.awssdk.services.iam.model.CreateRoleRequest; +import software.amazon.awssdk.services.iam.model.CreateRoleResponse; +import software.amazon.awssdk.services.iam.model.DeleteRoleRequest; +import software.amazon.awssdk.services.iam.model.DetachRolePolicyRequest; +import software.amazon.awssdk.services.iam.model.GetRoleRequest; +import software.amazon.awssdk.services.iam.model.GetRoleResponse; +import software.amazon.awssdk.services.iam.model.IamException; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.CreateEventSourceMappingRequest; +import software.amazon.awssdk.services.lambda.model.CreateEventSourceMappingResponse; +import software.amazon.awssdk.services.lambda.model.CreateFunctionRequest; +import software.amazon.awssdk.services.lambda.model.CreateFunctionResponse; +import software.amazon.awssdk.services.lambda.model.DeleteEventSourceMappingRequest; +import software.amazon.awssdk.services.lambda.model.DeleteFunctionRequest; +import software.amazon.awssdk.services.lambda.model.FunctionCode; +import software.amazon.awssdk.services.lambda.model.GetFunctionRequest; +import software.amazon.awssdk.services.lambda.model.GetFunctionResponse; +import software.amazon.awssdk.services.lambda.model.LambdaException; +import software.amazon.awssdk.services.lambda.model.Runtime; +import software.amazon.awssdk.services.lambda.waiters.LambdaWaiter; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.Delete; +import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; +import software.amazon.awssdk.services.s3.model.DeleteObjectsRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; +import software.amazon.awssdk.services.s3.model.HeadBucketResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.ObjectIdentifier; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.S3Object; +import software.amazon.awssdk.services.s3.waiters.S3Waiter; +import software.amazon.awssdk.services.sagemaker.SageMakerClient; +import software.amazon.awssdk.services.sagemaker.model.CreatePipelineRequest; +import software.amazon.awssdk.services.sagemaker.model.DeletePipelineRequest; +import software.amazon.awssdk.services.sagemaker.model.DescribePipelineExecutionRequest; +import 
software.amazon.awssdk.services.sagemaker.model.DescribePipelineExecutionResponse; +import software.amazon.awssdk.services.sagemaker.model.StartPipelineExecutionRequest; +import software.amazon.awssdk.services.sagemaker.model.StartPipelineExecutionResponse; +import software.amazon.awssdk.services.sagemakergeospatial.model.ExportVectorEnrichmentJobOutputConfig; +import software.amazon.awssdk.services.sagemakergeospatial.model.ReverseGeocodingConfig; +import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobConfig; +import software.amazon.awssdk.services.sagemakergeospatial.model.VectorEnrichmentJobS3Data; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.CreateQueueRequest; +import software.amazon.awssdk.services.sqs.model.DeleteQueueRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueAttributesRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueAttributesResponse; +import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueUrlResponse; +import software.amazon.awssdk.services.sqs.model.QueueAttributeName; +import software.amazon.awssdk.services.sqs.model.SqsException; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Scanner; +import java.util.concurrent.TimeUnit; +import software.amazon.awssdk.services.sagemaker.model.Parameter; + +/** + * Before running this Java V2 code example, set up your development environment, including your credentials. + * + * For more information, see the following documentation topic: + * + * ... + * + * Before running this code example, read the corresponding Readme for instructions on + * where to get the required input files. You need the two files (latlongtest.csv and GeoSpatialPipeline.json) and + * the Lambda JAR file to successfully run this example. + * + * This example shows you how to do the following: + * + * 1. Set up resources for a pipeline. + * 2. Set up a pipeline that runs a geospatial job. + * 3. Start a pipeline run. + * 4. Monitor the status of the run. + * 5. View the output of the pipeline. + * 6. Clean up resources. 
+ */ + +//snippet-start:[sagemaker.java2.sc.main] +public class SagemakerWorkflow { + public static final String DASHES = new String(new char[80]).replace("\0", "-"); + private static String eventSourceMapping = ""; + + public static void main(String[] args) throws InterruptedException { + final String usage = "\n" + + "Usage:\n" + + " \n\n" + + "Where:\n" + + " sageMakerRoleName - The name of the Amazon SageMaker role.\n\n"+ + " lambdaRoleName - The name of the AWS Lambda role.\n\n"+ + " functionFileLocation - The file location where the JAR file that represents the AWS Lambda function is located.\n\n"+ + " functionName - The name of the AWS Lambda function (for example,SageMakerExampleFunction).\n\n"+ + " queueName - The name of the Amazon Simple Queue Service (Amazon SQS) queue.\n\n"+ + " bucketName - The name of the Amazon Simple Storage Service (Amazon S3) bucket.\n\n"+ + " lnglatData - The file location of the latlongtest.csv file required for this use case.\n\n"+ + " spatialPipelinePath - The file location of the GeoSpatialPipeline.json file required for this use case.\n\n"+ + " pipelineName - The name of the pipeline to create (for example, sagemaker-sdk-example-pipeline).\n\n" ; + + if (args.length != 9) { + System.out.println(usage); + System.exit(1); + } + + String sageMakerRoleName = args[0]; + String lambdaRoleName = args[1]; + String functionFileLocation = args[2]; + String functionName = args[3]; + String queueName = args[4]; + String bucketName = args[5]; + String lnglatData = args[6]; + String spatialPipelinePath = args[7]; + String pipelineName = args[8]; + String handlerName = "org.example.SageMakerLambdaFunction::handleRequest"; + + Region region = Region.US_WEST_2; + SageMakerClient sageMakerClient = SageMakerClient.builder() + .region(region) + .build(); + + IamClient iam = IamClient.builder() + .region(region) + .build(); + + LambdaClient lambdaClient = LambdaClient.builder() + .region(region) + .build(); + + SqsClient sqsClient = SqsClient.builder() + .region(region) + .build(); + + S3Client s3Client = S3Client.builder() + .region(region) + .build(); + + System.out.println(DASHES); + System.out.println("Welcome to the Amazon SageMaker pipeline example scenario."); + System.out.println( + "\nThis example workflow will guide you through setting up and running an" + + "\nAmazon SageMaker pipeline. The pipeline uses an AWS Lambda function and an" + + "\nAmazon SQS Queue. 
It runs a vector enrichment reverse geocode job to" +
+                "\nreverse geocode addresses in an input file and store the results in an export file.");
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("First, we will set up the roles, functions, and queue needed by the SageMaker pipeline.");
+        String lambdaRoleArn = checkLambdaRole(iam, lambdaRoleName);
+        String sageMakerRoleArn = checkSageMakerRole(iam, sageMakerRoleName);
+
+        String functionArn = checkFunction(lambdaClient, functionName, functionFileLocation, lambdaRoleArn, handlerName);
+        String queueUrl = checkQueue(sqsClient, lambdaClient, queueName, functionName);
+        System.out.println("The queue URL is " + queueUrl);
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Setting up bucket " + bucketName);
+        if (!checkBucket(s3Client, bucketName)) {
+            setupBucket(s3Client, bucketName);
+            System.out.println("Put " + lnglatData + " into " + bucketName);
+            putS3Object(s3Client, bucketName, "latlongtest.csv", lnglatData);
+        }
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Now we can create and run our pipeline.");
+        setupPipeline(sageMakerClient, spatialPipelinePath, sageMakerRoleArn, functionArn, pipelineName);
+        String pipelineExecutionARN = executePipeline(sageMakerClient, bucketName, queueUrl, sageMakerRoleArn, pipelineName);
+        System.out.println("The pipeline execution ARN value is " + pipelineExecutionARN);
+        waitForPipelineExecution(sageMakerClient, pipelineExecutionARN);
+        System.out.println("Getting output results " + bucketName);
+        getOutputResults(s3Client, bucketName);
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("The pipeline has completed. To view the pipeline and runs " +
+                "in SageMaker Studio, follow these instructions:" +
+                "\nhttps://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-studio.html");
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Do you want to delete the AWS resources used in this workflow? (y/n)");
+        Scanner in = new Scanner(System.in);
+        String delResources = in.nextLine();
+        if (delResources.compareTo("y") == 0) {
+            System.out.println("Let's clean up the AWS resources. Wait 30 seconds.");
+            TimeUnit.SECONDS.sleep(30);
+            deleteEventSourceMapping(lambdaClient);
+            deleteSQSQueue(sqsClient, queueName);
+            listBucketObjects(s3Client, bucketName);
+            deleteBucket(s3Client, bucketName);
+            deleteLambdaFunction(lambdaClient, functionName);
+            deleteLambdaRole(iam, lambdaRoleName);
+            deleteSagemakerRole(iam, sageMakerRoleName);
+            deletePipeline(sageMakerClient, pipelineName);
+        } else {
+            System.out.println("The AWS resources were not deleted.");
+        }
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("SageMaker pipeline scenario is complete.");
+        System.out.println(DASHES);
+    }
+
+    private static void readObject(S3Client s3Client, String bucketName, String key) {
+        System.out.println("Output file contents: \n");
+        GetObjectRequest objectRequest = GetObjectRequest.builder()
+            .bucket(bucketName)
+            .key(key)
+            .build();
+
+        ResponseBytes<GetObjectResponse> objectBytes = s3Client.getObjectAsBytes(objectRequest);
+        byte[] byteArray = objectBytes.asByteArray();
+        String text = new String(byteArray, StandardCharsets.UTF_8);
+        System.out.println("Text output: " + text);
+    }
+
+    // Display some results from the output directory.
+    public static void getOutputResults(S3Client s3Client, String bucketName) {
+        System.out.println("Getting output results from bucket " + bucketName + ".");
+        ListObjectsRequest listObjectsRequest = ListObjectsRequest.builder()
+            .bucket(bucketName)
+            .prefix("outputfiles/")
+            .build();
+
+        ListObjectsResponse response = s3Client.listObjects(listObjectsRequest);
+        List<S3Object> s3Objects = response.contents();
+        for (S3Object object : s3Objects) {
+            readObject(s3Client, bucketName, object.key());
+        }
+    }
+
+    //snippet-start:[sagemaker.java2.describe_pipeline_execution.main]
+    // Check the status of a pipeline execution until it is no longer executing.
+    public static void waitForPipelineExecution(SageMakerClient sageMakerClient, String executionArn) throws InterruptedException {
+        String status;
+        int index = 0;
+        do {
+            DescribePipelineExecutionRequest pipelineExecutionRequest = DescribePipelineExecutionRequest.builder()
+                .pipelineExecutionArn(executionArn)
+                .build();
+
+            DescribePipelineExecutionResponse response = sageMakerClient.describePipelineExecution(pipelineExecutionRequest);
+            status = response.pipelineExecutionStatusAsString();
+            System.out.println(index + ". The status of the pipeline is " + status);
+            TimeUnit.SECONDS.sleep(4);
+            index++;
+        } while ("Executing".equals(status));
+        System.out.println("Pipeline finished with status " + status);
+    }
+    //snippet-end:[sagemaker.java2.describe_pipeline_execution.main]
+
+    //snippet-start:[sagemaker.java2.delete_pipeline.main]
+    // Delete a SageMaker pipeline by name.
+    public static void deletePipeline(SageMakerClient sageMakerClient, String pipelineName) {
+        DeletePipelineRequest pipelineRequest = DeletePipelineRequest.builder()
+            .pipelineName(pipelineName)
+            .build();
+
+        sageMakerClient.deletePipeline(pipelineRequest);
+        System.out.println("*** Successfully deleted " + pipelineName);
+    }
+    //snippet-end:[sagemaker.java2.delete_pipeline.main]
+
+    //snippet-start:[sagemaker.java2.create_pipeline.main]
+    // Create a pipeline from the example pipeline JSON.
+    public static void setupPipeline(SageMakerClient sageMakerClient, String filePath, String roleArn, String functionArn, String pipelineName) {
+        System.out.println("Setting up the pipeline.");
+        JSONParser parser = new JSONParser();
+
+        // Read the JSON file and get the pipeline definition.
+        try (FileReader reader = new FileReader(filePath)) {
+            Object obj = parser.parse(reader);
+            JSONObject jsonObject = (JSONObject) obj;
+            JSONArray stepsArray = (JSONArray) jsonObject.get("Steps");
+            for (Object stepObj : stepsArray) {
+                JSONObject step = (JSONObject) stepObj;
+                if (step.containsKey("FunctionArn")) {
+                    step.put("FunctionArn", functionArn);
+                }
+            }
+            System.out.println(jsonObject);
+
+            // Create the pipeline.
+            CreatePipelineRequest pipelineRequest = CreatePipelineRequest.builder()
+                .pipelineDescription("Java SDK example pipeline")
+                .roleArn(roleArn)
+                .pipelineName(pipelineName)
+                .pipelineDefinition(jsonObject.toString())
+                .build();
+
+            sageMakerClient.createPipeline(pipelineRequest);
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        } catch (IOException | ParseException e) {
+            throw new RuntimeException(e);
+        }
+    }
+    //snippet-end:[sagemaker.java2.create_pipeline.main]
+
+    //snippet-start:[sagemaker.java2.execute_pipeline.main]
+    // Start a pipeline run with job configurations.
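+    // The run is parameterized with five values that the pipeline definition declares:
+    // the execution role, the queue URL, the VEJ input configuration, the VEJ export
+    // configuration, and the VEJ job configuration.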
+ public static String executePipeline(SageMakerClient sageMakerClient, String bucketName,String queueUrl, String roleArn, String pipelineName) { + System.out.println("Starting pipeline execution."); + String inputBucketLocation = "s3://"+bucketName+"/samplefiles/latlongtest.csv"; + String output = "s3://"+bucketName+"/outputfiles/"; + Gson gson = new GsonBuilder() + .setFieldNamingPolicy(FieldNamingPolicy.UPPER_CAMEL_CASE) + .setPrettyPrinting().create(); + + // Set up all parameters required to start the pipeline. + List parameters = new ArrayList<>(); + Parameter para1 = Parameter.builder() + .name("parameter_execution_role") + .value(roleArn) + .build(); + + Parameter para2 = Parameter.builder() + .name("parameter_queue_url") + .value(queueUrl) + .build(); + + String inputJSON = "{\n" + + " \"DataSourceConfig\": {\n" + + " \"S3Data\": {\n" + + " \"S3Uri\": \"s3://"+bucketName+"/samplefiles/latlongtest.csv\"\n" + + " },\n" + + " \"Type\": \"S3_DATA\"\n" + + " },\n" + + " \"DocumentType\": \"CSV\"\n" + + "}"; + + System.out.println(inputJSON); + + Parameter para3 = Parameter.builder() + .name("parameter_vej_input_config") + .value(inputJSON) + .build(); + + // Create an ExportVectorEnrichmentJobOutputConfig object. + VectorEnrichmentJobS3Data jobS3Data = VectorEnrichmentJobS3Data.builder() + .s3Uri(output) + .build(); + + ExportVectorEnrichmentJobOutputConfig outputConfig = ExportVectorEnrichmentJobOutputConfig.builder() + .s3Data(jobS3Data) + .build(); + + String gson4 = gson.toJson(outputConfig); + Parameter para4 = Parameter.builder() + .name("parameter_vej_export_config") + .value(gson4) + .build(); + System.out.println("parameter_vej_export_config:"+gson.toJson(outputConfig)); + + // Create a VectorEnrichmentJobConfig object. + ReverseGeocodingConfig reverseGeocodingConfig = ReverseGeocodingConfig.builder() + .xAttributeName("Longitude") + .yAttributeName("Latitude") + .build(); + + VectorEnrichmentJobConfig jobConfig = VectorEnrichmentJobConfig.builder() + .reverseGeocodingConfig(reverseGeocodingConfig) + .build(); + + String para5JSON = "{\"MapMatchingConfig\":null,\"ReverseGeocodingConfig\":{\"XAttributeName\":\"Longitude\",\"YAttributeName\":\"Latitude\"}}"; + Parameter para5 = Parameter.builder() + .name("parameter_step_1_vej_config") + .value(para5JSON) + .build(); + + System.out.println("parameter_step_1_vej_config:"+gson.toJson(jobConfig)); + parameters.add(para1); + parameters.add(para2); + parameters.add(para3); + parameters.add(para4); + parameters.add(para5); + + StartPipelineExecutionRequest pipelineExecutionRequest = StartPipelineExecutionRequest.builder() + .pipelineExecutionDescription("Created using Java SDK") + .pipelineExecutionDisplayName(pipelineName + "-example-execution") + .pipelineParameters(parameters) + .pipelineName(pipelineName) + .build(); + + StartPipelineExecutionResponse response = sageMakerClient.startPipelineExecution(pipelineExecutionRequest); + return response.pipelineExecutionArn(); + } + //snippet-end:[sagemaker.java2.execute_pipeline.main] + + public static void deleteEventSourceMapping(LambdaClient lambdaClient){ + DeleteEventSourceMappingRequest eventSourceMappingRequest = DeleteEventSourceMappingRequest.builder() + .uuid(eventSourceMapping) + .build(); + + lambdaClient.deleteEventSourceMapping(eventSourceMappingRequest); + } + + public static void deleteSagemakerRole(IamClient iam, String roleName) { + String[] sageMakerRolePolicies = getSageMakerRolePolicies(); + try { + for (String policy : sageMakerRolePolicies) { + // First the policy 
needs to be detached.
+                DetachRolePolicyRequest rolePolicyRequest = DetachRolePolicyRequest.builder()
+                    .policyArn(policy)
+                    .roleName(roleName)
+                    .build();
+
+                iam.detachRolePolicy(rolePolicyRequest);
+            }
+
+            // Delete the role.
+            DeleteRoleRequest roleRequest = DeleteRoleRequest.builder()
+                .roleName(roleName)
+                .build();
+
+            iam.deleteRole(roleRequest);
+            System.out.println("*** Successfully deleted " + roleName);
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        }
+    }
+
+    public static void deleteLambdaRole(IamClient iam, String roleName) {
+        String[] lambdaRolePolicies = getLambdaRolePolicies();
+        try {
+            for (String policy : lambdaRolePolicies) {
+                // First the policy needs to be detached.
+                DetachRolePolicyRequest rolePolicyRequest = DetachRolePolicyRequest.builder()
+                    .policyArn(policy)
+                    .roleName(roleName)
+                    .build();
+
+                iam.detachRolePolicy(rolePolicyRequest);
+            }
+
+            // Delete the role.
+            DeleteRoleRequest roleRequest = DeleteRoleRequest.builder()
+                .roleName(roleName)
+                .build();
+
+            iam.deleteRole(roleRequest);
+            System.out.println("*** Successfully deleted " + roleName);
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        }
+    }
+
+    // Delete the specified AWS Lambda function.
+    public static void deleteLambdaFunction(LambdaClient awsLambda, String functionName) {
+        try {
+            DeleteFunctionRequest request = DeleteFunctionRequest.builder()
+                .functionName(functionName)
+                .build();
+
+            awsLambda.deleteFunction(request);
+            System.out.println("*** " + functionName + " was deleted");
+
+        } catch (LambdaException e) {
+            System.err.println(e.getMessage());
+            System.exit(1);
+        }
+    }
+
+    // Delete the specified S3 bucket.
+    public static void deleteBucket(S3Client s3Client, String bucketName) {
+        DeleteBucketRequest deleteBucketRequest = DeleteBucketRequest.builder()
+            .bucket(bucketName)
+            .build();
+        s3Client.deleteBucket(deleteBucketRequest);
+        System.out.println("*** " + bucketName + " was deleted.");
+    }
+
+    public static void listBucketObjects(S3Client s3, String bucketName) {
+        try {
+            ListObjectsRequest listObjects = ListObjectsRequest
+                .builder()
+                .bucket(bucketName)
+                .build();
+
+            ListObjectsResponse res = s3.listObjects(listObjects);
+            List<S3Object> objects = res.contents();
+            for (S3Object myValue : objects) {
+                System.out.print("\n The name of the key is " + myValue.key());
+                deleteBucketObjects(s3, bucketName, myValue.key());
+            }
+
+        } catch (S3Exception e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        }
+    }
+
+    public static void deleteBucketObjects(S3Client s3, String bucketName, String objectName) {
+        ArrayList<ObjectIdentifier> toDelete = new ArrayList<>();
+        toDelete.add(ObjectIdentifier.builder()
+            .key(objectName)
+            .build());
+        try {
+            DeleteObjectsRequest dor = DeleteObjectsRequest.builder()
+                .bucket(bucketName)
+                .delete(Delete.builder()
+                    .objects(toDelete).build())
+                .build();
+
+            s3.deleteObjects(dor);
+            System.out.println("*** " + bucketName + " objects were deleted.");
+
+        } catch (S3Exception e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        }
+    }
+
+    // Delete the specified Amazon SQS queue.
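+    // The DeleteQueue API takes a queue URL rather than a queue name, so the URL is
+    // looked up first with GetQueueUrl.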
+ public static void deleteSQSQueue(SqsClient sqsClient, String queueName) { + try { + GetQueueUrlRequest getQueueRequest = GetQueueUrlRequest.builder() + .queueName(queueName) + .build(); + + String queueUrl = sqsClient.getQueueUrl(getQueueRequest).queueUrl(); + DeleteQueueRequest deleteQueueRequest = DeleteQueueRequest.builder() + .queueUrl(queueUrl) + .build(); + + sqsClient.deleteQueue(deleteQueueRequest); + + } catch (SqsException e) { + System.err.println(e.awsErrorDetails().errorMessage()); + System.exit(1); + } + } + + public static void putS3Object(S3Client s3, String bucketName, String objectKey, String objectPath) { + try { + Map metadata = new HashMap<>(); + metadata.put("x-amz-meta-myVal", "test"); + PutObjectRequest putOb = PutObjectRequest.builder() + .bucket(bucketName) + .key("samplefiles/"+objectKey) + .metadata(metadata) + .build(); + + s3.putObject(putOb, RequestBody.fromFile(new File(objectPath))); + System.out.println("Successfully placed " + objectKey +" into bucket "+bucketName); + + } catch (S3Exception e) { + System.err.println(e.getMessage()); + System.exit(1); + } + } + + public static void setupBucket(S3Client s3Client, String bucketName) { + try { + S3Waiter s3Waiter = s3Client.waiter(); + CreateBucketRequest bucketRequest = CreateBucketRequest.builder() + .bucket(bucketName) + .build(); + + s3Client.createBucket(bucketRequest); + HeadBucketRequest bucketRequestWait = HeadBucketRequest.builder() + .bucket(bucketName) + .build(); + + // Wait until the bucket is created and print out the response. + WaiterResponse waiterResponse = s3Waiter.waitUntilBucketExists(bucketRequestWait); + waiterResponse.matched().response().ifPresent(System.out::println); + System.out.println(bucketName +" is ready"); + + } catch (S3Exception e) { + System.err.println(e.awsErrorDetails().errorMessage()); + System.exit(1); + } + } + + // Set up the SQS queue to use with the pipeline. + public static String setupQueue(SqsClient sqsClient, LambdaClient lambdaClient, String queueName, String lambdaName) { + System.out.println("Setting up queue named "+queueName); + try { + Map queueAtt = new HashMap<>(); + queueAtt.put(QueueAttributeName.DELAY_SECONDS, "5"); + queueAtt.put( QueueAttributeName.RECEIVE_MESSAGE_WAIT_TIME_SECONDS, "5"); + queueAtt.put( QueueAttributeName.VISIBILITY_TIMEOUT, "300"); + CreateQueueRequest createQueueRequest = CreateQueueRequest.builder() + .queueName(queueName) + .attributes(queueAtt) + .build(); + + sqsClient.createQueue(createQueueRequest); + System.out.println("\nGet queue url"); + GetQueueUrlResponse getQueueUrlResponse = sqsClient.getQueueUrl(GetQueueUrlRequest.builder().queueName(queueName).build()); + TimeUnit.SECONDS.sleep(15); + + connectLambda(sqsClient, lambdaClient, getQueueUrlResponse.queueUrl(), lambdaName); + System.out.println("Queue ready with Url "+ getQueueUrlResponse.queueUrl()); + return getQueueUrlResponse.queueUrl(); + + } catch (SqsException e) { + System.err.println(e.awsErrorDetails().errorMessage()); + System.exit(1); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return ""; + } + + // Connect the queue to the Lambda function as an event source. + public static void connectLambda(SqsClient sqsClient, LambdaClient lambdaClient, String queueUrl, String lambdaName) { + System.out.println("Connecting the Lambda function and queue for the pipeline."); + String queueArn=""; + + // Specify the attributes to retrieve. 
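+        // Only the queue ARN is requested here because CreateEventSourceMapping
+        // identifies the event source by ARN rather than by queue URL.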
+        List<QueueAttributeName> atts = new ArrayList<>();
+        atts.add(QueueAttributeName.QUEUE_ARN);
+        GetQueueAttributesRequest attributesRequest = GetQueueAttributesRequest.builder()
+            .queueUrl(queueUrl)
+            .attributeNames(atts)
+            .build();
+
+        GetQueueAttributesResponse response = sqsClient.getQueueAttributes(attributesRequest);
+        Map<String, String> queueAtts = response.attributesAsStrings();
+        for (Map.Entry<String, String> queueAtt : queueAtts.entrySet()) {
+            System.out.println("Key = " + queueAtt.getKey() + ", Value = " + queueAtt.getValue());
+            queueArn = queueAtt.getValue();
+        }
+
+        CreateEventSourceMappingRequest eventSourceMappingRequest = CreateEventSourceMappingRequest.builder()
+            .eventSourceArn(queueArn)
+            .functionName(lambdaName)
+            .build();
+
+        CreateEventSourceMappingResponse response1 = lambdaClient.createEventSourceMapping(eventSourceMappingRequest);
+        eventSourceMapping = response1.uuid();
+        System.out.println("The mapping between the event source and Lambda function was successful");
+    }
+
+    // Create an AWS Lambda function.
+    public static String createLambdaFunction(LambdaClient awsLambda, String functionName, String filePath, String role, String handler) {
+        try {
+            LambdaWaiter waiter = awsLambda.waiter();
+            InputStream is = new FileInputStream(filePath);
+            SdkBytes fileToUpload = SdkBytes.fromInputStream(is);
+            FunctionCode code = FunctionCode.builder()
+                .zipFile(fileToUpload)
+                .build();
+
+            CreateFunctionRequest functionRequest = CreateFunctionRequest.builder()
+                .functionName(functionName)
+                .description("SageMaker example function.")
+                .code(code)
+                .handler(handler)
+                .runtime(Runtime.JAVA11)
+                .timeout(200)
+                .memorySize(1024)
+                .role(role)
+                .build();
+
+            // Create a Lambda function using a waiter.
+            CreateFunctionResponse functionResponse = awsLambda.createFunction(functionRequest);
+            GetFunctionRequest getFunctionRequest = GetFunctionRequest.builder()
+                .functionName(functionName)
+                .build();
+            WaiterResponse<GetFunctionResponse> waiterResponse = waiter.waitUntilFunctionExists(getFunctionRequest);
+            waiterResponse.matched().response().ifPresent(System.out::println);
+            System.out.println("The function ARN is " + functionResponse.functionArn());
+            return functionResponse.functionArn();
+
+        } catch (LambdaException | FileNotFoundException e) {
+            System.err.println(e.getMessage());
+            System.exit(1);
+        }
+        return "";
+    }
+
+    public static String createSageMakerRole(IamClient iam, String roleName) {
+        String[] sageMakerRolePolicies = getSageMakerRolePolicies();
+        System.out.println("Creating a role to use with SageMaker.");
+        String assumeRolePolicy = "{" +
+            "\"Version\": \"2012-10-17\"," +
+            "\"Statement\": [{" +
+            "\"Effect\": \"Allow\"," +
+            "\"Principal\": {" +
+            "\"Service\": [" +
+            "\"sagemaker.amazonaws.com\"," +
+            "\"sagemaker-geospatial.amazonaws.com\"," +
+            "\"lambda.amazonaws.com\"," +
+            "\"s3.amazonaws.com\"" +
+            "]" +
+            "}," +
+            "\"Action\": \"sts:AssumeRole\"" +
+            "}]" +
+            "}";
+
+        try {
+            CreateRoleRequest request = CreateRoleRequest.builder()
+                .roleName(roleName)
+                .assumeRolePolicyDocument(assumeRolePolicy)
+                .description("Created using the AWS SDK for Java")
+                .build();
+
+            CreateRoleResponse roleResult = iam.createRole(request);
+
+            // Attach the policies to the role.
+            for (String policy : sageMakerRolePolicies) {
+                AttachRolePolicyRequest attachRequest = AttachRolePolicyRequest.builder()
+                    .roleName(roleName)
+                    .policyArn(policy)
+                    .build();
+
+                iam.attachRolePolicy(attachRequest);
+            }
+
+            // Allow time for the role to be ready.
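+            // IAM is eventually consistent, so a newly created role may not be
+            // usable by other services right away. A fixed 15-second sleep keeps this
+            // example simple; production code would typically retry the dependent call.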
+            TimeUnit.SECONDS.sleep(15);
+            System.out.println("Role ready with ARN " + roleResult.role().arn());
+            return roleResult.role().arn();
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            System.exit(1);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+        return "";
+    }
+
+    private static String createLambdaRole(IamClient iam, String roleName) {
+        String[] lambdaRolePolicies = getLambdaRolePolicies();
+        String assumeRolePolicy = "{" +
+            "\"Version\": \"2012-10-17\"," +
+            "\"Statement\": [{" +
+            "\"Effect\": \"Allow\"," +
+            "\"Principal\": {" +
+            "\"Service\": [" +
+            "\"sagemaker.amazonaws.com\"," +
+            "\"sagemaker-geospatial.amazonaws.com\"," +
+            "\"lambda.amazonaws.com\"," +
+            "\"s3.amazonaws.com\"" +
+            "]" +
+            "}," +
+            "\"Action\": \"sts:AssumeRole\"" +
+            "}]" +
+            "}";
+
+        try {
+            CreateRoleRequest request = CreateRoleRequest.builder()
+                .roleName(roleName)
+                .assumeRolePolicyDocument(assumeRolePolicy)
+                .description("Created using the AWS SDK for Java")
+                .build();
+
+            CreateRoleResponse roleResult = iam.createRole(request);
+
+            // Attach the policies to the role.
+            for (String policy : lambdaRolePolicies) {
+                AttachRolePolicyRequest attachRequest = AttachRolePolicyRequest.builder()
+                    .roleName(roleName)
+                    .policyArn(policy)
+                    .build();
+
+                iam.attachRolePolicy(attachRequest);
+            }
+
+            // Allow time for the role to be ready.
+            TimeUnit.SECONDS.sleep(15);
+            System.out.println("Role ready with ARN " + roleResult.role().arn());
+            return roleResult.role().arn();
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+        return "";
+    }
+
+    public static String checkFunction(LambdaClient lambdaClient, String functionName, String filePath, String role, String handler) {
+        System.out.println("Create an AWS Lambda function used in this workflow.");
+        String functionArn;
+        try {
+            // Does this function already exist?
+            GetFunctionRequest functionRequest = GetFunctionRequest.builder()
+                .functionName(functionName)
+                .build();
+
+            GetFunctionResponse response = lambdaClient.getFunction(functionRequest);
+            functionArn = response.configuration().functionArn();
+
+        } catch (LambdaException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            functionArn = createLambdaFunction(lambdaClient, functionName, filePath, role, handler);
+        }
+        return functionArn;
+    }
+
+    // Check to see if the specified S3 bucket exists. If so, this method returns true.
+    public static boolean checkBucket(S3Client s3, String bucketName) {
+        try {
+            HeadBucketRequest headBucketRequest = HeadBucketRequest.builder()
+                .bucket(bucketName)
+                .build();
+
+            s3.headBucket(headBucketRequest);
+            System.out.println(bucketName + " exists");
+            return true;
+
+        } catch (S3Exception e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+        }
+        return false;
+    }
+
+    // Check to see if the Amazon SQS queue exists. If not, this method creates a new queue
+    // and returns its URL.
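+    // GetQueueUrl throws an SqsException (QueueDoesNotExistException) when the queue
+    // is missing, so the catch block doubles as the create-queue path.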
+    public static String checkQueue(SqsClient sqsClient, LambdaClient lambdaClient, String queueName, String lambdaName) {
+        System.out.println("Creating a queue for this use case.");
+        String queueUrl;
+        try {
+            GetQueueUrlRequest request = GetQueueUrlRequest.builder()
+                .queueName(queueName)
+                .build();
+
+            GetQueueUrlResponse response = sqsClient.getQueueUrl(request);
+            queueUrl = response.queueUrl();
+            System.out.println(queueUrl);
+
+        } catch (SqsException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            queueUrl = setupQueue(sqsClient, lambdaClient, queueName, lambdaName);
+        }
+        return queueUrl;
+    }
+
+    // Check to see if the Lambda role exists. If not, this method creates it.
+    public static String checkLambdaRole(IamClient iam, String roleName) {
+        System.out.println("Creating a role for AWS Lambda to use.");
+        String roleArn;
+        try {
+            GetRoleRequest roleRequest = GetRoleRequest.builder()
+                .roleName(roleName)
+                .build();
+
+            GetRoleResponse response = iam.getRole(roleRequest);
+            roleArn = response.role().arn();
+            System.out.println(roleArn);
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            roleArn = createLambdaRole(iam, roleName);
+        }
+        return roleArn;
+    }
+
+    // Check to see if the SageMaker role exists. If not, this method creates it.
+    public static String checkSageMakerRole(IamClient iam, String roleName) {
+        System.out.println("Creating a role for Amazon SageMaker to use.");
+        String roleArn;
+        try {
+            GetRoleRequest roleRequest = GetRoleRequest.builder()
+                .roleName(roleName)
+                .build();
+
+            GetRoleResponse response = iam.getRole(roleRequest);
+            roleArn = response.role().arn();
+            System.out.println(roleArn);
+
+        } catch (IamException e) {
+            System.err.println(e.awsErrorDetails().errorMessage());
+            roleArn = createSageMakerRole(iam, roleName);
+        }
+        return roleArn;
+    }
+
+    private static String[] getSageMakerRolePolicies() {
+        String[] sageMakerRolePolicies = new String[3];
+        sageMakerRolePolicies[0] = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess";
+        sageMakerRolePolicies[1] = "arn:aws:iam::aws:policy/" + "AmazonSageMakerGeospatialFullAccess";
+        sageMakerRolePolicies[2] = "arn:aws:iam::aws:policy/AmazonSQSFullAccess";
+        return sageMakerRolePolicies;
+    }
+
+    private static String[] getLambdaRolePolicies() {
+        String[] lambdaRolePolicies = new String[5];
+        lambdaRolePolicies[0] = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess";
+        lambdaRolePolicies[1] = "arn:aws:iam::aws:policy/AmazonSQSFullAccess";
+        lambdaRolePolicies[2] = "arn:aws:iam::aws:policy/service-role/" + "AmazonSageMakerGeospatialFullAccess";
+        lambdaRolePolicies[3] = "arn:aws:iam::aws:policy/service-role/" + "AmazonSageMakerServiceCatalogProductsLambdaServiceRolePolicy";
+        lambdaRolePolicies[4] = "arn:aws:iam::aws:policy/service-role/" + "AWSLambdaSQSQueueExecutionRole";
+        return lambdaRolePolicies;
+    }
+}
+//snippet-end:[sagemaker.java2.sc.main]
\ No newline at end of file
diff --git a/javav2/usecases/workflow_sagemaker_pipes/src/main/resources/config.properties b/javav2/usecases/workflow_sagemaker_pipes/src/main/resources/config.properties
new file mode 100644
index 00000000000..1ae198c14dd
--- /dev/null
+++ b/javav2/usecases/workflow_sagemaker_pipes/src/main/resources/config.properties
@@ -0,0 +1,9 @@
+sageMakerRoleName =
+lambdaRoleName =
+functionFileLocation =
+functionName =
+queueName =
+bucketName =
+lnglatData =
+spatialPipelinePath =
+pipelineName =
\ No newline at end of file
diff --git a/javav2/usecases/workflow_sagemaker_pipes/src/test/java/SageMakerpipelineTest.java b/javav2/usecases/workflow_sagemaker_pipes/src/test/java/SageMakerpipelineTest.java
new file mode 100644
index 00000000000..b1307547f78
--- /dev/null
+++ b/javav2/usecases/workflow_sagemaker_pipes/src/test/java/SageMakerpipelineTest.java
@@ -0,0 +1,175 @@
+import com.google.gson.Gson;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.TestMethodOrder;
+import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.iam.IamClient;
+import software.amazon.awssdk.services.lambda.LambdaClient;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.sagemaker.SageMakerClient;
+import software.amazon.awssdk.services.sqs.SqsClient;
+import com.example.sage.*;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Properties;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.concurrent.TimeUnit;
+
+@TestInstance(TestInstance.Lifecycle.PER_METHOD)
+@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
+public class SageMakerpipelineTest {
+    public static final String DASHES = new String(new char[80]).replace("\0", "-");
+
+    private static SageMakerClient sageMakerClient;
+    private static IamClient iam;
+    private static LambdaClient lambdaClient;
+    private static SqsClient sqsClient;
+    private static S3Client s3Client;
+
+    private static String sageMakerRoleName = "";
+    private static String lambdaRoleName = "";
+    private static String functionFileLocation = "";
+    private static String functionName = "";
+    private static String queueName = "";
+    private static String bucketName = "";
+    private static String lnglatData = "";
+    private static String spatialPipelinePath = "";
+    private static String pipelineName = "";
+
+    @BeforeAll
+    public static void setUp() throws IOException {
+        Random random = new Random();
+        int randomNum = random.nextInt((10000 - 1) + 1) + 1;
+        Region region = Region.US_WEST_2;
+        sageMakerClient = SageMakerClient.builder()
+            .region(region)
+            .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
+            .build();
+
+        iam = IamClient.builder()
+            .region(region)
+            .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
+            .build();
+
+        lambdaClient = LambdaClient.builder()
+            .region(region)
+            .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
+            .build();
+
+        sqsClient = SqsClient.builder()
+            .region(region)
+            .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
+            .build();
+
+        s3Client = S3Client.builder()
+            .region(region)
+            .credentialsProvider(EnvironmentVariableCredentialsProvider.create())
+            .build();
+
+        try (InputStream input = SageMakerpipelineTest.class.getClassLoader().getResourceAsStream("config.properties")) {
+            Properties prop = new Properties();
+            if (input == null) {
+                System.out.println("Sorry, unable to find config.properties");
+                return;
+            }
+
+            // Populate the data members required for all tests.
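+            // Each field maps to one key in config.properties. A random suffix is
+            // appended to the pipeline name so that repeated test runs do not
+            // collide with an existing pipeline of the same name.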
+            prop.load(input);
+            sageMakerRoleName = prop.getProperty("sageMakerRoleName");
+            lambdaRoleName = prop.getProperty("lambdaRoleName");
+            functionFileLocation = prop.getProperty("functionFileLocation");
+            functionName = prop.getProperty("functionName");
+            queueName = prop.getProperty("queueName");
+            bucketName = prop.getProperty("bucketName");
+            lnglatData = prop.getProperty("lnglatData");
+            spatialPipelinePath = prop.getProperty("spatialPipelinePath");
+            pipelineName = prop.getProperty("pipelineName") + randomNum;
+
+        } catch (IOException ex) {
+            ex.printStackTrace();
+        }
+    }
+
+    @Test
+    @Tag("IntegrationTest")
+    @Order(1)
+    public void testSagemakerWorkflow() throws InterruptedException {
+        String handlerName = "SageMakerLambda::SageMakerLambda.SageMakerLambdaFunction::FunctionHandler";
+        System.out.println(DASHES);
+        System.out.println("Welcome to the Amazon SageMaker pipeline example scenario.");
+        System.out.println(
+            "\nThis example workflow will guide you through setting up and running an" +
+            "\nAmazon SageMaker pipeline. The pipeline uses an AWS Lambda function and an" +
+            "\nAmazon SQS queue. It runs a vector enrichment reverse geocode job to" +
+            "\nreverse geocode addresses in an input file and store the results in an export file.");
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("First, we will set up the roles, functions, and queue needed by the SageMaker pipeline.");
+        String sageMakerRoleArn = SagemakerWorkflow.checkSageMakerRole(iam, sageMakerRoleName);
+        String lambdaRoleArn = SagemakerWorkflow.checkLambdaRole(iam, lambdaRoleName);
+
+        String functionArn = SagemakerWorkflow.checkFunction(lambdaClient, functionName, functionFileLocation, lambdaRoleArn, handlerName);
+        String queueUrl = SagemakerWorkflow.checkQueue(sqsClient, lambdaClient, queueName, functionName);
+        System.out.println("The queue URL is " + queueUrl);
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Setting up bucket " + bucketName);
+        if (!SagemakerWorkflow.checkBucket(s3Client, bucketName)) {
+            SagemakerWorkflow.setupBucket(s3Client, bucketName);
+            System.out.println("Put " + lnglatData + " into " + bucketName);
+            SagemakerWorkflow.putS3Object(s3Client, bucketName, "latlongtest.csv", lnglatData);
+        }
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Now we can create and run our pipeline.");
+        SagemakerWorkflow.setupPipeline(sageMakerClient, spatialPipelinePath, sageMakerRoleArn, functionArn, pipelineName);
+        String pipelineExecutionARN = SagemakerWorkflow.executePipeline(sageMakerClient, bucketName, queueUrl, sageMakerRoleArn, pipelineName);
+        System.out.println("The pipeline execution ARN value is " + pipelineExecutionARN);
+        SagemakerWorkflow.waitForPipelineExecution(sageMakerClient, pipelineExecutionARN);
+        System.out.println("Getting output results " + bucketName);
+        SagemakerWorkflow.getOutputResults(s3Client, bucketName);
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("The pipeline has completed. To view the pipeline and runs " +
+            "in SageMaker Studio, follow these instructions:" +
+            "\nhttps://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-studio.html");
+        System.out.println(DASHES);
+
+        System.out.println(DASHES);
+        System.out.println("Do you want to delete the AWS resources used in this workflow? (y/n)");
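+        // Read the response from the console; anything other than "y" leaves the
+        // AWS resources in place.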
(y/n)"); + Scanner in = new Scanner(System.in); + String delResources = in.nextLine(); + if (delResources.compareTo("y") == 0) { + System.out.println("Lets clean up the AWS resources. Wait 30 seconds"); + TimeUnit.SECONDS.sleep(30); + SagemakerWorkflow.deleteEventSourceMapping(lambdaClient); + SagemakerWorkflow.deleteSQSQueue(sqsClient, queueName); + SagemakerWorkflow.listBucketObjects(s3Client, bucketName); + SagemakerWorkflow.deleteBucket(s3Client, bucketName); + SagemakerWorkflow.deleteLambdaFunction(lambdaClient, functionName); + SagemakerWorkflow.deleteLambdaRole(iam, lambdaRoleName); + SagemakerWorkflow.deleteSagemakerRole(iam, sageMakerRoleName); + SagemakerWorkflow.deletePipeline(sageMakerClient, pipelineName); + } else { + System.out.println("The AWS Resources were not deleted!"); + } + System.out.println(DASHES); + + System.out.println(DASHES); + System.out.println("SageMaker pipeline scenario is complete."); + System.out.println(DASHES); + } +}