Skip to content

Check for IEstimator/ITransformer schema consistency, fix bugs uncovered #3408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 18, 2019
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/AnnotationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey)
if (labelColumn.HasValue && labelColumn.Value.IsKey)
{
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
Expand Down
55 changes: 51 additions & 4 deletions src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
Expand Down Expand Up @@ -76,6 +77,8 @@ private protected CalibratorEstimatorBase(IHostEnvironment env,
/// <summary>
/// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
/// Fitting the calibrator will add a column named "Probability" to the schema. If you already had such a column, a new one will be added.
/// The same annotation data that would be produced by <see cref="AnnotationUtils.GetTrainerOutputAnnotation(bool)"/> is marked as
/// being present on the output, if it is present on the input score column.
/// </summary>
/// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
Expand All @@ -96,13 +99,32 @@ SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(Sche
checkColumnValid(WeightColumn, "weight");
checkColumnValid(LabelColumn, "label");

bool success = inputSchema.TryFindColumn(ScoreColumn.Name, out var inputScoreCol);
Host.Assert(success);
const SchemaShape.Column.VectorKind scalar = SchemaShape.Column.VectorKind.Scalar;

var annotations = new List<SchemaShape.Column>();
annotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized,
SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
// We only propagate this training column metadata if it looks like it's all there, and all correct.
if (inputScoreCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnSetId, out var setIdCol) &&
setIdCol.Kind == scalar && setIdCol.IsKey && setIdCol.ItemType == NumberDataViewType.UInt32 &&
inputScoreCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnKind, out var kindCol) &&
kindCol.Kind == scalar && kindCol.ItemType is TextDataViewType &&
inputScoreCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.ScoreValueKind, out var valueKindCol) &&
valueKindCol.Kind == scalar && valueKindCol.ItemType is TextDataViewType)
{
annotations.Add(setIdCol);
annotations.Add(kindCol);
annotations.Add(valueKindCol);
}

// Create the new Probability column.
var outColumns = inputSchema.ToDictionary(x => x.Name);
outColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability,
SchemaShape.Column.VectorKind.Scalar,
NumberDataViewType.Single,
false,
new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true)));
false, new SchemaShape(annotations));

return new SchemaShape(outColumns.Values);
}
Expand Down Expand Up @@ -182,7 +204,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

// *** Binary format ***
// model: _calibrator
ctx.SaveModel(_calibrator, @"Calibrator");
ctx.SaveModel(_calibrator, "Calibrator");
}

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper<TICalibrator>(this, _calibrator, schema);
Expand Down Expand Up @@ -223,9 +245,34 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var builder = new DataViewSchema.Annotations.Builder();
var annotation = InputSchema[_scoreColIndex].Annotations;
var schema = annotation.Schema;

// We only propagate this training column metadata if it looks like it's all there, and all correct.
if (schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnSetId) is DataViewSchema.Column setIdCol &&
setIdCol.Type is KeyDataViewType setIdType && setIdType.RawType == typeof(uint) &&
schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnKind) is DataViewSchema.Column kindCol &&
kindCol.Type is TextDataViewType &&
schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreValueKind) is DataViewSchema.Column valueKindCol &&
valueKindCol.Type is TextDataViewType)
{
builder.Add(setIdCol.Name, setIdType, annotation.GetGetter<uint>(setIdCol));
// Now, this next one I'm a little less sure about. It is entirely reasonable for someone to, say,
// try to calibrate the result of a regression or ranker training, or something else. But should we
// just pass through this class just like that? Having thought through the alternatives I view this
// as the least harmful thing we could be doing, but it is something to consider I may be wrong
// about if it proves that it ever causes problems to, say, have something identified as a probability
// column but be marked as being a regression task, or what have you.
builder.Add(kindCol.Name, kindCol.Type, annotation.GetGetter<ReadOnlyMemory<char>>(kindCol));
builder.Add(valueKindCol.Name, valueKindCol.Type, annotation.GetGetter<ReadOnlyMemory<char>>(valueKindCol));
}
// Probabilities are always considered normalized.
builder.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ref bool value) => value = true);

