Skip to content

Commit

Permalink
Export to ONNX and cross-platform command-line tool to script ML.NET …
Browse files Browse the repository at this point in the history
…training and inference (#248)

* Export to ONNX and Maml cross-platform executable.
  • Loading branch information
codemzs authored Jun 6, 2018
1 parent 5730685 commit 1bb1249
Show file tree
Hide file tree
Showing 20 changed files with 5,446 additions and 3,347 deletions.
24 changes: 17 additions & 7 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{09EADF06-BE25-4228-AB53-95AE3E15B530}"
ProjectSection(SolutionItems) = preProject
src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj = src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AED9C836-31E3-4F3F-8ABC-929555D3F3C4}"
EndProject
Expand All @@ -30,8 +33,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}"
Expand Down Expand Up @@ -104,6 +105,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -158,10 +163,6 @@ Global
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.Build.0 = Release|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand Down Expand Up @@ -202,6 +203,14 @@ Global
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.Build.0 = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -219,7 +228,6 @@ Global
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D956E291-F6E5-4474-9023-91793F45ABEB} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{FDA2FD2C-A708-43AC-A941-4D941B0853BF} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
Expand All @@ -236,6 +244,8 @@ Global
{DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Console/Console.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime.Tools.Console
{
public static class Console
{
public static int Main(string[] args) => Maml.Main(args);
}
}
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>CORECLR</DefineConstants>
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
<TargetFramework>netcoreapp2.0</TargetFramework>
<OutputType>Exe</OutputType>
<AssemblyName>MML</AssemblyName>
<StartupObject>Microsoft.ML.Runtime.Tools.Console.Console</StartupObject>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
</ItemGroup>

</Project>
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Data/Commands/DataCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,38 @@ public static class DataCommand
{
public abstract class ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
public SubComponent<IDataLoader, SignatureDataLoader> Loader;

[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
public string DataFile;

[Argument(ArgumentType.AtMostOnce, HelpText = "Model file to save", ShortName = "out")]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")]
public string OutputModelFile;

[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
public string InputModelFile;

[Argument(ArgumentType.Multiple, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
public bool? LoadTransforms;

[Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
public int? RandomSeed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Verbose?", ShortName = "v", Hide = true)]
public bool? Verbose;

[Argument(ArgumentType.AtMostOnce, HelpText = "The web server to publish the RESTful API", Hide = true)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)]
public ServerChannel.IServerFactory Server;

// This is actually an advisory value. The implementations themselves are responsible for
// determining what they consider appropriate, and the actual heuristics is a bit more
// complex than just this.
[Argument(ArgumentType.LastOccurenceWins,
[Argument(ArgumentType.LastOccurenceWins, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly,
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
public int? Parallel;

[Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
}

Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ public sealed class OnnxContext
private readonly HashSet<string> _variableMap;
private readonly HashSet<string> _nodeNames;
private readonly string _name;
private readonly string _producerName;
private readonly IHost _host;
private readonly string _domain;
private readonly string _producerVersion;
private readonly long _modelVersion;

public OnnxContext(IHostEnvironment env, string name, string domain)
public OnnxContext(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain)
{
Contracts.CheckValue(env, nameof(env));
Contracts.CheckValue(name, nameof(name));
Expand All @@ -41,6 +45,9 @@ public OnnxContext(IHostEnvironment env, string name, string domain)
_variableMap = new HashSet<string>();
_nodeNames = new HashSet<string>();
_name = name;
_producerName = producerName;
_producerVersion = producerVersion;
_modelVersion = modelVersion;
_domain = domain;
}

Expand Down Expand Up @@ -234,6 +241,6 @@ public void AddInputVariable(ColumnType type, string colName)
/// Makes the ONNX model based on the context.
/// </summary>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _name, _name, _domain, _inputs, _outputs, _intermediateValues);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
}
}
13 changes: 10 additions & 3 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable<GraphProto>

private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0);

public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name)
public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name, string domain = null)
{
Contracts.CheckNonEmpty(opType, nameof(opType));
Contracts.CheckValue(inputs, nameof(inputs));
Expand All @@ -165,7 +165,7 @@ public static NodeProto MakeNode(string opType, List<string> inputs, List<string
node.Input.Add(inputs);
node.Output.Add(outputs);
node.Name = name;
node.Domain = "ai.onnx.ml";
node.Domain = domain ?? "ai.onnx.ml";
return node;
}

Expand Down Expand Up @@ -251,7 +251,8 @@ public NodeProtoWrapper(NodeProto node)
}
}

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name, string domain, List<ModelArgs> inputs,
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
{
Contracts.CheckValue(nodes, nameof(nodes));
Expand All @@ -261,10 +262,16 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
Contracts.CheckNonEmpty(producerName, nameof(producerName));
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckNonEmpty(domain, nameof(domain));
Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));

var model = new ModelProto();
model.Domain = domain;
model.ProducerName = producerName;
model.ProducerVersion = producerVersion;
model.IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx", Version = 6 });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
Loading

0 comments on commit 1bb1249

Please sign in to comment.