Skip to content
This repository has been archived by the owner on May 3, 2022. It is now read-only.

Commit

Permalink
Closes #321 : add oryx.ml.eval.threshold parameter for rejecting mode…
Browse files Browse the repository at this point in the history
…ls based on evaluation. And some related test refactoring
  • Loading branch information
srowen committed Oct 25, 2016
1 parent 03ebe2e commit 341182b
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.typesafe.config.Config;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.dmg.pmml.PMML;
import org.junit.Test;
import org.slf4j.Logger;
Expand Down Expand Up @@ -171,31 +172,20 @@ public void testRDFSpeedClassification() throws Exception {
Map<String,Integer> countMap = (Map<String,Integer>) fields.get(2);
assertEquals(0, treeID);
assertContains(Arrays.asList("r-", "r+"), nodeID);
int yellowCount = countMap.containsKey(yellow) ? countMap.get(yellow) : 0;
int redCount = countMap.containsKey(red) ? countMap.get(red) : 0;
int yellowCount = countMap.getOrDefault(yellow, 0);
int redCount = countMap.getOrDefault(red, 0);
int count = yellowCount + redCount;
assertGreater(count, 0);
BinomialDistribution dist = new BinomialDistribution(RandomManager.getRandom(), count, 0.9);
IntegerDistribution dist = new BinomialDistribution(RandomManager.getRandom(), count, 0.9);
if ("r+".equals(nodeID)) {
// Should be about 9x more yellow
checkProbability(yellowCount, count, dist);
checkDiscreteProbability(yellowCount, dist);
} else {
// Should be about 9x more red
checkProbability(redCount, count, dist);
checkDiscreteProbability(redCount, dist);
}
}

}