return new[]
{
new DataViewSchema.DetachedColumn(DefaultColumnNames.Probability, NumberDataViewType.Single, null)
new DataViewSchema.DetachedColumn(DefaultColumnNames.Probability, NumberDataViewType.Single, builder.ToAnnotations())
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static bool[] GetActive(BindingsBase bindings,
Contracts.Assert(active.Length == bindings.ColumnCount);

var activeInput = bindings.GetActiveInput(columns);
Contracts.Assert(activeInput.Count() == bindings.Input.Count);
Contracts.Assert(activeInput.Length == bindings.Input.Count);

// Get a predicate that determines which Mapper outputs are active.
var predicateMapper = bindings.GetActiveMapperColumns(active);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
metadata.Add(slotMeta);
if (colInfo.MaximumNumberOfInverts != 0)
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.ItemType is VectorDataViewType ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, new SchemaShape(metadata));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.Kind, NumberDataViewType.UInt32, true, new SchemaShape(metadata));
}
return new SchemaShape(result.Values);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw Host.ExceptParam(nameof(inputSchema), $"Input column '{colInfo.inputColumnName}' doesn't contain key values metadata");

SchemaShape metadata = null;
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotCol))
if (col.HasSlotNames() && col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotCol))
metadata = new SchemaShape(new[] { slotCol });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why this is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it was just something I noticed, that we were propagating that we were going to do this even if it wasn't the right type, which actually led to a test failure. So I just changed the condition to be a bit tighter is all.


result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, col.Kind, keyMetaCol.ItemType, keyMetaCol.IsKey, metadata);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
typeNames = null;
}

if (_parent._columns[iinfo].OutputCountVector || srcValueCount == 1)
if (_parent._columns[iinfo].OutputCountVector || srcType is PrimitiveDataViewType)
{
if (typeNames != null)
{
Expand Down Expand Up @@ -336,7 +336,7 @@ private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
builder.Add(AnnotationUtils.Kinds.CategoricalSlotRanges, AnnotationUtils.GetCategoricalType(srcValueCount), getter);
}

if (!_parent._columns[iinfo].OutputCountVector || srcValueCount == 1)
if (!_parent._columns[iinfo].OutputCountVector || srcType is PrimitiveDataViewType)
{
ValueGetter<bool> getter = (ref bool dst) =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,39 +160,25 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)

private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
SchemaShape.Column? labelCol = null;
var predictedLabelAnnotationCols = AnnotationUtils.GetTrainerOutputAnnotation();

if (LabelColumn.IsValid)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var inputLabelCol);
Contracts.Assert(success);

var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
.Concat(MetadataForScoreColumn()));
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, metadata)
};
labelCol = inputLabelCol;
predictedLabelAnnotationCols = predictedLabelAnnotationCols.Concat(
inputLabelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues));
}
else
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(MetadataForScoreColumn())),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, new SchemaShape(MetadataForScoreColumn()))
};
}

/// <summary>
/// Normal metadata that we produce for score columns.
/// </summary>
private static IEnumerable<SchemaShape.Column> MetadataForScoreColumn()
{
var cols = new List<SchemaShape.Column>();
cols.Add(new SchemaShape.Column(AnnotationUtils.Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
cols.Add(new SchemaShape.Column(AnnotationUtils.Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
cols.Add(new SchemaShape.Column(AnnotationUtils.Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false));

return cols;
var scoreAnnotationCols = AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol);
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single,
false, new SchemaShape(scoreAnnotationCols)),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32,
true, new SchemaShape(predictedLabelAnnotationCols))
};
}

IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<IPredictor>)this).Train(context);
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Transforms/CountFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
metadata.Add(slotMeta);
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.CategoricalSlotRanges, out var categoricalSlotMeta))
metadata.Add(categoricalSlotMeta);
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
if (col.IsNormalized() && col.Annotations.TryFindColumn(AnnotationUtils.Kinds.IsNormalized, out var isNormalizedAnnotation))
metadata.Add(isNormalizedAnnotation);
result[colPair.Name] = new SchemaShape.Column(colPair.Name, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray()));
}
return new SchemaShape(result.Values);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/KeyToVectorMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
typeNames = null;
}

