Skip to content

Commit

Permalink
ML Context to create them all (#1252)
Browse files Browse the repository at this point in the history
* ML Context and a couple extensions
  • Loading branch information
Zruty0 authored Oct 18, 2018
1 parent c6d4e62 commit 9157cea
Show file tree
Hide file tree
Showing 56 changed files with 1,647 additions and 756 deletions.
212 changes: 93 additions & 119 deletions docs/code/MlNetCookBook.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.Samples/Trainers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// the alignment of the usings with the methods is intentional so they can display on the same level in the docs site.
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.StaticPipe;
using System;

// NOTE: WHEN ADDING TO THE FILE, ALWAYS APPEND TO THE END OF IT.
Expand Down
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime
{
/// <summary>
/// A catalog object for data loading and saving operations.
/// It only carries the host environment it was created with.
/// </summary>
public sealed class DataLoadSaveOperations
{
    /// <summary>
    /// The host environment supplied at construction time.
    /// </summary>
    internal IHostEnvironment Environment { get; }

    /// <summary>
    /// Creates the catalog on top of the given environment.
    /// </summary>
    /// <param name="env">The host environment; asserted (not checked) to be non-null.</param>
    internal DataLoadSaveOperations(IHostEnvironment env)
    {
        // Internal constructor: a null environment is a programming error, hence an assert.
        Contracts.AssertValue(env);

        Environment = env;
    }
}
}
19 changes: 19 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ public Column() { }
/// <summary>
/// Creates a column backed by a single source index.
/// Forwards to the <see cref="Range"/>-array constructor with a one-element array
/// holding <c>new Range(index)</c>.
/// </summary>
/// <param name="name">The column name; passed through unchanged to the chained constructor.</param>
/// <param name="type">The data kind of the column; nullable and passed through unchanged.</param>
/// <param name="index">The index handed to <see cref="Range"/>. Presumably the position of the source field in the text file - TODO confirm against <see cref="Range"/>.</param>
public Column(string name, DataKind? type, int index)
: this(name, type, new[] { new Range(index) }) { }

/// <summary>
/// Creates a column backed by a span of source indices.
/// Forwards to the <see cref="Range"/>-array constructor with a one-element array
/// holding <c>new Range(minIndex, maxIndex)</c>.
/// </summary>
/// <param name="name">The column name; passed through unchanged to the chained constructor.</param>
/// <param name="type">The data kind of the column; nullable and passed through unchanged.</param>
/// <param name="minIndex">The first argument of the <see cref="Range"/>; presumably the lowest source index - TODO confirm.</param>
/// <param name="maxIndex">The second argument of the <see cref="Range"/>; presumably the highest source index. NOTE(review): whether this bound is inclusive is defined by <see cref="Range"/> and is not visible here - confirm.</param>
public Column(string name, DataKind? type, int minIndex, int maxIndex)
: this(name, type, new[] { new Range(minIndex, maxIndex) })
{
}

public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null)
{
Contracts.CheckValue(name, nameof(name));
Expand Down Expand Up @@ -1003,6 +1008,18 @@ private bool HasHeader
private readonly IHost _host;
private const string RegistrationName = "TextLoader";

/// <summary>
/// Creates a text loader from a column array and an optional settings callback.
/// The arguments object is built by <see cref="MakeArgs"/> and then handed, together with
/// <paramref name="env"/> and <paramref name="dataSample"/>, to the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">The host environment; passed through to the chained constructor.</param>
/// <param name="columns">The columns of the schema; assigned to <c>Arguments.Column</c>.</param>
/// <param name="advancedSettings">Callback that can modify the arguments after the columns are set. It has no default value here, but <c>null</c> is accepted and means no further changes.</param>
/// <param name="dataSample">Optional data sample; passed through unchanged. Defaults to <c>null</c>.</param>
public TextLoader(IHostEnvironment env, Column[] columns, Action<Arguments> advancedSettings, IMultiStreamSource dataSample = null)
: this(env, MakeArgs(columns, advancedSettings), dataSample)
{
}

/// <summary>
/// Builds an <see cref="Arguments"/> instance whose columns are <paramref name="columns"/>,
/// then lets the caller adjust it through <paramref name="advancedSettings"/>.
/// </summary>
/// <param name="columns">The columns to assign to <c>Arguments.Column</c>; stored as given.</param>
/// <param name="advancedSettings">Optional callback invoked with the new arguments; skipped when <c>null</c>.</param>
/// <returns>The arguments object, after the callback (if any) has run.</returns>
private static Arguments MakeArgs(Column[] columns, Action<Arguments> advancedSettings)
{
    var args = new Arguments();
    args.Column = columns;

    // The callback runs after the columns are set, so it sees them and may replace them.
    if (advancedSettings != null)
        advancedSettings(args);

    return args;
}

public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null)
{
Contracts.CheckValue(env, nameof(env));
Expand Down Expand Up @@ -1320,6 +1337,8 @@ public void Save(ModelSaveContext ctx)

/// <summary>
/// Binds this loader to the given source and returns the result as a data view.
/// </summary>
/// <param name="source">The multi-stream source to read from.</param>
/// <returns>A <see cref="BoundLoader"/> wrapping this loader and <paramref name="source"/>.</returns>
public IDataView Read(IMultiStreamSource source)
{
    return new BoundLoader(this, source);
}

/// <summary>
/// Convenience overload: wraps <paramref name="path"/> in a <see cref="MultiFileSource"/>
/// and reads from it.
/// </summary>
/// <param name="path">The path handed to <see cref="MultiFileSource"/>.</param>
/// <returns>The data view produced by the source-based overload.</returns>
public IDataView Read(string path)
{
    var fileSource = new MultiFileSource(path);
    return Read(fileSource);
}

private sealed class BoundLoader : IDataLoader
{
private readonly TextLoader _reader;
Expand Down
85 changes: 85 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;

namespace Microsoft.ML
{
/// <summary>
/// Extension methods on <see cref="DataLoadSaveOperations"/> for creating text readers,
/// reading text files and saving data views as text.
/// </summary>
public static class TextLoaderSaverCatalog
{
    /// <summary>
    /// Creates a <see cref="TextLoader"/> from a complete arguments object.
    /// </summary>
    /// <param name="catalog">The catalog whose environment the reader is created in.</param>
    /// <param name="args">The arguments to the text reader, describing the data schema.</param>
    /// <param name="dataSample">The optional location of a data sample.</param>
    /// <returns>The new text reader.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Arguments args, IMultiStreamSource dataSample = null)
    {
        var env = CatalogUtils.GetEnvironment(catalog);
        return new TextLoader(env, args, dataSample);
    }

    /// <summary>
    /// Creates a <see cref="TextLoader"/> from a column array and an optional settings callback.
    /// </summary>
    /// <param name="catalog">The catalog whose environment the reader is created in.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <param name="dataSample">The optional location of a data sample.</param>
    /// <returns>The new text reader.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null)
    {
        var env = CatalogUtils.GetEnvironment(catalog);
        return new TextLoader(env, columns, advancedSettings, dataSample);
    }

    /// <summary>
    /// Reads a data view from a text file using <see cref="TextLoader"/>.
    /// </summary>
    /// <param name="catalog">The catalog whose environment the reader is created in.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="path">The path to the file; must be non-empty.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <returns>The data view.</returns>
    public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null)
    {
        Contracts.CheckNonEmpty(path, nameof(path));

        var host = catalog.GetEnvironment();

        // REVIEW: it is almost always a mistake to have a 'trainable' text loader here.
        // Therefore, we are going to disallow data sample.
        var loader = new TextLoader(host, columns, advancedSettings, dataSample: null);

        var fileSource = new MultiFileSource(path);
        return loader.Read(fileSource);
    }

    /// <summary>
    /// Writes the data view to a stream as text.
    /// </summary>
    /// <param name="catalog">The catalog whose environment the saver is created in.</param>
    /// <param name="data">The data view to save.</param>
    /// <param name="stream">The stream to write to.</param>
    /// <param name="separator">The column separator.</param>
    /// <param name="headerRow">Whether to write the header row.</param>
    /// <param name="schema">Whether to write the header comment with the schema.</param>
    /// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param>
    public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream,
        char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false)
    {
        Contracts.CheckValue(catalog, nameof(catalog));
        Contracts.CheckValue(data, nameof(data));
        Contracts.CheckValue(stream, nameof(stream));

        var host = catalog.GetEnvironment();

        // The saver takes its separator as a string, so the single character is converted.
        var saverArgs = new TextSaver.Arguments
        {
            Separator = separator.ToString(),
            OutputHeader = headerRow,
            OutputSchema = schema
        };
        var textSaver = new TextSaver(host, saverArgs);

        using (var channel = host.Start("Saving data"))
        {
            DataSaverUtils.SaveDataView(channel, textSaver, data, stream, keepHidden);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static RegressionEvaluator.Result Evaluate<T>(
/// <param name="score">The index delegate for predicted score column.</param>
/// <returns>The evaluation metrics.</returns>
public static RankerEvaluator.Result Evaluate<T, TVal>(
this RankerContext ctx,
this RankingContext ctx,
DataView<T> data,
Func<T, Scalar<float>> label,
Func<T, Key<uint, TVal>> groupId,
Expand Down
108 changes: 108 additions & 0 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System;

namespace Microsoft.ML
{
/// <summary>
/// <see cref="MLContext"/> is the starting point for all ML.NET operations. The user creates an instance
/// and uses it for logging and as the entry point to training, prediction, model operations etc.
/// </summary>
public sealed class MLContext : IHostEnvironment
{
    // REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
    // The wrapped environment: every IHostEnvironment member at the bottom of the class forwards to it.
    private readonly LocalEnvironment _env;

    /// <summary>
    /// Binary classification: trainers and tasks specific to this kind of problem.
    /// </summary>
    public BinaryClassificationContext BinaryClassification { get; }

    /// <summary>
    /// Multiclass classification: trainers and tasks specific to this kind of problem.
    /// </summary>
    public MulticlassClassificationContext MulticlassClassification { get; }

    /// <summary>
    /// Regression: trainers and tasks specific to this kind of problem.
    /// </summary>
    public RegressionContext Regression { get; }

    /// <summary>
    /// Clustering: trainers and tasks specific to this kind of problem.
    /// </summary>
    public ClusteringContext Clustering { get; }

    /// <summary>
    /// Ranking: trainers and tasks specific to this kind of problem.
    /// </summary>
    public RankingContext Ranking { get; }

    /// <summary>
    /// The catalog of data processing operations.
    /// </summary>
    public TransformsCatalog Transforms { get; }

    /// <summary>
    /// The catalog of operations with trained models.
    /// </summary>
    public ModelOperationsCatalog Model { get; }

    /// <summary>
    /// The catalog of data loading and saving operations.
    /// </summary>
    public DataLoadSaveOperations Data { get; }

    // REVIEW: I think it's valuable to have the simplest possible interface for logging interception here,
    // and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on
    // looks premature at this point.
    /// <summary>
    /// The handler for the log messages. Each message arrives as a single string of the form
    /// <c>[Source=..., Kind=...] text</c>. While this is <c>null</c>, messages are not formatted or delivered.
    /// </summary>
    public Action<string> Log { get; set; }

    /// <summary>
    /// Creates the ML context, the environment behind it, and all of its catalogs.
    /// </summary>
    /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
    /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
    public MLContext(int? seed = null, int conc = 0)
    {
        var environment = new LocalEnvironment(seed, conc);
        _env = environment;

        // Route the environment's channel messages to the user-supplied Log handler.
        environment.AddListener(ProcessMessage);

        // Every catalog is built on the same underlying environment.
        BinaryClassification = new BinaryClassificationContext(environment);
        MulticlassClassification = new MulticlassClassificationContext(environment);
        Regression = new RegressionContext(environment);
        Clustering = new ClusteringContext(environment);
        Ranking = new RankingContext(environment);
        Transforms = new TransformsCatalog(environment);
        Model = new ModelOperationsCatalog(environment);
        Data = new DataLoadSaveOperations(environment);
    }

    /// <summary>
    /// Listener registered on the environment: formats a channel message and passes it to <see cref="Log"/>.
    /// </summary>
    /// <param name="source">The message source; its full name goes into the formatted text.</param>
    /// <param name="message">The channel message; its kind and text go into the formatted text.</param>
    private void ProcessMessage(IMessageSource source, ChannelMessage message)
    {
        // Without a handler there is no point in building the message text.
        if (Log != null)
        {
            var text = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}";

            // Log may have been reset from another thread since the check above.
            // We don't care which logger we send the message to, just making sure we don't crash.
            Log?.Invoke(text);
        }
    }

    // Explicit interface implementations, grouped by declaring interface. All of them forward to _env.

    // IHostEnvironment
    int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
    bool IHostEnvironment.IsCancelled => _env.IsCancelled;
    ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
    IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path);
    IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
    IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path);
    IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);

    // IExceptionContext
    string IExceptionContext.ContextDescription => _env.ContextDescription;
    TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);

    // IChannelProvider
    IChannel IChannelProvider.Start(string name) => _env.Start(name);
    IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);

    // IProgressChannelProvider
    IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
}
}
38 changes: 38 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System.IO;

