diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 63a15c977c..a1373d87ec 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}"
EndProject
diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
new file mode 100644
index 0000000000..bcc86939e2
--- /dev/null
+++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET component for exporting ONNX Models
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
new file mode 100644
index 0000000000..07807bb54b
--- /dev/null
+++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index e28b198bcf..2886b39ad3 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -6,7 +6,6 @@
-
diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index 1256bf75ba..5a5a67ccda 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -19,13 +19,13 @@
+
-
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index cab6467043..a2ab3a7b16 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo()
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
///
- internal IDataView View { get; }
+ public IDataView View { get; }
///
/// Creates a loader according to the specified .
diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
index 38fb6075ce..8d5b0fd2d0 100644
--- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
+++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
@@ -15,7 +15,6 @@
-
diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
index 0e25840c3b..36d839b93d 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
@@ -19,7 +19,7 @@ public interface ICanSaveOnnx
}
///
- /// This data model component is savable as Onnx.
+ /// This data model component is savable as ONNX.
///
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
{
diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 46c083173f..bdef784b29 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -2,245 +2,101 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
-using System.Linq;
-using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
using Microsoft.ML.Runtime.Data;
namespace Microsoft.ML.Runtime.Model.Onnx
{
///
- /// A context for defining a ONNX output.
+ /// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This
+ /// same context object is iteratively given to exportable components via the interface
+ /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
+ /// given to a component, all other components up to that component have already attempted to express themselves in
+ /// this context, with their outputs possibly available in the ONNX graph.
///
- public sealed class OnnxContext
+ public abstract class OnnxContext
{
- private readonly List _nodes;
- private readonly List _inputs;
- private readonly List _intermediateValues;
- private readonly List _outputs;
- private readonly Dictionary _columnNameMap;
- private readonly HashSet _variableMap;
- private readonly HashSet _nodeNames;
- private readonly string _name;
- private readonly string _producerName;
- private readonly IHost _host;
- private readonly string _domain;
- private readonly string _producerVersion;
- private readonly long _modelVersion;
-
- public OnnxContext(IHostEnvironment env, string name, string producerName,
- string producerVersion, long modelVersion, string domain)
- {
- Contracts.CheckValue(env, nameof(env));
- Contracts.CheckValue(name, nameof(name));
- Contracts.CheckValue(name, nameof(domain));
-
- _host = env.Register(nameof(OnnxContext));
- _nodes = new List();
- _intermediateValues = new List();
- _inputs = new List();
- _outputs = new List();
- _columnNameMap = new Dictionary();
- _variableMap = new HashSet();
- _nodeNames = new HashSet();
- _name = name;
- _producerName = producerName;
- _producerVersion = producerVersion;
- _modelVersion = modelVersion;
- _domain = domain;
- }
-
- public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
-
- ///
- /// Stops tracking a column. If removeVariable is true then it also removes the
- /// variable associated with it, this is useful in the event where an output variable is
- /// created before realizing the transform cannot actually save as ONNX.
- ///
- /// IDataView column name to stop tracking
- /// Remove associated ONNX variable at the time.
- public void RemoveColumn(string colName, bool removeVariable)
- {
-
- if (removeVariable)
- {
- foreach (var val in _intermediateValues)
- {
- if (val.Name == _columnNameMap[colName])
- {
- _intermediateValues.Remove(val);
- break;
- }
- }
- }
-
- if (_columnNameMap.ContainsKey(colName))
- _columnNameMap.Remove(colName);
- }
-
- ///
- /// Removes an ONNX variable. If removeColumn is true then it also removes the
- /// IDataView column associated with it.
- ///
- /// ONNX variable to remove.
- /// IDataView column to stop tracking
- public void RemoveVariable(string variableName, bool removeColumn)
- {
- _host.Assert(_columnNameMap.ContainsValue(variableName));
- if (removeColumn)
- {
- foreach (var val in _intermediateValues)
- {
- if (val.Name == variableName)
- {
- _intermediateValues.Remove(val);
- break;
- }
- }
- }
-
- string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key;
-
- Contracts.Assert(_variableMap.Contains(columnName));
-
- _columnNameMap.Remove(columnName);
- _variableMap.Remove(columnName);
- }
-
///
/// Generates a unique name for the node based on a prefix.
///
- public string GetNodeName(string prefix)
- {
- _host.CheckValue(prefix, nameof(prefix));
- return GetUniqueName(prefix, c => _nodeNames.Contains(c));
- }
+ /// The prefix for the node
+ /// A name that has not yet been returned from this function, starting with
+ public abstract string GetNodeName(string prefix);
///
- /// Adds a node to the node list of the graph.
+ /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
+ /// safely call .
///
- ///
- public void AddNode(NodeProto node)
- {
- _host.CheckValue(node, nameof(node));
- _host.Assert(!_nodeNames.Contains(node.Name));
-
- _nodeNames.Add(node.Name);
- _nodes.Add(node);
- }
+ /// The data view column name
+ /// Whether the column is mapped in this context
+ public abstract bool ContainsColumn(string colName);
///
- /// Generates a unique name based on a prefix.
+ /// Stops tracking a column.
///
- private string GetUniqueName(string prefix, Func pred)
- {
- _host.CheckValue(prefix, nameof(prefix));
- _host.CheckValue(pred, nameof(pred));
-
- if (!pred(prefix))
- return prefix;
-
- int count = 0;
- while (pred(prefix + count++)) ;
- return prefix + --count;
- }
+ /// Column name to stop tracking
+ /// Remove associated ONNX variable. This is useful in the event where an output
+ /// variable is created through before realizing
+ /// the transform cannot actually save as ONNX.
+ public abstract void RemoveColumn(string colName, bool removeVariable = false);
///
- /// Retrieves the variable name that maps to the IDataView column name at a
- /// given point in the pipeline execution.
+ /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it.
///
- /// Column Name mapping.
- public string GetVariableName(string colName)
- {
- _host.CheckValue(colName, nameof(colName));
- _host.Assert(_columnNameMap.ContainsKey(colName));
-
- return _columnNameMap[colName];
- }
-
- ///
- /// Retrieves the variable name that maps to the IDataView column name at a
- /// given point in the pipeline execution.
- ///
- /// Column Name mapping.
- public string TryGetVariableName(string colName)
- {
- if (_columnNameMap.ContainsKey(colName))
- return GetVariableName(colName);
-
- return null;
- }
-
- ///
- /// Generates a unique column name based on the IDataView column name if
- /// there is a collision between names in the pipeline at any point.
- ///
- /// IDataView column name.
- /// Unique variable name.
- private string AddVariable(string colName)
- {
- _host.CheckValue(colName, nameof(colName));
-
- if (!_columnNameMap.ContainsKey(colName))
- _columnNameMap.Add(colName, colName);
- else
- _columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s));
-
- _variableMap.Add(_columnNameMap[colName]);
- return _columnNameMap[colName];
- }
+ /// ONNX variable to remove. Note that this is an ONNX variable name, not an column name
+ /// IDataView column to stop tracking
+ public abstract void RemoveVariable(string variableName, bool removeColumn);
///
- /// Adds an intermediate column to the list.
+ /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
+ /// 's column names will map to a variable in the ONNX graph if the intermediate steps
+ /// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
+ /// to the column name at a given point in the pipeline execution. Callers should
+ /// probably confirm with whether a mapping for that data view column
+ /// already exists.
///
- public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
- {
-
- colName = AddVariable(colName);
-
- //Let the runtime figure the shape.
- if (!skip)
- {
- _host.CheckValue(type, nameof(type));
-
- _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
- }
-
- return colName;
- }
+ /// The data view column name
+ /// The ONNX variable name corresponding to that data view column
+ public abstract string GetVariableName(string colName);
///
- /// Adds an output variable to the list.
+ /// Establishes a new mapping from an data view column in the context, if necessary generates a unique name, and
+ /// returns that newly allocated name.
///
- public string AddOutputVariable(ColumnType type, string colName, List dim = null)
- {
- _host.CheckValue(type, nameof(type));
-
- if (!ContainsColumn(colName))
- AddVariable(colName);
-
- colName = GetVariableName(colName);
- _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
- return colName;
- }
+ /// The data view type associated with this column name
+ /// The data view column name
+ /// Whether we should skip the process of establishing the mapping from data view column to
+ /// ONNX variable name.
+ /// The returned value is the name of the variable corresponding
+ public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);
///
- /// Adds an input variable to the list.
+ /// Creates an ONNX node
///
- public void AddInputVariable(ColumnType type, string colName)
- {
- _host.CheckValue(type, nameof(type));
- _host.CheckValue(colName, nameof(colName));
-
- colName = AddVariable(colName);
- _inputs.Add(OnnxUtils.GetModelArgs(type, colName));
- }
+ /// The name of the ONNX operator to apply
+ /// The names of the variables as inputs
+ /// The names of the variables to create as outputs,
+ /// which ought to have been something returned from
+ /// The name of the operator, which ought to be something returned from
+ /// The domain of the ONNX operator, if non-default
+ /// A node added to the in-progress ONNX graph, that attributes can be set on
+ public abstract OnnxNode CreateNode(string opType, IEnumerable inputs,
+ IEnumerable outputs, string name, string domain = null);
///
- /// Makes the ONNX model based on the context.
+ /// Convenience alternative to
+ /// for the case where there is exactly one input and output.
///
- public ModelProto MakeModel()
- => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
+ /// The name of the ONNX operator to apply
+ /// The name of the variable as input
+ /// The name of the variable as output,
+ /// which ought to have been something returned from
+ /// The name of the operator, which ought to be something returned from
+ /// The domain of the ONNX operator, if non-default
+ /// A node added to the in-progress ONNX graph, that attributes can be set on
+ public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
+ => CreateNode(opType, new[] { input }, new[] { output }, name, domain);
}
}
diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
new file mode 100644
index 0000000000..259a6d27d4
--- /dev/null
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime.Model.Onnx
+{
+ ///
+ /// An abstraction for an ONNX node as created by
+ /// .
+ /// That method creates a with inputs and outputs, but this object can modify the node further
+ /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations).
+ ///
+ public abstract class OnnxNode
+ {
+ public abstract void AddAttribute(string argName, double value);
+ public abstract void AddAttribute(string argName, long value);
+ public abstract void AddAttribute(string argName, DvText value);
+ public abstract void AddAttribute(string argName, string value);
+ public abstract void AddAttribute(string argName, bool value);
+
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, string[] value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ public abstract void AddAttribute(string argName, IEnumerable value);
+ }
+}
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 487726572b..835ba5d99a 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -1437,19 +1437,14 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str
string opType = "Affine";
string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
- var node = OnnxUtils.MakeNode(opType, new List { scoreProbablityColumnNames[0] },
- new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx");
-
- OnnxUtils.NodeAddAttributes(node, "alpha", ParamA * -1);
- OnnxUtils.NodeAddAttributes(node, "beta", -0.0000001);
-
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] },
+ new[] { linearOutput }, ctx.GetNodeName(opType), "ai.onnx");
+ node.AddAttribute("alpha", ParamA * -1);
+ node.AddAttribute("beta", -0.0000001);
opType = "Sigmoid";
- node = OnnxUtils.MakeNode(opType, new List { linearOutput },
- new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx");
-
- ctx.AddNode(node);
+ node = ctx.CreateNode(opType, new[] { linearOutput },
+ new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx");
return true;
}
diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
index 6fde9a5815..91f799b2e7 100644
--- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
@@ -206,11 +206,9 @@ public override void SaveAsOnnx(OnnxContext ctx)
if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]))
{
string opType = "Binarizer";
- var node = OnnxUtils.MakeNode(opType, new List { ctx.GetVariableName(outColumnNames[2]) },
- new List { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
-
- OnnxUtils.NodeAddAttributes(node, "threshold", 0.5);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) },
+ new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
+ node.AddAttribute("threshold", 0.5);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
index a6c0e3f490..db572fc93f 100644
--- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
@@ -720,13 +720,11 @@ public void SaveAsOnnx(OnnxContext ctx)
Source.Schema.GetColumnType(srcIndex).ValueCount));
}
- var node = OnnxUtils.MakeNode(opType, new List(inputList.Select(t => t.Key)),
- new List { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
+ var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
+ new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
- ctx.AddNode(node);
-
- OnnxUtils.NodeAddAttributes(node, "inputList", inputList.Select(x => x.Key));
- OnnxUtils.NodeAddAttributes(node, "inputdimensions", inputList.Select(x => x.Value));
+ node.AddAttribute("inputList", inputList.Select(x => x.Key));
+ node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
index bffbaa881c..d177d75647 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
@@ -244,10 +244,9 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo
protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
string opType = "OneHotEncoder";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x));
- OnnxUtils.NodeAddAttributes(node, "zeros", true);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x));
+ node.AddAttribute("zeros", true);
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index f6e8851f51..6092c2d8b9 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -434,8 +434,8 @@ private AffineColumnFunction(IHost host)
public abstract void Save(ModelSaveContext ctx);
public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
-
- public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
+ public bool CanSaveOnnx => true;
+ public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
public abstract Delegate GetGetter(IRow input, int icol);
@@ -546,10 +546,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}
- public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
- {
- return false;
- }
+ public bool CanSaveOnnx => false;
+
+ public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
+ => throw Host.ExceptNotSupp();
public abstract Delegate GetGetter(IRow input, int icol);
@@ -671,10 +671,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}
- public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
- {
- return false;
- }
+ public bool CanSaveOnnx => false;
+
+ public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
+ => throw Host.ExceptNotSupp();
public abstract Delegate GetGetter(IRow input, int icol);
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
index 41e55ee338..e577b9370e 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
@@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx)
public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
=> PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
{
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount));
+ nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
+ nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
return true;
}
@@ -648,12 +648,12 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
}
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
- if (Offset != null)
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset);
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale);
+ if (Offset != null)
+ node.AddAttribute("offset", Offset);
+ node.AddAttribute("scale", Scale);
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
index ef5eef8551..4c6e1fb011 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
@@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx)
public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
=> PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount));
+ node.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
+ node.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
return true;
}
@@ -648,14 +648,14 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
}
- public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount)
+ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{
if (Offset != null)
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset);
+ node.AddAttribute("offset", Offset);
else
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(0, featureCount));
+ node.AddAttribute("offset", Enumerable.Repeat(0, featureCount));
- OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale);
+ node.AddAttribute("scale", Scale);
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
index 318bbb44f1..22fa9686ed 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
@@ -66,7 +66,9 @@ public interface IColumnFunction : ICanSaveModel
JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
- bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
+ bool CanSaveOnnx { get; }
+
+ bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
}
public sealed partial class NormalizeTransform : OneToOneTransformBase
@@ -316,11 +318,11 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
if (info.TypeSrc.ValueCount == 0)
return false;
- string opType = "Scaler";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- if (_functions[iinfo].OnnxInfo(ctx, new OnnxUtils.NodeProtoWrapper(node), info.TypeSrc.ValueCount))
+ if (_functions[iinfo].CanSaveOnnx)
{
- ctx.AddNode(node);
+ string opType = "Scaler";
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ _functions[iinfo].OnnxInfo(ctx, node, info.TypeSrc.ValueCount);
return true;
}
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
index da7442f90e..bb9adf21e1 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
@@ -690,11 +690,10 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
TermMap map = (TermMap)_termMap[iinfo].Map;
map.GetTerms(ref terms);
string opType = "LabelEncoder";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "classes_strings", terms.DenseValues());
- OnnxUtils.NodeAddAttributes(node, "default_int64", -1);
- OnnxUtils.NodeAddAttributes(node, "default_string", DvText.Empty);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("classes_strings", terms.DenseValues());
+ node.AddAttribute("default_int64", -1);
+ node.AddAttribute("default_string", DvText.Empty);
return true;
}
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 633c484716..01668eb3f2 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -3130,26 +3130,24 @@ public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fea
}
string opType = "TreeEnsembleRegressor";
- var node = OnnxUtils.MakeNode(opType, new List { featureColumn },
- new List(outputNames), ctx.GetNodeName(opType));
-
- OnnxUtils.NodeAddAttributes(node, "post_transform", PostTransform.None.GetDescription());
- OnnxUtils.NodeAddAttributes(node, "n_targets", 1);
- OnnxUtils.NodeAddAttributes(node, "base_values", new List() { 0 });
- OnnxUtils.NodeAddAttributes(node, "aggregate_function", AggregateFunction.Sum.GetDescription());
- OnnxUtils.NodeAddAttributes(node, "nodes_treeids", nodesTreeids);
- OnnxUtils.NodeAddAttributes(node, "nodes_nodeids", nodesIds);
- OnnxUtils.NodeAddAttributes(node, "nodes_featureids", nodesFeatureIds);
- OnnxUtils.NodeAddAttributes(node, "nodes_modes", nodeModes);
- OnnxUtils.NodeAddAttributes(node, "nodes_values", nodesValues);
- OnnxUtils.NodeAddAttributes(node, "nodes_truenodeids", nodesTrueNodeIds);
- OnnxUtils.NodeAddAttributes(node, "nodes_falsenodeids", nodesFalseNodeIds);
- OnnxUtils.NodeAddAttributes(node, "nodes_missing_value_tracks_true", missingValueTracksTrue);
- OnnxUtils.NodeAddAttributes(node, "target_treeids", classTreeIds);
- OnnxUtils.NodeAddAttributes(node, "target_nodeids", classNodeIds);
- OnnxUtils.NodeAddAttributes(node, "target_ids", classIds);
- OnnxUtils.NodeAddAttributes(node, "target_weights", classWeights);
- ctx.AddNode(node);
+ var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType));
+
+ node.AddAttribute("post_transform", PostTransform.None.GetDescription());
+ node.AddAttribute("n_targets", 1);
+ node.AddAttribute("base_values", new List() { 0 });
+ node.AddAttribute("aggregate_function", AggregateFunction.Sum.GetDescription());
+ node.AddAttribute("nodes_treeids", nodesTreeids);
+ node.AddAttribute("nodes_nodeids", nodesIds);
+ node.AddAttribute("nodes_featureids", nodesFeatureIds);
+ node.AddAttribute("nodes_modes", nodeModes);
+ node.AddAttribute("nodes_values", nodesValues);
+ node.AddAttribute("nodes_truenodeids", nodesTrueNodeIds);
+ node.AddAttribute("nodes_falsenodeids", nodesFalseNodeIds);
+ node.AddAttribute("nodes_missing_value_tracks_true", missingValueTracksTrue);
+ node.AddAttribute("target_treeids", classTreeIds);
+ node.AddAttribute("target_nodeids", classNodeIds);
+ node.AddAttribute("target_ids", classIds);
+ node.AddAttribute("target_weights", classWeights);
return true;
}
diff --git a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj
similarity index 52%
rename from src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj
rename to src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj
index 4244681bd6..145dd8be8c 100644
--- a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj
+++ b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj
@@ -2,11 +2,16 @@
netstandard2.0
- Microsoft.ML
+ Microsoft.ML.Onnx
+ Microsoft.ML.Runtime.Model.Onnx
+
+
+
+
diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs
new file mode 100644
index 0000000000..f37a1ea557
--- /dev/null
+++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs
@@ -0,0 +1,255 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime.Model.Onnx
+{
+ ///
+ /// A context for defining a ONNX output.
+ ///
+ internal sealed class OnnxContextImpl : OnnxContext
+ {
+ private readonly List _nodes;
+ private readonly List _inputs;
+ // The map from IDataView column names to variable names.
+ private readonly List _intermediateValues;
+ private readonly List _outputs;
+ private readonly Dictionary _columnNameMap;
+ // All existing variable names. New variables must not exist in this set.
+ private readonly HashSet _variableNames;
+ // All existing node names. New node names must not alrady exist in this set.
+ private readonly HashSet _nodeNames;
+ private readonly string _name;
+ private readonly string _producerName;
+ private readonly IHost _host;
+ private readonly string _domain;
+ private readonly string _producerVersion;
+ private readonly long _modelVersion;
+
+ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
+ string producerVersion, long modelVersion, string domain)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(OnnxContext));
+ _host.CheckValue(name, nameof(name));
+ _host.CheckValue(name, nameof(domain));
+
+ _nodes = new List();
+ _intermediateValues = new List();
+ _inputs = new List();
+ _outputs = new List();
+ _columnNameMap = new Dictionary();
+ _variableNames = new HashSet();
+ _nodeNames = new HashSet();
+ _name = name;
+ _producerName = producerName;
+ _producerVersion = producerVersion;
+ _modelVersion = modelVersion;
+ _domain = domain;
+ }
+
+ public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
+
+ ///
+ /// Stops tracking a column. If removeVariable is true then it also removes the
+ /// variable associated with it, this is useful in the event where an output variable is
+ /// created before realizing the transform cannot actually save as ONNX.
+ ///
+ /// IDataView column name to stop tracking
+ /// Remove associated ONNX variable at the time.
+ public override void RemoveColumn(string colName, bool removeVariable)
+ {
+ _host.CheckNonEmpty(colName, nameof(colName));
+
+ if (removeVariable)
+ {
+ foreach (var val in _intermediateValues)
+ {
+ if (val.Name == _columnNameMap[colName])
+ {
+ _intermediateValues.Remove(val);
+ break;
+ }
+ }
+ }
+ _columnNameMap.Remove(colName);
+ }
+
+ ///
+ /// Removes an ONNX variable. If removeColumn is true then it also removes the
+ /// IDataView column associated with it.
+ ///
+ /// ONNX variable to remove.
+ /// IDataView column to stop tracking
+ public override void RemoveVariable(string variableName, bool removeColumn)
+ {
+ _host.CheckNonEmpty(variableName, nameof(variableName));
+ if (!_columnNameMap.ContainsValue(variableName))
+ throw _host.ExceptParam(nameof(variableName), $"Could not find '{variableName}' declared in ONNX graph");
+
+ if (removeColumn)
+ {
+ foreach (var val in _intermediateValues)
+ {
+ if (val.Name == variableName)
+ {
+ _intermediateValues.Remove(val);
+ break;
+ }
+ }
+ }
+
+ string columnName = _columnNameMap.Single(kvp => kvp.Value == variableName).Key;
+
+ Contracts.Assert(_variableNames.Contains(columnName));
+
+ _columnNameMap.Remove(columnName);
+ _variableNames.Remove(columnName);
+ }
+
+ ///
+ /// Generates a unique name for the node based on a prefix.
+ ///
+ public override string GetNodeName(string prefix)
+ {
+ _host.CheckNonEmpty(prefix, nameof(prefix));
+ return GetUniqueName(prefix, _nodeNames.Contains);
+ }
+
+ ///
+ /// Adds a node to the node list of the graph.
+ ///
+ ///
+ private void AddNode(NodeProto node)
+ {
+ _host.CheckValue(node, nameof(node));
+ _host.Assert(!_nodeNames.Contains(node.Name));
+
+ _nodeNames.Add(node.Name);
+ _nodes.Add(node);
+ }
+
+ public override OnnxNode CreateNode(string opType, IEnumerable inputs,
+ IEnumerable outputs, string name, string domain = null)
+ {
+ _host.CheckNonEmpty(opType, nameof(opType));
+ _host.CheckValue(inputs, nameof(inputs));
+ _host.CheckValue(outputs, nameof(outputs));
+ _host.CheckNonEmpty(name, nameof(name));
+
+ var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain);
+ AddNode(innerNode);
+ return new OnnxNodeImpl(innerNode);
+ }
+
+ ///
+ /// Generates a unique name based on a prefix.
+ ///
+ private string GetUniqueName(string prefix, Func pred)
+ {
+ _host.CheckNonEmpty(prefix, nameof(prefix));
+ _host.CheckValue(pred, nameof(pred));
+
+ if (!pred(prefix))
+ return prefix;
+
+ int count = 0;
+ while (pred(prefix + count++)) ;
+ return prefix + --count;
+ }
+
+ ///
+ /// Retrieves the variable name that maps to the IDataView column name at a
+ /// given point in the pipeline execution.
+ ///
+ /// Column Name mapping.
+ public override string GetVariableName(string colName)
+ {
+ _host.CheckNonEmpty(colName, nameof(colName));
+ _host.Assert(_columnNameMap.ContainsKey(colName));
+
+ return _columnNameMap[colName];
+ }
+
+ ///
+ /// Retrieves the variable name that maps to the IDataView column name at a
+ /// given point in the pipeline execution.
+ ///
+ /// Column Name mapping.
+ public string TryGetVariableName(string colName)
+ {
+ _host.CheckNonEmpty(colName, nameof(colName));
+ if (_columnNameMap.ContainsKey(colName))
+ return GetVariableName(colName);
+ return null;
+ }
+
+ ///
+ /// Generates a unique column name based on the IDataView column name if
+ /// there is a collision between names in the pipeline at any point.
+ ///
+ /// IDataView column name.
+ /// Unique variable name.
+ private string AddVariable(string colName)
+ {
+ _host.CheckNonEmpty(colName, nameof(colName));
+ _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains);
+ _variableNames.Add(_columnNameMap[colName]);
+ return _columnNameMap[colName];
+ }
+
+ ///
+ /// Adds an intermediate column to the list.
+ ///
+ public override string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
+ {
+ colName = AddVariable(colName);
+ // Let the runtime figure the shape.
+ if (!skip)
+ {
+ _host.CheckValue(type, nameof(type));
+ _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
+ }
+ return colName;
+ }
+
+ ///
+ /// Adds an output variable to the list.
+ ///
+ public string AddOutputVariable(ColumnType type, string colName, List dim = null)
+ {
+ _host.CheckValue(type, nameof(type));
+
+ if (!ContainsColumn(colName))
+ AddVariable(colName);
+
+ colName = GetVariableName(colName);
+ _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
+ return colName;
+ }
+
+ ///
+ /// Adds an input variable to the list.
+ ///
+ public void AddInputVariable(ColumnType type, string colName)
+ {
+ _host.CheckValue(type, nameof(type));
+ _host.CheckValue(colName, nameof(colName));
+
+ colName = AddVariable(colName);
+ _inputs.Add(OnnxUtils.GetModelArgs(type, colName));
+ }
+
+ ///
+ /// Makes the ONNX model based on the context.
+ ///
+ public ModelProto MakeModel()
+ => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
+ }
+}
diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs b/src/Microsoft.ML.Onnx/OnnxMl.cs
similarity index 100%
rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs
rename to src/Microsoft.ML.Onnx/OnnxMl.cs
diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md b/src/Microsoft.ML.Onnx/OnnxMl.md
similarity index 100%
rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md
rename to src/Microsoft.ML.Onnx/OnnxMl.md
diff --git a/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs
new file mode 100644
index 0000000000..9b30fd1d87
--- /dev/null
+++ b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
+
+namespace Microsoft.ML.Runtime.Model.Onnx
+{
+ internal sealed class OnnxNodeImpl : OnnxNode
+ {
+ private readonly NodeProto _node;
+
+ public OnnxNodeImpl(NodeProto node)
+ {
+ Contracts.AssertValue(node);
+ _node = node;
+ }
+
+ public override void AddAttribute(string argName, double value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, long value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, DvText value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, string[] value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, IEnumerable value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, string value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ public override void AddAttribute(string argName, bool value)
+ => OnnxUtils.NodeAddAttributes(_node, argName, value);
+ }
+}
diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs
similarity index 95%
rename from src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs
rename to src/Microsoft.ML.Onnx/OnnxUtils.cs
index 8ad9f40c20..be5a833643 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs
+++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -15,7 +14,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx
///
/// Contains methods to create ONNX models in protocol buffer.
///
- public sealed class OnnxUtils
+ internal static class OnnxUtils
{
private static TypeProto MakeType(TypeProto typeProto, TensorProto.Types.DataType dataType,
List dims, List dimsParam)
@@ -153,7 +152,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable
private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0);
- public static NodeProto MakeNode(string opType, List inputs, List outputs, string name, string domain = null)
+ public static NodeProto MakeNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null)
{
Contracts.CheckNonEmpty(opType, nameof(opType));
Contracts.CheckValue(inputs, nameof(inputs));
@@ -169,11 +168,6 @@ public static NodeProto MakeNode(string opType, List inputs, List() { inputs }, new List() { outputs }, name);
- }
-
public static void NodeAddAttributes(NodeProto node, string argName, double value)
=> node.Attribute.Add(MakeAttribute(argName, value));
@@ -241,16 +235,6 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List di
}
}
- public sealed class NodeProtoWrapper
- {
- public NodeProto Node;
-
- public NodeProtoWrapper(NodeProto node)
- {
- Node = node;
- }
- }
-
public static ModelProto MakeModel(List nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List inputs,
List outputs, List intermediateValues)
diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
similarity index 97%
rename from src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs
rename to src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
index d2dfc93fde..6b72d64af5 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs
+++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model.Onnx;
-using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
using Newtonsoft.Json;
[assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand),
@@ -69,7 +68,6 @@ public sealed class Arguments : DataCommand.ArgumentsBase
private readonly HashSet _outputsToDrop;
private readonly ITransformModel _model;
private const string ProducerName = "ML.NET";
- private const string ProducerVersion = "0.2.0.0000";
private const long ModelVersion = 0;
public SaveOnnxCommand(IHostEnvironment env, Arguments args)
@@ -164,7 +162,11 @@ private void Run(IChannel ch)
GetPipe(ch, view, out source, out end, out transforms);
Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);
- var ctx = new OnnxContext(Host, _name, ProducerName, ProducerVersion, ModelVersion, _domain);
+ var assembly = System.Reflection.Assembly.GetExecutingAssembly();
+ var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
+
+ var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
+ ModelVersion, _domain);
// If we have a predictor, try to get the scorer for it.
if (rawPred != null)
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
index a3ac7ce72e..3fc63060e8 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
@@ -236,13 +236,12 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
Host.Check(Utils.Size(outputs) == 1);
string opType = "LinearRegressor";
- var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List (outputs), ctx.GetNodeName(opType));
+ var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
- OnnxUtils.NodeAddAttributes(node, "post_transform", 0);
- OnnxUtils.NodeAddAttributes(node, "targets", 1);
- OnnxUtils.NodeAddAttributes(node, "coefficients", Weight.DenseValues());
- OnnxUtils.NodeAddAttributes(node, "intercepts", Bias);
- ctx.AddNode(node);
+ node.AddAttribute("post_transform", 0);
+ node.AddAttribute("targets", 1);
+ node.AddAttribute("coefficients", Weight.DenseValues());
+ node.AddAttribute("intercepts", Bias);
return true;
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index f2fad63794..5fbc6ab74a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -844,14 +844,13 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
Host.CheckValue(ctx, nameof(ctx));
string opType = "LinearClassifier";
- var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List(outputs), ctx.GetNodeName(opType));
+ var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
- OnnxUtils.NodeAddAttributes(node, "post_transform", 0);
- OnnxUtils.NodeAddAttributes(node, "multi_class", true);
- OnnxUtils.NodeAddAttributes(node, "coefficients", _weights.SelectMany(w => w.DenseValues()));
- OnnxUtils.NodeAddAttributes(node, "intercepts", _biases);
- OnnxUtils.NodeAddAttributes(node, "classlabels_strings", _labelNames);
- ctx.AddNode(node);
+ node.AddAttribute("post_transform", 0);
+ node.AddAttribute("multi_class", true);
+ node.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues()));
+ node.AddAttribute("intercepts", _biases);
+ node.AddAttribute("classlabels_strings", _labelNames);
return true;
}
diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
index 52989472fc..44832ee517 100644
--- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
+++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
@@ -616,20 +616,19 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info,
return false;
string opType = "Imputer";
- var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
- OnnxUtils.NodeAddAttributes(node, "replaced_value_float", Single.NaN);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("replaced_value_float", Single.NaN);
if (!Infos[iinfo].TypeSrc.IsVector)
- OnnxUtils.NodeAddAttributes(node, "imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
+ node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
else
{
if (_repIsDefault[iinfo] != null)
- OnnxUtils.NodeAddAttributes(node, "imputed_value_floats", (float[])_repValues[iinfo]);
+ node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]);
else
- OnnxUtils.NodeAddAttributes(node, "imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
+ node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1));
}
- ctx.AddNode(node);
return true;
}
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json
index 3afd6dc171..b2d611b784 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json
@@ -1,7 +1,7 @@
{
"irVersion": "3",
"producerName": "ML.NET",
- "producerVersion": "0.2.0.0000",
+ "producerVersion": "##VERSION##",
"domain": "Onnx",
"graph": {
"node": [
diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
index bed3dce0eb..ea92975b35 100644
--- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
+++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
@@ -11,6 +11,7 @@
+
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index de0db6f3a9..43e208ad3c 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -329,12 +329,12 @@ public void EntryPointInputBuilderOptionals()
ib1.TrySetValue("WeightColumn", "OtherWeight");
Assert.True(instance.WeightColumn.IsExplicit);
- Assert.True(string.Compare(instance.WeightColumn.Value, "OtherWeight") == 0);
+ Assert.Equal("OtherWeight", instance.WeightColumn.Value);
var tok = (JToken)JValue.CreateString("AnotherWeight");
ib1.TrySetValueJson("WeightColumn", tok);
Assert.True(instance.WeightColumn.IsExplicit);
- Assert.True(string.Compare(instance.WeightColumn.Value, "AnotherWeight") == 0);
+ Assert.Equal("AnotherWeight", instance.WeightColumn.Value);
}
[Fact]
diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
index 8635fdf798..2a2ea8bca1 100644
--- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
+++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -11,6 +11,7 @@
+
diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs
index 477a9c6fa6..f4428242be 100644
--- a/test/Microsoft.ML.Tests/OnnxTests.cs
+++ b/test/Microsoft.ML.Tests/OnnxTests.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Trainers;
using System.IO;
+using System.Text.RegularExpressions;
using Xunit;
using Xunit.Abstractions;
@@ -86,6 +87,11 @@ public void BinaryClassificationSaveModelToOnnxTest()
converter.Convert(model);
+ // Strip the version.
+ var fileText = File.ReadAllText(onnxAsJsonPath);
+ fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\"");
+ File.WriteAllText(onnxAsJsonPath, fileText);
+
CheckEquality(subDir, "SaveModelToOnnxTest.json");
Done();
}