Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the creation of ONNX initializers #965

Merged
merged 9 commits into from
Oct 2, 2018
51 changes: 51 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,56 @@ public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);

/// <summary>
/// Declares a global float value (an initializer) in the ONNX graph.
/// </summary>
/// <param name="value">The float number which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(float value, string name = null);

/// <summary>
/// Declares a global int64 value (an initializer) in the ONNX graph.
wschin marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(long value, string name = null);

/// <summary>
/// Declares a global string value (an initializer) in the ONNX graph.
/// </summary>
/// <param name="value">The string which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(string value, string name = null);

/// <summary>
/// Declares a global float tensor (an initializer) in the ONNX graph.
/// </summary>
/// <param name="values">The floats which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the tensor that holds the floats</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null);

/// <summary>
/// Declares a global long (int64) tensor (an initializer) in the ONNX graph.
/// </summary>
/// <param name="values">The longs which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the tensor that holds the longs</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null);

/// <summary>
/// Declares a global string tensor (an initializer) in the ONNX graph.
wschin marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the tensor that holds the strings</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null);
}
}
87 changes: 85 additions & 2 deletions src/Microsoft.ML.Onnx/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ namespace Microsoft.ML.Runtime.Model.Onnx
/// <summary>
/// A context for defining a ONNX output.
/// </summary>
internal sealed class OnnxContextImpl : OnnxContext
public sealed class OnnxContextImpl : OnnxContext
wschin marked this conversation as resolved.
Show resolved Hide resolved
{
private readonly List<NodeProto> _nodes;
private readonly List<OnnxUtils.ModelArgs> _inputs;
// The map from IDataView column names to variable names.
private readonly List<TensorProto> _initializers;
private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
private readonly List<OnnxUtils.ModelArgs> _outputs;
private readonly Dictionary<string, string> _columnNameMap;
Expand All @@ -43,6 +44,7 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_nodes = new List<NodeProto>();
_intermediateValues = new List<OnnxUtils.ModelArgs>();
_inputs = new List<OnnxUtils.ModelArgs>();
_initializers = new List<TensorProto>();
_outputs = new List<OnnxUtils.ModelArgs>();
_columnNameMap = new Dictionary<string, string>();
_variableNames = new HashSet<string>();
Expand Down Expand Up @@ -246,10 +248,91 @@ public void AddInputVariable(ColumnType type, string colName)
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
}

/// <summary>
/// Registers a float value as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="value">The float to store.</param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(float value, string name = null)
{
    // Fall back to the generic seed when the caller supplied none.
    string seed = name ?? "initializer";
    string onnxName = AddVariable(seed);

    _initializers.Add(OnnxUtils.MakeFloat(onnxName, value));
    return onnxName;
}

/// <summary>
/// Registers a string value as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="value">The string to store.</param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(string value, string name = null)
{
    // Fall back to the generic seed when the caller supplied none.
    string seed = name ?? "initializer";
    string onnxName = AddVariable(seed);

    _initializers.Add(OnnxUtils.MakeString(onnxName, value));
    return onnxName;
}

/// <summary>
/// Registers an int64 value as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="value">The long to store.</param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(long value, string name = null)
{
    // Fall back to the generic seed when the caller supplied none.
    string seed = name ?? "initializer";
    string onnxName = AddVariable(seed);

    _initializers.Add(OnnxUtils.MakeInt64(onnxName, value));
    return onnxName;
}

/// <summary>
/// Registers a float tensor as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="values">The tensor elements; must not be null.</param>
/// <param name="dims">
/// The tensor shape. When not null, the product of its entries must equal the number of
/// elements in <paramref name="values"/>; an empty shape has product 1. When null, the
/// shape is left to OnnxUtils.MakeFloats.
/// </param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null)
{
    _host.CheckValue(values, nameof(values));

    // Materialize both sequences once so lazy or one-shot enumerables are not
    // re-evaluated by the size check and again while the tensor is built.
    var valueList = values.ToList();
    var dimList = dims?.ToList();

    if (dimList != null)
    {
        // Seed the product with 1: an unseeded Aggregate throws on an empty shape.
        long expectedCount = dimList.Aggregate(1L, (x, y) => x * y);
        _host.Check(expectedCount == valueList.Count, "Number of elements doesn't match tensor size");
    }

    name = AddVariable(name ?? "initializer");

    _initializers.Add(OnnxUtils.MakeFloats(name, valueList, dimList));
    return name;
}

