Skip to content

Commit

Permalink
more strict check on input parameters by applying non-coerce mode (#173
Browse files Browse the repository at this point in the history
…) (#174)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit c50e0d6)

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Mar 8, 2022
1 parent e895405 commit 14f9bb2
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
kernelType = ADKernelType.from(parser.text().toUpperCase(Locale.ROOT));
break;
case GAMMA_FIELD:
gamma = parser.doubleValue();
gamma = parser.doubleValue(false);
break;
case NU_FIELD:
nu = parser.doubleValue();
nu = parser.doubleValue(false);
break;
case COST_FIELD:
cost = parser.doubleValue();
cost = parser.doubleValue(false);
break;
case COEFF_FIELD:
coeff = parser.doubleValue();
coeff = parser.doubleValue(false);
break;
case EPSILON_FIELD:
epsilon = parser.doubleValue();
epsilon = parser.doubleValue(false);
break;
case DEGREE_FIELD:
degree = parser.intValue();
degree = parser.intValue(false);
break;
default:
parser.skipChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,22 @@ public static BatchRCFParams parse(XContentParser parser) throws IOException {

switch (fieldName) {
case NUMBER_OF_TREES:
numberOfTrees = parser.intValue();
numberOfTrees = parser.intValue(false);
break;
case SHINGLE_SIZE:
shingleSize = parser.intValue();
shingleSize = parser.intValue(false);
break;
case SAMPLE_SIZE:
sampleSize = parser.intValue();
sampleSize = parser.intValue(false);
break;
case OUTPUT_AFTER:
outputAfter = parser.intValue();
outputAfter = parser.intValue(false);
break;
case TRAINING_DATA_SIZE:
trainingDataSize = parser.intValue();
trainingDataSize = parser.intValue(false);
break;
case ANOMALY_SCORE_THRESHOLD:
anomalyScoreThreshold = parser.doubleValue();
anomalyScoreThreshold = parser.doubleValue(false);
break;
default:
parser.skipChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,22 @@ public static FitRCFParams parse(XContentParser parser) throws IOException {

switch (fieldName) {
case NUMBER_OF_TREES:
numberOfTrees = parser.intValue();
numberOfTrees = parser.intValue(false);
break;
case SHINGLE_SIZE:
shingleSize = parser.intValue();
shingleSize = parser.intValue(false);
break;
case SAMPLE_SIZE:
sampleSize = parser.intValue();
sampleSize = parser.intValue(false);
break;
case OUTPUT_AFTER:
outputAfter = parser.intValue();
outputAfter = parser.intValue(false);
break;
case TIME_DECAY:
timeDecay = parser.doubleValue();
timeDecay = parser.doubleValue(false);
break;
case ANOMALY_RATE:
anomalyRate = parser.doubleValue();
anomalyRate = parser.doubleValue(false);
break;
case TIME_FIELD:
timeField = parser.text();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -70,10 +69,10 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {

switch (fieldName) {
case CENTROIDS_FIELD:
k = parser.intValue();
k = parser.intValue(false);
break;
case ITERATIONS_FIELD:
iterations = parser.intValue();
iterations = parser.intValue(false);
break;
case DISTANCE_TYPE_FIELD:
distanceType = DistanceType.from(parser.text());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,34 +126,34 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
optimizerType = OptimizerType.valueOf(parser.text().toUpperCase(Locale.ROOT));
break;
case LEARNING_RATE_FIELD:
learningRate = parser.doubleValue();
learningRate = parser.doubleValue(false);
break;
case MOMENTUM_TYPE_FIELD:
momentumType = MomentumType.valueOf(parser.text().toUpperCase(Locale.ROOT));
break;
case MOMENTUM_FACTOR_FIELD:
momentumFactor = parser.doubleValue();
momentumFactor = parser.doubleValue(false);
break;
case EPSILON_FIELD:
epsilon = parser.doubleValue();
epsilon = parser.doubleValue(false);
break;
case BETA1_FIELD:
beta1 = parser.doubleValue();
beta1 = parser.doubleValue(false);
break;
case BETA2_FIELD:
beta2 = parser.doubleValue();
beta2 = parser.doubleValue(false);
break;
case DECAY_RATE_FIELD:
decayRate = parser.doubleValue();
decayRate = parser.doubleValue(false);
break;
case EPOCHS_FIELD:
epochs = parser.intValue();
epochs = parser.intValue(false);
break;
case BATCH_SIZE_FIELD:
batchSize = parser.intValue();
batchSize = parser.intValue(false);
break;
case SEED_FIELD:
seed = parser.longValue();
seed = parser.longValue(false);
break;
case TARGET_FIELD:
target = parser.text();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static LocalSampleCalculatorInput parse(XContentParser parser) throws IOE
case INPUT_DATA_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
inputData.add(parser.doubleValue());
inputData.add(parser.doubleValue(false));
}
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
content = parser.text();
break;
case MODEL_VERSION:
version = parser.intValue();
version = parser.intValue(false);
break;
case USER:
user = User.parse(parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {

switch (fieldName) {
case SAMPLE_PARAM_FIELD:
sampleParam = parser.intValue();
sampleParam = parser.intValue(false);
break;
default:
parser.skipChildren();
Expand Down
6 changes: 6 additions & 0 deletions common/src/test/java/org/opensearch/ml/common/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ public static <T> void testParseFromString(ToXContentObject obj, String jsonStr,
obj.equals(parsedObj);
}

public static String contentObjectToString(ToXContentObject obj) throws IOException {
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
obj.toXContent(builder, ToXContent.EMPTY_PARAMS);
return xContentBuilderToString(builder);
}

public static String xContentBuilderToString(XContentBuilder builder) {
return BytesReference.bytes(builder).utf8ToString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.common.parameter;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
Expand All @@ -16,8 +18,12 @@
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

public class KMeansParamsTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

KMeansParams params;
private Function<XContentParser, KMeansParams> function = parser -> {
Expand All @@ -42,6 +48,22 @@ public void parse_KMeansParams() throws IOException {
TestHelper.testParse(params, function);
}

@Test
public void parse_KMeansParams_InvalidDoubleValue() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("10.01 cannot be converted to Integer without data loss");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":10.01,"), function);
}

@Test
public void parse_KMeansParams_InvalidDoubleString() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Integer value passed as String");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":\"10.01\","), function);
}

@Test
public void parse_EmptyKMeansParams() throws IOException {
TestHelper.testParse(KMeansParams.builder().build(), function);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.common.parameter;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
Expand All @@ -16,9 +18,14 @@
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

public class LinearRegressionParamsTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

private Function<XContentParser, LinearRegressionParams> function = parser -> {
try {
return (LinearRegressionParams) LinearRegressionParams.parse(parser);
Expand Down Expand Up @@ -59,6 +66,36 @@ public void readInputStream_Success() throws IOException {
assertEquals(params, parsedParams);
}

@Test
public void parse_PassIntValueToDoubleField() throws IOException {
LinearRegressionParams params = LinearRegressionParams
.builder()
.objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS)
.optimizerType(LinearRegressionParams.OptimizerType.ADAM)
.learningRate(0.1)
.momentumType(LinearRegressionParams.MomentumType.NESTEROV)
.momentumFactor(0.2)
.epsilon(3.0)
.beta1(0.4)
.beta2(0.5)
.decayRate(0.6)
.epochs(1)
.batchSize(2)
.seed(3L)
.target("test_target")
.build();
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"epsilon\":3.0,", "\"epsilon\":3,"), function);
}

@Test
public void parse_InvalidParam_InvalidDoubleValue() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Double value passed as String");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"epsilon\":0.3,", "\"epsilon\":\"0.3\","), function);
}

@Test
public void readInputStream_Success_Empty() throws IOException {
LinearRegressionParams linearRegressionParams = LinearRegressionParams.builder().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ public static AnomalyLocalizationInput parse(XContentParser parser) throws IOExc
break;
case FIELD_START_TIME:
parser.nextToken();
startTime = parser.longValue();
startTime = parser.longValue(false);
break;
case FIELD_END_TIME:
parser.nextToken();
endTime = parser.longValue();
endTime = parser.longValue(false);
break;
case FIELD_MIN_TIME_INTERVAL:
parser.nextToken();
minTimeInterval = parser.longValue();
minTimeInterval = parser.longValue(false);
break;
case FIELD_NUM_OUTPUTS:
parser.nextToken();
numOutputs = parser.intValue();
numOutputs = parser.intValue(false);
break;
case FIELD_ANOMALY_START_TIME:
parser.nextToken();
anomalyStartTime = Optional.of(parser.longValue());
anomalyStartTime = Optional.of(parser.longValue(false));
break;
case FIELD_FILTER_QUERY:
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Expand Down

0 comments on commit 14f9bb2

Please sign in to comment.