From 14f9bb22715660809bd96146ec850232681fa97e Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 8 Mar 2022 15:35:22 -0800 Subject: [PATCH] more strict check on input parameters by applying non-coerce mode (#173) (#174) Signed-off-by: Yaliang Wu (cherry picked from commit c50e0d6aa6ab63723a64af94ccccefb54ef50a22) Co-authored-by: Yaliang Wu --- .../parameter/AnomalyDetectionParams.java | 12 +++--- .../ml/common/parameter/BatchRCFParams.java | 12 +++--- .../ml/common/parameter/FitRCFParams.java | 12 +++--- .../ml/common/parameter/KMeansParams.java | 5 +-- .../parameter/LinearRegressionParams.java | 18 ++++----- .../parameter/LocalSampleCalculatorInput.java | 2 +- .../ml/common/parameter/MLModel.java | 2 +- .../ml/common/parameter/SampleAlgoParams.java | 2 +- .../org/opensearch/ml/common/TestHelper.java | 6 +++ .../ml/common/parameter/KMeansParamsTest.java | 22 +++++++++++ .../parameter/LinearRegressionParamsTest.java | 37 +++++++++++++++++++ .../AnomalyLocalizationInput.java | 10 ++--- 12 files changed, 102 insertions(+), 38 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java index 8c5793ff9a..667f64eae6 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java @@ -88,22 +88,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { kernelType = ADKernelType.from(parser.text().toUpperCase(Locale.ROOT)); break; case GAMMA_FIELD: - gamma = parser.doubleValue(); + gamma = parser.doubleValue(false); break; case NU_FIELD: - nu = parser.doubleValue(); + nu = parser.doubleValue(false); break; case COST_FIELD: - cost = parser.doubleValue(); + cost = parser.doubleValue(false); break; case COEFF_FIELD: - coeff = parser.doubleValue(); + coeff = parser.doubleValue(false); break; case EPSILON_FIELD: - epsilon = parser.doubleValue(); + epsilon = parser.doubleValue(false); break; case DEGREE_FIELD: - degree = parser.intValue(); + degree = parser.intValue(false); break; default: parser.skipChildren(); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/BatchRCFParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/BatchRCFParams.java index 8da997bf2f..488d792f8e 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/BatchRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/BatchRCFParams.java @@ -91,22 +91,22 @@ public static BatchRCFParams parse(XContentParser parser) throws IOException { switch (fieldName) { case NUMBER_OF_TREES: - numberOfTrees = parser.intValue(); + numberOfTrees = parser.intValue(false); break; case SHINGLE_SIZE: - shingleSize = parser.intValue(); + shingleSize = parser.intValue(false); break; case SAMPLE_SIZE: - sampleSize = parser.intValue(); + sampleSize = parser.intValue(false); break; case OUTPUT_AFTER: - outputAfter = parser.intValue(); + outputAfter = parser.intValue(false); break; case TRAINING_DATA_SIZE: - trainingDataSize = parser.intValue(); + trainingDataSize = parser.intValue(false); break; case ANOMALY_SCORE_THRESHOLD: - anomalyScoreThreshold = parser.doubleValue(); + anomalyScoreThreshold = parser.doubleValue(false); break; default: parser.skipChildren(); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/FitRCFParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/FitRCFParams.java index 51e55af094..6214ac7874 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/FitRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/FitRCFParams.java @@ -112,22 +112,22 @@ public static FitRCFParams parse(XContentParser parser) throws IOException { switch (fieldName) { case NUMBER_OF_TREES: - numberOfTrees = parser.intValue(); + numberOfTrees = parser.intValue(false); break; case SHINGLE_SIZE: - shingleSize = parser.intValue(); + shingleSize = parser.intValue(false); break; case SAMPLE_SIZE: - sampleSize = parser.intValue(); + sampleSize = parser.intValue(false); break; case OUTPUT_AFTER: - outputAfter = parser.intValue(); + outputAfter = parser.intValue(false); break; case TIME_DECAY: - timeDecay = parser.doubleValue(); + timeDecay = parser.doubleValue(false); break; case ANOMALY_RATE: - anomalyRate = parser.doubleValue(); + anomalyRate = parser.doubleValue(false); break; case TIME_FIELD: timeField = parser.text(); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java index 7ed6a02963..93291438e9 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java @@ -7,7 +7,6 @@ import lombok.Builder; import lombok.Data; -import lombok.Getter; import org.opensearch.common.ParseField; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -70,10 +69,10 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { switch (fieldName) { case CENTROIDS_FIELD: - k = parser.intValue(); + k = parser.intValue(false); break; case ITERATIONS_FIELD: - iterations = parser.intValue(); + iterations = parser.intValue(false); break; case DISTANCE_TYPE_FIELD: distanceType = DistanceType.from(parser.text()); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java index 2772bd8cd6..68bd7830e7 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java @@ -126,34 +126,34 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { optimizerType = OptimizerType.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; case LEARNING_RATE_FIELD: - learningRate = parser.doubleValue(); + learningRate = parser.doubleValue(false); break; case MOMENTUM_TYPE_FIELD: momentumType = MomentumType.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; case MOMENTUM_FACTOR_FIELD: - momentumFactor = parser.doubleValue(); + momentumFactor = parser.doubleValue(false); break; case EPSILON_FIELD: - epsilon = parser.doubleValue(); + epsilon = parser.doubleValue(false); break; case BETA1_FIELD: - beta1 = parser.doubleValue(); + beta1 = parser.doubleValue(false); break; case BETA2_FIELD: - beta2 = parser.doubleValue(); + beta2 = parser.doubleValue(false); break; case DECAY_RATE_FIELD: - decayRate = parser.doubleValue(); + decayRate = parser.doubleValue(false); break; case EPOCHS_FIELD: - epochs = parser.intValue(); + epochs = parser.intValue(false); break; case BATCH_SIZE_FIELD: - batchSize = parser.intValue(); + batchSize = parser.intValue(false); break; case SEED_FIELD: - seed = parser.longValue(); + seed = parser.longValue(false); break; case TARGET_FIELD: target = parser.text(); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java b/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java index 2553e331b6..dd6e71b572 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java @@ -48,7 +48,7 @@ public static LocalSampleCalculatorInput parse(XContentParser parser) throws IOE case INPUT_DATA_FIELD: ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - inputData.add(parser.doubleValue()); + inputData.add(parser.doubleValue(false)); } break; default: diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java b/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java index 86e9b05760..1433586547 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java @@ -113,7 +113,7 @@ public static MLModel parse(XContentParser parser) throws IOException { content = parser.text(); break; case MODEL_VERSION: - version = parser.intValue(); + version = parser.intValue(false); break; case USER: user = User.parse(parser); diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java b/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java index 1efe6f72d1..e7aa948567 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java @@ -51,7 +51,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { switch (fieldName) { case SAMPLE_PARAM_FIELD: - sampleParam = parser.intValue(); + sampleParam = parser.intValue(false); break; default: parser.skipChildren(); diff --git a/common/src/test/java/org/opensearch/ml/common/TestHelper.java b/common/src/test/java/org/opensearch/ml/common/TestHelper.java index 3a843e7cdb..3934417514 100644 --- a/common/src/test/java/org/opensearch/ml/common/TestHelper.java +++ b/common/src/test/java/org/opensearch/ml/common/TestHelper.java @@ -45,6 +45,12 @@ public static void testParseFromString(ToXContentObject obj, String jsonStr, obj.equals(parsedObj); } + public static String contentObjectToString(ToXContentObject obj) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + obj.toXContent(builder, ToXContent.EMPTY_PARAMS); + return xContentBuilderToString(builder); + } + public static String xContentBuilderToString(XContentBuilder builder) { return BytesReference.bytes(builder).utf8ToString(); } diff --git a/common/src/test/java/org/opensearch/ml/common/parameter/KMeansParamsTest.java b/common/src/test/java/org/opensearch/ml/common/parameter/KMeansParamsTest.java index 64ef443443..d595890ab1 100644 --- a/common/src/test/java/org/opensearch/ml/common/parameter/KMeansParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/parameter/KMeansParamsTest.java @@ -6,7 +6,9 @@ package org.opensearch.ml.common.parameter; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentParser; @@ -16,8 +18,12 @@ import java.util.function.Function; import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; public class KMeansParamsTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); KMeansParams params; private Function function = parser -> { @@ -42,6 +48,22 @@ public void parse_KMeansParams() throws IOException { TestHelper.testParse(params, function); } + @Test + public void parse_KMeansParams_InvalidDoubleValue() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("10.01 cannot be converted to Integer without data loss"); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":10.01,"), function); + } + + @Test + public void parse_KMeansParams_InvalidDoubleString() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Integer value passed as String"); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":\"10.01\","), function); + } + @Test public void parse_EmptyKMeansParams() throws IOException { TestHelper.testParse(KMeansParams.builder().build(), function); diff --git a/common/src/test/java/org/opensearch/ml/common/parameter/LinearRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/parameter/LinearRegressionParamsTest.java index 66ecaca80a..1bac41e4d6 100644 --- a/common/src/test/java/org/opensearch/ml/common/parameter/LinearRegressionParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/parameter/LinearRegressionParamsTest.java @@ -6,7 +6,9 @@ package org.opensearch.ml.common.parameter; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentParser; @@ -16,9 +18,14 @@ import java.util.function.Function; import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; public class LinearRegressionParamsTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + private Function function = parser -> { try { return (LinearRegressionParams) LinearRegressionParams.parse(parser); @@ -59,6 +66,36 @@ public void readInputStream_Success() throws IOException { assertEquals(params, parsedParams); } + @Test + public void parse_PassIntValueToDoubleField() throws IOException { + LinearRegressionParams params = LinearRegressionParams + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.ADAM) + .learningRate(0.1) + .momentumType(LinearRegressionParams.MomentumType.NESTEROV) + .momentumFactor(0.2) + .epsilon(3.0) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("\"epsilon\":3.0,", "\"epsilon\":3,"), function); + } + + @Test + public void parse_InvalidParam_InvalidDoubleValue() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Double value passed as String"); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("\"epsilon\":0.3,", "\"epsilon\":\"0.3\","), function); + } + @Test public void readInputStream_Success_Empty() throws IOException { LinearRegressionParams linearRegressionParams = LinearRegressionParams.builder().build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationInput.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationInput.java index d22c6cc9f4..1f730c4b28 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationInput.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationInput.java @@ -93,23 +93,23 @@ public static AnomalyLocalizationInput parse(XContentParser parser) throws IOExc break; case FIELD_START_TIME: parser.nextToken(); - startTime = parser.longValue(); + startTime = parser.longValue(false); break; case FIELD_END_TIME: parser.nextToken(); - endTime = parser.longValue(); + endTime = parser.longValue(false); break; case FIELD_MIN_TIME_INTERVAL: parser.nextToken(); - minTimeInterval = parser.longValue(); + minTimeInterval = parser.longValue(false); break; case FIELD_NUM_OUTPUTS: parser.nextToken(); - numOutputs = parser.intValue(); + numOutputs = parser.intValue(false); break; case FIELD_ANOMALY_START_TIME: parser.nextToken(); - anomalyStartTime = Optional.of(parser.longValue()); + anomalyStartTime = Optional.of(parser.longValue(false)); break; case FIELD_FILTER_QUERY: ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);