/// <summary>
/// Registers an int64 tensor as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="values">The tensor elements; must not be null.</param>
/// <param name="dims">
/// The tensor shape. When not null, the product of its entries must equal the number of
/// elements in <paramref name="values"/>; an empty shape has product 1. When null, the
/// shape is left to OnnxUtils.MakeInt64s.
/// </param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null)
{
    _host.CheckValue(values, nameof(values));

    // Materialize both sequences once so lazy or one-shot enumerables are not
    // re-evaluated by the size check and again while the tensor is built.
    var valueList = values.ToList();
    var dimList = dims?.ToList();

    if (dimList != null)
    {
        // Seed the product with 1: an unseeded Aggregate throws on an empty shape.
        long expectedCount = dimList.Aggregate(1L, (x, y) => x * y);
        _host.Check(expectedCount == valueList.Count, "Number of elements doesn't match tensor size");
    }

    name = AddVariable(name ?? "initializer");

    _initializers.Add(OnnxUtils.MakeInt64s(name, valueList, dimList));
    return name;
}

/// <summary>
/// Registers a string tensor as a constant tensor (initializer) of the graph.
/// </summary>
/// <param name="values">The tensor elements; must not be null.</param>
/// <param name="dims">
/// The tensor shape. When not null, the product of its entries must equal the number of
/// elements in <paramref name="values"/>; an empty shape has product 1. When null, the
/// shape is left to OnnxUtils.MakeStrings.
/// </param>
/// <param name="name">Optional seed for the variable name; "initializer" is used as the seed when this is null.</param>
/// <returns>The name returned by AddVariable for the new initializer.</returns>
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null)
{
    _host.CheckValue(values, nameof(values));

    // Materialize both sequences once so lazy or one-shot enumerables are not
    // re-evaluated by the size check and again while the tensor is built.
    var valueList = values.ToList();
    var dimList = dims?.ToList();

    if (dimList != null)
    {
        // Seed the product with 1: an unseeded Aggregate throws on an empty shape.
        long expectedCount = dimList.Aggregate(1L, (x, y) => x * y);
        _host.Check(expectedCount == valueList.Count, "Number of elements doesn't match tensor size");
    }

    name = AddVariable(name ?? "initializer");

    _initializers.Add(OnnxUtils.MakeStrings(name, valueList, dimList));
    return name;
}

/// <summary>
/// Makes the ONNX model based on the context.
/// </summary>
/// <returns>The model built by OnnxUtils.MakeModel from the nodes, inputs, outputs, intermediate values and initializers held by this context.</returns>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
}
}
79 changes: 77 additions & 2 deletions src/Microsoft.ML.Onnx/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,13 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> di

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
{
Contracts.CheckValue(nodes, nameof(nodes));
Contracts.CheckValue(inputs, nameof(inputs));
Contracts.CheckValue(outputs, nameof(outputs));
Contracts.CheckValue(outputs, nameof(intermediateValues));
Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
Contracts.CheckValue(initializers, nameof(initializers));
Contracts.CheckNonEmpty(producerName, nameof(producerName));
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckNonEmpty(domain, nameof(domain));
Expand Down Expand Up @@ -282,6 +283,8 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
}

graph.Initializer.AddRange(initializers);

return model;
}

Expand Down Expand Up @@ -349,5 +352,77 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,

return new ModelArgs(name, dataType, dimsLocal, dimsParamLocal);
}