namespace Microsoft.ML.Runtime
{
/// <summary>
/// A 'catalog' object grouping the available model operations: saving a trained model to a stream
/// and loading one back.
/// </summary>
public sealed class ModelOperationsCatalog
{
    /// <summary>
    /// The host environment supplied at construction time; passed to the save and load routines.
    /// </summary>
    internal IHostEnvironment Environment { get; }

    /// <summary>
    /// Creates the catalog on top of the given environment.
    /// </summary>
    /// <param name="env">The host environment; asserted (not checked) to be non-null.</param>
    internal ModelOperationsCatalog(IHostEnvironment env)
    {
        // Internal constructor: a null environment is a programming error, hence an assert.
        Contracts.AssertValue(env);

        Environment = env;
    }

    /// <summary>
    /// Saves the model to the stream.
    /// </summary>
    /// <param name="model">The trained model to be saved.</param>
    /// <param name="stream">A writeable, seekable stream to save to.</param>
    public void Save(ITransformer model, Stream stream)
    {
        model.SaveTo(Environment, stream);
    }

    /// <summary>
    /// Loads the model from the stream.
    /// </summary>
    /// <param name="stream">A readable, seekable stream to load from.</param>
    /// <returns>The loaded model.</returns>
    public ITransformer Load(Stream stream)
    {
        return TransformerChain.LoadFrom(Environment, stream);
    }
}
}
Loading

0 comments on commit 9157cea

Please sign in to comment.