diff --git a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs index 6296c2aa55..6c3c6bebee 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.EntryPoints internal static class SchemaManipulation { [TlcModule.EntryPoint(Name = "Transforms.ColumnConcatenator", Desc = ColumnConcatenatingTransformer.Summary, UserName = ColumnConcatenatingTransformer.UserName, ShortName = ColumnConcatenatingTransformer.LoadName)] - public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Arguments input) + public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("ConcatColumns"); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs index 3c376d4145..73c9c0882c 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs @@ -313,7 +313,7 @@ protected ColumnBindingsBase(Schema input, bool user, params string[] names) // standard column name. 
const string standardColumnArgName = "Columns"; Contracts.Assert(nameof(ValueToKeyMappingTransformer.Options.Columns) == standardColumnArgName); - Contracts.Assert(nameof(ColumnConcatenatingTransformer.Arguments.Columns) == standardColumnArgName); + Contracts.Assert(nameof(ColumnConcatenatingTransformer.Options.Columns) == standardColumnArgName); for (int iinfo = 0; iinfo < names.Length; iinfo++) { diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs index e901375f83..175e384a64 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs @@ -11,7 +11,10 @@ namespace Microsoft.ML.Transforms { - public sealed class ColumnConcatenatingEstimator : IEstimator + /// + /// Concatenates columns in an into one single column. Estimator for the . + /// + public sealed class ColumnConcatenatingEstimator : IEstimator { private readonly IHost _host; private readonly string _name; @@ -22,8 +25,8 @@ public sealed class ColumnConcatenatingEstimator : IEstimator /// /// The local instance of . /// The name of the resulting column. - /// The columns to concatenate together. - public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) + /// The columns to concatenate into one single column. + internal ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("ColumnConcatenatingEstimator "); @@ -37,7 +40,10 @@ public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnNam _source = inputColumnNames; } - public ITransformer Fit(IDataView input) + /// + /// Trains and returns a . 
+ /// + public ColumnConcatenatingTransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); return new ColumnConcatenatingTransformer(_host, _name, _source); @@ -109,6 +115,10 @@ private SchemaShape.Column CheckInputsAndMakeColumn( return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta)); } + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index 1b7912854b..f9b5d6bf8e 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -15,9 +15,10 @@ using Microsoft.ML.Model; using Microsoft.ML.Model.Onnx; using Microsoft.ML.Model.Pfa; +using Microsoft.ML.Transforms; using Newtonsoft.Json.Linq; -[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedArguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedOptions), typeof(SignatureDataTransform), ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")] [assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadDataTransform), @@ -33,6 +34,10 @@ namespace Microsoft.ML.Data { using PfaType = PfaUtils.Type; + /// + /// Concatenates columns in an into one single column. 
Please see for + /// constructing . + /// public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase { internal const string Summary = "Concatenates one or more columns of the same item type."; @@ -42,7 +47,7 @@ public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase internal const string LoaderSignature = "ConcatTransform"; internal const string LoaderSignatureOld = "ConcatFunction"; - public sealed class Column : ManyToOneColumn + internal sealed class Column : ManyToOneColumn { internal static Column Parse(string str) { @@ -60,7 +65,8 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class TaggedColumn + [BestFriend] + internal sealed class TaggedColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", ShortName = "name")] public string Name; @@ -99,13 +105,13 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { - public Arguments() + public Options() { } - public Arguments(string name, params string[] source) + public Options(string name, params string[] source) { Columns = new[] { new Column() { @@ -119,14 +125,16 @@ public Arguments(string name, params string[] source) public Column[] Columns; } - public sealed class TaggedArguments + [BestFriend] + internal sealed class TaggedOptions { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)] public TaggedColumn[] Columns; } - public sealed class ColumnInfo + [BestFriend] + internal sealed class ColumnInfo { public readonly string Name; private readonly (string name, string alias)[] _sources; @@ -212,14 +220,18 @@ internal ColumnInfo(ModelLoadContext ctx) private readonly ColumnInfo[] _columns; - public IReadOnlyCollection Columns => _columns.AsReadOnly(); + /// + /// The names of the output and input column pairs for the 
transformation. + /// + public IReadOnlyCollection<(string outputColumnName, string[] inputColumnNames)> Columns + => _columns.Select(col => (outputColumnName: col.Name, inputColumnNames: col.Sources.Select(source => source.name).ToArray())).ToArray().AsReadOnly(); /// /// Concatename columns in into one column . /// Original columns are also preserved. /// The column types must match, and the output column type is always a vector. /// - public ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) + internal ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) : this(env, new ColumnInfo(outputColumnName, inputColumnNames)) { } @@ -227,7 +239,7 @@ public ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnN /// /// Concatenates multiple groups of columns, each group is denoted by one of . /// - public ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : + internal ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer))) { Contracts.CheckValue(columns, nameof(columns)); @@ -357,17 +369,17 @@ private ColumnInfo[] LoadLegacy(ModelLoadContext ctx) /// /// Factory method for SignatureDataTransform. 
/// - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); + env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns)); - for (int i = 0; i < args.Columns.Length; i++) - env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns)); + for (int i = 0; i < options.Columns.Length; i++) + env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns)); - var cols = args.Columns + var cols = options.Columns .Select(c => new ColumnInfo(c.Name, c.Source)) .ToArray(); var transformer = new ColumnConcatenatingTransformer(env, cols); @@ -377,17 +389,17 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat /// Factory method corresponding to SignatureDataTransform. 
/// [BestFriend] - internal static IDataTransform Create(IHostEnvironment env, TaggedArguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, TaggedOptions options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); + env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns)); - for (int i = 0; i < args.Columns.Length; i++) - env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns)); + for (int i = 0; i < options.Columns.Length; i++) + env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns)); - var cols = args.Columns + var cols = options.Columns .Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key != "" ? kvp.Key : null)))) .ToArray(); var transformer = new ColumnConcatenatingTransformer(env, cols); diff --git a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs index cc556a92cc..80f9a67f10 100644 --- a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs @@ -45,7 +45,7 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, => new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// - /// Concatenates two columns together. + /// Concatenates columns together. /// /// The transform's catalog. /// Name of the column resulting from the transformation of . 
diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 92f4d33fb1..bff9503614 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -16,7 +16,6 @@ using Microsoft.ML.Model.Pfa; using Microsoft.ML.Transforms.Normalizers; using Newtonsoft.Json.Linq; -using static Microsoft.ML.Transforms.Normalizers.NormalizeTransform; [assembly: LoadableClass(typeof(NormalizingTransformer), null, typeof(SignatureLoadModel), "", NormalizingTransformer.LoaderSignature)] @@ -206,7 +205,7 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, C /// Name of the column to transform. /// If set to , the value of the will be used as source. /// The indicating how to the old values are mapped to the new values. - public NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax) + internal NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax) : this(env, mode, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -217,7 +216,7 @@ public NormalizingEstimator(IHostEnvironment env, string outputColumnName, strin /// The private instance of . /// The indicating how to the old values are mapped to the new values. /// An array of (outputColumnName, inputColumnName) tuples. 
- public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns) + internal NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(NormalizingEstimator)); @@ -230,7 +229,7 @@ public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (s /// /// The private instance of the . /// An array of defining the inputs to the Normalizer, and their settings. - public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns) + internal NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(NormalizingEstimator)); @@ -239,12 +238,19 @@ public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns) _columns = columns.ToArray(); } + /// + /// Trains and returns a . + /// public NormalizingTransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); return NormalizingTransformer.Train(_host, input, _columns); } + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. 
+ /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -275,7 +281,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public sealed partial class NormalizingTransformer : OneToOneTransformerBase { - public const string LoaderSignature = "Normalizer"; + internal const string LoaderSignature = "Normalizer"; internal const string LoaderSignatureOld = "NormalizeFunction"; @@ -387,7 +393,7 @@ private NormalizingTransformer(IHostEnvironment env, ColumnInfo[] columns) ColumnFunctions = new ColumnFunctionAccessor(Columns); } - public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns) + internal static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(data, nameof(data)); @@ -510,7 +516,7 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx, IDataView input Columns = ImmutableArray.Create(cols); } - public static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + private static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index 1844de7347..fd4dba2aba 100644 --- a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -74,7 +74,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env // as a group id. That's just one example - you get the idea. 
string nameFeat = DefaultColumnNames.Features; viewTrain = ColumnConcatenatingTransformer.Create(host, - new ColumnConcatenatingTransformer.TaggedArguments() + new ColumnConcatenatingTransformer.TaggedOptions() { Columns = new[] { new ColumnConcatenatingTransformer.TaggedColumn() { Name = nameFeat, Source = concatNames.ToArray() } } diff --git a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs index b575777ba4..10bee69c74 100644 --- a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs @@ -735,9 +735,9 @@ public static class NAReplacerStaticExtensions private readonly struct Config { public readonly bool ImputeBySlot; - public readonly MissingValueReplacingTransformer.ColumnInfo.ReplacementMode ReplacementMode; + public readonly MissingValueReplacingEstimator.ColumnInfo.ReplacementMode ReplacementMode; - public Config(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, + public Config(MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { ImputeBySlot = imputeBySlot; @@ -803,11 +803,11 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var infos = new MissingValueReplacingTransformer.ColumnInfo[toOutput.Length]; + var infos = new MissingValueReplacingEstimator.ColumnInfo[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new MissingValueReplacingTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[col.Input], col.Config.ReplacementMode, col.Config.ImputeBySlot); + infos[i] = new MissingValueReplacingEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[col.Input], 
col.Config.ReplacementMode, col.Config.ImputeBySlot); } return new MissingValueReplacingEstimator(env, infos); } @@ -818,7 +818,7 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// Incoming data. /// How NaN should be replaced - public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, false)); @@ -829,7 +829,7 @@ public static Scalar ReplaceNaNValues(this Scalar input, MissingVa /// /// Incoming data. /// How NaN should be replaced - public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, false)); @@ -842,7 +842,7 @@ public static Scalar ReplaceNaNValues(this Scalar input, Missing /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. 
- public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); @@ -856,7 +856,7 @@ public static Vector ReplaceNaNValues(this Vector input, MissingVa /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. - public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); @@ -867,7 +867,7 @@ public static Vector ReplaceNaNValues(this Vector input, Missing /// /// Incoming data. 
/// How NaN should be replaced - public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, false)); @@ -877,7 +877,7 @@ public static VarVector ReplaceNaNValues(this VarVector input, Mis /// /// Incoming data. /// How NaN should be replaced - public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, false)); diff --git a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs index 71f8464892..d2cf692852 100644 --- a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs @@ -39,26 +39,26 @@ public static MissingValueIndicatorEstimator IndicateMissingValues(this Transfor /// (depending on whether the is given a value, or left to null) /// identical to the input column for everything but the missing values. The missing values of the input column, in this new column are replaced with /// one of the values specifid in the . The default for the is - /// . + /// . /// /// The transform extensions' catalog. /// Name of the column resulting from the transformation of . /// Name of column to transform. 
If set to , the value of the will be used as source. /// If not provided, the will be replaced with the results of the transforms. - /// The type of replacement to use as specified in + /// The type of replacement to use as specified in public static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null, - MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementKind = MissingValueReplacingEstimator.Defaults.ReplacementMode) + MissingValueReplacingEstimator.ColumnInfo.ReplacementMode replacementKind = MissingValueReplacingEstimator.Defaults.ReplacementMode) => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, replacementKind); /// /// Creates a new output column, identical to the input column for everything but the missing values. - /// The missing values of the input column, in this new column are replaced with . + /// The missing values of the input column, in this new column are replaced with . /// /// The transform extensions' catalog. /// The name of the columns to use, and per-column transformation configuraiton. 
- public static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, params MissingValueReplacingTransformer.ColumnInfo[] columns) + public static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, params MissingValueReplacingEstimator.ColumnInfo[] columns) => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 06b37861ec..9f2e780fc0 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -15,13 +15,13 @@ using Microsoft.ML.Transforms.Conversions; [assembly: LoadableClass(MissingValueHandlingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueHandlingTransformer), - typeof(MissingValueHandlingTransformer.Arguments), typeof(SignatureDataTransform), + typeof(MissingValueHandlingTransformer.Options), typeof(SignatureDataTransform), MissingValueHandlingTransformer.FriendlyName, "NAHandleTransform", MissingValueHandlingTransformer.ShortName, "NA", DocName = "transform/NAHandle.md")] namespace Microsoft.ML.Transforms { /// - public static class MissingValueHandlingTransformer + internal static class MissingValueHandlingTransformer { public enum ReplacementKind : byte { @@ -56,7 +56,7 @@ public enum ReplacementKind : byte Max = Maximum, } - public sealed class Arguments : TransformInputBase + public sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:rep:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -117,9 +117,10 @@ internal bool TryUnparse(StringBuilder sb) /// Name of the output column. /// Name of the column to be transformed. If this is null '' will be used. 
/// The replacement method to utilize. - public static IDataView Create(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null, ReplacementKind replaceWith = ReplacementKind.DefaultValue) + private static IDataView Create(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null, + ReplacementKind replaceWith = ReplacementKind.DefaultValue) { - var args = new Arguments() + var args = new Options() { Columns = new[] { @@ -131,7 +132,7 @@ public static IDataView Create(IHostEnvironment env, IDataView input, string out } /// Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register("Categorical"); @@ -139,7 +140,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); - var replaceCols = new List(); + var replaceCols = new List(); var naIndicatorCols = new List(); var naConvCols = new List(); var concatCols = new List(); @@ -153,7 +154,8 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat var addInd = column.ConcatIndicator ?? args.Concat; if (!addInd) { - replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Name, column.Source,(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(column.Name, column.Source, + (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? 
args.ImputeBySlot)); continue; } @@ -187,7 +189,8 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat } // Add the NAReplaceTransform column. - replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(tmpReplacementColName, column.Source, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(tmpReplacementColName, column.Source, + (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType is VectorType) @@ -223,7 +226,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // Create the indicator columns. if (naIndicatorCols.Count > 0) - output = MissingValueIndicatorTransformer.Create(h, new MissingValueIndicatorTransformer.Arguments() { Columns = naIndicatorCols.ToArray() }, input); + output = MissingValueIndicatorTransformer.Create(h, new MissingValueIndicatorTransformer.Options() { Columns = naIndicatorCols.ToArray() }, input); // Convert the indicator columns to the correct type so that they can be concatenated to the NAReplace outputs. if (naConvCols.Count > 0) @@ -237,7 +240,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // Concat the NAReplaceTransform output and the NAIndicatorTransform output. if (naIndicatorCols.Count > 0) - output = ColumnConcatenatingTransformer.Create(h, new ColumnConcatenatingTransformer.TaggedArguments() { Columns = concatCols.ToArray() }, output); + output = ColumnConcatenatingTransformer.Create(h, new ColumnConcatenatingTransformer.TaggedOptions() { Columns = concatCols.ToArray() }, output); // Finally, drop the temporary indicator columns. 
if (dropCols.Count > 0) diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 5801eda65e..eb91e3eeed 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms; -[assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(IDataTransform), typeof(MissingValueIndicatorTransformer), typeof(MissingValueIndicatorTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(IDataTransform), typeof(MissingValueIndicatorTransformer), typeof(MissingValueIndicatorTransformer.Options), typeof(SignatureDataTransform), MissingValueIndicatorTransformer.FriendlyName, MissingValueIndicatorTransformer.LoadName, "NAIndicator", MissingValueIndicatorTransformer.ShortName, DocName = "transform/NAHandle.md")] [assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(IDataTransform), typeof(MissingValueIndicatorTransformer), null, typeof(SignatureLoadDataTransform), @@ -33,7 +33,7 @@ namespace Microsoft.ML.Transforms /// public sealed class MissingValueIndicatorTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { @@ -52,7 +52,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -78,6 +78,9 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = 
nameof(MissingValueIndicatorTransformer); + /// + /// The names of the output and input column pairs for the transformation. + /// public IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// @@ -85,13 +88,13 @@ private static VersionInfo GetVersionInfo() /// /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueIndicatorTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) + internal MissingValueIndicatorTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), columns) { } - internal MissingValueIndicatorTransformer(IHostEnvironment env, Arguments args) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), GetColumnPairs(args.Columns)) + internal MissingValueIndicatorTransformer(IHostEnvironment env, Options options) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), GetColumnPairs(options.Columns)) { } @@ -114,8 +117,8 @@ internal static MissingValueIndicatorTransformer Create(IHostEnvironment env, Mo } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) - => new MissingValueIndicatorTransformer(env, args).MakeDataTransform(input); + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) + => new MissingValueIndicatorTransformer(env, options).MakeDataTransform(input); // Factory method for SignatureLoadDataTransform. 
internal static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) @@ -434,7 +437,7 @@ public sealed class MissingValueIndicatorEstimator : TrivialEstimator /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueIndicatorEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) + internal MissingValueIndicatorEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), new MissingValueIndicatorTransformer(env, columns)) { Contracts.CheckValue(env, nameof(env)); @@ -446,13 +449,14 @@ public MissingValueIndicatorEstimator(IHostEnvironment env, params (string outpu /// The environment to use. /// Name of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. - public MissingValueIndicatorEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) + internal MissingValueIndicatorEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) : this(env, (outputColumnName, inputColumnName ?? outputColumnName)) { } /// - /// Returns the schema that would be produced by the transformation. + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. 
/// public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 734535e9cc..b99676cbc4 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -21,7 +21,7 @@ using Microsoft.ML.Model.Onnx; using Microsoft.ML.Transforms; -[assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueReplacingTransformer), typeof(MissingValueReplacingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueReplacingTransformer), typeof(MissingValueReplacingTransformer.Options), typeof(SignatureDataTransform), MissingValueReplacingTransformer.FriendlyName, MissingValueReplacingTransformer.LoadName, "NAReplace", MissingValueReplacingTransformer.ShortName, DocName = "transform/NAHandle.md")] [assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueReplacingTransformer), null, typeof(SignatureLoadDataTransform), @@ -43,7 +43,7 @@ namespace Microsoft.ML.Transforms /// public sealed partial class MissingValueReplacingTransformer : OneToOneTransformerBase { - public enum ReplacementKind : byte + internal enum ReplacementKind : byte { // REVIEW: What should the full list of options for this transform be? DefaultValue = 0, @@ -73,7 +73,7 @@ public enum ReplacementKind : byte // *mean: use domain value closest to the mean // Potentially also min/max; probably will not include median due to its relatively low value and high computational cost. // Note: Will need to support different replacement values for different slots to implement this. 
- public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { // REVIEW: Should flexibility for different replacement values for slots be introduced? [Argument(ArgumentType.AtMostOnce, HelpText = "Replacement value for NAs (uses default value if not given)", ShortName = "rep")] @@ -114,7 +114,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:rep:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -127,7 +127,7 @@ public sealed class Arguments : TransformInputBase public bool ImputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot; } - public const string LoadName = "NAReplaceTransform"; + internal const string LoadName = "NAReplaceTransform"; private static VersionInfo GetVersionInfo() { @@ -174,47 +174,7 @@ private static string TestType(ColumnType type) return null; } - /// - /// Describes how the transformer handles one column pair. - /// - public sealed class ColumnInfo - { - public enum ReplacementMode : byte - { - DefaultValue = 0, - Mean = 1, - Minimum = 2, - Maximum = 3, - } - - public readonly string Name; - public readonly string InputColumnName; - public readonly bool ImputeBySlot; - public readonly ReplacementMode Replacement; - - /// - /// Describes how the transformer handles one column pair. - /// - /// Name of the column resulting from the transformation of . - /// Name of column to transform. If set to , the value of the will be used as source. - /// What to replace the missing value with. - /// If true, per-slot imputation of replacement is performed. - /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, - /// where imputation is always for the entire column. 
- public ColumnInfo(string name, string inputColumnName = null, ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, - bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) - { - Contracts.CheckNonWhiteSpace(name, nameof(name)); - Name = name; - InputColumnName = inputColumnName ?? name; - ImputeBySlot = imputeBySlot; - Replacement = replacementMode; - } - - internal string ReplacementString { get; set; } - } - - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(MissingValueReplacingEstimator.ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); @@ -243,7 +203,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptParam(nameof(inputSchema), reason); } - public MissingValueReplacingTransformer(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + internal MissingValueReplacingTransformer(IHostEnvironment env, IDataView input, params MissingValueReplacingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueReplacingTransformer)), GetColumnPairs(columns)) { // Check that all the input columns are present and correct. @@ -310,7 +270,7 @@ private T[] GetValuesArray(VBuffer src, VectorType srcType, int iinfo) /// Vectors default to by-slot imputation unless otherwise specified, except for unknown sized vectors /// which force across-slot imputation. 
/// - private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out object[] repValues, out BitArray[] slotIsDefault, out ColumnType[] types) + private void GetReplacementValues(IDataView input, MissingValueReplacingEstimator.ColumnInfo[] columns, out object[] repValues, out BitArray[] slotIsDefault, out ColumnType[] types) { repValues = new object[columns.Length]; slotIsDefault = new BitArray[columns.Length]; @@ -466,32 +426,32 @@ private object GetSpecifiedValue(string srcStr, ColumnType dstType, InPredica } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new MissingValueReplacingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - var kind = item.Kind ?? args.ReplacementKind; + var item = options.Columns[i]; + var kind = item.Kind ?? options.ReplacementKind; if (!Enum.IsDefined(typeof(ReplacementKind), kind)) - throw env.ExceptUserArg(nameof(args.ReplacementKind), "Undefined sorting criteria '{0}' detected for column '{1}'", kind, item.Name); + throw env.ExceptUserArg(nameof(options.ReplacementKind), "Undefined sorting criteria '{0}' detected for column '{1}'", kind, item.Name); - cols[i] = new ColumnInfo( + cols[i] = new MissingValueReplacingEstimator.ColumnInfo( item.Name, item.Source, - (ColumnInfo.ReplacementMode)(item.Kind ?? args.ReplacementKind), - item.Slot ?? 
args.ImputeBySlot); - cols[i].ReplacementString = item.ReplacementString; + (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(item.Kind ?? options.ReplacementKind), + item.Slot ?? options.ImputeBySlot, + item.ReplacementString); }; return new MissingValueReplacingTransformer(env, input, cols).MakeDataTransform(input); } - internal static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + internal static IDataTransform Create(IHostEnvironment env, IDataView input, params MissingValueReplacingEstimator.ColumnInfo[] columns) { return new MissingValueReplacingTransformer(env, input, columns).MakeDataTransform(input); } @@ -933,28 +893,107 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src public sealed class MissingValueReplacingEstimator : IEstimator { - public static class Defaults + [BestFriend] + internal static class Defaults { - public const MissingValueReplacingTransformer.ColumnInfo.ReplacementMode ReplacementMode = MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.DefaultValue; + public const ColumnInfo.ReplacementMode ReplacementMode = ColumnInfo.ReplacementMode.DefaultValue; public const bool ImputeBySlot = true; } - private readonly IHost _host; - private readonly MissingValueReplacingTransformer.ColumnInfo[] _columns; - - public MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) - : this(env, new MissingValueReplacingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, replacementKind)) + /// + /// Describes how the transformer handles one column pair. + /// + public sealed class ColumnInfo + { + /// + /// The possible ways to replace missing values. + /// + public enum ReplacementMode : byte { + /// + /// Replace with the default value of the column based on its type. 
For example, 'zero' for numeric and 'empty' for string/text columns. + /// + DefaultValue = 0, + /// + /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. + /// + Mean = 1, + /// + /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. + /// + Minimum = 2, + /// + /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. + /// + Maximum = 3, + } + /// Name of the column resulting from the transformation of . + public readonly string Name; + /// Name of column to transform. + public readonly string InputColumnName; + /// + /// If true, per-slot imputation of replacement is performed. + /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, + /// where imputation is always for the entire column. + /// + public readonly bool ImputeBySlot; + /// How to replace the missing values. + public readonly ReplacementMode Replacement; + /// Replacement value for missing values (only used in entrypoint and command line API). + internal readonly string ReplacementString; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// How to replace the missing values. + /// If true, per-slot imputation of replacement is performed. + /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, + /// where imputation is always for the entire column. + public ColumnInfo(string name, string inputColumnName = null, ReplacementMode replacementMode = Defaults.ReplacementMode, + bool imputeBySlot = Defaults.ImputeBySlot) + { + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ??
name; + ImputeBySlot = imputeBySlot; + Replacement = replacementMode; } - public MissingValueReplacingEstimator(IHostEnvironment env, params MissingValueReplacingTransformer.ColumnInfo[] columns) + /// + /// This constructor is used internally to convert from to + /// as we support in command line and entrypoint API only. + /// + internal ColumnInfo(string name, string inputColumnName, ReplacementMode replacementMode, bool imputeBySlot, string replacementString) + : this(name, inputColumnName, replacementMode, imputeBySlot) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(MissingValueReplacingEstimator)); - _columns = columns; + ReplacementString = replacementString; } + } + + private readonly IHost _host; + private readonly ColumnInfo[] _columns; + internal MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) + : this(env, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, replacementKind)) + { + + } + + [BestFriend] + internal MissingValueReplacingEstimator(IHostEnvironment env, params ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(MissingValueReplacingEstimator)); + _columns = columns; + } + + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -979,7 +1018,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } - public MissingValueReplacingTransformer Fit(IDataView input) => new MissingValueReplacingTransformer(_host, input, _columns); - } - + /// + /// Trains and returns a . 
+ /// + public MissingValueReplacingTransformer Fit(IDataView input) => new MissingValueReplacingTransformer(_host, input, _columns); } +} diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index a964f89eb7..15f65b729e 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -48,7 +48,7 @@ public static CommonOutputs.TransformOutput Filter(IHostEnvironment env, NAFilte Desc = MissingValueHandlingTransformer.Summary, UserName = MissingValueHandlingTransformer.FriendlyName, ShortName = MissingValueHandlingTransformer.ShortName)] - public static CommonOutputs.TransformOutput Handle(IHostEnvironment env, MissingValueHandlingTransformer.Arguments input) + public static CommonOutputs.TransformOutput Handle(IHostEnvironment env, MissingValueHandlingTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAHandle", input); var xf = MissingValueHandlingTransformer.Create(h, input, input.Data); @@ -63,7 +63,7 @@ public static CommonOutputs.TransformOutput Handle(IHostEnvironment env, Missing Desc = MissingValueIndicatorTransformer.Summary, UserName = MissingValueIndicatorTransformer.FriendlyName, ShortName = MissingValueIndicatorTransformer.ShortName)] - public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, MissingValueIndicatorTransformer.Arguments input) + public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, MissingValueIndicatorTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAIndicator", input); var xf = new MissingValueIndicatorTransformer(h, input).Transform(input.Data); @@ -78,7 +78,7 @@ public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, Miss Desc = MissingValueReplacingTransformer.Summary, UserName = MissingValueReplacingTransformer.FriendlyName, ShortName = MissingValueReplacingTransformer.ShortName)] - public static 
CommonOutputs.TransformOutput Replace(IHostEnvironment env, MissingValueReplacingTransformer.Arguments input) + public static CommonOutputs.TransformOutput Replace(IHostEnvironment env, MissingValueReplacingTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAReplace", input); var xf = MissingValueReplacingTransformer.Create(h, input, input.Data); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index c93bb6157b..1c8470a58d 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -76,7 +76,7 @@ Transforms.BinNormalizer The values are assigned into equidensity bins and a val Transforms.CategoricalHashOneHotVectorizer Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the bag. If the input column is a vector, a single indicator bag is returned for it. Microsoft.ML.Transforms.Categorical.Categorical CatTransformHash Microsoft.ML.Transforms.Categorical.OneHotHashEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CategoricalOneHotVectorizer Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array. Microsoft.ML.Transforms.Categorical.Categorical CatTransformDict Microsoft.ML.Transforms.Categorical.OneHotEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. 
Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnSelector Selects a set of columns, dropping all others Microsoft.ML.EntryPoints.SchemaManipulation SelectColumns Microsoft.ML.Transforms.ColumnSelectingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnTypeConverter Converts a column to a different type, using standard conversions. Microsoft.ML.Transforms.Conversions.TypeConversion Convert Microsoft.ML.Transforms.Conversions.TypeConvertingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput @@ -106,11 +106,11 @@ Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineModels Microsoft.ML.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput Transforms.MeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the data. Microsoft.ML.Data.Normalize MeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+MeanVarArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MinMaxNormalizer Normalizes the data based on the observed minimum and maximum values of the data. 
Microsoft.ML.Data.Normalize MinMax Microsoft.ML.Transforms.Normalizers.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.MissingValueHandler Handle missing values by replacing them with either the default value or the mean/min/max value (for non-text columns only). An indicator column can optionally be concatenated, if theinput column type is numeric. Microsoft.ML.Transforms.NAHandling Handle Microsoft.ML.Transforms.MissingValueHandlingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.MissingValueIndicator Create a boolean output column with the same number of slots as the input column, where the output value is true if the value in the input column is missing. Microsoft.ML.Transforms.NAHandling Indicator Microsoft.ML.Transforms.MissingValueIndicatorTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.MissingValueHandler Handle missing values by replacing them with either the default value or the mean/min/max value (for non-text columns only). An indicator column can optionally be concatenated, if theinput column type is numeric. Microsoft.ML.Transforms.NAHandling Handle Microsoft.ML.Transforms.MissingValueHandlingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.MissingValueIndicator Create a boolean output column with the same number of slots as the input column, where the output value is true if the value in the input column is missing. Microsoft.ML.Transforms.NAHandling Indicator Microsoft.ML.Transforms.MissingValueIndicatorTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Transforms.NAHandling Drop Microsoft.ML.Transforms.MissingValueDroppingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValuesRowDropper Filters out rows that contain missing values. 
Microsoft.ML.Transforms.NAHandling Filter Microsoft.ML.Transforms.NAFilter+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.MissingValueSubstitutor Create an output column of the same type and size of the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Transforms.NAHandling Replace Microsoft.ML.Transforms.MissingValueReplacingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.MissingValueSubstitutor Create an output column of the same type and size of the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Transforms.NAHandling Replace Microsoft.ML.Transforms.MissingValueReplacingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ModelCombiner Combines a sequence of TransformModels into a single model Microsoft.ML.EntryPoints.ModelOperations CombineTransformModels Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsInput Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsOutput Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Transforms.Text.TextAnalytics NGramTransform Microsoft.ML.Transforms.Text.NgramExtractingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.NoOperation Does nothing. 
Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs index 76df72a1f5..7d838bc5ba 100644 --- a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs +++ b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs @@ -39,7 +39,7 @@ internal static MLContext CreateRankingEnvironment ( - A: row.ScalarFloat.ReplaceNaNValues(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Maximum), - B: row.ScalarDouble.ReplaceNaNValues(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - C: row.VectorFloat.ReplaceNaNValues(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - D: row.VectorDoulbe.ReplaceNaNValues(MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Minimum) + A: row.ScalarFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Maximum), + B: row.ScalarDouble.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean), + C: row.VectorFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean), + D: row.VectorDoulbe.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Minimum) )); TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); @@ -108,11 +108,11 @@ public void TestOldSavingAndLoading() }; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new MissingValueReplacingEstimator(Env, - new MissingValueReplacingTransformer.ColumnInfo("NAA", "A", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("NAB", "B", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("NAC", "C", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new 
MissingValueReplacingTransformer.ColumnInfo("NAD", "D", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean)); + var pipe = ML.Transforms.ReplaceMissingValues( + new MissingValueReplacingEstimator.ColumnInfo("NAA", "A", MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnInfo("NAB", "B", MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnInfo("NAC", "C", MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnInfo("NAD", "D", MissingValueReplacingEstimator.ColumnInfo.ReplacementMode.Mean)); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result);