-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ML Context to create them all (#1252)
* ML Context and a couple extensions
- Loading branch information
Showing
56 changed files
with
1,647 additions
and
756 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
/// <summary> | ||
/// A catalog of operations to load and save data. | ||
/// </summary> | ||
public sealed class DataLoadSaveOperations | ||
{ | ||
internal IHostEnvironment Environment { get; } | ||
|
||
internal DataLoadSaveOperations(IHostEnvironment env) | ||
{ | ||
Contracts.AssertValue(env); | ||
Environment = env; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Text; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
public static class TextLoaderSaverCatalog | ||
{ | ||
/// <summary> | ||
/// Create a text reader. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="args">The arguments to text reader, describing the data schema.</param> | ||
/// <param name="dataSample">The optional location of a data sample.</param> | ||
public static TextLoader TextReader(this DataLoadSaveOperations catalog, | ||
TextLoader.Arguments args, IMultiStreamSource dataSample = null) | ||
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample); | ||
|
||
/// <summary> | ||
/// Create a text reader. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="columns">The columns of the schema.</param> | ||
/// <param name="advancedSettings">The delegate to set additional settings.</param> | ||
/// <param name="dataSample">The optional location of a data sample.</param> | ||
public static TextLoader TextReader(this DataLoadSaveOperations catalog, | ||
TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null) | ||
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, advancedSettings, dataSample); | ||
|
||
/// <summary> | ||
/// Read a data view from a text file using <see cref="TextLoader"/>. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="columns">The columns of the schema.</param> | ||
/// <param name="advancedSettings">The delegate to set additional settings</param> | ||
/// <param name="path">The path to the file</param> | ||
/// <returns>The data view.</returns> | ||
public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog, | ||
TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckNonEmpty(path, nameof(path)); | ||
|
||
var env = catalog.GetEnvironment(); | ||
|
||
// REVIEW: it is almost always a mistake to have a 'trainable' text loader here. | ||
// Therefore, we are going to disallow data sample. | ||
var reader = new TextLoader(env, columns, advancedSettings, dataSample: null); | ||
return reader.Read(new MultiFileSource(path)); | ||
} | ||
|
||
/// <summary> | ||
/// Save the data view as text. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="data">The data view to save.</param> | ||
/// <param name="stream">The stream to write to.</param> | ||
/// <param name="separator">The column separator.</param> | ||
/// <param name="headerRow">Whether to write the header row.</param> | ||
/// <param name="schema">Whether to write the header comment with the schema.</param> | ||
/// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param> | ||
public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream, | ||
char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false) | ||
{ | ||
Contracts.CheckValue(catalog, nameof(catalog)); | ||
Contracts.CheckValue(data, nameof(data)); | ||
Contracts.CheckValue(stream, nameof(stream)); | ||
|
||
var env = catalog.GetEnvironment(); | ||
var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema }); | ||
|
||
using (var ch = env.Start("Saving data")) | ||
DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using System; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
/// <summary> | ||
/// The <see cref="MLContext"/> is a starting point for all ML.NET operations. It is instantiated by user, | ||
/// provides mechanisms for logging and entry points for training, prediction, model operations etc. | ||
/// </summary> | ||
public sealed class MLContext : IHostEnvironment | ||
{ | ||
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation. | ||
private readonly LocalEnvironment _env; | ||
|
||
/// <summary> | ||
/// Trainers and tasks specific to binary classification problems. | ||
/// </summary> | ||
public BinaryClassificationContext BinaryClassification { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to multiclass classification problems. | ||
/// </summary> | ||
public MulticlassClassificationContext MulticlassClassification { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to regression problems. | ||
/// </summary> | ||
public RegressionContext Regression { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to clustering problems. | ||
/// </summary> | ||
public ClusteringContext Clustering { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to ranking problems. | ||
/// </summary> | ||
public RankingContext Ranking { get; } | ||
|
||
/// <summary> | ||
/// Data processing operations. | ||
/// </summary> | ||
public TransformsCatalog Transforms { get; } | ||
|
||
/// <summary> | ||
/// Operations with trained models. | ||
/// </summary> | ||
public ModelOperationsCatalog Model { get; } | ||
|
||
/// <summary> | ||
/// Data loading and saving. | ||
/// </summary> | ||
public DataLoadSaveOperations Data { get; } | ||
|
||
// REVIEW: I think it's valuable to have the simplest possible interface for logging interception here, | ||
// and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on | ||
// looks premature at this point. | ||
/// <summary> | ||
/// The handler for the log messages. | ||
/// </summary> | ||
public Action<string> Log { get; set; } | ||
|
||
/// <summary> | ||
/// Create the ML context. | ||
/// </summary> | ||
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param> | ||
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param> | ||
public MLContext(int? seed = null, int conc = 0) | ||
{ | ||
_env = new LocalEnvironment(seed, conc); | ||
_env.AddListener(ProcessMessage); | ||
|
||
BinaryClassification = new BinaryClassificationContext(_env); | ||
MulticlassClassification = new MulticlassClassificationContext(_env); | ||
Regression = new RegressionContext(_env); | ||
Clustering = new ClusteringContext(_env); | ||
Ranking = new RankingContext(_env); | ||
Transforms = new TransformsCatalog(_env); | ||
Model = new ModelOperationsCatalog(_env); | ||
Data = new DataLoadSaveOperations(_env); | ||
} | ||
|
||
private void ProcessMessage(IMessageSource source, ChannelMessage message) | ||
{ | ||
if (Log == null) | ||
return; | ||
|
||
var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}"; | ||
// Log may have been reset from another thread. | ||
// We don't care which logger we send the message to, just making sure we don't crash. | ||
Log?.Invoke(msg); | ||
} | ||
|
||
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor; | ||
bool IHostEnvironment.IsCancelled => _env.IsCancelled; | ||
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog; | ||
string IExceptionContext.ContextDescription => _env.ContextDescription; | ||
IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path); | ||
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix); | ||
IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path); | ||
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex); | ||
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc); | ||
IChannel IChannelProvider.Start(string name) => _env.Start(name); | ||
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name); | ||
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using System.IO; | ||
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
/// <summary> | ||
/// An object serving as a 'catalog' of available model operations. | ||
/// </summary> | ||
public sealed class ModelOperationsCatalog | ||
{ | ||
internal IHostEnvironment Environment { get; } | ||
|
||
internal ModelOperationsCatalog(IHostEnvironment env) | ||
{ | ||
Contracts.AssertValue(env); | ||
Environment = env; | ||
} | ||
|
||
/// <summary> | ||
/// Save the model to the stream. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream); | ||
|
||
/// <summary> | ||
/// Load the model from the stream. | ||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <returns>The loaded model.</returns> | ||
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); | ||
} | ||
} |
Oops, something went wrong.