// Make int64 scalar in ONNX from native C# number.
// The returned tensor carries the given name, the Int64 data type and a single
// Int64Data entry; no dims are added.
public static TensorProto MakeInt64(string name, long value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(string name, long value) [](start = 43, length = 25)

Can you please add checks? See how above functions do checks. Same for below.

{
var tensor = new TensorProto();
tensor.Name = name;
tensor.DataType = TensorProto.Types.DataType.Int64;
tensor.Int64Data.Add(value);
return tensor;
}

// Make int64 vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
// name: the name assigned to the tensor; must not be null.
// values: the int64 elements stored in the tensor; must not be null.
// dims: optional shape; when null the shape is the single dimension [number of values].
public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnumerable<long> dims = null)
{
    // Validate arguments up front, in the same style MakeModel uses.
    Contracts.CheckValue(name, nameof(name));
    Contracts.CheckValue(values, nameof(values));

    var tensor = new TensorProto();
    tensor.Name = name;
    tensor.DataType = TensorProto.Types.DataType.Int64;
    tensor.Int64Data.AddRange(values);
    if (dims != null)
        tensor.Dims.AddRange(dims);
    else
        // Use the number of stored elements so that values is enumerated only once.
        tensor.Dims.Add(tensor.Int64Data.Count);
    return tensor;
}

// Build an ONNX float scalar from a native C# float.
// The returned tensor carries the given name, the Float data type and a single
// FloatData entry; no dims are added.
public static TensorProto MakeFloat(string name, float value)
{
    var scalar = new TensorProto
    {
        Name = name,
        DataType = TensorProto.Types.DataType.Float,
        FloatData = { value }
    };
    return scalar;
}

// Make float vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
// name: the name assigned to the tensor; must not be null.
// values: the float elements stored in the tensor; must not be null.
// dims: optional shape; when null the shape is the single dimension [number of values].
public static TensorProto MakeFloats(string name, IEnumerable<float> values, IEnumerable<long> dims = null)
{
    // Validate arguments up front, in the same style MakeModel uses.
    Contracts.CheckValue(name, nameof(name));
    Contracts.CheckValue(values, nameof(values));

    var tensor = new TensorProto();
    tensor.Name = name;
    tensor.DataType = TensorProto.Types.DataType.Float;
    tensor.FloatData.AddRange(values);
    if (dims != null)
        tensor.Dims.AddRange(dims);
    else
        // Use the number of stored elements so that values is enumerated only once.
        tensor.Dims.Add(tensor.FloatData.Count);
    return tensor;
}

// Build an ONNX string scalar from a native C# string.
// The returned tensor carries the given name, the String data type and a single
// StringData entry produced by StringToByteString; no dims are added.
public static TensorProto MakeString(string name, string value)
{
    var scalar = new TensorProto
    {
        Name = name,
        DataType = TensorProto.Types.DataType.String,
        StringData = { StringToByteString(value) }
    };
    return scalar;
}

// Make string vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
// name: the name assigned to the tensor; must not be null.
// values: the strings stored in the tensor (converted with StringToByteString); must not be null.
// dims: optional shape; when null the shape is the single dimension [number of values].
public static TensorProto MakeStrings(string name, IEnumerable<string> values, IEnumerable<long> dims = null)
{
    // Validate arguments up front, in the same style MakeModel uses.
    Contracts.CheckValue(name, nameof(name));
    Contracts.CheckValue(values, nameof(values));

    var tensor = new TensorProto();
    tensor.Name = name;
    tensor.DataType = TensorProto.Types.DataType.String;
    tensor.StringData.AddRange(StringToByteString(values));
    if (dims != null)
        tensor.Dims.AddRange(dims);
    else
        // Use the number of stored elements so that values is enumerated only once.
        tensor.Dims.Add(tensor.StringData.Count);
    return tensor;
}
}
}
68 changes: 68 additions & 0 deletions test/Microsoft.ML.Tests/OnnxTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
using Microsoft.ML.Legacy.Transforms;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.RunTests;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text.RegularExpressions;
using Xunit;
Expand Down Expand Up @@ -51,6 +53,72 @@ public class BreastCancerMCPrediction
public float[] Scores;
}

