diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs
index ccb65d43ab..ccc73265ec 100644
--- a/src/Microsoft.ML.Core/Data/ITransformModel.cs
+++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs
@@ -18,13 +18,20 @@ public interface ITransformModel
/// Note that the schema may have columns that aren't needed by this transform model.
/// If an IDataView exists with this schema, then applying this transform model to it
/// shouldn't fail because of column type issues.
- /// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
- /// however that doing so may cause issues for composing transform models. For example,
- /// if transform model A needs column X and model B needs Y, that is NOT produced by A,
- /// then trimming A's input schema would cause composition to fail.
///
+ // REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
+ // however that doing so may cause issues for composing transform models. For example,
+ // if transform model A needs column X and model B needs Y, that is NOT produced by A,
+ // then trimming A's input schema would cause composition to fail.
ISchema InputSchema { get; }
+ ///
+ /// The output schema that this transform model was originally instantiated on. The schema resulting
+ /// from may differ from this, similarly to how
+ /// may differ from the schema of dataviews we apply this transform model to.
+ ///
+ ISchema OutputSchema { get; }
+
///
/// Apply the transform(s) in the model to the given input data.
///
diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
index b840529e77..9edc87df6d 100644
--- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
@@ -39,10 +39,14 @@ public sealed class TransformModel : ITransformModel
/// if transform model A needs column X and model B needs Y, that is NOT produced by A,
/// then trimming A's input schema would cause composition to fail.
///
- public ISchema InputSchema
- {
- get { return _schemaRoot; }
- }
+ public ISchema InputSchema => _schemaRoot;
+
+ ///
+ /// The resulting schema once applied to this model. The might have
+ /// columns that are not needed by this transform and these columns will be seen in the
+ /// produced by this transform.
+ ///
+ public ISchema OutputSchema => _chain.Schema;
///
/// Create a TransformModel containing the transforms from "result" back to "input".
diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs
index d074d64441..6eb1b3c6f5 100644
--- a/src/Microsoft.ML/PredictionModel.cs
+++ b/src/Microsoft.ML/PredictionModel.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
@@ -29,6 +30,39 @@ internal Runtime.EntryPoints.TransformModel PredictorModel
get { return _predictorModel; }
}
+ ///
+ /// Returns labels that correspond to indices of the score array in the case of
+ /// multi-class classification problem.
+ ///
+ /// Label to score mapping
+ /// Name of the score column
+ ///
+ public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score)
+ {
+ names = null;
+ ISchema schema = _predictorModel.OutputSchema;
+ int colIndex = -1;
+ if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex))
+ return false;
+
+ int expectedLabelCount = schema.GetColumnType(colIndex).ValueCount;
+ if (!schema.HasSlotNames(colIndex, expectedLabelCount))
+ return false;
+
+ VBuffer labels = default;
+ schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels);
+
+ if (labels.Length != expectedLabelCount)
+ return false;
+
+ names = new string[expectedLabelCount];
+ int index = 0;
+ foreach(var label in labels.DenseValues())
+ names[index++] = label.ToString();
+
+ return true;
+ }
+
///
/// Read model from file asynchronously.
///
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
index ebddc33b03..348a851020 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
@@ -30,6 +30,14 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
PredictionModel model = pipeline.Train();
+ string[] scoreLabels;
+ model.TryGetScoreLabelNames(out scoreLabels);
+
+ Assert.NotNull(scoreLabels);
+ Assert.Equal(3, scoreLabels.Length);
+ Assert.Equal("Iris-setosa", scoreLabels[0]);
+ Assert.Equal("Iris-versicolor", scoreLabels[1]);
+ Assert.Equal("Iris-virginica", scoreLabels[2]);
IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
{