From 094df6a6761d43867f0e9fc61bbacd6e9398d859 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 24 May 2018 15:35:03 -0700 Subject: [PATCH 1/5] Scores to label mapping for multi-class classification problem. --- src/Microsoft.ML.Core/Data/ITransformModel.cs | 4 +++ .../EntryPoints/TransformModel.cs | 8 ++++++ src/Microsoft.ML/PredictionModel.cs | 27 +++++++++++++++++++ ...PlantClassificationWithStringLabelTests.cs | 5 ++++ 4 files changed, 44 insertions(+) diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index ccb65d43ab..a225f6ed21 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -25,6 +25,10 @@ public interface ITransformModel /// ISchema InputSchema { get; } + /// + /// Schema of the transform model. + /// + IDataView Schema { get; } /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index b840529e77..dd16decc92 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -44,6 +44,14 @@ public ISchema InputSchema get { return _schemaRoot; } } + /// + /// Schema of the transform model. + /// + public IDataView Schema + { + get { return _chain; } + } + /// /// Create a TransformModel containing the transforms from "result" back to "input". /// diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index d074d64441..de7484de49 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -29,6 +29,33 @@ internal Runtime.EntryPoints.TransformModel PredictorModel get { return _predictorModel; } } + /// + /// Returns labels that correspond to indices of the score array in the case of + /// multi-class classification problem. + /// + /// Label to score mapping + /// Name of the score column + /// + public bool TryGetScoreLabelMapping(out string[] mapping, string scoreColumnName = "Score") + { + mapping = null; + IDataView idv = _predictorModel.Schema; + int colIndex = -1; + if (!idv.Schema.TryGetColumnIndex(scoreColumnName, out colIndex)) + return false; + + VBuffer labels = default(VBuffer); + idv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels); + + Contracts.Assert(labels.IsDense); + + mapping = new string[labels.Count]; + for (int index = 0; index < labels.Count; index++) + mapping[index] = labels.Values[index].ToString(); + + return true; + } + /// /// Read model from file asynchronously. /// diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index ebddc33b03..ee3301ff9f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -30,6 +30,11 @@ public void TrainAndPredictIrisModelWithStringLabelTest() pipeline.Add(new StochasticDualCoordinateAscentClassifier()); PredictionModel model = pipeline.Train(); + string[] scoreLabels; + model.TryGetScoreLabelMapping(out scoreLabels); + + Assert.NotNull(scoreLabels); + Assert.Equal(3, scoreLabels.Length); IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel() { From f067e0f6503a42c0eb72bf42b2e3499633180a16 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 24 May 2018 15:42:24 -0700 Subject: [PATCH 2/5] update test. --- .../Scenarios/IrisPlantClassificationWithStringLabelTests.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index ee3301ff9f..cfebbbeb74 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -35,6 +35,9 @@ public void TrainAndPredictIrisModelWithStringLabelTest() Assert.NotNull(scoreLabels); Assert.Equal(3, scoreLabels.Length); + Assert.True(scoreLabels[0] == "Iris-setosa"); + Assert.True(scoreLabels[1] == "Iris-versicolor"); + Assert.True(scoreLabels[2] == "Iris-virginica"); IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel() { From 431117d919965703a5bdea0c5137e7b9c7e77e6f Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 14:55:27 -0700 Subject: [PATCH 3/5] PR feedback. --- src/Microsoft.ML.Core/Data/ITransformModel.cs | 7 +++++-- .../EntryPoints/TransformModel.cs | 14 +++++-------- src/Microsoft.ML/PredictionModel.cs | 21 ++++++++++++------- ...PlantClassificationWithStringLabelTests.cs | 6 +++--- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index a225f6ed21..2fffc1930a 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -26,9 +26,12 @@ public interface ITransformModel ISchema InputSchema { get; } /// - /// Schema of the transform model. + /// The resulting schema once applied to this model. The might have + /// columns that are not needed by this transform and these columns will be seen in the + /// produced by this transform. /// - IDataView Schema { get; } + ISchema OutputSchema { get; } + /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index dd16decc92..9edc87df6d 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -39,18 +39,14 @@ public sealed class TransformModel : ITransformModel /// if transform model A needs column X and model B needs Y, that is NOT produced by A, /// then trimming A's input schema would cause composition to fail. /// - public ISchema InputSchema - { - get { return _schemaRoot; } - } + public ISchema InputSchema => _schemaRoot; /// - /// Schema of the transform model. + /// The resulting schema once applied to this model. The might have + /// columns that are not needed by this transform and these columns will be seen in the + /// produced by this transform. /// - public IDataView Schema - { - get { return _chain; } - } + public ISchema OutputSchema => _chain.Schema; /// /// Create a TransformModel containing the transforms from "result" back to "input". diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index de7484de49..d4e7a62d21 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.IO; @@ -36,20 +37,26 @@ internal Runtime.EntryPoints.TransformModel PredictorModel /// Label to score mapping /// Name of the score column /// - public bool TryGetScoreLabelMapping(out string[] mapping, string scoreColumnName = "Score") + public bool TryGetScoreLabelMapping(out string[] mapping, string scoreColumnName = DefaultColumnNames.Score) { mapping = null; - IDataView idv = _predictorModel.Schema; + ISchema schema = _predictorModel.OutputSchema; int colIndex = -1; - if (!idv.Schema.TryGetColumnIndex(scoreColumnName, out colIndex)) + if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex)) + return false; + + int expectedLabelCount = schema.GetColumnType(colIndex).AsVector.ValueCount; + if (!schema.HasSlotNames(colIndex, expectedLabelCount)) return false; - VBuffer labels = default(VBuffer); - idv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels); + VBuffer labels = default; + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels); + VBufferUtils.Densify(ref labels); - Contracts.Assert(labels.IsDense); + if (labels.Length != expectedLabelCount) + return false; - mapping = new string[labels.Count]; + mapping = new string[labels.Length]; for (int index = 0; index < labels.Count; index++) mapping[index] = labels.Values[index].ToString(); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index cfebbbeb74..3643a920dd 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -35,9 +35,9 @@ public void TrainAndPredictIrisModelWithStringLabelTest() Assert.NotNull(scoreLabels); Assert.Equal(3, scoreLabels.Length); - Assert.True(scoreLabels[0] == "Iris-setosa"); - Assert.True(scoreLabels[1] == "Iris-versicolor"); - Assert.True(scoreLabels[2] == "Iris-virginica"); + Assert.Equal("Iris-setosa", scoreLabels[0]); + Assert.Equal("Iris-versicolor", scoreLabels[1]); + Assert.Equal("Iris-virginica", scoreLabels[2]); IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel() { From ef46c6cd1fdef24e33f2d14614b79cf303343744 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 15:25:09 -0700 Subject: [PATCH 4/5] PR feedback. --- src/Microsoft.ML.Core/Data/ITransformModel.cs | 14 +++++++------- src/Microsoft.ML/PredictionModel.cs | 17 +++++++++-------- ...isPlantClassificationWithStringLabelTests.cs | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index 2fffc1930a..ccc73265ec 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -18,17 +18,17 @@ public interface ITransformModel /// Note that the schema may have columns that aren't needed by this transform model. /// If an IDataView exists with this schema, then applying this transform model to it /// shouldn't fail because of column type issues. - /// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note - /// however that doing so may cause issues for composing transform models. For example, - /// if transform model A needs column X and model B needs Y, that is NOT produced by A, - /// then trimming A's input schema would cause composition to fail. /// + // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note + // however that doing so may cause issues for composing transform models. For example, + // if transform model A needs column X and model B needs Y, that is NOT produced by A, + // then trimming A's input schema would cause composition to fail. ISchema InputSchema { get; } /// - /// The resulting schema once applied to this model. The might have - /// columns that are not needed by this transform and these columns will be seen in the - /// produced by this transform. + /// The output schema that this transform model was originally instantiated on. The schema resulting + /// from may differ from this, similarly to how + /// may differ from the schema of dataviews we apply this transform model to. /// ISchema OutputSchema { get; } diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index d4e7a62d21..5a884eb558 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -34,31 +34,32 @@ internal Runtime.EntryPoints.TransformModel PredictorModel /// Returns labels that correspond to indices of the score array in the case of /// multi-class classification problem. /// - /// Label to score mapping + /// Label to score mapping /// Name of the score column /// - public bool TryGetScoreLabelMapping(out string[] mapping, string scoreColumnName = DefaultColumnNames.Score) + public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score) { - mapping = null; + names = null; ISchema schema = _predictorModel.OutputSchema; int colIndex = -1; if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex)) return false; - int expectedLabelCount = schema.GetColumnType(colIndex).AsVector.ValueCount; + int expectedLabelCount = schema.GetColumnType(colIndex).ValueCount; if (!schema.HasSlotNames(colIndex, expectedLabelCount)) return false; VBuffer labels = default; schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels); - VBufferUtils.Densify(ref labels); if (labels.Length != expectedLabelCount) return false; - mapping = new string[labels.Length]; - for (int index = 0; index < labels.Count; index++) - mapping[index] = labels.Values[index].ToString(); + names = new string[expectedLabelCount]; + int index = 0; + foreach(var label in labels.DenseValues()) + names[index++] = label.ToString(); + return true; } diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 3643a920dd..348a851020 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -31,7 +31,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() PredictionModel model = pipeline.Train(); string[] scoreLabels; - model.TryGetScoreLabelMapping(out scoreLabels); + model.TryGetScoreLabelNames(out scoreLabels); Assert.NotNull(scoreLabels); Assert.Equal(3, scoreLabels.Length); From 6eae0b67e5df4c40033c4a988bf645be7530ae9f Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 15:49:19 -0700 Subject: [PATCH 5/5] cleanup. --- src/Microsoft.ML/PredictionModel.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index 5a884eb558..6eb1b3c6f5 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -60,7 +60,6 @@ public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = D foreach(var label in labels.DenseValues()) names[index++] = label.ToString(); - return true; }