// Adds one initializer of each supported kind (float, int64, string scalars and
// float, int64, string tensors) through the OnnxContext API, builds the model,
// and checks the contents of model.Graph.Initializer.
// The assertions index the initializers by position, so they rely on the model
// listing them in the order they were added.
[Fact]
public void InitializerCreationTest()
{
using (var env = new ConsoleEnvironment())
{
// Create the actual implementation
var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test");

// Use implementation as in the actual conversion code
var ctx = ctxImpl as OnnxContext;
ctx.AddInitializer(9.4f, "float");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.AddInitializer(9.4f, "float"); [](start = 16, length = 34)

var floatScalar = ctx.AddInitializer(9.4f, "float");

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better than do
var floatScalar = model.Graph.Initializer[0];

If someone decide to change this test, and add one more initializer in the beginning, it will break tests.


In reply to: 220259227 [](ancestors = 220259227)

Copy link
Member Author

@wschin wschin Sep 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.AddInitializer returns the ONNX name of a tensor, not the actual ONNX object.
In addition, the test aims at examining the generated ONNX objects under the real scenario. Getting information before finishing the conversion is not ideal because it means we are checking an intermediate state of a conversion instead of the final state we care the most.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use that name of tensor to access initializer instead of model.Graph.Initalizer[0]?
If yes, i would suggest to use names instead of index.
If no, I find it strange, but i'm ok with current code.


In reply to: 220418126 [](ancestors = 220418126)

ctx.AddInitializer(17L, "int64");
ctx.AddInitializer("36", "string");
// The three tensors below all use shape [1, 3] with three elements each.
ctx.AddInitializer(new List<float> { 9.4f, 1.7f, 3.6f }, new List<long> { 1, 3 }, "floats");
ctx.AddInitializer(new List<long> { 94L, 17L, 36L }, new List<long> { 1, 3 }, "int64s");
ctx.AddInitializer(new List<string> { "94" , "17", "36" }, new List<long> { 1, 3 }, "strings");

// Inspect the finished model rather than the context's intermediate state.
var model = ctxImpl.MakeModel();

// Scalars: expected to keep the requested name, have no dims and hold exactly one value.
var floatScalar = model.Graph.Initializer[0];
Assert.True(floatScalar.Name == "float");
Assert.True(floatScalar.Dims.Count == 0);
Assert.True(floatScalar.FloatData.Count == 1);
Assert.True(floatScalar.FloatData[0] == 9.4f);

var int64Scalar = model.Graph.Initializer[1];
Assert.True(int64Scalar.Name == "int64");
Assert.True(int64Scalar.Dims.Count == 0);
Assert.True(int64Scalar.Int64Data.Count == 1);
Assert.True(int64Scalar.Int64Data[0] == 17L);

var stringScalar = model.Graph.Initializer[2];
Assert.True(stringScalar.Name == "string");
Assert.True(stringScalar.Dims.Count == 0);
Assert.True(stringScalar.StringData.Count == 1);
Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36");

// Tensors: expected to carry the [1, 3] shape and the three values in insertion order.
var floatsTensor = model.Graph.Initializer[3];
Assert.True(floatsTensor.Dims.Count == 2);
Assert.True(floatsTensor.Dims[0] == 1);
Assert.True(floatsTensor.Dims[1] == 3);
Assert.True(floatsTensor.FloatData.Count == 3);
Assert.True(floatsTensor.FloatData[0] == 9.4f);
Assert.True(floatsTensor.FloatData[1] == 1.7f);
Assert.True(floatsTensor.FloatData[2] == 3.6f);

var int64sTensor = model.Graph.Initializer[4];
Assert.True(int64sTensor.Dims.Count == 2);
Assert.True(int64sTensor.Dims[0] == 1);
Assert.True(int64sTensor.Dims[1] == 3);
Assert.True(int64sTensor.Int64Data.Count == 3);
Assert.True(int64sTensor.Int64Data[0] == 94L);
Assert.True(int64sTensor.Int64Data[1] == 17L);
Assert.True(int64sTensor.Int64Data[2] == 36L);

var stringsTensor = model.Graph.Initializer[5];
Assert.True(stringsTensor.Dims.Count == 2);
Assert.True(stringsTensor.Dims[0] == 1);
Assert.True(stringsTensor.Dims[1] == 3);
Assert.True(stringsTensor.StringData.Count == 3);
Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
}
}

[Fact]
public void BinaryClassificationFastTreeSaveModelToOnnxTest()
{
Expand Down