Skip to content

Commit

Permalink
Creation of components through MLContext and cleanup (Concat, Normal…
Browse files Browse the repository at this point in the history
…izer, NA Indicator/Replace) (#2363)
  • Loading branch information
artidoro authored Feb 5, 2019
1 parent d7eb0a6 commit 58ff9b5
Show file tree
Hide file tree
Showing 20 changed files with 271 additions and 181 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.ML.EntryPoints
internal static class SchemaManipulation
{
[TlcModule.EntryPoint(Name = "Transforms.ColumnConcatenator", Desc = ColumnConcatenatingTransformer.Summary, UserName = ColumnConcatenatingTransformer.UserName, ShortName = ColumnConcatenatingTransformer.LoadName)]
public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Arguments input)
public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("ConcatColumns");
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ protected ColumnBindingsBase(Schema input, bool user, params string[] names)
// standard column name.
const string standardColumnArgName = "Columns";
Contracts.Assert(nameof(ValueToKeyMappingTransformer.Options.Columns) == standardColumnArgName);
Contracts.Assert(nameof(ColumnConcatenatingTransformer.Arguments.Columns) == standardColumnArgName);
Contracts.Assert(nameof(ColumnConcatenatingTransformer.Options.Columns) == standardColumnArgName);

for (int iinfo = 0; iinfo < names.Length; iinfo++)
{
Expand Down
18 changes: 14 additions & 4 deletions src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

namespace Microsoft.ML.Transforms
{
public sealed class ColumnConcatenatingEstimator : IEstimator<ITransformer>
/// <summary>
/// Concatenates columns in an <see cref="IDataView"/> into one single column. Estimator for the <see cref="ColumnConcatenatingTransformer"/>.
/// </summary>
public sealed class ColumnConcatenatingEstimator : IEstimator<ColumnConcatenatingTransformer>
{
private readonly IHost _host;
private readonly string _name;
Expand All @@ -22,8 +25,8 @@ public sealed class ColumnConcatenatingEstimator : IEstimator<ITransformer>
/// </summary>
/// <param name="env">The local instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="outputColumnName">The name of the resulting column.</param>
/// <param name="inputColumnNames">The columns to concatenate together.</param>
public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
/// <param name="inputColumnNames">The columns to concatenate into one single column.</param>
internal ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register("ColumnConcatenatingEstimator ");
Expand All @@ -37,7 +40,10 @@ public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnNam
_source = inputColumnNames;
}

public ITransformer Fit(IDataView input)
/// <summary>
/// Trains and returns a <see cref="ColumnConcatenatingTransformer"/>.
/// </summary>
public ColumnConcatenatingTransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));
return new ColumnConcatenatingTransformer(_host, _name, _source);
Expand Down Expand Up @@ -109,6 +115,10 @@ private SchemaShape.Column CheckInputsAndMakeColumn(
return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta));
}

/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
Expand Down
58 changes: 35 additions & 23 deletions src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
using Microsoft.ML.Model;
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;

[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedArguments), typeof(SignatureDataTransform),
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedOptions), typeof(SignatureDataTransform),
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")]

[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadDataTransform),
Expand All @@ -33,6 +34,10 @@ namespace Microsoft.ML.Data
{
using PfaType = PfaUtils.Type;

/// <summary>
/// Concatenates columns in an <see cref="IDataView"/> into one single column. Please see <see cref="ColumnConcatenatingEstimator"/> for
/// constructing <see cref="ColumnConcatenatingTransformer"/>.
/// </summary>
public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase
{
internal const string Summary = "Concatenates one or more columns of the same item type.";
Expand All @@ -42,7 +47,7 @@ public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase
internal const string LoaderSignature = "ConcatTransform";
internal const string LoaderSignatureOld = "ConcatFunction";

public sealed class Column : ManyToOneColumn
internal sealed class Column : ManyToOneColumn
{
internal static Column Parse(string str)
{
Expand All @@ -60,7 +65,8 @@ internal bool TryUnparse(StringBuilder sb)
}
}

public sealed class TaggedColumn
[BestFriend]
internal sealed class TaggedColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", ShortName = "name")]
public string Name;
Expand Down Expand Up @@ -99,13 +105,13 @@ internal bool TryUnparse(StringBuilder sb)
}
}

public sealed class Arguments : TransformInputBase
internal sealed class Options : TransformInputBase
{
public Arguments()
public Options()
{
}

public Arguments(string name, params string[] source)
public Options(string name, params string[] source)
{
Columns = new[] { new Column()
{
Expand All @@ -119,14 +125,16 @@ public Arguments(string name, params string[] source)
public Column[] Columns;
}

public sealed class TaggedArguments
[BestFriend]
internal sealed class TaggedOptions
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)",
Name = "Column", ShortName = "col", SortOrder = 1)]
public TaggedColumn[] Columns;
}