private static void checkProbability(int majorityCount,
int count,
BinomialDistribution dist) {
double expected = 0.9 * count;
double probAsExtreme = majorityCount <= expected ?
dist.cumulativeProbability(majorityCount) :
(1.0 - dist.cumulativeProbability(majorityCount)) + dist.probability(majorityCount);
assertTrue(majorityCount + " should be about " + expected + " (~90% of " + count + ")",
probAsExtreme >= 0.001);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static Config overlayOn(Map<String,?> overlay, Config underlying) {
* @return value for given key, or {@code null} if none exists
*/
public static String getOptionalString(Config config, String key) {
return config.hasPath(key) ? config.getString(key) : null;
return config.getIsNull(key) ? null : config.getString(key);
}

/**
Expand All @@ -89,7 +89,16 @@ public static String getOptionalString(Config config, String key) {
* @return value for given key, or {@code null} if none exists
*/
public static List<String> getOptionalStringList(Config config, String key) {
return config.hasPath(key) ? config.getStringList(key) : null;
return config.getIsNull(key) ? null : config.getStringList(key);
}

/**
* @param config configuration to query for value
* @param key configuration path key
* @return value for given key, or {@code null} if none exists
*/
public static Double getOptionalDouble(Config config, String key) {
return config.getIsNull(key) ? null : config.getDouble(key);
}

/**
Expand Down
6 changes: 6 additions & 0 deletions framework/oryx-common/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,12 @@ oryx = {
candidates = 1
# Number of models to build in parallel
parallelism = 1
# Minimum evaluation for a model to be accepted as valid, if specified.
# Note that some evaluation metrics, like squared error for k-means clustering,
# are better when smaller, and the framework ranks models based on the negative of
# such values. To specify, for example, a *maximum* squared error of 100 for k-means
# clustering, specify a threshold of -100.
threshold = null
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.apache.commons.math3.distribution.IntegerDistribution;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -187,4 +188,20 @@ public static void sleepSeconds(int seconds) {
}
}

/**
* Asserts that the probability of sampling a value as or more extreme than the given value,
* from the given discrete distribution, is at least 0.001.
*
* @param value sample value
* @param dist discrete distribution
*/
public static void checkDiscreteProbability(int value, IntegerDistribution dist) {
double probAsExtreme = value <= dist.getNumericalMean() ?
dist.cumulativeProbability(value) :
(1.0 - dist.cumulativeProbability(value - 1));
assertTrue(value + " is not likely (" + probAsExtreme + " ) to differ from expected value " +
dist.getNumericalMean() + " by chance",
probAsExtreme >= 0.001);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
Expand Down Expand Up @@ -50,12 +51,23 @@ public void testSerialize() {

@Test
public void testOptionalString() {
assertNull(ConfigUtils.getOptionalString(ConfigUtils.getDefault(), "nonexistent"));
Config config = ConfigUtils.overlayOn(Collections.singletonMap("nonexistent", "null"),
ConfigUtils.getDefault());
assertNull(ConfigUtils.getOptionalString(config, "nonexistent"));
}

@Test
public void testOptionalStringList() {
assertNull(ConfigUtils.getOptionalStringList(ConfigUtils.getDefault(), "nonexistent"));
Config config = ConfigUtils.overlayOn(Collections.singletonMap("nonexistent", "null"),
ConfigUtils.getDefault());
assertNull(ConfigUtils.getOptionalStringList(config, "nonexistent"));
}

@Test
public void testOptionalDouble() {
Config config = ConfigUtils.overlayOn(Collections.singletonMap("nonexistent", "null"),
ConfigUtils.getDefault());
assertNull(ConfigUtils.getOptionalDouble(config, "nonexistent"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.cloudera.oryx.common.lang.ExecUtils;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.ml.param.HyperParamValues;
import com.cloudera.oryx.ml.param.HyperParams;

Expand All @@ -65,12 +66,14 @@ public abstract class MLUpdate<M> implements BatchLayerUpdate<Object,M,String> {
private final double testFraction;
private final int candidates;
private final int evalParallelism;
private final Double threshold;
private final int maxMessageSize;

protected MLUpdate(Config config) {
this.testFraction = config.getDouble("oryx.ml.eval.test-fraction");
int candidates = config.getInt("oryx.ml.eval.candidates");
this.evalParallelism = config.getInt("oryx.ml.eval.parallelism");
this.threshold = ConfigUtils.getOptionalDouble(config, "oryx.ml.eval.threshold");
this.maxMessageSize = config.getInt("oryx.update-topic.message.max-size");
Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0);
Preconditions.checkArgument(candidates > 0);
Expand Down Expand Up @@ -284,6 +287,11 @@ private Path findBestCandidatePath(JavaSparkContext sparkContext,
}
} // else can't do anything; no model at all
}
if (threshold != null && bestEval < threshold) {
log.info("Best model at {} had eval {}, but did not exceed threshold {}; discarding model",
bestCandidatePath, bestEval, threshold);
bestCandidatePath = null;
}
return bestCandidatePath;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.cloudera.oryx.ml;

import java.util.ArrayList;
import java.util.List;

import com.typesafe.config.Config;
Expand All @@ -35,12 +36,17 @@ public final class MockMLUpdate extends MLUpdate<String> {

private static final Logger log = LoggerFactory.getLogger(MockMLUpdate.class);

private static List<Integer> trainCounts;
private static List<Integer> testCounts;
private static final List<Integer> trainCounts = new ArrayList<>();
private static final List<Integer> testCounts = new ArrayList<>();

static void setCountHolders(List<Integer> trainCounts, List<Integer> testCounts) {
MockMLUpdate.trainCounts = trainCounts;
MockMLUpdate.testCounts = testCounts;
static List<Integer> getResetTrainCounts() {
trainCounts.clear();
return trainCounts;
}

static List<Integer> getResetTestCounts() {
testCounts.clear();
return testCounts;
}

public MockMLUpdate(Config config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.cloudera.oryx.ml;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -49,21 +48,19 @@ public final class SimpleMLUpdateIT extends AbstractBatchIT {
public void testMLUpdate() throws Exception {
Path tempDir = getTempDir();
Path dataDir = tempDir.resolve("data");
Path modelDir = tempDir.resolve("model");
Map<String,Object> overlayConfig = new HashMap<>();
overlayConfig.put("oryx.batch.update-class", MockMLUpdate.class.getName());
ConfigUtils.set(overlayConfig, "oryx.batch.storage.data-dir", dataDir);
ConfigUtils.set(overlayConfig, "oryx.batch.storage.model-dir", modelDir);
ConfigUtils.set(overlayConfig, "oryx.batch.storage.model-dir", tempDir.resolve("model"));
overlayConfig.put("oryx.batch.streaming.generation-interval-sec", GEN_INTERVAL_SEC);
overlayConfig.put("oryx.ml.eval.test-fraction", TEST_FRACTION);
overlayConfig.put("oryx.ml.eval.threshold", DATA_TO_WRITE / 2); // Should easily pass threshold
Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

startMessaging();

List<Integer> trainCounts = new ArrayList<>();
List<Integer> testCounts = new ArrayList<>();

MockMLUpdate.setCountHolders(trainCounts, testCounts);
List<Integer> trainCounts = MockMLUpdate.getResetTrainCounts();
List<Integer> testCounts = MockMLUpdate.getResetTestCounts();

startServerProduceConsumeTopics(config, DATA_TO_WRITE, WRITE_INTERVAL_MSEC);

Expand Down Expand Up @@ -96,15 +93,7 @@ public void testMLUpdate() throws Exception {
int totalNew = testCount + newTrainInGen;

IntegerDistribution dist = new BinomialDistribution(random, totalNew, TEST_FRACTION);
double probability;
if (testCount < dist.getNumericalMean()) {
probability = dist.cumulativeProbability(testCount);
} else {
probability = 1.0 - dist.cumulativeProbability(testCount);
}
log.info("Probability of observing {} as {} sample of {}: {}",
testCount, TEST_FRACTION, totalNew, probability);
assertGreaterOrEqual(probability, 0.001);
checkDiscreteProbability(testCount, dist);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.oryx.ml;

import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;

import com.typesafe.config.Config;
import org.junit.Test;

import com.cloudera.oryx.common.io.IOUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.lambda.batch.AbstractBatchIT;

/**
* Tests {@link MLUpdate} threshold, where model doesn't pass threshold.
*/
public final class ThresholdIT extends AbstractBatchIT {

private static final int DATA_TO_WRITE = 50;
private static final int WRITE_INTERVAL_MSEC = 25;
private static final int GEN_INTERVAL_SEC = 2;

@Test
public void testMLUpdate() throws Exception {
Path tempDir = getTempDir();
Path modelDir = tempDir.resolve("model");
Map<String,Object> overlayConfig = new HashMap<>();
overlayConfig.put("oryx.batch.update-class", MockMLUpdate.class.getName());
ConfigUtils.set(overlayConfig, "oryx.batch.storage.data-dir", tempDir.resolve("data"));
ConfigUtils.set(overlayConfig, "oryx.batch.storage.model-dir", modelDir);
overlayConfig.put("oryx.batch.streaming.generation-interval-sec", GEN_INTERVAL_SEC);
overlayConfig.put("oryx.ml.eval.threshold", DATA_TO_WRITE * 2); // Won't pass
Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

startMessaging();
startServerProduceConsumeTopics(config, DATA_TO_WRITE, WRITE_INTERVAL_MSEC);
assertTrue(IOUtils.listFiles(modelDir, "*").isEmpty());
}

}

0 comments on commit 341182b

Please sign in to comment.