From 26913ac9c057fdd491983b839a79e98cebd4be0e Mon Sep 17 00:00:00 2001 From: Pete Luferenko Date: Fri, 12 Oct 2018 18:50:18 -0700 Subject: [PATCH] ML Context and a couple extensions --- docs/samples/Microsoft.ML.Samples/Trainers.cs | 2 +- .../DataLoadSave/DataLoadSaveCatalog.cs | 20 +++ .../DataLoadSave/Text/TextLoader.cs | 19 ++ .../Text/TextLoaderSaverCatalog.cs | 87 +++++++++ .../Evaluators/EvaluatorStaticExtensions.cs | 2 +- .../Model/ModelOperationsCatalog.cs | 38 ++++ src/Microsoft.ML.Data/Training/MLContext.cs | 107 ++++++++++++ .../Training/TrainContext.cs | 55 ++---- .../Transforms/CatalogUtils.cs | 24 +++ .../Transforms/Normalizer.cs | 12 ++ .../Transforms/NormalizerCatalog.cs | 44 +++++ .../Transforms/TransformsCatalog.cs | 87 +++++++++ src/Microsoft.ML.FastTree/FastTreeCatalog.cs | 102 +++++++++++ src/Microsoft.ML.FastTree/FastTreeStatic.cs | 7 +- .../Microsoft.ML.FastTree.csproj | 1 + .../KMeansCatalog.cs | 36 ++++ .../KMeansStatic.cs | 3 +- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 79 +++++++++ src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 3 +- .../FactorizationMachineCatalog.cs | 36 ++++ .../FactorizationMachineStatic.cs | 3 +- .../LogisticRegression/LbfgsCatalog.cs | 122 +++++++++++++ .../LogisticRegression/LbfgsStatic.cs | 3 +- .../Standard/Online/OnlineLearnerCatalog.cs | 98 +++++++++++ .../Standard/Online/OnlineLearnerStatic.cs | 2 +- .../Standard/SdcaCatalog.cs | 110 ++++++++++++ .../Standard/SdcaStatic.cs | 5 +- .../Standard/SgdCatalog.cs | 46 +++++ ...arClassificationStatic.cs => SgdStatic.cs} | 5 +- .../CategoricalCatalog.cs | 61 +++++++ .../CategoricalHashTransform.cs | 11 ++ .../CategoricalTransform.cs | 12 +- .../Text/TextTransformCatalog.cs | 38 ++++ .../Training.cs | 3 +- .../Api/Estimators/CrossValidation.cs | 22 +-- .../Estimators/DecomposableTrainAndPredict.cs | 34 ++-- .../Scenarios/Api/Estimators/Evaluation.cs | 24 ++- .../Scenarios/Api/Estimators/Extensibility.cs | 50 +++--- .../Api/Estimators/FileBasedSavingOfData.cs | 26 +-- .../Api/Estimators/IntrospectiveTraining.cs | 23 ++- .../Api/Estimators/Metacomponents.cs | 22 ++- .../Api/Estimators/MultithreadedPrediction.cs | 40 ++--- .../Estimators/ReconfigurablePrediction.cs | 33 ++-- .../Api/Estimators/SimpleTrainAndPredict.cs | 45 +++-- .../Estimators/TrainSaveModelAndPredict.cs | 68 ++++---- .../Estimators/TrainWithInitialPredictor.cs | 30 ++-- .../Api/Estimators/TrainWithValidationSet.cs | 26 ++- .../Scenarios/Api/Estimators/Visibility.cs | 41 +---- .../Scenarios/Api/Estimators/Wrappers.cs | 165 ------------------ 49 files changed, 1426 insertions(+), 506 deletions(-) create mode 100644 src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs create mode 100644 src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs create mode 100644 src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs create mode 100644 src/Microsoft.ML.Data/Training/MLContext.cs create mode 100644 src/Microsoft.ML.Data/Transforms/CatalogUtils.cs create mode 100644 src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs create mode 100644 src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs create mode 100644 src/Microsoft.ML.FastTree/FastTreeCatalog.cs create mode 100644 src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs create mode 100644 src/Microsoft.ML.LightGBM/LightGbmCatalog.cs create mode 100644 src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs create mode 100644 src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsCatalog.cs create mode 100644 src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerCatalog.cs create mode 100644 src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs create mode 100644 src/Microsoft.ML.StandardLearners/Standard/SgdCatalog.cs rename src/Microsoft.ML.StandardLearners/Standard/{LinearClassificationStatic.cs => SgdStatic.cs} (98%) create mode 100644 src/Microsoft.ML.Transforms/CategoricalCatalog.cs create mode 100644 src/Microsoft.ML.Transforms/Text/TextTransformCatalog.cs diff --git a/docs/samples/Microsoft.ML.Samples/Trainers.cs b/docs/samples/Microsoft.ML.Samples/Trainers.cs index 9d1001278c..b89361cbd1 100644 --- a/docs/samples/Microsoft.ML.Samples/Trainers.cs +++ b/docs/samples/Microsoft.ML.Samples/Trainers.cs @@ -5,7 +5,7 @@ // the alignment of the usings with the methods is intentional so they can display on the same level in the docs site. using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; - using Microsoft.ML.Trainers; + using Microsoft.ML.StaticPipe; using System; // NOTE: WHEN ADDING TO THE FILE, ALWAYS APPEND TO THE END OF IT. diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs new file mode 100644 index 0000000000..1938cc80a5 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Runtime +{ + /// + /// A catalog of operations to load and save data. + /// + public sealed class DataLoadSaveOperations + { + internal IHostEnvironment Environment { get; } + + internal DataLoadSaveOperations(IHostEnvironment env) + { + Contracts.AssertValue(env); + Environment = env; + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index f5d8eab550..ce40bb59da 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -46,6 +46,11 @@ public Column() { } public Column(string name, DataKind? type, int index) : this(name, type, new[] { new Range(index) }) { } + public Column(string name, DataKind? type, int minIndex, int maxIndex) + : this(name, type, new[] { new Range(minIndex, maxIndex) }) + { + } + public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null) { Contracts.CheckValue(name, nameof(name)); @@ -998,6 +1003,18 @@ private bool HasHeader private readonly IHost _host; private const string RegistrationName = "TextLoader"; + public TextLoader(IHostEnvironment env, Column[] columns, Action advancedSettings, IMultiStreamSource dataSample = null) + : this(env, MakeArgs(columns, advancedSettings), dataSample) + { + } + + private static Arguments MakeArgs(Column[] columns, Action advancedSettings) + { + var result = new Arguments { Column = columns }; + advancedSettings?.Invoke(result); + return result; + } + public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null) { Contracts.CheckValue(env, nameof(env)); @@ -1315,6 +1332,8 @@ public void Save(ModelSaveContext ctx) public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source); + public IDataView Read(string path) => Read(new MultiFileSource(path)); + private sealed class BoundLoader : IDataLoader { private readonly TextLoader _reader; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs new file mode 100644 index 0000000000..eb7753883a --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Internal.Utilities; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace Microsoft.ML +{ + public static class TextLoaderSaverCatalog + { + /// + /// Create a text reader. + /// + /// The catalog. + /// The arguments to text reader, describing the data schema. + /// The optional data sample + /// + public static TextLoader TextReader(this DataLoadSaveOperations catalog, + TextLoader.Arguments args, IMultiStreamSource dataSample = null) + => new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample); + + /// + /// Create a text reader. + /// + /// The catalog. + /// The columns of the schema. + /// The delegate to set additional settings + /// The optional data sample + /// + public static TextLoader TextReader(this DataLoadSaveOperations catalog, + TextLoader.Column[] columns, Action advancedSettings = null, IMultiStreamSource dataSample = null) + => new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, advancedSettings, dataSample); + + /// + /// Read a data view from a text file using . + /// + /// The catalog. + /// The columns of the schema. + /// The delegate to set additional settings + /// The path to the file + /// The data view. + public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog, + TextLoader.Column[] columns, string path, Action advancedSettings = null) + { + Contracts.CheckNonEmpty(path, nameof(path)); + + var env = catalog.GetEnvironment(); + + // REVIEW: it is almost always a mistake to have a 'trainable' text loader here. + // Therefore, we are going to disallow data sample. + var reader = new TextLoader(env, columns, advancedSettings, dataSample: null); + return reader.Read(new MultiFileSource(path)); + } + + /// + /// Save the data view as text. + /// + /// The catalog. + /// The data view to save. + /// The stream to write to. + /// The column separator. + /// Whether to write the header row. + /// Whether to write the header comment with the schema. + /// Whether to keep hidden columns in the dataset. + public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream, + char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false) + { + Contracts.CheckValue(catalog, nameof(catalog)); + Contracts.CheckValue(data, nameof(data)); + Contracts.CheckValue(stream, nameof(stream)); + + var env = catalog.GetEnvironment(); + var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema }); + + using (var ch = env.Start("Saving data")) + DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden); + } + } +} diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs index a8a4b34780..df58e4d46c 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs @@ -213,7 +213,7 @@ public static RegressionEvaluator.Result Evaluate( /// The index delegate for predicted score column. /// The evaluation metrics. public static RankerEvaluator.Result Evaluate( - this RankerContext ctx, + this RankingContext ctx, DataView data, Func> label, Func> groupId, diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs new file mode 100644 index 0000000000..fb1a8d39d5 --- /dev/null +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Data; +using System.IO; + +namespace Microsoft.ML.Runtime +{ + /// + /// An object serving as a 'catalog' of available model operations. + /// + public sealed class ModelOperationsCatalog + { + internal IHostEnvironment Environment { get; } + + internal ModelOperationsCatalog(IHostEnvironment env) + { + Contracts.AssertValue(env); + Environment = env; + } + + /// + /// Save the model to the stream. + /// + /// The trained model to be saved. + /// A writeable, seekable stream to save to. + public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream); + + /// + /// Load the model from the stream. + /// + /// A readable, seekable stream to load from. + /// The loaded model. + public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); + } +} diff --git a/src/Microsoft.ML.Data/Training/MLContext.cs b/src/Microsoft.ML.Data/Training/MLContext.cs new file mode 100644 index 0000000000..4fed27b12c --- /dev/null +++ b/src/Microsoft.ML.Data/Training/MLContext.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; + +namespace Microsoft.ML +{ + /// + /// The is a starting point for all ML.NET operations. It is instantiated by user, + /// provides mechanisms for logging and entry points for training, prediction, model operations etc. + /// + public sealed class MLContext : IHostEnvironment + { + private readonly LocalEnvironment _env; + + /// + /// Trainers and tasks specific to binary classification problems. + /// + public BinaryClassificationContext BinaryClassification { get; } + /// + /// Trainers and tasks specific to multiclass classification problems. + /// + public MulticlassClassificationContext MulticlassClassification { get; } + /// + /// Trainers and tasks specific to regression problems. + /// + public RegressionContext Regression { get; } + /// + /// Trainers and tasks specific to clustering problems. + /// + public ClusteringContext Clustering { get; } + /// + /// Trainers and tasks specific to ranking problems. + /// + public RankingContext Ranking { get; } + + /// + /// Data processing operations. + /// + public TransformsCatalog Transform { get; } + + /// + /// Operations with trained models. + /// + public ModelOperationsCatalog Model { get; } + + /// + /// Data loading and saving. + /// + public DataLoadSaveOperations Data { get; } + + // REVIEW: I think it's valuable to have the simplest possible interface for logging interception here, + // and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on + // looks premature at this point. + /// + /// The handler for the log messages. + /// + public Action Log { get; set; } + + /// + /// Create the ML context. + /// + /// Random seed. Set to null for a non-deterministic environment. + /// Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically. + public MLContext(int? seed = null, int conc = 0) + { + _env = new LocalEnvironment(seed, conc); + _env.AddListener(ProcessMessage); + + BinaryClassification = new BinaryClassificationContext(_env); + MulticlassClassification = new MulticlassClassificationContext(_env); + Regression = new RegressionContext(_env); + Clustering = new ClusteringContext(_env); + Ranking = new RankingContext(_env); + Transform = new TransformsCatalog(_env); + Model = new ModelOperationsCatalog(_env); + Data = new DataLoadSaveOperations(_env); + } + + private void ProcessMessage(IMessageSource source, ChannelMessage message) + { + if (Log == null) + return; + + var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}"; + // Log may have been reset from another thread. + // We don't care which logger we send the message to, just making sure we don't crash. + Log?.Invoke(msg); + } + + int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor; + bool IHostEnvironment.IsCancelled => _env.IsCancelled; + ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog; + string IExceptionContext.ContextDescription => _env.ContextDescription; + IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path); + IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix); + IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path); + TException IExceptionContext.Process(TException ex) => _env.Process(ex); + IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc); + IChannel IChannelProvider.Start(string name) => _env.Start(name); + IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name); + IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); + } +} diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index 2e44c1accc..2b385b0393 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -156,7 +156,7 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific /// Subclasses of will provide little "extension method" hookable objects /// (for example, something like ). User code will only /// interact with these objects by invoking the extension methods. The actual component code can work - /// through to get more "hidden" information from this object, + /// through to get more "hidden" information from this object, /// for example, the environment. /// public abstract class ContextInstantiatorBase @@ -170,37 +170,6 @@ protected ContextInstantiatorBase(TrainContextBase ctx) } } - /// - /// Utilities for component authors that want to be able to instantiate components using these context - /// objects. These utilities are not hidden from non-component authoring users per see, but are at least - /// registered somewhat less obvious so that they are not confused by the presence. - /// - /// - public static class TrainContextComponentUtils - { - /// - /// Gets the environment hidden within the instantiator's context. - /// - /// The extension method hook object for a context. - /// An environment that can be used when instantiating components. - public static IHostEnvironment GetEnvironment(TrainContextBase.ContextInstantiatorBase obj) - { - Contracts.CheckValue(obj, nameof(obj)); - return obj.Owner.Environment; - } - - /// - /// Gets the environment hidden within the context. - /// - /// The context. - /// An environment that can be used when instantiating components. - public static IHostEnvironment GetEnvironment(TrainContextBase ctx) - { - Contracts.CheckValue(ctx, nameof(ctx)); - return ctx.Environment; - } - } - /// /// The central context for binary classification trainers. /// @@ -234,7 +203,7 @@ internal BinaryClassificationTrainers(BinaryClassificationContext ctx) /// The name of the probability column in , the calibrated version of . /// The name of the predicted label column in . /// The evaluation results for these calibrated outputs. - public BinaryClassifierEvaluator.CalibratedResult Evaluate(IDataView data, string label, string score = DefaultColumnNames.Score, + public BinaryClassifierEvaluator.CalibratedResult Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string probability = DefaultColumnNames.Probability, string predictedLabel = DefaultColumnNames.PredictedLabel) { Host.CheckValue(data, nameof(data)); @@ -255,7 +224,7 @@ public BinaryClassifierEvaluator.CalibratedResult Evaluate(IDataView data, strin /// The name of the score column in . /// The name of the predicted label column in . /// The evaluation results for these uncalibrated outputs. - public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, string label, string score = DefaultColumnNames.Score, + public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string predictedLabel = DefaultColumnNames.PredictedLabel) { Host.CheckValue(data, nameof(data)); @@ -403,7 +372,7 @@ internal MulticlassClassificationTrainers(MulticlassClassificationContext ctx) /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within /// the top-K values as being stored "correctly." /// The evaluation results for these calibrated outputs. - public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string label, string score = DefaultColumnNames.Score, + public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, string predictedLabel = DefaultColumnNames.PredictedLabel, int topK = 0) { Host.CheckValue(data, nameof(data)); @@ -472,7 +441,7 @@ internal RegressionTrainers(RegressionContext ctx) /// The name of the label column in . /// The name of the score column in . /// The evaluation results for these calibrated outputs. - public RegressionEvaluator.Result Evaluate(IDataView data, string label, string score = DefaultColumnNames.Score) + public RegressionEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score) { Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); @@ -508,22 +477,22 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string /// /// The central context for regression trainers. /// - public sealed class RankerContext : TrainContextBase + public sealed class RankingContext : TrainContextBase { /// /// For trainers for performing regression. /// - public RankerTrainers Trainers { get; } + public RankingTrainers Trainers { get; } - public RankerContext(IHostEnvironment env) - : base(env, nameof(RankerContext)) + public RankingContext(IHostEnvironment env) + : base(env, nameof(RankingContext)) { - Trainers = new RankerTrainers(this); + Trainers = new RankingTrainers(this); } - public sealed class RankerTrainers : ContextInstantiatorBase + public sealed class RankingTrainers : ContextInstantiatorBase { - internal RankerTrainers(RankerContext ctx) + internal RankingTrainers(RankingContext ctx) : base(ctx) { } diff --git a/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs b/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs new file mode 100644 index 0000000000..98d3a3afc0 --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/CatalogUtils.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// Set of extension methods to extract from various catalog classes. + /// + public static class CatalogUtils + { + public static IHostEnvironment GetEnvironment(this TransformsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; + public static IHostEnvironment GetEnvironment(this TransformsCatalog.SubCatalogBase subCatalog) => Contracts.CheckRef(subCatalog, nameof(subCatalog)).Environment; + public static IHostEnvironment GetEnvironment(this ModelOperationsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; + public static IHostEnvironment GetEnvironment(this DataLoadSaveOperations catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment; + public static IHostEnvironment GetEnvironment(TrainContextBase.ContextInstantiatorBase obj) => Contracts.CheckRef(obj, nameof(obj)).Owner.Environment; + public static IHostEnvironment GetEnvironment(TrainContextBase ctx) => Contracts.CheckRef(ctx, nameof(ctx)).Environment; + + } +} diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 4dfd33c113..e148e4ec8e 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -37,9 +37,21 @@ internal static class Defaults public enum NormalizerMode { + /// + /// Linear rescale such that minimum and maximum values are mapped between -1 and 1. + /// MinMax = 0, + /// + /// Rescale to unit variance and, optionally, zero mean. + /// MeanVariance = 1, + /// + /// Rescale to unit variance on the log scale. + /// LogMeanVariance = 2, + /// + /// Bucketize and then rescale to between -1 and 1. + /// Binning = 3 } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs new file mode 100644 index 0000000000..c98e94c232 --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Data.Transforms +{ + /// + /// Extensions for normalizer operations. + /// + public static class NormalizerCatalog + { + /// + /// Normalize (rescale) the column according to the specified . + /// + /// The transform catalog + /// The column name + /// The normalization mode + public static Normalizer Normalize(this TransformsCatalog catalog, string columnName, Normalizer.NormalizerMode mode = Runtime.Data.Normalizer.NormalizerMode.MinMax) + => new Normalizer(CatalogUtils.GetEnvironment(catalog), columnName, mode); + + /// + /// Normalize (rescale) several columns according to the specified . + /// + /// The transform catalog + /// The normalization mode + /// The pairs of input and output columns. + public static Normalizer Normalize(this TransformsCatalog catalog, Normalizer.NormalizerMode mode, params (string input, string output)[] columns) + => new Normalizer(CatalogUtils.GetEnvironment(catalog), mode, columns); + + /// + /// Normalize (rescale) columns according to specified custom parameters. + /// + /// The transform catalog + /// The normalization settings for all the columns + public static Normalizer Normalize(this TransformsCatalog catalog, params Normalizer.ColumnBase[] columns) + => new Normalizer(CatalogUtils.GetEnvironment(catalog), columns); + } +} diff --git a/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs b/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs new file mode 100644 index 0000000000..5038591b75 --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Runtime +{ + /// + /// Similar to training context, a transform context is an object serving as a 'catalog' of available transforms. + /// Individual transforms are exposed as extension methods of this class or its subclasses. + /// + public sealed class TransformsCatalog + { + internal IHostEnvironment Environment { get; } + + public CategoricalTransforms Categorical { get; } + public Conversions Conversion { get; } + public TextTransforms Text { get; } + public ProjectionTransforms Projections { get; } + + internal TransformsCatalog(IHostEnvironment env) + { + Contracts.AssertValue(env); + Environment = env; + + Categorical = new CategoricalTransforms(this); + Conversion = new Conversions(this); + Text = new TextTransforms(this); + Projections = new ProjectionTransforms(this); + } + + public abstract class SubCatalogBase + { + internal IHostEnvironment Environment { get; } + + protected SubCatalogBase(TransformsCatalog owner) + { + Environment = owner.Environment; + } + + } + + /// + /// The catalog of operations over categorical data. + /// + public sealed class CategoricalTransforms : SubCatalogBase + { + internal CategoricalTransforms(TransformsCatalog owner) : base(owner) + { + } + } + + /// + /// The catalog of rescaling operations. + /// + public sealed class Conversions : SubCatalogBase + { + public Conversions(TransformsCatalog owner) : base(owner) + { + } + } + + /// + /// The catalog of text processing operations. + /// + public sealed class TextTransforms : SubCatalogBase + { + public TextTransforms(TransformsCatalog owner) : base(owner) + { + } + } + + /// + /// The catalog of projection operations. + /// + public sealed class ProjectionTransforms : SubCatalogBase + { + public ProjectionTransforms(TransformsCatalog owner) : base(owner) + { + } + } + } +} diff --git a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs new file mode 100644 index 0000000000..730b6816ae --- /dev/null +++ b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.StaticPipe.Runtime; +using System; + +namespace Microsoft.ML +{ + /// + /// FastTree extension methods. + /// + public static partial class RegressionTrainers + { + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + } + + public static partial class BinaryClassificationTrainers + { + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + } + + public static partial class RankingTrainers + { + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features colum. + /// The groupId column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string groupId = DefaultColumnNames.GroupId, + string features = DefaultColumnNames.Features, + string weights = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index b2218fc3f3..334f96d6f4 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -6,11 +6,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// FastTree extension methods. @@ -115,7 +114,7 @@ public static partial class RankingTrainers { /// - /// FastTree . + /// FastTree . /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . /// /// The . @@ -134,7 +133,7 @@ public static partial class RankingTrainers /// the linear model that was trained. Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. - public static Scalar FastTree(this RankerContext.RankerTrainers ctx, + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, diff --git a/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj b/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj index ceca51632e..a92b3c26d7 100644 --- a/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj +++ b/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj @@ -26,6 +26,7 @@ + diff --git a/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs b/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs new file mode 100644 index 0000000000..b7c95227c9 --- /dev/null +++ b/src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.KMeans; +using System; + +namespace Microsoft.ML +{ + /// + /// The trainer context extensions for the . + /// + public static class ClusteringTrainers + { + /// + /// Train a KMeans++ clustering algorithm. + /// + /// The regression context trainer object. + /// The features, or independent variables. + /// The optional example weights. + /// The number of clusters to use for KMeans. + /// Algorithm advanced settings. + public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx, + string features = DefaultColumnNames.Features, + string weights = null, + int clustersCount = KMeansPlusPlusTrainer.Defaults.K, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new KMeansPlusPlusTrainer(env, features, clustersCount, weights, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs index b9f77aabb3..7ce2783efd 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs @@ -5,11 +5,10 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.KMeans; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// The trainer context extensions for the . diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs new file mode 100644 index 0000000000..aced461f9d --- /dev/null +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.LightGBM; +using System; + +namespace Microsoft.ML +{ + /// + /// Regression trainer estimators. + /// + public static partial class RegressionTrainers + { + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static LightGbmRegressorTrainer LightGbm(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmRegressorTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + } + } + + /// + /// Binary Classification trainer estimators. + /// + public static partial class ClassificationTrainers + { + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmBinaryTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } + } +} diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index a0fbf2410c..c4ea389eff 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -6,11 +6,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.LightGBM; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Regression trainer estimators. diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs new file mode 100644 index 0000000000..036341e97f --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; +using System; + +namespace Microsoft.ML +{ + /// + /// Extension method to create + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a field-aware factorization machine algorithm. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// A delegate to set more settings. + /// + public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label, string[] features, + string weights = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FieldAwareFactorizationMachineTrainer(env, label, features, weights, advancedSettings: advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index e6ce5038cf..751a0e70f0 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -7,13 +7,12 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Extension methods and utilities for instantiating FFM trainer estimators inside statically typed pipelines. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsCatalog.cs new file mode 100644 index 0000000000..42bb07fc92 --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsCatalog.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Learners; +using System; + +namespace Microsoft.ML +{ + using Arguments = LogisticRegression.Arguments; + + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate to apply all the advanced arguments to the algorithm. + public static LogisticRegression LogisticRegression(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LogisticRegression(env, features, label, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity, advancedSettings); + } + } + + /// + /// Regression trainer estimators. + /// + public static partial class RegressionTrainers + { + + /// + /// Predict a target using a linear regression model trained with the trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate to apply all the advanced arguments to the algorithm. + public static PoissonRegression PoissonRegression(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new PoissonRegression(env, features, label, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity, advancedSettings); + } + } + + /// + /// MultiClass Classification trainer estimators. + /// + public static partial class MultiClassClassificationTrainers + { + + /// + /// Predict a target using a linear multiclass classification model trained with the trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Enforce non-negative weights. + /// Weight of L1 regularization term. + /// Weight of L2 regularization term. + /// Memory size for . Lower=faster, less accurate. + /// Threshold for optimizer convergence. + /// A delegate to apply all the advanced arguments to the algorithm. + public static MulticlassLogisticRegression LogisticRegression(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + float l1Weight = Arguments.Defaults.L1Weight, + float l2Weight = Arguments.Defaults.L2Weight, + float optimizationTolerance = Arguments.Defaults.OptTol, + int memorySize = Arguments.Defaults.MemorySize, + bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new MulticlassLogisticRegression(env, features, label, weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity, advancedSettings); + } + + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs index 360fbca1d8..69f23a6961 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsStatic.cs @@ -6,11 +6,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { using Arguments = LogisticRegression.Arguments; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerCatalog.cs new file mode 100644 index 0000000000..feca9a295a --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerCatalog.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System; + +namespace Microsoft.ML +{ + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the AveragedPerceptron trainer, and a custom loss. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The custom loss. + /// The optional example weights. + /// The learning Rate. + /// Decrease learning rate as iterations progress. + /// L2 regularization weight. + /// Number of training iterations through the data. + /// A delegate to supply more advanced arguments to the algorithm. + public static AveragedPerceptronTrainer AveragedPerceptron( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + IClassificationLoss lossFunction = null, + float learningRate = AveragedLinearArguments.AveragedDefaultArgs.LearningRate, + bool decreaseLearningRate = AveragedLinearArguments.AveragedDefaultArgs.DecreaseLearningRate, + float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight, + int numIterations = AveragedLinearArguments.AveragedDefaultArgs.NumIterations, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + var loss = new TrivialClassificationLossFactory(lossFunction ?? new LogLoss()); + return new AveragedPerceptronTrainer(env, label, features, weights, loss, learningRate, decreaseLearningRate, l2RegularizerWeight, numIterations, advancedSettings); + } + + private sealed class TrivialClassificationLossFactory : ISupportClassificationLossFactory + { + private readonly IClassificationLoss _loss; + + public TrivialClassificationLossFactory(IClassificationLoss loss) + { + _loss = loss; + } + + public IClassificationLoss CreateComponent(IHostEnvironment env) + { + return _loss; + } + } + } + + /// + /// Regression trainer estimators. + /// + public static partial class RegressionTrainers + { + /// + /// Predict a target using a linear regression model trained with the trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// The custom loss. Defaults to if not provided. + /// The learning Rate. + /// Decrease learning rate as iterations progress. + /// L2 regularization weight. + /// Number of training iterations through the data. + /// A delegate to supply more advanced arguments to the algorithm. + public static OnlineGradientDescentTrainer OnlineGradientDescent(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + IRegressionLoss lossFunction = null, + float learningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.LearningRate, + bool decreaseLearningRate = OnlineGradientDescentTrainer.Arguments.OgdDefaultArgs.DecreaseLearningRate, + float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight, + int numIterations = OnlineLinearArguments.OnlineDefaultArgs.NumIterations, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new OnlineGradientDescentTrainer(env, label, features, learningRate, decreaseLearningRate, l2RegularizerWeight, numIterations, weights, lossFunction, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs index ba1e8c88f6..9a2ca2d537 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLearnerStatic.cs @@ -9,7 +9,7 @@ using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Binary Classification trainer estimators. diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs new file mode 100644 index 0000000000..2a05edb685 --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System; + +namespace Microsoft.ML +{ + /// + /// Extension methods for instantiating SDCA trainer estimators. + /// + public static partial class RegressionTrainers + { + /// + /// Predict a target using a linear regression model trained with the SDCA trainer. + /// + /// The regression context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// The L2 regularization hyperparameter. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The maximum number of passes to perform over the data. + /// The custom loss, if unspecified will be . + /// A delegate to set more settings. + /// + /// + /// + /// + public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + ISupportSdcaRegressionLoss loss = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new SdcaRegressionTrainer(env, features, label, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + } + } + + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the SDCA trainer. + /// + /// The binary classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// The custom loss. Defaults to log-loss if not specified. + /// The L2 regularization hyperparameter. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The maximum number of passes to perform over the data. + /// A delegate to set more settings. + public static LinearClassificationTrainer StochasticDualCoordinateAscent( + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, + string weights = null, + ISupportSdcaClassificationLoss loss = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null, + Action advancedSettings = null + ) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LinearClassificationTrainer(env, features, label, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + } + } + + public static partial class MultiClassClassificationTrainers + { + + /// + /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional custom loss. + /// The optional example weights. + /// The L2 regularization hyperparameter. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The maximum number of passes to perform over the data. + /// A delegate to set more settings. + public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + ISupportSdcaClassificationLoss loss = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new SdcaMultiClassTrainer(env, features, label, weights, loss, l2Const, l1Threshold, maxIterations, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index ac0616d19f..6ddd004255 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. diff --git a/src/Microsoft.ML.StandardLearners/Standard/SgdCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/SgdCatalog.cs new file mode 100644 index 0000000000..8954110bad --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/SgdCatalog.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System; + +namespace Microsoft.ML +{ + using Arguments = StochasticGradientDescentClassificationTrainer.Arguments; + + /// + /// Binary Classification trainer estimators. + /// + public static partial class BinaryClassificationTrainers + { + /// + /// Predict a target using a linear binary classification model trained with the trainer. + /// + /// The binary classificaiton context trainer object. + /// The name of the label column. + /// The name of the feature column. + /// The name for the example weight column. + /// The maximum number of iterations; set to 1 to simulate online learning. + /// The initial learning rate used by SGD. + /// The L2 regularization constant. + /// The loss function to use. + /// A delegate to apply all the advanced arguments to the algorithm. + public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int maxIterations = Arguments.Defaults.MaxIterations, + double initLearningRate = Arguments.Defaults.InitLearningRate, + float l2Weight = Arguments.Defaults.L2Weight, + ISupportClassificationLossFactory loss = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new StochasticGradientDescentClassificationTrainer(env, features, label, weights, maxIterations, initLearningRate, l2Weight, loss, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SgdStatic.cs similarity index 98% rename from src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs rename to src/Microsoft.ML.StandardLearners/Standard/SgdStatic.cs index cf6bb9bb13..aec2ee0abf 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SgdStatic.cs @@ -2,15 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { using Arguments = StochasticGradientDescentClassificationTrainer.Arguments; diff --git a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs new file mode 100644 index 0000000000..bfca9bac66 --- /dev/null +++ b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML +{ + /// + /// Static extensions for categorical transforms. + /// + public static class CategoricalCatalog + { + /// + /// Convert a text column into one-hot encoded vector. + /// + /// The transform catalog + /// The input column + /// The output column. If null, is used. + /// The conversion mode. + /// + public static CategoricalEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, + string inputColumn, string outputColumn = null, CategoricalTransform.OutputKind outputKind = CategoricalTransform.OutputKind.Ind) + => new CategoricalEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind); + + /// + /// Convert several text column into one-hot encoded vectors. + /// + /// The transform catalog + /// The column settings. + /// + public static CategoricalEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, + params CategoricalEstimator.ColumnInfo[] columns) + => new CategoricalEstimator(CatalogUtils.GetEnvironment(catalog), columns); + + /// + /// Convert a text column into hash-based one-hot encoded vector. + /// + /// The transform catalog + /// The input column + /// The output column. If null, is used. + /// The conversion mode. + /// + public static CategoricalHashEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, + string inputColumn, string outputColumn = null, CategoricalTransform.OutputKind outputKind = CategoricalTransform.OutputKind.Ind) + => new CategoricalHashEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind); + + /// + /// Convert several text column into hash-based one-hot encoded vectors. + /// + /// The transform catalog + /// The column settings. + /// + public static CategoricalHashEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, + params CategoricalHashEstimator.ColumnInfo[] columns) + => new CategoricalHashEstimator(CatalogUtils.GetEnvironment(catalog), columns); + } +} diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index b36fd8e6fc..ea90899d09 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -208,6 +208,17 @@ public sealed class ColumnInfo { public readonly HashTransformer.ColumnInfo HashInfo; public readonly CategoricalTransform.OutputKind OutputKind; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of output column. + /// Kind of output: bag, indicator vector etc. + /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Hashing seed. + /// Whether the position of each term should be included in the hash. + /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. public ColumnInfo(string input, string output, CategoricalTransform.OutputKind outputKind = Defaults.OutputKind, int hashBits = Defaults.HashBits, diff --git a/src/Microsoft.ML.Transforms/CategoricalTransform.cs b/src/Microsoft.ML.Transforms/CategoricalTransform.cs index f0646103ee..4aaeab41b1 100644 --- a/src/Microsoft.ML.Transforms/CategoricalTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalTransform.cs @@ -128,7 +128,7 @@ public Arguments() public static IDataView Create(IHostEnvironment env, IDataView input, string name, string source = null, OutputKind outputKind = CategoricalEstimator.Defaults.OutKind) { - return new CategoricalEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView; + return new CategoricalEstimator(env, source, name, outputKind).Fit(input).Transform(input) as IDataView; } public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) @@ -207,12 +207,12 @@ internal void SetTerms(string terms) /// A helper method to create for public facing API. /// Host Environment. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the column to be transformed. + /// Name of the output column. If this is null, is used. /// The type of output expected. - public CategoricalEstimator(IHostEnvironment env, string name, - string source = null, CategoricalTransform.OutputKind outputKind = Defaults.OutKind) - : this(env, new ColumnInfo(source ?? name, name, outputKind)) + public CategoricalEstimator(IHostEnvironment env, string input, + string output = null, CategoricalTransform.OutputKind outputKind = Defaults.OutKind) + : this(env, new ColumnInfo(input ?? output, output, outputKind)) { } diff --git a/src/Microsoft.ML.Transforms/Text/TextTransformCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextTransformCatalog.cs new file mode 100644 index 0000000000..11557c82c3 --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/TextTransformCatalog.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML +{ + public static class TextTransformCatalog + { + /// + /// Transform a text column into featurized float array that represents counts of ngrams and char-grams. + /// + /// The transform catalog + /// The input column + /// The output column + /// Advanced transform settings + public static TextTransform FeaturizeText(this TransformsCatalog.TextTransforms catalog, + string inputColumn, string outputColumn = null, + Action advancedSettings = null) + => new TextTransform(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, advancedSettings); + + /// + /// Transform several text columns into featurized float array that represents counts of ngrams and char-grams. + /// + /// The transform catalog + /// The input columns + /// The output column + /// Advanced transform settings + public static TextTransform FeaturizeText(this TransformsCatalog.TextTransforms catalog, + IEnumerable inputColumns, string outputColumn, + Action advancedSettings = null) + => new TextTransform(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumns, outputColumn, advancedSettings); + } +} diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index b148622611..be284fa437 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.StaticPipe; using Microsoft.ML.Trainers; using System; using System.Linq; @@ -719,7 +720,7 @@ public void FastTreeRanking() var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); var dataSource = new MultiFileSource(dataPath); - var ctx = new RankerContext(env); + var ctx = new RankingContext(env); var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs index 6d9fd5cb81..cd621e9d37 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; using Xunit; @@ -22,23 +21,14 @@ public partial class ApiScenariosTests [Fact] void New_CrossValidation() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { + var ml = new MLContext(seed: 1, conc: 1); - var data = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => { s.ConvergenceTolerance = 1f; s.NumThreads = 1; })); + var data = ml.Data.TextReader(MakeSentimentTextLoaderArgs()).Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: (s) => { s.ConvergenceTolerance = 1f; s.NumThreads = 1; })); - - var cv = new MyCrossValidation.BinaryCrossValidator(env) - { - NumFolds = 2 - }; - - var cvResult = cv.CrossValidate(data, pipeline); - } + var cvResult = ml.BinaryClassification.CrossValidate(data, pipeline); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 1a11bc6478..4a1d8e126e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -26,27 +26,25 @@ public partial class ApiScenariosTests void New_DecomposableTrainAndPredict() { var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); - using (var env = new LocalEnvironment() - .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog - { - var data = new TextLoader(env, MakeIrisTextLoaderArgs()) - .Read(new MultiFileSource(dataPath)); + var ml = new MLContext(); + + var data = ml.Data.TextReader(MakeIrisTextLoaderArgs()) + .Read(dataPath); - var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) - .Append(new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); + var pipeline = new ConcatEstimator(ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new TermEstimator(ml, "Label"), TransformerScope.TrainTest) + .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) + .Append(new KeyToValueEstimator(ml, "PredictedLabel")); - var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.MakePredictionFunction(env); + var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); + var engine = model.MakePredictionFunction(ml); - var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var testData = testLoader.AsEnumerable(env, false); - foreach (var input in testData.Take(20)) - { - var prediction = engine.Predict(input); - Assert.True(prediction.PredictedLabel == input.Label); - } + var testLoader = TextLoader.ReadFile(ml, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var testData = testLoader.AsEnumerable(ml, false); + foreach (var input in testData.Take(20)) + { + var prediction = engine.Predict(input); + Assert.True(prediction.PredictedLabel == input.Label); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs index be7e413ba4..cbd05e23d4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs @@ -20,23 +20,19 @@ public partial class ApiScenariosTests [Fact] public void New_Evaluation() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); + var ml = new MLContext(seed: 1, conc: 1); - // Pipeline. - var pipeline = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Append(new TextTransform(env, "SentimentText", "Features")) - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1)); + // Pipeline. + var pipeline = ml.Data.TextReader(MakeSentimentTextLoaderArgs()) + .Append(ml.Transform.Text.FeaturizeText("SentimentText", "Features")) + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1)); - // Train. - var readerModel = pipeline.Fit(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + // Train. + var readerModel = pipeline.Fit(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - // Evaluate on the test set. - var dataEval = readerModel.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.validFilename))); - var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { }); - var metrics = evaluator.Evaluate(dataEval); - } + // Evaluate on the test set. + var dataEval = readerModel.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.validFilename))); + var metrics = ml.BinaryClassification.Evaluate(dataEval); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index 2e019d327d..620bb0aeb4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -25,35 +25,33 @@ void New_Extensibility() { var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); - using (var env = new LocalEnvironment()) - { - var data = new TextLoader(env, MakeIrisTextLoaderArgs()) - .Read(new MultiFileSource(dataPath)); + var ml = new MLContext(); + var data = ml.Data.TextReader(MakeIrisTextLoaderArgs()) + .Read(dataPath); - Action action = (i, j) => - { - j.Label = i.Label; - j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength; - j.PetalWidth = i.PetalWidth; - j.SepalLength = i.SepalLength; - j.SepalWidth = i.SepalWidth; - }; - var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new MyLambdaTransform(env, action), TransformerScope.TrainTest) - .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) - .Append(new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); + Action action = (i, j) => + { + j.Label = i.Label; + j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength; + j.PetalWidth = i.PetalWidth; + j.SepalLength = i.SepalLength; + j.SepalWidth = i.SepalWidth; + }; + var pipeline = new ConcatEstimator(ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new MyLambdaTransform(ml, action), TransformerScope.TrainTest) + .Append(new TermEstimator(ml, "Label"), TransformerScope.TrainTest) + .Append(ml.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; })) + .Append(new KeyToValueEstimator(ml, "PredictedLabel")); - var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.MakePredictionFunction(env); + var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); + var engine = model.MakePredictionFunction(ml); - var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var testData = testLoader.AsEnumerable(env, false); - foreach (var input in testData.Take(20)) - { - var prediction = engine.Predict(input); - Assert.True(prediction.PredictedLabel == input.Label); - } + var testLoader = TextLoader.ReadFile(ml, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var testData = testLoader.AsEnumerable(ml, false); + foreach (var input in testData.Take(20)) + { + var prediction = engine.Predict(input); + Assert.True(prediction.PredictedLabel == input.Label); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs index 880f8d5c96..063ff0a9ff 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; +using System.IO; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -23,22 +24,21 @@ public partial class ApiScenariosTests [Fact] void New_FileBasedSavingOfData() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var trainData = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Append(new TextTransform(env, "SentimentText", "Features")) - .FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - using (var file = env.CreateOutputFile("i.idv")) - trainData.SaveAsBinary(env, file.CreateWriteStream()); + var ml = new MLContext(seed: 1, conc: 1); + var trainData = ml.Data.TextReader(MakeSentimentTextLoaderArgs()) + .Append(ml.Transform.Text.FeaturizeText("SentimentText", "Features")) + .FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - var trainer = new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1); - var loadedTrainData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv")); + using (var file = File.Create(GetOutputPath("i.idv"))) + trainData.SaveAsBinary(ml, file); - // Train. - var model = trainer.Train(new RoleMappedData(loadedTrainData, DefaultColumnNames.Label, DefaultColumnNames.Features)); - DeleteOutputPath("i.idv"); - } + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1); + var loadedTrainData = new BinaryLoader(ml, new BinaryLoader.Arguments(), new MultiFileSource("i.idv")); + + // Train. + var model = trainer.Fit(loadedTrainData); + DeleteOutputPath("i.idv"); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs index a0e696343e..0f6bb09def 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -33,21 +33,20 @@ public partial class ApiScenariosTests [Fact] public void New_IntrospectiveTraining() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var data = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + var ml = new MLContext(seed: 1, conc: 1); + var data = ml.Data.TextReader(MakeSentimentTextLoaderArgs()) + .Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1)); + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1)); - // Train. - var model = pipeline.Fit(data); + // Train. + var model = pipeline.Fit(data); + + // Get feature weights. + VBuffer weights = default; + model.LastTransformer.Model.GetFeatureWeights(ref weights); - // Get feature weights. - VBuffer weights = default; - model.LastTransformer.Model.GetFeatureWeights(ref weights); - } } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index fd1a3d137f..bb3cb001ba 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; using System.Linq; @@ -21,19 +20,18 @@ public partial class ApiScenariosTests [Fact] public void New_Metacomponents() { - using (var env = new LocalEnvironment()) - { - var data = new TextLoader(env, MakeIrisTextLoaderArgs()) - .Read(new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); + var ml = new MLContext(); + var data = ml.Data.TextReader(MakeIrisTextLoaderArgs()) + .Read(GetDataPath(TestDatasets.irisData.trainFilename)); - var sdcaTrainer = new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; }); - var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) - .Append(new Ova(env, sdcaTrainer)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); + var sdcaTrainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; }); - var model = pipeline.Fit(data); - } + var pipeline = new ConcatEstimator(ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new TermEstimator(ml, "Label"), TransformerScope.TrainTest) + .Append(new Ova(ml, sdcaTrainer)) + .Append(new KeyToValueEstimator(ml, "PredictedLabel")); + + var model = pipeline.Fit(data); } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs index e7b3cbe40b..ef6b246fe8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -25,33 +25,31 @@ public partial class ApiScenariosTests [Fact] void New_MultithreadedPrediction() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + var ml = new MLContext(seed: 1, conc: 1); + var reader = ml.Data.TextReader(MakeSentimentTextLoaderArgs()); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1)); + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1)); - // Train. - var model = pipeline.Fit(data); + // Train. + var model = pipeline.Fit(data); - // Create prediction engine and test predictions. - var engine = model.MakePredictionFunction(env); + // Create prediction engine and test predictions. + var engine = model.MakePredictionFunction(ml); - // Take a couple examples out of the test data and run predictions on top. - var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) - .AsEnumerable(env, false); + // Take a couple examples out of the test data and run predictions on top. + var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) + .AsEnumerable(ml, false); - Parallel.ForEach(testData, (input) => + Parallel.ForEach(testData, (input) => + { + lock (engine) { - lock (engine) - { - var prediction = engine.Predict(input); - } - }); - } + var prediction = engine.Predict(input); + } + }); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs index 1387f07b7c..b06610488b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs @@ -21,29 +21,26 @@ public partial class ApiScenariosTests [Fact] public void New_ReconfigurablePrediction() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var dataReader = new TextLoader(env, MakeSentimentTextLoaderArgs()); + var ml = new MLContext(seed: 1, conc: 1); + var dataReader = ml.Data.TextReader(MakeSentimentTextLoaderArgs()); - var data = dataReader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - var testData = dataReader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); + var data = dataReader.Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); + var testData = dataReader.Read(GetDataPath(TestDatasets.Sentiment.testFilename)); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Fit(data); + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Fit(data); - var trainer = new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1); - var trainData = pipeline.Transform(data); - var model = trainer.Fit(trainData); + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: (s) => s.NumThreads = 1); + var trainData = pipeline.Transform(data); + var model = trainer.Fit(trainData); - var scoredTest = model.Transform(pipeline.Transform(testData)); - var metrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()).Evaluate(scoredTest, "Label", "Probability"); - - var newModel = new BinaryPredictionTransformer>(env, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); - var newScoredTest = newModel.Transform(pipeline.Transform(testData)); - var newMetrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments { Threshold = 0.01f, UseRawScoreThreshold = false }).Evaluate(newScoredTest, "Label", "Probability"); - } + var scoredTest = model.Transform(pipeline.Transform(testData)); + var metrics = ml.BinaryClassification.Evaluate(scoredTest); + var newModel = new BinaryPredictionTransformer>(ml, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); + var newScoredTest = newModel.Transform(pipeline.Transform(testData)); + var newMetrics = ml.BinaryClassification.Evaluate(scoredTest); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 041ae9b2f4..09ceb08d6e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -4,10 +4,9 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Learners; -using Xunit; -using System.Linq; using Microsoft.ML.Runtime.RunTests; +using System.Linq; +using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api { @@ -22,30 +21,28 @@ public partial class ApiScenariosTests [Fact] public void New_SimpleTrainAndPredict() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1)); + var ml = new MLContext(seed: 1, conc: 1); + var reader = ml.Data.TextReader(MakeSentimentTextLoaderArgs()); + var data = reader.Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1)); - // Train. - var model = pipeline.Fit(data); + // Train. + var model = pipeline.Fit(data); - // Create prediction engine and test predictions. - var engine = model.MakePredictionFunction(env); + // Create prediction engine and test predictions. + var engine = model.MakePredictionFunction(ml); - // Take a couple examples out of the test data and run predictions on top. - var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) - .AsEnumerable(env, false); - foreach (var input in testData.Take(5)) - { - var prediction = engine.Predict(input); - // Verify that predictions match and scores are separated from zero. - Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); - } + // Take a couple examples out of the test data and run predictions on top. + var testData = reader.Read(GetDataPath(TestDatasets.Sentiment.testFilename)) + .AsEnumerable(ml, false); + foreach (var input in testData.Take(5)) + { + var prediction = engine.Predict(input); + // Verify that predictions match and scores are separated from zero. + Assert.Equal(input.Sentiment, prediction.Sentiment); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index 40ee0de532..8b97259716 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; +using System.IO; using System.Linq; using Xunit; @@ -23,42 +24,39 @@ public partial class ApiScenariosTests [Fact] public void New_TrainSaveModelAndPredict() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) + var ml = new MLContext(seed: 1, conc: 1); + var reader = ml.Data.TextReader(MakeSentimentTextLoaderArgs()); + var data = reader.Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); + + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features") + .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1)); + + // Train. + var model = pipeline.Fit(data); + + var modelPath = GetOutputPath("temp.zip"); + // Save model. + using (var file = File.Create(modelPath)) + model.SaveTo(ml, file); + + // Load model. + ITransformer loadedModel; + using (var file = File.OpenRead(modelPath)) + loadedModel = TransformerChain.LoadFrom(ml, file); + + // Create prediction engine and test predictions. + var engine = loadedModel.MakePredictionFunction(ml); + + // Take a couple examples out of the test data and run predictions on top. + var testData = reader.Read(GetDataPath(TestDatasets.Sentiment.testFilename)) + .AsEnumerable(ml, false); + foreach (var input in testData.Take(5)) { - var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features") - .Append(new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1)); - - // Train. - var model = pipeline.Fit(data); - - ITransformer loadedModel; - using (var file = env.CreateTempFile()) - { - // Save model. - using (var fs = file.CreateWriteStream()) - model.SaveTo(env, fs); - - // Load model. - loadedModel = TransformerChain.LoadFrom(env, file.OpenReadStream()); - } - - // Create prediction engine and test predictions. - var engine = loadedModel.MakePredictionFunction(env); - - // Take a couple examples out of the test data and run predictions on top. - var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) - .AsEnumerable(env, false); - foreach (var input in testData.Take(5)) - { - var prediction = engine.Predict(input); - // Verify that predictions match and scores are separated from zero. - Assert.Equal(input.Sentiment, prediction.Sentiment); - Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); - } + var prediction = engine.Predict(input); + // Verify that predictions match and scores are separated from zero. + Assert.Equal(input.Sentiment, prediction.Sentiment); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index 7d1d30aff2..045936e738 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -20,26 +20,26 @@ public partial class ApiScenariosTests public void New_TrainWithInitialPredictor() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var data = new TextLoader(env, MakeSentimentTextLoaderArgs()).Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + var ml = new MLContext(seed: 1, conc: 1); - // Pipeline. - var pipeline = new TextTransform(env, "SentimentText", "Features"); + var data = ml.Data.TextReader(MakeSentimentTextLoaderArgs()).Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); - // Train the pipeline, prepare train set. - var trainData = pipeline.FitAndTransform(data); + // Pipeline. + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features"); - // Train the first predictor. - var trainer = new LinearClassificationTrainer(env, "Features", "Label", advancedSettings: (s) => s.NumThreads = 1); - var firstModel = trainer.Fit(trainData); + // Train the pipeline, prepare train set. + var trainData = pipeline.FitAndTransform(data); - // Train the second predictor on the same data. - var secondTrainer = new AveragedPerceptronTrainer(env, "Label", "Features"); + // Train the first predictor. + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: s => s.NumThreads = 1); + var firstModel = trainer.Fit(trainData); + + // Train the second predictor on the same data. + var secondTrainer = ml.BinaryClassification.Trainers.AveragedPerceptron(); + + var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); + var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model)); - var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); - var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model)); - } } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs index 98ee8fadd8..2fa236437b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs @@ -18,22 +18,20 @@ public partial class ApiScenariosTests [Fact] public void New_TrainWithValidationSet() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - // Pipeline. - var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var pipeline = new TextTransform(env, "SentimentText", "Features"); + var ml = new MLContext(seed: 1, conc: 1); + // Pipeline. + var reader = ml.Data.TextReader(MakeSentimentTextLoaderArgs()); + var pipeline = ml.Transform.Text.FeaturizeText("SentimentText", "Features"); - // Train the pipeline, prepare train and validation set. - var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - var preprocess = pipeline.Fit(data); - var trainData = preprocess.Transform(data); - var validData = preprocess.Transform(reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)))); + // Train the pipeline, prepare train and validation set. + var data = reader.Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); + var preprocess = pipeline.Fit(data); + var trainData = preprocess.Transform(data); + var validData = preprocess.Transform(reader.Read(GetDataPath(TestDatasets.Sentiment.testFilename))); - // Train model with validation set. - var trainer = new LinearClassificationTrainer(env, "Features", "Label"); - var model = trainer.Train(trainData, validData); - } + // Train model with validation set. + var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(); + var model = trainer.Train(trainData, validData); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs index 4840f0300c..7c932fc844 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Linq; +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.RunTests; using Xunit; @@ -22,40 +24,15 @@ public partial class ApiScenariosTests [Fact] void New_Visibility() { - using (var env = new LocalEnvironment(seed: 1, conc: 1)) - { - var pipeline = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Append(new TextTransform(env, "SentimentText", "Features", s => s.OutputTokens = true)); + var ml = new MLContext(seed: 1, conc: 1); + var pipeline = ml.Data.TextReader(MakeSentimentTextLoaderArgs()) + .Append(ml.Transform.Text.FeaturizeText("SentimentText", "Features", s => s.OutputTokens = true)); - var data = pipeline.FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - // In order to find out available column names, you can go through schema and check - // column names and appropriate type for getter. - for (int i = 0; i < data.Schema.ColumnCount; i++) - { - var columnName = data.Schema.GetColumnName(i); - var columnType = data.Schema.GetColumnType(i).RawType; - } + var data = pipeline.FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); - using (var cursor = data.GetRowCursor(x => true)) - { - Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn)); - Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn)); - Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn)); - - var originalTextGettter = cursor.GetGetter>(textColumn); - var transformedTextGettter = cursor.GetGetter>>(transformedTextColumn); - var featureGettter = cursor.GetGetter>(featureColumn); - ReadOnlyMemory text = default; - VBuffer> transformedText = default; - VBuffer features = default; - while (cursor.MoveNext()) - { - originalTextGettter(ref text); - transformedTextGettter(ref transformedText); - featureGettter(ref features); - } - } - } + var textColumn = data.GetColumn(ml, "SentimentText").Take(20); + var transformedTextColumn = data.GetColumn(ml, "Features_TransformedText").Take(20); + var features = data.GetColumn(ml, "Features").Take(20); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index aa425e3f4e..fa5b6e1808 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -16,171 +16,6 @@ namespace Microsoft.ML.Tests.Scenarios.Api { - public sealed class LoaderWrapper : IDataReader, ICanSaveModel - { - public const string LoaderSignature = "LoaderWrapper"; - - private readonly IHostEnvironment _env; - private readonly Func _loaderFactory; - - public LoaderWrapper(IHostEnvironment env, Func loaderFactory) - { - _env = env; - _loaderFactory = loaderFactory; - } - - public ISchema GetOutputSchema() - { - var emptyData = Read(new MultiFileSource(null)); - return emptyData.Schema; - } - - public IDataView Read(IMultiStreamSource input) => _loaderFactory(input); - - public void Save(ModelSaveContext ctx) - { - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - var ldr = Read(new MultiFileSource(null)); - ctx.SaveModel(ldr, "Loader"); - } - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "LDR WRPR", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LoaderWrapper).Assembly.FullName); - } - - public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx) - { - ctx.CheckAtModel(GetVersionInfo()); - ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); - - var loaderStream = new MemoryStream(); - using (var rep = RepositoryWriter.CreateNew(loaderStream)) - { - ModelSaveContext.SaveModel(rep, loader, "Loader"); - rep.Commit(); - } - - _env = env; - _loaderFactory = (IMultiStreamSource source) => - { - using (var rep = RepositoryReader.Open(loaderStream)) - { - ModelLoadContext.LoadModel(env, out var ldr, rep, "Loader", source); - return ldr; - } - }; - - } - } - - public sealed class MyBinaryClassifierEvaluator - { - private readonly IHostEnvironment _env; - private readonly BinaryClassifierEvaluator _evaluator; - - public MyBinaryClassifierEvaluator(IHostEnvironment env, BinaryClassifierEvaluator.Arguments args) - { - _env = env; - _evaluator = new BinaryClassifierEvaluator(env, args); - } - - public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn = DefaultColumnNames.Label, - string probabilityColumn = DefaultColumnNames.Probability) - { - var ci = EvaluateUtils.GetScoreColumnInfo(_env, data.Schema, null, DefaultColumnNames.Score, MetadataUtils.Const.ScoreColumnKind.BinaryClassification); - var map = new KeyValuePair[] - { - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probabilityColumn), - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, ci.Name) - }; - var rmd = new RoleMappedData(data, labelColumn, DefaultColumnNames.Features, opt: true, custom: map); - - var metricsDict = _evaluator.Evaluate(rmd); - return BinaryClassificationMetrics.FromMetrics(_env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"]).Single(); - } - } - - public static class MyCrossValidation - { - public sealed class BinaryCrossValidationMetrics - { - public readonly ITransformer[] FoldModels; - public readonly BinaryClassificationMetrics[] FoldMetrics; - - public BinaryCrossValidationMetrics(ITransformer[] models, BinaryClassificationMetrics[] metrics) - { - FoldModels = models; - FoldMetrics = metrics; - } - } - - public sealed class BinaryCrossValidator - { - private readonly IHostEnvironment _env; - - public int NumFolds { get; set; } = 2; - - public string StratificationColumn { get; set; } - - public string LabelColumn { get; set; } = DefaultColumnNames.Label; - - public BinaryCrossValidator(IHostEnvironment env) - { - _env = env; - } - - public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator estimator) - { - var models = new ITransformer[NumFolds]; - var metrics = new BinaryClassificationMetrics[NumFolds]; - - if (StratificationColumn == null) - { - StratificationColumn = "StratificationColumn"; - var random = new GenerateNumberTransform(_env, trainData, StratificationColumn); - trainData = random; - } - else - throw new NotImplementedException(); - - var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments() { }); - - for (int fold = 0; fold < NumFolds; fold++) - { - var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments() - { - Column = StratificationColumn, - Min = (Double)fold / NumFolds, - Max = (Double)(fold + 1) / NumFolds, - Complement = true - }, trainData); - var testFilter = new RangeFilter(_env, new RangeFilter.Arguments() - { - Column = StratificationColumn, - Min = (Double)fold / NumFolds, - Max = (Double)(fold + 1) / NumFolds, - Complement = false - }, trainData); - - models[fold] = estimator.Fit(trainFilter); - var scoredTest = models[fold].Transform(testFilter); - metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability"); - } - - return new BinaryCrossValidationMetrics(models, metrics); - - } - } - } - public class MyLambdaTransform : IEstimator where TSrc : class, new() where TDst : class, new()