public sealed class ColumnInfo
[BestFriend]
internal sealed class ColumnInfo
{
public readonly string Name;
private readonly (string name, string alias)[] _sources;
Expand Down Expand Up @@ -212,22 +220,26 @@ internal ColumnInfo(ModelLoadContext ctx)

private readonly ColumnInfo[] _columns;

public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
/// <summary>
/// The names of the output and input column pairs for the transformation.
/// </summary>
public IReadOnlyCollection<(string outputColumnName, string[] inputColumnNames)> Columns
=> _columns.Select(col => (outputColumnName: col.Name, inputColumnNames: col.Sources.Select(source => source.name).ToArray())).ToArray().AsReadOnly();

/// <summary>
/// Concatename columns in <paramref name="inputColumnNames"/> into one column <paramref name="outputColumnName"/>.
/// Original columns are also preserved.
/// The column types must match, and the output column type is always a vector.
/// </summary>
public ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
internal ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
: this(env, new ColumnInfo(outputColumnName, inputColumnNames))
{
}

/// <summary>
/// Concatenates multiple groups of columns, each group is denoted by one of <paramref name="columns"/>.
/// </summary>
public ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
internal ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer)))
{
Contracts.CheckValue(columns, nameof(columns));
Expand Down Expand Up @@ -357,17 +369,17 @@ private ColumnInfo[] LoadLegacy(ModelLoadContext ctx)
///<summary>
/// Factory method for SignatureDataTransform.
/// </summary>
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));

for (int i = 0; i < args.Columns.Length; i++)
env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns));
for (int i = 0; i < options.Columns.Length; i++)
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));

var cols = args.Columns
var cols = options.Columns
.Select(c => new ColumnInfo(c.Name, c.Source))
.ToArray();
var transformer = new ColumnConcatenatingTransformer(env, cols);
Expand All @@ -377,17 +389,17 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
/// Factory method corresponding to SignatureDataTransform.
/// </summary>
[BestFriend]
internal static IDataTransform Create(IHostEnvironment env, TaggedArguments args, IDataView input)
internal static IDataTransform Create(IHostEnvironment env, TaggedOptions options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));

for (int i = 0; i < args.Columns.Length; i++)
env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns));
for (int i = 0; i < options.Columns.Length; i++)
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));

var cols = args.Columns
var cols = options.Columns
.Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key != "" ? kvp.Key : null))))
.ToArray();
var transformer = new ColumnConcatenatingTransformer(env, cols);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog,
=> new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), columns);

/// <summary>
/// Concatenates two columns together.
/// Concatenates columns together.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>
Expand Down
20 changes: 13 additions & 7 deletions src/Microsoft.ML.Data/Transforms/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Transforms.Normalizers;
using Newtonsoft.Json.Linq;
using static Microsoft.ML.Transforms.Normalizers.NormalizeTransform;

[assembly: LoadableClass(typeof(NormalizingTransformer), null, typeof(SignatureLoadModel),
"", NormalizingTransformer.LoaderSignature)]
Expand Down Expand Up @@ -206,7 +205,7 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, C
/// <param name="inputColumnName">Name of the column to transform.
/// If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="mode">The <see cref="NormalizerMode"/> indicating how to the old values are mapped to the new values.</param>
public NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax)
internal NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax)
: this(env, mode, (outputColumnName, inputColumnName ?? outputColumnName))
{
}
Expand All @@ -217,7 +216,7 @@ public NormalizingEstimator(IHostEnvironment env, string outputColumnName, strin
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="mode">The <see cref="NormalizerMode"/> indicating how to the old values are mapped to the new values.</param>
/// <param name="columns">An array of (outputColumnName, inputColumnName) tuples.</param>
public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns)
internal NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(NormalizingEstimator));
Expand All @@ -230,7 +229,7 @@ public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (s
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="columns">An array of <see cref="ColumnBase"/> defining the inputs to the Normalizer, and their settings.</param>
public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
internal NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(NormalizingEstimator));
Expand All @@ -239,12 +238,19 @@ public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
_columns = columns.ToArray();
}

/// <summary>
/// Trains and returns a <see cref="NormalizingTransformer"/>.
/// </summary>
public NormalizingTransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));
return NormalizingTransformer.Train(_host, input, _columns);
}

/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
Expand Down Expand Up @@ -275,7 +281,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)

public sealed partial class NormalizingTransformer : OneToOneTransformerBase
{
public const string LoaderSignature = "Normalizer";
internal const string LoaderSignature = "Normalizer";

internal const string LoaderSignatureOld = "NormalizeFunction";

Expand Down Expand Up @@ -387,7 +393,7 @@ private NormalizingTransformer(IHostEnvironment env, ColumnInfo[] columns)
ColumnFunctions = new ColumnFunctionAccessor(Columns);
}

public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns)
internal static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(data, nameof(data));
Expand Down Expand Up @@ -510,7 +516,7 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx, IDataView input
Columns = ImmutableArray.Create(cols);
}

public static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
private static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.EntryPoints/FeatureCombiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env
// as a group id. That's just one example - you get the idea.
string nameFeat = DefaultColumnNames.Features;
viewTrain = ColumnConcatenatingTransformer.Create(host,
new ColumnConcatenatingTransformer.TaggedArguments()
new ColumnConcatenatingTransformer.TaggedOptions()
{
Columns =
new[] { new ColumnConcatenatingTransformer.TaggedColumn() { Name = nameFeat, Source = concatNames.ToArray() } }
Expand Down
Loading

0 comments on commit 58ff9b5

Please sign in to comment.