Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make num_top_classes parameter's default value equal to 2 #48119

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");

private static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
Expand All @@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
(Double) a[7],
(Integer) a[8]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
Expand All @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
}

private final String dependentVariable;
Expand All @@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
private final Double featureBagFraction;
private final String predictionFieldName;
private final Double trainingPercent;
private final Integer numTopClasses;

private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
Expand All @@ -94,6 +98,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.numTopClasses = numTopClasses;
}

@Override
Expand Down Expand Up @@ -133,6 +138,10 @@ public Double getTrainingPercent() {
return trainingPercent;
}

/**
 * Returns the number of top classes configured for this analysis.
 *
 * @return the configured value, or {@code null} when none was supplied
 */
public Integer getNumTopClasses() {
    return this.numTopClasses;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, numTopClasses);
}

@Override
Expand All @@ -180,7 +192,8 @@ public boolean equals(Object o) {
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent);
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(numTopClasses, that.numTopClasses);
}

@Override
Expand All @@ -197,6 +210,7 @@ public static class Builder {
private Double featureBagFraction;
private String predictionFieldName;
private Double trainingPercent;
private Integer numTopClasses;

private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
Expand Down Expand Up @@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
return this;
}

/**
 * Sets the number of top classes that {@link #build()} hands to the new {@code Classification}.
 *
 * @param numTopClasses the value to store; this setter stores it as given
 * @return this builder, so calls can be chained
 */
public Builder setNumTopClasses(Integer numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
}

public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, numTopClasses);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1266,8 +1266,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
.setDest(DataFrameAnalyticsDest.builder()
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
.builder("my_dependent_variable")
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
.setTrainingPercent(80.0)
.build())
.setDescription("this is a regression")
Expand Down Expand Up @@ -1301,9 +1300,9 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
.setDest(DataFrameAnalyticsDest.builder()
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
.builder("my_dependent_variable")
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
.setTrainingPercent(80.0)
.setNumTopClasses(1)
.build())
.setDescription("this is a classification")
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2951,6 +2951,7 @@ public void testPutDataFrameAnalytics() throws Exception {
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.setNumTopClasses(1) // <9>
.build();
// end::put-data-frame-analytics-classification

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static Classification randomClassification() {
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<9> The number of top classes to be reported in the results. Defaults to 2.

===== Regression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
.flatMap(Set::stream)
.collect(Collectors.toUnmodifiableSet());
/**
* As long as we only support binary classification, it makes sense to always report both classes with their probabilities.
* This way the user can see whether the prediction was made with the confidence they need.
*/
private static final int DEFAULT_NUM_TOP_CLASSES = 2;

private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
Expand All @@ -86,7 +91,7 @@ public Classification(String dependentVariable,
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}

Expand All @@ -106,6 +111,10 @@ public String getDependentVariable() {
return dependentVariable;
}

/**
 * Returns the number of top classes held by this analysis.
 *
 * @return the stored {@code numTopClasses} value
 */
public int getNumTopClasses() {
    return this.numTopClasses;
}

/**
 * Returns the training percentage held by this analysis.
 *
 * @return the stored {@code trainingPercent} value
 */
public double getTrainingPercent() {
    return this.trainingPercent;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

public class ClassificationTests extends AbstractSerializingTestCase<Classification> {

private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);

@Override
protected Classification doParseInstance(XContentParser parser) throws IOException {
return Classification.fromXContent(parser, false);
Expand All @@ -43,32 +45,68 @@ protected Writeable.Reader<Classification> instanceReader() {
return Classification::new;
}

/** A {@code null} training percent must resolve to 100.0 on the constructed instance. */
public void testConstructor_GivenTrainingPercentIsNull() {
    BoostedTreeParams params = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
    Classification withoutTrainingPercent = new Classification("foo", params, "result", 3, null);

    assertThat(withoutTrainingPercent.getTrainingPercent(), equalTo(100.0));
}

/** Both ends of the accepted training percent range, 1.0 and 100.0, must be kept as given. */
public void testConstructor_GivenTrainingPercentIsBoundary() {
    // Lower boundary.
    assertThat(
        new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0).getTrainingPercent(),
        equalTo(1.0));

    // Upper boundary.
    assertThat(
        new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0).getTrainingPercent(),
        equalTo(100.0));
}

public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));

assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));

assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

/** A negative {@code num_top_classes} must be rejected by the constructor with the range error message. */
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
    ElasticsearchStatusException rejection = expectThrows(
        ElasticsearchStatusException.class,
        () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));

    String actualMessage = rejection.getMessage();
    assertThat(actualMessage, equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}

/** A {@code num_top_classes} of 1001 must be rejected by the constructor with the range error message. */
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
    ElasticsearchStatusException rejection = expectThrows(
        ElasticsearchStatusException.class,
        () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));

    String actualMessage = rejection.getMessage();
    assertThat(actualMessage, equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}

/**
 * Checks the value reported by {@code getNumTopClasses()} for an ordinary value, for both
 * boundaries (0 and 1000), and for {@code null}, where the default of 2 is expected.
 */
public void testGetNumTopClasses() {
    // Ordinary value is returned unchanged.
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0).getNumTopClasses(),
        equalTo(7));

    // Boundary condition: num_top_classes == 0
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0).getNumTopClasses(),
        equalTo(0));

    // Boundary condition: num_top_classes == 1000
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0).getNumTopClasses(),
        equalTo(1000));

    // num_top_classes == null, default applied
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0).getNumTopClasses(),
        equalTo(2));
}

/**
 * Checks the value reported by {@code getTrainingPercent()} for an ordinary value, for both
 * boundaries (1.0 and 100.0), and for {@code null}, where the default of 100.0 is expected.
 */
public void testGetTrainingPercent() {
    // Ordinary value is returned unchanged.
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0).getTrainingPercent(),
        equalTo(50.0));

    // Boundary condition: training_percent == 1.0
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0).getTrainingPercent(),
        equalTo(1.0));

    // Boundary condition: training_percent == 100.0
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0).getTrainingPercent(),
        equalTo(100.0));

    // training_percent == null, default applied
    assertThat(
        new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null).getTrainingPercent(),
        equalTo(100.0));
}

/** A test instance must never report {@code null} field cardinality limits. */
public void testFieldCardinalityLimitsIsNonNull() {
    Classification instance = createTestInstance();

    assertThat(instance.getFieldCardinalityLimits(), is(not(nullValue())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertThat(resultsObject.containsKey("top_classes"), is(false));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
}

assertProgress(jobId, 100, 100, 100, 100);
Expand Down Expand Up @@ -120,7 +120,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
assertThat(resultsObject.containsKey("top_classes"), is(false));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
}

assertProgress(jobId, 100, 100, 100, 100);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1810,7 +1810,7 @@ setup:
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"training_percent": 60.3,
"num_top_classes": 0
"num_top_classes": 2
}
}}
- is_true: create_time
Expand Down