if (_infos[iinfo].TypeSrc.GetValueCount() == 1)
if (_infos[iinfo].TypeSrc is PrimitiveDataViewType)
{
if (typeNames != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
metadata.Add(slotMeta);
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.CategoricalSlotRanges, out var categoricalSlotMeta))
metadata.Add(categoricalSlotMeta);
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
if (col.IsNormalized() && col.Annotations.TryFindColumn(AnnotationUtils.Kinds.IsNormalized, out var isNormalizedAnnotation))
metadata.Add(isNormalizedAnnotation);
result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray()));
}
return new SchemaShape(result.Values);
Expand All @@ -198,7 +199,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
var host = env.Register(RegistrationName);
host.CheckValue(options, nameof(options));
host.CheckValue(input, nameof(input));
host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
host.CheckNonEmpty(options.Columns, nameof(options.Columns));
host.CheckUserArg(options.SlotsInOutput > 0, nameof(options.SlotsInOutput));
host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
host.Check(options.NumBins > 1, "numBins must be greater than 1.");
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, ExpectedColumnType, col.GetTypeString());
}
var metadata = new List<SchemaShape.Column>();
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
// If this is non-zero, we will be doing invert hashing at some level and so have slot names.
if (colInfo.MaximumNumberOfInverts != 0)
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
}
return new SchemaShape(result.Values);
Expand Down
14 changes: 9 additions & 5 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,18 @@ protected void TestEstimatorCore(IEstimator<ITransformer> estimator,
private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered)
{
Assert.True(promised.Count == delivered.Count);
var sortedCols1 = promised.OrderBy(x => x.Name);
var sortedCols2 = delivered.OrderBy(x => x.Name);
var promisedCols = promised.OrderBy(x => x.Name);
var deliveredCols = delivered.OrderBy(x => x.Name);

foreach (var (x, y) in sortedCols1.Zip(sortedCols2, (x, y) => (x, y)))
foreach (var (p, d) in promisedCols.Zip(deliveredCols, (p, d) => (p, d)))
{
Assert.Equal(x.Name, y.Name);
Assert.Equal(p.Name, d.Name);
// We want the 'promised' metadata to be a superset of 'delivered'.
Assert.True(y.IsCompatibleWith(x), $"Mismatch on {x.Name}");
Assert.True(d.IsCompatibleWith(p), $"Mismatch on {p.Name}, there was a mismatch, or some unexpected annotations was present.");
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Apr 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assert.True(d.IsCompatibleWith(p), $"Mismatch on {p.Name}, there was a mismatch, or some unexpected annotations was present."); [](start = 16, length = 127)

I'm trying to come up with some example where this wouldn't be true....
I think the previous statement was made based on the fact that SchemaShape has less information than Schema and we don't have actual data, so in some cases we will add annotations which were impossible to predict during GetOutputSchema time.

But nothing comes to my mind, and since tests are not failing, I would guess this is our current state of annotation business. #ByDesign

Copy link
Contributor Author

@TomFinley TomFinley Apr 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so in some cases we will add annotations which were impossible to predict during GetOutputSchema time.

Maybe. The way the test worked, it made sure the promise was a superset of what was delivered -- that is, the estimator's annotation list would be a sort of "semi-promise" like, "ok, this annotation may be there or not, but if it is there, it will have this type," and so on. At least, so I must interpret this test as previously written.

There is one case in which this could potentially have been valuable, the key-to-vector annotation, as discussed in depth in issue #3380 (the special case where it was a known size vector that was of length precisely one). But that was such an obscure situation and such a micro-optimization that I hardly felt it worthwhile.

And anyway, the IEstimator did things in the wrong way: it only said the slot names would be there if it was scalar, whereas if it actually wanted to capture this "possibility" it should have said that the slot names could be there if scalar or vector (but not var-vector!). But if you're going to go that far, I'd argue that SchemaShape annotation data becomes utterly useless.

I could have imagined a "richer" schema shape where there were differing degrees of certainty -- "this will definitely be here, this might be here," etc., but in the end I decided with saying, "you know what, let's just make it a simple promise."

One nice thing is that IEstimator is still not possible to fully implement, so hypothetical future authors might be able to refine this. Or come up with a future successor concept -- certainly not every ITransformer is the result of fitting an IEstimator.

// We also want the 'delivered' to be a superset of 'promised'. Since the above
// test must have worked if we got this far, I believe the only plausible reason
// this could happen is if there was something promised but not delivered.
Assert.True(p.IsCompatibleWith(d), $"Mismatch on {p.Name}, something was promised in the annotations but not delivered.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void FixedPlattCalibratorEstimator()
CheckValidCalibratedData(calibratorTestData.ScoredData, fixedPlattCalibratorTransformer);

//test estimator
TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data);
TestEstimatorCore(fixedPlattCalibratorEstimator, calibratorTestData.ScoredData);

Done();
}
Expand All @@ -68,7 +68,7 @@ public void NaiveCalibratorEstimator()
CheckValidCalibratedData(calibratorTestData.ScoredData, naiveCalibratorTransformer);

//test estimator
TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data);
TestEstimatorCore(naiveCalibratorEstimator, calibratorTestData.ScoredData);

Done();
}
Expand All @@ -88,7 +88,7 @@ public void PavCalibratorEstimator()
CheckValidCalibratedData(calibratorTestData.ScoredData, pavCalibratorTransformer);

//test estimator
TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data);
TestEstimatorCore(pavCalibratorEstimator, calibratorTestData.ScoredData);

Done();
}
Expand Down
Loading