diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs
index ccb65d43ab..ccc73265ec 100644
--- a/src/Microsoft.ML.Core/Data/ITransformModel.cs
+++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs
@@ -18,13 +18,20 @@ public interface ITransformModel
         /// Note that the schema may have columns that aren't needed by this transform model.
         /// If an IDataView exists with this schema, then applying this transform model to it
         /// shouldn't fail because of column type issues.
-        /// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
-        /// however that doing so may cause issues for composing transform models. For example,
-        /// if transform model A needs column X and model B needs Y, that is NOT produced by A,
-        /// then trimming A's input schema would cause composition to fail.
         /// </summary>
+        // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
+        // however that doing so may cause issues for composing transform models. For example,
+        // if transform model A needs column X and model B needs Y, that is NOT produced by A,
+        // then trimming A's input schema would cause composition to fail.
         ISchema InputSchema { get; }
 
+        /// <summary>
+        /// The output schema that this transform model was originally instantiated on. The schema resulting
+        /// from <c>Apply</c> may differ from this, similarly to how <see cref="InputSchema"/>
+        /// may differ from the schema of dataviews we apply this transform model to.
+        /// </summary>
+        ISchema OutputSchema { get; }
+
         /// <summary>
         /// Apply the transform(s) in the model to the given input data.
         /// </summary>
diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
index b840529e77..9edc87df6d 100644
--- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
@@ -39,10 +39,14 @@ public sealed class TransformModel : ITransformModel
         /// if transform model A needs column X and model B needs Y, that is NOT produced by A,
         /// then trimming A's input schema would cause composition to fail.
         /// </summary>
-        public ISchema InputSchema
-        {
-            get { return _schemaRoot; }
-        }
+        public ISchema InputSchema => _schemaRoot;
+
+        /// <summary>
+        /// The resulting schema once applied to this model. The <see cref="InputSchema"/> might have
+        /// columns that are not needed by this transform and these columns will be seen in the
+        /// <see cref="OutputSchema"/> produced by this transform.
+        /// </summary>
+        public ISchema OutputSchema => _chain.Schema;
 
         /// <summary>
         /// Create a TransformModel containing the transforms from "result" back to "input".
diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs
index d074d64441..6eb1b3c6f5 100644
--- a/src/Microsoft.ML/PredictionModel.cs
+++ b/src/Microsoft.ML/PredictionModel.cs
@@ -6,6 +6,7 @@
 using Microsoft.ML.Runtime.Api;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
 using System;
 using System.Collections.Generic;
 using System.IO;
@@ -29,6 +30,39 @@ internal Runtime.EntryPoints.TransformModel PredictorModel
             get { return _predictorModel; }
         }
 
+        /// <summary>
+        /// Returns labels that correspond to indices of the score array in the case of
+        /// multi-class classification problem.
+        /// </summary>
+        /// <param name="names">Label to score mapping; left as null when the names cannot be retrieved.</param>
+        /// <param name="scoreColumnName">Name of the score column</param>
+        /// <returns>true if the label names were retrieved; otherwise false.</returns>
+        public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score)
+        {
+            names = null;
+            ISchema schema = _predictorModel.OutputSchema;
+            if (!schema.TryGetColumnIndex(scoreColumnName, out int colIndex))
+                return false;
+
+            // Only proceed when the score column has slot-name metadata for the expected number of slots.
+            int expectedLabelCount = schema.GetColumnType(colIndex).ValueCount;
+            if (!schema.HasSlotNames(colIndex, expectedLabelCount))
+                return false;
+
+            VBuffer<DvText> labels = default;
+            schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels);
+
+            if (labels.Length != expectedLabelCount)
+                return false;
+
+            names = new string[expectedLabelCount];
+            int index = 0;
+            foreach (var label in labels.DenseValues())
+                names[index++] = label.ToString();
+
+            return true;
+        }
+
         /// <summary>
         /// Read model from file asynchronously.
         /// </summary>
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
index ebddc33b03..348a851020 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
@@ -30,6 +30,14 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
             pipeline.Add(new StochasticDualCoordinateAscentClassifier());
 
             PredictionModel<IrisDataWithStringLabel, IrisPrediction> model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
+            string[] scoreLabels;
+            Assert.True(model.TryGetScoreLabelNames(out scoreLabels));
+
+            Assert.NotNull(scoreLabels);
+            Assert.Equal(3, scoreLabels.Length);
+            Assert.Equal("Iris-setosa", scoreLabels[0]);
+            Assert.Equal("Iris-versicolor", scoreLabels[1]);
+            Assert.Equal("Iris-virginica", scoreLabels[2]);
 
             IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
             {