
Add cross-validation (CV), and auto-CV for small datasets; push common API experiment methods into base class (dotnet#287)
daholste authored Apr 3, 2019
1 parent f9d547b commit 43fe8b8
Showing 57 changed files with 1,413 additions and 763 deletions.
78 changes: 18 additions & 60 deletions src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs
@@ -14,7 +14,6 @@ public sealed class BinaryExperimentSettings : ExperimentSettings
public BinaryClassificationMetric OptimizingMetric { get; set; } = BinaryClassificationMetric.Accuracy;
public ICollection<BinaryClassificationTrainer> Trainers { get; } =
Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType<BinaryClassificationTrainer>().ToList();
public IProgress<RunResult<BinaryClassificationMetrics>> ProgressHandler { get; set; }
}

public enum BinaryClassificationMetric
@@ -42,74 +41,33 @@ public enum BinaryClassificationTrainer
SymbolicSgdLogisticRegression,
}

public sealed class BinaryClassificationExperiment
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics>
{
private readonly MLContext _context;
private readonly BinaryExperimentSettings _settings;

internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings)
: base(context,
new BinaryMetricsAgent(context, settings.OptimizingMetric),
new OptimizingMetricInfo(settings.OptimizingMetric),
settings,
TaskKind.BinaryClassification,
TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
{
_context = context;
_settings = settings;
}

public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
{
var columnInformation = new ColumnInformation()
{
LabelColumn = labelColumn,
SamplingKeyColumn = samplingKeyColumn
};
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
}

public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
{
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
}

public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
{
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
}

public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
{
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
}

internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null)
{
throw new NotImplementedException();
}

internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(MLContext context,
IDataView trainData,
ColumnInformation columnInfo,
IDataView validationData = null,
IEstimator<ITransformer> preFeaturizers = null)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);

// run autofit & get all pipelines run in that process
var experiment = new Experiment<BinaryClassificationMetrics>(context, TaskKind.BinaryClassification, trainData, columnInfo,
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric), _settings.ProgressHandler,
_settings, new BinaryMetricsAgent(_settings.OptimizingMetric),
TrainerExtensionUtil.GetTrainerNames(_settings.Trainers));

return experiment.Execute();
}
}

public static class BinaryExperimentResultExtensions
{
public static RunResult<BinaryClassificationMetrics> Best(this IEnumerable<RunResult<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
public static RunDetails<BinaryClassificationMetrics> Best(this IEnumerable<RunDetails<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
{
var metricsAgent = new BinaryMetricsAgent(null, metric);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}

public static CrossValidationRunDetails<BinaryClassificationMetrics> Best(this IEnumerable<CrossValidationRunDetails<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
{
var metricsAgent = new BinaryMetricsAgent(metric);
return RunResultUtil.GetBestRunResult(results, metricsAgent);
var metricsAgent = new BinaryMetricsAgent(null, metric);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
}
}
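A hedged usage sketch of the renamed Best() extensions above: both overloads now operate on RunDetails / CrossValidationRunDetails rather than RunResult. The TrainerName property and the Microsoft.ML.Data namespace for BinaryClassificationMetrics are assumed from code outside this diff.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Auto;

public static class BestRunExample
{
    // Reports the best run from a train/validate experiment and from a
    // cross-validation experiment, using the Best() extension methods above.
    public static void Report(
        IEnumerable<RunDetails<BinaryClassificationMetrics>> runs,
        IEnumerable<CrossValidationRunDetails<BinaryClassificationMetrics>> cvRuns)
    {
        // Explicit metric shown for clarity; both overloads default to Accuracy.
        var bestRun = runs.Best(BinaryClassificationMetric.Accuracy);
        var bestCvRun = cvRuns.Best();

        // TrainerName is assumed to be exposed by both run-detail types.
        Console.WriteLine($"Best trainer (train/validate): {bestRun.TrainerName}");
        Console.WriteLine($"Best trainer (cross-validation): {bestCvRun.TrainerName}");
    }
}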
162 changes: 162 additions & 0 deletions src/Microsoft.ML.Auto/API/ExperimentBase.cs
@@ -0,0 +1,162 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;

namespace Microsoft.ML.Auto
{
public abstract class ExperimentBase<TMetrics> where TMetrics : class
{
protected readonly MLContext Context;

private readonly IMetricsAgent<TMetrics> _metricsAgent;
private readonly OptimizingMetricInfo _optimizingMetricInfo;
private readonly ExperimentSettings _settings;
private readonly TaskKind _task;
private readonly IEnumerable<TrainerName> _trainerWhitelist;

internal ExperimentBase(MLContext context,
IMetricsAgent<TMetrics> metricsAgent,
OptimizingMetricInfo optimizingMetricInfo,
ExperimentSettings settings,
TaskKind task,
IEnumerable<TrainerName> trainerWhitelist)
{
Context = context;
_metricsAgent = metricsAgent;
_optimizingMetricInfo = optimizingMetricInfo;
_settings = settings;
_task = task;
_trainerWhitelist = trainerWhitelist;
}

public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation()
{
LabelColumn = labelColumn,
SamplingKeyColumn = samplingKeyColumn
};
return Execute(trainData, columnInformation, preFeaturizers, progressHandler);
}

public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation,
IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
{
// Cross val threshold for # of dataset rows --
// If dataset has < threshold # of rows, use cross val.
// Else, run the experiment using a train-validate split.
const int crossValRowCountThreshold = 15000;

var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);

if (rowCount < crossValRowCountThreshold)
{
const int numCrossValFolds = 10;
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumn);
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
}
else
{
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
}
}

public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
}

public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
{
if (validationData == null)
{
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
trainData = splitResult.trainData;
validationData = splitResult.validationData;
}
return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
}

public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null, IProgress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
{
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumn);
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizers, progressHandler);
}

public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData,
uint numberOfCVFolds, string labelColumn = DefaultColumnNames.Label,
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null,
Progress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
{
var columnInformation = new ColumnInformation()
{
LabelColumn = labelColumn,
SamplingKeyColumn = samplingKeyColumn
};
return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizers, progressHandler);
}

private IEnumerable<RunDetails<TMetrics>> ExecuteTrainValidate(IDataView trainData,
ColumnInformation columnInfo,
IDataView validationData,
IEstimator<ITransformer> preFeaturizer,
IProgress<RunDetails<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumn, _metricsAgent,
preFeaturizer, _settings.DebugLogger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
}

private IEnumerable<CrossValidationRunDetails<TMetrics>> ExecuteCrossVal(IDataView[] trainDatasets,
ColumnInformation columnInfo,
IDataView[] validationDatasets,
IEstimator<ITransformer> preFeaturizer,
IProgress<CrossValidationRunDetails<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
columnInfo.LabelColumn, _settings.DebugLogger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
}

private IEnumerable<RunDetails<TMetrics>> ExecuteCrossValSummary(IDataView[] trainDatasets,
ColumnInformation columnInfo,
IDataView[] validationDatasets,
IEstimator<ITransformer> preFeaturizer,
IProgress<RunDetails<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
columnInfo.LabelColumn, _optimizingMetricInfo, _settings.DebugLogger);
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
}

private IEnumerable<TRunDetails> Execute<TRunDetails>(ColumnInformation columnInfo,
DatasetColumnInfo[] columns,
IEstimator<ITransformer> preFeaturizer,
IProgress<TRunDetails> progressHandler,
IRunner<TRunDetails> runner)
where TRunDetails : RunDetails
{
// Execute experiment & get all pipelines run
var experiment = new Experiment<TRunDetails, TMetrics>(Context, _task, _optimizingMetricInfo, progressHandler,
_settings, _metricsAgent, _trainerWhitelist, columns, runner);

return experiment.Execute();
}
}
}
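A minimal sketch of how the new ExperimentBase entry points behave, assuming a BinaryClassificationExperiment obtained through the AutoML catalog (not part of this diff) and a dataset with a "Label" column; the method and variable names here are hypothetical.

using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Auto;

public static class AutoCvExample
{
    public static void Run(BinaryClassificationExperiment experiment, IDataView trainData)
    {
        // Single-dataset overload: the base class counts up to 15,000 rows and,
        // below that threshold, switches to 10-fold cross-validation internally,
        // returning one summarized RunDetails per pipeline it tries.
        var autoResults = experiment.Execute(trainData, labelColumn: "Label");
        Console.WriteLine($"Runs (auto split/CV): {autoResults.Count()}");

        // Explicit cross-validation overload: request 5 folds and get per-fold
        // metrics back as CrossValidationRunDetails.
        var cvResults = experiment.Execute(trainData, 5, new ColumnInformation { LabelColumn = "Label" });
        Console.WriteLine($"Runs (explicit 5-fold CV): {cvResults.Count()}");
    }
}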
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Auto/API/ExperimentSettings.cs
@@ -18,14 +18,14 @@ public class ExperimentSettings
/// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
/// memory could cause a system to run out of memory.)
/// </summary>
public DirectoryInfo ModelDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto"));
public DirectoryInfo CacheDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto"));

/// <summary>
/// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching.
/// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off.
/// If set to null (default value), AutoML will decide whether to enable caching for each model.
/// </summary>
public bool? EnableCaching = null;
public bool? CacheBeforeTrainer = null;

internal int MaxModels = int.MaxValue;
internal IDebugLogger DebugLogger;
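For reference, a hedged sketch of configuring the renamed settings above (CacheDirectory, formerly ModelDirectory; CacheBeforeTrainer, formerly EnableCaching); the subdirectory name is hypothetical.

using System.IO;
using Microsoft.ML.Auto;

public static class SettingsExample
{
    public static BinaryExperimentSettings CreateSettings()
    {
        return new BinaryExperimentSettings
        {
            OptimizingMetric = BinaryClassificationMetric.Accuracy,
            // Renamed from ModelDirectory: where models are cached on disk.
            CacheDirectory = new DirectoryInfo(
                Path.Combine(Path.GetTempPath(), "MyAutoMLExperiment")),
            // Renamed from EnableCaching: null (the default) lets AutoML decide
            // per trainer whether to cache before training.
            CacheBeforeTrainer = null,
        };
    }
}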
