Skip to content

Commit

Permalink
OVA should respect normalization in underlying learner (#310)
Browse files Browse the repository at this point in the history
* Respect normalization in OVA.

* some cleanup

* fix copypaste issues
  • Loading branch information
Ivanidzo4ka authored Jun 5, 2018
1 parent ab4108d commit 5730685
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,21 @@ public static ModelOperations.PredictorModelOutput CombineOvaModels(IHostEnviron
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
host.CheckNonEmpty(input.ModelArray, nameof(input.ModelArray));

// Something tells me we should put normalization as part of macro expansion, but since i get
// subgraph instead of learner it's a bit tricky to get learner and decide should we add
// normalization node or not, plus everywhere in code we leave that reposnsibility to TransformModel.
var normalizedView = input.ModelArray[0].TransformModel.Apply(host, input.TrainingData);
using (var ch = host.Start("CombineOvaModels"))
{
ISchema schema = input.TrainingData.Schema;
ISchema schema = normalizedView.Schema;
var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.LabelColumn),
input.LabelColumn,
DefaultColumnNames.Label);
var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.FeatureColumn),
input.FeatureColumn, DefaultColumnNames.Features);
var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.WeightColumn),
input.WeightColumn, DefaultColumnNames.Weight);
var data = TrainUtils.CreateExamples(input.TrainingData, label, feature, null, weight);
var data = TrainUtils.CreateExamples(normalizedView, label, feature, null, weight);

return new ModelOperations.PredictorModelOutput
{
Expand Down
59 changes: 59 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -739,5 +739,64 @@ public void TestCrossValidationMacroWithNonDefaultNames()
}
}
}

[Fact]
public void TestOvaMacro()
{
var dataPath = GetDataPath(@"iris.txt");
using (var env = new TlcEnvironment(42))
{
// Specify subgraph for OVA
var subGraph = env.CreateExperiment();
var learnerInput = new Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
var learnerOutput = subGraph.Add(learnerInput);
// Create pipeline with OVA and multiclass scoring.
var experiment = env.CreateExperiment();
var importInput = new ML.Data.TextLoader(dataPath);
importInput.Arguments.Column = new TextLoaderColumn[]
{
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
};
var importOutput = experiment.Add(importInput);
var oneVersusAll = new Models.OneVersusAll
{
TrainingData = importOutput.Data,
Nodes = subGraph,
UseProbabilities = true,
};
var ovaOutput = experiment.Add(oneVersusAll);
var scoreInput = new ML.Transforms.DatasetScorer
{
Data = importOutput.Data,
PredictorModel = ovaOutput.PredictorModel
};
var scoreOutput = experiment.Add(scoreInput);
var evalInput = new ML.Models.ClassificationEvaluator
{
Data = scoreOutput.ScoredData
};
var evalOutput = experiment.Add(evalInput);
experiment.Compile();
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
experiment.Run();

var data = experiment.GetOutput(evalOutput.OverallMetrics);
var schema = data.Schema;
var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
Assert.True(b);
using (var cursor = data.GetRowCursor(col => col == accCol))
{
var getter = cursor.GetGetter<double>(accCol);
b = cursor.MoveNext();
Assert.True(b);
double acc = 0;
getter(ref acc);
Assert.Equal(0.96, acc, 2);
b = cursor.MoveNext();
Assert.False(b);
}
}
}
}
}

0 comments on commit 5730685

Please sign in to comment.