Skip to content

Commit

Permalink
[ML][Inference] Adding classification_weights to ensemble models (#50874
Browse files Browse the repository at this point in the history
) (#50994)

* [ML][Inference] Adding classification_weights to ensemble models

classification_weights are a way to allow models to
prefer specific classification results over others
this might be advantageous if classification value
probabilities are a known quantity and can improve
model error rates.
  • Loading branch information
benwtrent authored Jan 14, 2020
1 parent de5713f commit 72c2709
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
Expand All @@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel {
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");

private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
NAME,
Expand All @@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel {
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
}

public static Ensemble fromXContent(XContentParser parser) {
Expand All @@ -71,17 +74,20 @@ public static Ensemble fromXContent(XContentParser parser) {
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;

Ensemble(List<String> featureNames,
List<TrainedModel> models,
@Nullable OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
@Nullable List<String> classificationLabels,
@Nullable double[] classificationWeights) {
this.featureNames = featureNames;
this.models = models;
this.outputAggregator = outputAggregator;
this.targetType = targetType;
this.classificationLabels = classificationLabels;
this.classificationWeights = classificationWeights;
}

@Override
Expand Down Expand Up @@ -116,6 +122,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
if (classificationWeights != null) {
builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
}
builder.endObject();
return builder;
}
Expand All @@ -129,12 +138,18 @@ public boolean equals(Object o) {
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Arrays.equals(classificationWeights, that.classificationWeights)
&& Objects.equals(outputAggregator, that.outputAggregator);
}

@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
return Objects.hash(featureNames,
models,
outputAggregator,
classificationLabels,
targetType,
Arrays.hashCode(classificationWeights));
}

public static Builder builder() {
Expand All @@ -147,6 +162,7 @@ public static class Builder {
private OutputAggregator outputAggregator;
private TargetType targetType;
private List<String> classificationLabels;
private double[] classificationWeights;

public Builder setFeatureNames(List<String> featureNames) {
this.featureNames = featureNames;
Expand All @@ -173,6 +189,11 @@ public Builder setClassificationLabels(List<String> classificationLabels) {
return this;
}

public Builder setClassificationWeights(List<Double> classificationWeights) {
this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}
Expand All @@ -182,7 +203,7 @@ private void setTargetType(String targetType) {
}

public Ensemble build() {
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,19 @@ public static Ensemble createRandom(TargetType targetType) {
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
Stream.generate(ESTestCase::randomDouble)
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
.mapToDouble(Double::valueOf)
.toArray() :
null;

return new Ensemble(featureNames,
models,
outputAggregator,
targetType,
categoryLabels);
categoryLabels,
thresholds);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,26 @@ public static class TopClassEntry implements Writeable {

public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");

private final String classification;
private final double probability;
private final double score;

public TopClassEntry(String classification, Double probability) {
public TopClassEntry(String classification, double probability) {
this(classification, probability, probability);
}

public TopClassEntry(String classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY);
this.probability = probability;
this.score = score;
}

public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
this.probability = in.readDouble();
this.score = in.readDouble();
}

public String getClassification() {
Expand All @@ -134,31 +142,36 @@ public double getProbability() {
return probability;
}

public double getScore() {
return score;
}

public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(2);
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
out.writeDouble(probability);
out.writeDouble(score);
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) &&
Objects.equals(probability, that.probability);
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}

@Override
public int hashCode() {
return Objects.hash(classification, probability);
return Objects.hash(classification, probability, score);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

Expand All @@ -20,25 +21,38 @@ public final class InferenceHelpers {

private InferenceHelpers() { }

public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
List<String> classificationLabels,
int numToInclude) {
if (numToInclude == 0) {
return Collections.emptyList();
}
int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(probabilities::get).reversed())
.mapToInt(i -> i)
.toArray();
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {

if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
classificationLabels);
classificationLabels.size());
}

List<Double> scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());

int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.mapToInt(i -> i)
.toArray();

if (numToInclude == 0) {
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
}

List<String> labels = classificationLabels == null ?
Expand All @@ -50,26 +64,24 @@ public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
}

return topClassEntries;
return Tuple.tuple(sortedIndices[0], topClassEntries);
}

public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
assert inferenceValue == Math.rint(inferenceValue);
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
if (classificationLabels == null) {
return String.valueOf(inferenceValue);
}
int label = Double.valueOf(inferenceValue).intValue();
if (label < 0 || label >= classificationLabels.size()) {
if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
throw ExceptionsHelper.serverError(
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
null,
label,
inferenceValue,
classificationLabels);
}
return classificationLabels.get(label);
return classificationLabels.get(inferenceValue);
}

public static Double toDouble(Object value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,14 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.util.List;
import java.util.Map;

public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {

/**
* @return List of featureNames expected by the model. In the order that they are expected
*/
List<String> getFeatureNames();

/**
* Infer against the provided fields
*
Expand All @@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
*/
TargetType targetType();

/**
* @return Ordinal encoded list of classification labels.
*/
@Nullable
List<String> classificationLabels();

/**
* Runs validations against the model.
*
Expand Down
Loading

0 comments on commit 72c2709

Please sign in to comment.