diff --git a/pom.xml b/pom.xml index 6be7b06..396f780 100644 --- a/pom.xml +++ b/pom.xml @@ -154,7 +154,13 @@ ml.combust.mleap mleap-runtime_2.11 - 0.14.0 + 0.15.0 + + + + org.apache.spark + spark-mllib-local_2.11 + 2.4.5 org.apache.commons diff --git a/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java b/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java index b412685..af8e517 100644 --- a/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java +++ b/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java @@ -16,24 +16,14 @@ package com.amazonaws.sagemaker.helper; -import com.amazonaws.sagemaker.dto.DataSchema; import com.amazonaws.sagemaker.dto.ColumnSchema; +import com.amazonaws.sagemaker.dto.DataSchema; import com.amazonaws.sagemaker.type.BasicDataType; import com.amazonaws.sagemaker.type.DataStructureType; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; -import java.io.IOException; -import java.io.StringReader; -import java.util.List; -import java.util.stream.Collectors; -import ml.combust.mleap.core.types.BasicType; -import ml.combust.mleap.core.types.DataType; -import ml.combust.mleap.core.types.ListType; -import ml.combust.mleap.core.types.ScalarType; -import ml.combust.mleap.core.types.StructField; -import ml.combust.mleap.core.types.StructType; -import ml.combust.mleap.core.types.TensorType; +import ml.combust.mleap.core.types.*; import ml.combust.mleap.runtime.frame.ArrayRow; import ml.combust.mleap.runtime.frame.DefaultLeapFrame; import ml.combust.mleap.runtime.frame.Row; @@ -43,9 +33,15 @@ import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; import org.apache.commons.lang3.StringUtils; +import org.apache.spark.ml.linalg.Vectors; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import 
java.io.IOException; +import java.io.StringReader; +import java.util.List; +import java.util.stream.Collectors; + /** * Converter class to convert data between input to MLeap expected types and convert back MLeap helper to Java types * for output. @@ -168,12 +164,12 @@ protected Object convertInputDataToJavaType(final String type, final String stru default: throw new IllegalArgumentException("Given type is not supported"); } - } else { + } else if (!StringUtils.isBlank(structure) && StringUtils.equals(structure, DataStructureType.ARRAY)) { List listOfObjects; try { listOfObjects = (List) value; } catch (ClassCastException cce) { - throw new IllegalArgumentException("Input val is not a list but struct passed is vector or array"); + throw new IllegalArgumentException("Input val is not a list but struct passed is array"); } switch (type) { case BasicDataType.INTEGER: @@ -194,7 +190,17 @@ protected Object convertInputDataToJavaType(final String type, final String stru default: throw new IllegalArgumentException("Given type is not supported"); } - + } else { + if(!type.equals(BasicDataType.DOUBLE)) + throw new IllegalArgumentException("Only Double type is supported for vector"); + List vectorValues; + try { + vectorValues = (List)value; + } catch (ClassCastException cce) { + throw new IllegalArgumentException("Input val is not a list but struct passed is vector"); + } + double[] primitiveVectorValues = vectorValues.stream().mapToDouble(d -> d).toArray(); + return Vectors.dense(primitiveVectorValues); } } diff --git a/src/test/java/com/amazonaws/sagemaker/dto/SageMakerRequestObjectTest.java b/src/test/java/com/amazonaws/sagemaker/dto/SageMakerRequestObjectTest.java index f4a350e..1096044 100644 --- a/src/test/java/com/amazonaws/sagemaker/dto/SageMakerRequestObjectTest.java +++ b/src/test/java/com/amazonaws/sagemaker/dto/SageMakerRequestObjectTest.java @@ -18,11 +18,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; -import 
java.io.IOException; import org.apache.commons.io.IOUtils; import org.junit.Assert; import org.junit.Test; +import java.io.IOException; + public class SageMakerRequestObjectTest { private ObjectMapper mapper = new ObjectMapper(); @@ -80,14 +81,14 @@ public void testParseCompleteInputJson() throws IOException { Assert.assertEquals(sro.getSchema().getInput().get(0).getName(), "name_1"); Assert.assertEquals(sro.getSchema().getInput().get(1).getName(), "name_2"); Assert.assertEquals(sro.getSchema().getInput().get(2).getName(), "name_3"); - Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "int"); + Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "double"); Assert.assertEquals(sro.getSchema().getInput().get(1).getType(), "string"); Assert.assertEquals(sro.getSchema().getInput().get(2).getType(), "double"); Assert.assertEquals(sro.getSchema().getInput().get(0).getStruct(), "vector"); Assert.assertEquals(sro.getSchema().getInput().get(1).getStruct(), "basic"); Assert.assertEquals(sro.getSchema().getInput().get(2).getStruct(), "array"); Assert.assertEquals(sro.getData(), - Lists.newArrayList(Lists.newArrayList(1, 2, 3), "C", Lists.newArrayList(38.0, 24.0))); + Lists.newArrayList(Lists.newArrayList(1.0, 2.0, 3.0), "C", Lists.newArrayList(38.0, 24.0))); Assert.assertEquals(sro.getSchema().getOutput().getName(), "features"); Assert.assertEquals(sro.getSchema().getOutput().getType(), "double"); Assert.assertEquals(sro.getSchema().getOutput().getStruct(), "vector"); diff --git a/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java b/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java index d070ae3..1ce188c 100644 --- a/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java +++ b/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java @@ -22,8 +22,6 @@ import com.amazonaws.sagemaker.type.DataStructureType; import com.fasterxml.jackson.databind.ObjectMapper; import 
com.google.common.collect.Lists; -import java.io.IOException; -import java.util.List; import ml.combust.mleap.core.types.ListType; import ml.combust.mleap.core.types.ScalarType; import ml.combust.mleap.core.types.TensorType; @@ -32,9 +30,13 @@ import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder; import ml.combust.mleap.runtime.javadsl.LeapFrameBuilderSupport; import org.apache.commons.io.IOUtils; +import org.apache.spark.ml.linalg.Vectors; import org.junit.Assert; import org.junit.Test; +import java.io.IOException; +import java.util.List; + public class DataConversionHelperTest { private ObjectMapper mapper = new ObjectMapper(); @@ -143,21 +145,11 @@ public void testCastingInputToJavaTypeSingle() { @Test public void testCastingInputToJavaTypeList() { - Assert.assertEquals(Lists.newArrayList(1, 2), dataConversionHelper - .convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR, - Lists.newArrayList(new Integer("1"), new Integer("2")))); - - Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper - .convertInputDataToJavaType(BasicDataType.FLOAT, DataStructureType.VECTOR, - Lists.newArrayList(new Double("1.0"), new Double("2.0")))); - Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper - .convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR, - Lists.newArrayList(new Double("1.0"), new Double("2.0")))); - - Assert.assertEquals(Lists.newArrayList(new Byte("1")), dataConversionHelper - .convertInputDataToJavaType(BasicDataType.BYTE, DataStructureType.VECTOR, - Lists.newArrayList(new Byte("1")))); + //Check vector struct and double type returns a Spark vector + Assert.assertEquals(Vectors.dense(new double[]{1.0, 2.0}),dataConversionHelper + .convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR, + Lists.newArrayList(new Double("1.0"), new Double("2.0")))); Assert.assertEquals(Lists.newArrayList(1L, 2L), dataConversionHelper .convertInputDataToJavaType(BasicDataType.LONG, 
DataStructureType.ARRAY, @@ -175,6 +167,12 @@ public void testCastingInputToJavaTypeList() { Lists.newArrayList(Boolean.valueOf("1")))); } + @Test(expected = IllegalArgumentException.class) + public void testConvertInputToJavaTypeNonDoibleVector() { /* Pass a well-formed list so that only the "Only Double type is supported for vector" check can raise the exception. */ + dataConversionHelper + .convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR, Lists.newArrayList(1, 2)); + } + @Test(expected = IllegalArgumentException.class) public void testCastingInputToJavaTypeNonList() { dataConversionHelper diff --git a/src/test/resources/com/amazonaws/sagemaker/dto/complete_input.json b/src/test/resources/com/amazonaws/sagemaker/dto/complete_input.json index 02f909c..9f614cf 100644 --- a/src/test/resources/com/amazonaws/sagemaker/dto/complete_input.json +++ b/src/test/resources/com/amazonaws/sagemaker/dto/complete_input.json @@ -3,7 +3,7 @@ "input": [ { "name": "name_1", - "type": "int", + "type": "double", "struct": "vector" }, { @@ -23,5 +23,5 @@ "struct": "vector" } }, - "data": [[1, 2, 3], "C", [38.0, 24.0]] + "data": [[1.0, 2.0, 3.0], "C", [38.0, 24.0]] } \ No newline at end of file