diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 63a15c977c..a1373d87ec 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}" EndProject diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj new file mode 100644 index 0000000000..bcc86939e2 --- /dev/null +++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj @@ -0,0 +1,13 @@ + + + + netstandard2.0 + ML.NET component for exporting ONNX Models + + + + + + + + diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj new file mode 100644 index 0000000000..07807bb54b --- /dev/null +++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj index e28b198bcf..2886b39ad3 100644 --- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj +++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj @@ -6,7 +6,6 @@ - diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index 1256bf75ba..5a5a67ccda 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -19,13 +19,13 @@ + - diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index cab6467043..a2ab3a7b16 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo() /// Returns the underlying data view of the composite loader. /// This can be used to programmatically explore the chain of transforms that's inside the composite loader. /// - internal IDataView View { get; } + public IDataView View { get; } /// /// Creates a loader according to the specified . diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj index 38fb6075ce..8d5b0fd2d0 100644 --- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj +++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj @@ -15,7 +15,6 @@ - diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 0e25840c3b..36d839b93d 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -19,7 +19,7 @@ public interface ICanSaveOnnx } /// - /// This data model component is savable as Onnx. + /// This data model component is savable as ONNX. /// public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform { diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 46c083173f..bdef784b29 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -2,245 +2,101 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; -using System.Linq; -using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; using Microsoft.ML.Runtime.Data; namespace Microsoft.ML.Runtime.Model.Onnx { /// - /// A context for defining a ONNX output. + /// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This + /// same context object is iteratively given to exportable components via the interface + /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is + /// given to a component, all other components up to that component have already attempted to express themselves in + /// this context, with their outputs possibly available in the ONNX graph. /// - public sealed class OnnxContext + public abstract class OnnxContext { - private readonly List _nodes; - private readonly List _inputs; - private readonly List _intermediateValues; - private readonly List _outputs; - private readonly Dictionary _columnNameMap; - private readonly HashSet _variableMap; - private readonly HashSet _nodeNames; - private readonly string _name; - private readonly string _producerName; - private readonly IHost _host; - private readonly string _domain; - private readonly string _producerVersion; - private readonly long _modelVersion; - - public OnnxContext(IHostEnvironment env, string name, string producerName, - string producerVersion, long modelVersion, string domain) - { - Contracts.CheckValue(env, nameof(env)); - Contracts.CheckValue(name, nameof(name)); - Contracts.CheckValue(name, nameof(domain)); - - _host = env.Register(nameof(OnnxContext)); - _nodes = new List(); - _intermediateValues = new List(); - _inputs = new List(); - _outputs = new List(); - _columnNameMap = new Dictionary(); - _variableMap = new HashSet(); - _nodeNames = new HashSet(); - _name = name; - _producerName = producerName; - _producerVersion = producerVersion; - _modelVersion = modelVersion; - _domain = domain; - } - - public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); - - /// - /// Stops tracking a column. If removeVariable is true then it also removes the - /// variable associated with it, this is useful in the event where an output variable is - /// created before realizing the transform cannot actually save as ONNX. - /// - /// IDataView column name to stop tracking - /// Remove associated ONNX variable at the time. - public void RemoveColumn(string colName, bool removeVariable) - { - - if (removeVariable) - { - foreach (var val in _intermediateValues) - { - if (val.Name == _columnNameMap[colName]) - { - _intermediateValues.Remove(val); - break; - } - } - } - - if (_columnNameMap.ContainsKey(colName)) - _columnNameMap.Remove(colName); - } - - /// - /// Removes an ONNX variable. If removeColumn is true then it also removes the - /// IDataView column associated with it. - /// - /// ONNX variable to remove. - /// IDataView column to stop tracking - public void RemoveVariable(string variableName, bool removeColumn) - { - _host.Assert(_columnNameMap.ContainsValue(variableName)); - if (removeColumn) - { - foreach (var val in _intermediateValues) - { - if (val.Name == variableName) - { - _intermediateValues.Remove(val); - break; - } - } - } - - string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key; - - Contracts.Assert(_variableMap.Contains(columnName)); - - _columnNameMap.Remove(columnName); - _variableMap.Remove(columnName); - } - /// /// Generates a unique name for the node based on a prefix. /// - public string GetNodeName(string prefix) - { - _host.CheckValue(prefix, nameof(prefix)); - return GetUniqueName(prefix, c => _nodeNames.Contains(c)); - } + /// The prefix for the node + /// A name that has not yet been returned from this function, starting with + public abstract string GetNodeName(string prefix); /// - /// Adds a node to the node list of the graph. + /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can + /// safely call . /// - /// - public void AddNode(NodeProto node) - { - _host.CheckValue(node, nameof(node)); - _host.Assert(!_nodeNames.Contains(node.Name)); - - _nodeNames.Add(node.Name); - _nodes.Add(node); - } + /// The data view column name + /// Whether the column is mapped in this context + public abstract bool ContainsColumn(string colName); /// - /// Generates a unique name based on a prefix. + /// Stops tracking a column. /// - private string GetUniqueName(string prefix, Func pred) - { - _host.CheckValue(prefix, nameof(prefix)); - _host.CheckValue(pred, nameof(pred)); - - if (!pred(prefix)) - return prefix; - - int count = 0; - while (pred(prefix + count++)) ; - return prefix + --count; - } + /// Column name to stop tracking + /// Remove associated ONNX variable. This is useful in the event where an output + /// variable is created through before realizing + /// the transform cannot actually save as ONNX. + public abstract void RemoveColumn(string colName, bool removeVariable = false); /// - /// Retrieves the variable name that maps to the IDataView column name at a - /// given point in the pipeline execution. + /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it. /// - /// Column Name mapping. - public string GetVariableName(string colName) - { - _host.CheckValue(colName, nameof(colName)); - _host.Assert(_columnNameMap.ContainsKey(colName)); - - return _columnNameMap[colName]; - } - - /// - /// Retrieves the variable name that maps to the IDataView column name at a - /// given point in the pipeline execution. - /// - /// Column Name mapping. - public string TryGetVariableName(string colName) - { - if (_columnNameMap.ContainsKey(colName)) - return GetVariableName(colName); - - return null; - } - - /// - /// Generates a unique column name based on the IDataView column name if - /// there is a collision between names in the pipeline at any point. - /// - /// IDataView column name. - /// Unique variable name. - private string AddVariable(string colName) - { - _host.CheckValue(colName, nameof(colName)); - - if (!_columnNameMap.ContainsKey(colName)) - _columnNameMap.Add(colName, colName); - else - _columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s)); - - _variableMap.Add(_columnNameMap[colName]); - return _columnNameMap[colName]; - } + /// ONNX variable to remove. Note that this is an ONNX variable name, not an column name + /// IDataView column to stop tracking + public abstract void RemoveVariable(string variableName, bool removeColumn); /// - /// Adds an intermediate column to the list. + /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding + /// 's column names will map to a variable in the ONNX graph if the intermediate steps + /// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps + /// to the column name at a given point in the pipeline execution. Callers should + /// probably confirm with whether a mapping for that data view column + /// already exists. /// - public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false) - { - - colName = AddVariable(colName); - - //Let the runtime figure the shape. - if (!skip) - { - _host.CheckValue(type, nameof(type)); - - _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName)); - } - - return colName; - } + /// The data view column name + /// The ONNX variable name corresponding to that data view column + public abstract string GetVariableName(string colName); /// - /// Adds an output variable to the list. + /// Establishes a new mapping from an data view column in the context, if necessary generates a unique name, and + /// returns that newly allocated name. /// - public string AddOutputVariable(ColumnType type, string colName, List dim = null) - { - _host.CheckValue(type, nameof(type)); - - if (!ContainsColumn(colName)) - AddVariable(colName); - - colName = GetVariableName(colName); - _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim)); - return colName; - } + /// The data view type associated with this column name + /// The data view column name + /// Whether we should skip the process of establishing the mapping from data view column to + /// ONNX variable name. + /// The returned value is the name of the variable corresponding + public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false); /// - /// Adds an input variable to the list. + /// Creates an ONNX node /// - public void AddInputVariable(ColumnType type, string colName) - { - _host.CheckValue(type, nameof(type)); - _host.CheckValue(colName, nameof(colName)); - - colName = AddVariable(colName); - _inputs.Add(OnnxUtils.GetModelArgs(type, colName)); - } + /// The name of the ONNX operator to apply + /// The names of the variables as inputs + /// The names of the variables to create as outputs, + /// which ought to have been something returned from + /// The name of the operator, which ought to be something returned from + /// The domain of the ONNX operator, if non-default + /// A node added to the in-progress ONNX graph, that attributes can be set on + public abstract OnnxNode CreateNode(string opType, IEnumerable inputs, + IEnumerable outputs, string name, string domain = null); /// - /// Makes the ONNX model based on the context. + /// Convenience alternative to + /// for the case where there is exactly one input and output. /// - public ModelProto MakeModel() - => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues); + /// The name of the ONNX operator to apply + /// The name of the variable as input + /// The name of the variable as output, + /// which ought to have been something returned from + /// The name of the operator, which ought to be something returned from + /// The domain of the ONNX operator, if non-default + /// A node added to the in-progress ONNX graph, that attributes can be set on + public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null) + => CreateNode(opType, new[] { input }, new[] { output }, name, domain); } } diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs new file mode 100644 index 0000000000..259a6d27d4 --- /dev/null +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + /// + /// An abstraction for an ONNX node as created by + /// . + /// That method creates a with inputs and outputs, but this object can modify the node further + /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations). + /// + public abstract class OnnxNode + { + public abstract void AddAttribute(string argName, double value); + public abstract void AddAttribute(string argName, long value); + public abstract void AddAttribute(string argName, DvText value); + public abstract void AddAttribute(string argName, string value); + public abstract void AddAttribute(string argName, bool value); + + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, string[] value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + } +} diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 487726572b..835ba5d99a 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1437,19 +1437,14 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str string opType = "Affine"; string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); - var node = OnnxUtils.MakeNode(opType, new List { scoreProbablityColumnNames[0] }, - new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); - - OnnxUtils.NodeAddAttributes(node, "alpha", ParamA * -1); - OnnxUtils.NodeAddAttributes(node, "beta", -0.0000001); - - ctx.AddNode(node); + var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] }, + new[] { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); + node.AddAttribute("alpha", ParamA * -1); + node.AddAttribute("beta", -0.0000001); opType = "Sigmoid"; - node = OnnxUtils.MakeNode(opType, new List { linearOutput }, - new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); - - ctx.AddNode(node); + node = ctx.CreateNode(opType, new[] { linearOutput }, + new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); return true; } diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 6fde9a5815..91f799b2e7 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -206,11 +206,9 @@ public override void SaveAsOnnx(OnnxContext ctx) if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2])) { string opType = "Binarizer"; - var node = OnnxUtils.MakeNode(opType, new List { ctx.GetVariableName(outColumnNames[2]) }, - new List { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); - - OnnxUtils.NodeAddAttributes(node, "threshold", 0.5); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) }, + new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); + node.AddAttribute("threshold", 0.5); } } diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index a6c0e3f490..db572fc93f 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -720,13 +720,11 @@ public void SaveAsOnnx(OnnxContext ctx) Source.Schema.GetColumnType(srcIndex).ValueCount)); } - var node = OnnxUtils.MakeNode(opType, new List(inputList.Select(t => t.Key)), - new List { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, inputList.Select(t => t.Key), + new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); - ctx.AddNode(node); - - OnnxUtils.NodeAddAttributes(node, "inputList", inputList.Select(x => x.Key)); - OnnxUtils.NodeAddAttributes(node, "inputdimensions", inputList.Select(x => x.Value)); + node.AddAttribute("inputList", inputList.Select(x => x.Key)); + node.AddAttribute("inputdimensions", inputList.Select(x => x.Value)); } } diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index bffbaa881c..d177d75647 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -244,10 +244,9 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { string opType = "OneHotEncoder"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); - OnnxUtils.NodeAddAttributes(node, "zeros", true); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); + node.AddAttribute("zeros", true); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index f6e8851f51..6092c2d8b9 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -434,8 +434,8 @@ private AffineColumnFunction(IHost host) public abstract void Save(ModelSaveContext ctx); public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - - public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount); + public bool CanSaveOnnx => true; + public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); public abstract Delegate GetGetter(IRow input, int icol); @@ -546,10 +546,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) - { - return false; - } + public bool CanSaveOnnx => false; + + public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) + => throw Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); @@ -671,10 +671,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) - { - return false; - } + public bool CanSaveOnnx => false; + + public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) + => throw Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 41e55ee338..e577b9370e 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) { - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount)); + nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); + nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount)); return true; } @@ -648,12 +648,12 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { - if (Offset != null) - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale); + if (Offset != null) + node.AddAttribute("offset", Offset); + node.AddAttribute("scale", Scale); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index ef5eef8551..4c6e1fb011 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount)); + node.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); + node.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount)); return true; } @@ -648,14 +648,14 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { if (Offset != null) - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset); + node.AddAttribute("offset", Offset); else - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(0, featureCount)); + node.AddAttribute("offset", Enumerable.Repeat(0, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale); + node.AddAttribute("scale", Scale); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index 318bbb44f1..22fa9686ed 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -66,7 +66,9 @@ public interface IColumnFunction : ICanSaveModel JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount); + bool CanSaveOnnx { get; } + + bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); } public sealed partial class NormalizeTransform : OneToOneTransformBase @@ -316,11 +318,11 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, if (info.TypeSrc.ValueCount == 0) return false; - string opType = "Scaler"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - if (_functions[iinfo].OnnxInfo(ctx, new OnnxUtils.NodeProtoWrapper(node), info.TypeSrc.ValueCount)) + if (_functions[iinfo].CanSaveOnnx) { - ctx.AddNode(node); + string opType = "Scaler"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + _functions[iinfo].OnnxInfo(ctx, node, info.TypeSrc.ValueCount); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index da7442f90e..bb9adf21e1 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -690,11 +690,10 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, TermMap map = (TermMap)_termMap[iinfo].Map; map.GetTerms(ref terms); string opType = "LabelEncoder"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "classes_strings", terms.DenseValues()); - OnnxUtils.NodeAddAttributes(node, "default_int64", -1); - OnnxUtils.NodeAddAttributes(node, "default_string", DvText.Empty); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("classes_strings", terms.DenseValues()); + node.AddAttribute("default_int64", -1); + node.AddAttribute("default_string", DvText.Empty); return true; } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 633c484716..01668eb3f2 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3130,26 +3130,24 @@ public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fea } string opType = "TreeEnsembleRegressor"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, - new List(outputNames), ctx.GetNodeName(opType)); - - OnnxUtils.NodeAddAttributes(node, "post_transform", PostTransform.None.GetDescription()); - OnnxUtils.NodeAddAttributes(node, "n_targets", 1); - OnnxUtils.NodeAddAttributes(node, "base_values", new List() { 0 }); - OnnxUtils.NodeAddAttributes(node, "aggregate_function", AggregateFunction.Sum.GetDescription()); - OnnxUtils.NodeAddAttributes(node, "nodes_treeids", nodesTreeids); - OnnxUtils.NodeAddAttributes(node, "nodes_nodeids", nodesIds); - OnnxUtils.NodeAddAttributes(node, "nodes_featureids", nodesFeatureIds); - OnnxUtils.NodeAddAttributes(node, "nodes_modes", nodeModes); - OnnxUtils.NodeAddAttributes(node, "nodes_values", nodesValues); - OnnxUtils.NodeAddAttributes(node, "nodes_truenodeids", nodesTrueNodeIds); - OnnxUtils.NodeAddAttributes(node, "nodes_falsenodeids", nodesFalseNodeIds); - OnnxUtils.NodeAddAttributes(node, "nodes_missing_value_tracks_true", missingValueTracksTrue); - OnnxUtils.NodeAddAttributes(node, "target_treeids", classTreeIds); - OnnxUtils.NodeAddAttributes(node, "target_nodeids", classNodeIds); - OnnxUtils.NodeAddAttributes(node, "target_ids", classIds); - OnnxUtils.NodeAddAttributes(node, "target_weights", classWeights); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType)); + + node.AddAttribute("post_transform", PostTransform.None.GetDescription()); + node.AddAttribute("n_targets", 1); + node.AddAttribute("base_values", new List() { 0 }); + node.AddAttribute("aggregate_function", AggregateFunction.Sum.GetDescription()); + node.AddAttribute("nodes_treeids", nodesTreeids); + node.AddAttribute("nodes_nodeids", nodesIds); + node.AddAttribute("nodes_featureids", nodesFeatureIds); + node.AddAttribute("nodes_modes", nodeModes); + node.AddAttribute("nodes_values", nodesValues); + node.AddAttribute("nodes_truenodeids", nodesTrueNodeIds); + node.AddAttribute("nodes_falsenodeids", nodesFalseNodeIds); + node.AddAttribute("nodes_missing_value_tracks_true", missingValueTracksTrue); + node.AddAttribute("target_treeids", classTreeIds); + node.AddAttribute("target_nodeids", classNodeIds); + node.AddAttribute("target_ids", classIds); + node.AddAttribute("target_weights", classWeights); return true; } diff --git a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj similarity index 52% rename from src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj rename to src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj index 4244681bd6..145dd8be8c 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj +++ b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj @@ -2,11 +2,16 @@ netstandard2.0 - Microsoft.ML + Microsoft.ML.Onnx + Microsoft.ML.Runtime.Model.Onnx + + + + diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs new file mode 100644 index 0000000000..f37a1ea557 --- /dev/null +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -0,0 +1,255 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + /// + /// A context for defining a ONNX output. + /// + internal sealed class OnnxContextImpl : OnnxContext + { + private readonly List _nodes; + private readonly List _inputs; + // The map from IDataView column names to variable names. + private readonly List _intermediateValues; + private readonly List _outputs; + private readonly Dictionary _columnNameMap; + // All existing variable names. New variables must not exist in this set. + private readonly HashSet _variableNames; + // All existing node names. New node names must not alrady exist in this set. + private readonly HashSet _nodeNames; + private readonly string _name; + private readonly string _producerName; + private readonly IHost _host; + private readonly string _domain; + private readonly string _producerVersion; + private readonly long _modelVersion; + + public OnnxContextImpl(IHostEnvironment env, string name, string producerName, + string producerVersion, long modelVersion, string domain) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(OnnxContext)); + _host.CheckValue(name, nameof(name)); + _host.CheckValue(name, nameof(domain)); + + _nodes = new List(); + _intermediateValues = new List(); + _inputs = new List(); + _outputs = new List(); + _columnNameMap = new Dictionary(); + _variableNames = new HashSet(); + _nodeNames = new HashSet(); + _name = name; + _producerName = producerName; + _producerVersion = producerVersion; + _modelVersion = modelVersion; + _domain = domain; + } + + public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); + + /// + /// Stops tracking a column. If removeVariable is true then it also removes the + /// variable associated with it, this is useful in the event where an output variable is + /// created before realizing the transform cannot actually save as ONNX. + /// + /// IDataView column name to stop tracking + /// Remove associated ONNX variable at the time. + public override void RemoveColumn(string colName, bool removeVariable) + { + _host.CheckNonEmpty(colName, nameof(colName)); + + if (removeVariable) + { + foreach (var val in _intermediateValues) + { + if (val.Name == _columnNameMap[colName]) + { + _intermediateValues.Remove(val); + break; + } + } + } + _columnNameMap.Remove(colName); + } + + /// + /// Removes an ONNX variable. If removeColumn is true then it also removes the + /// IDataView column associated with it. + /// + /// ONNX variable to remove. + /// IDataView column to stop tracking + public override void RemoveVariable(string variableName, bool removeColumn) + { + _host.CheckNonEmpty(variableName, nameof(variableName)); + if (!_columnNameMap.ContainsValue(variableName)) + throw _host.ExceptParam(nameof(variableName), $"Could not find '{variableName}' declared in ONNX graph"); + + if (removeColumn) + { + foreach (var val in _intermediateValues) + { + if (val.Name == variableName) + { + _intermediateValues.Remove(val); + break; + } + } + } + + string columnName = _columnNameMap.Single(kvp => kvp.Value == variableName).Key; + + Contracts.Assert(_variableNames.Contains(columnName)); + + _columnNameMap.Remove(columnName); + _variableNames.Remove(columnName); + } + + /// + /// Generates a unique name for the node based on a prefix. + /// + public override string GetNodeName(string prefix) + { + _host.CheckNonEmpty(prefix, nameof(prefix)); + return GetUniqueName(prefix, _nodeNames.Contains); + } + + /// + /// Adds a node to the node list of the graph. + /// + /// + private void AddNode(NodeProto node) + { + _host.CheckValue(node, nameof(node)); + _host.Assert(!_nodeNames.Contains(node.Name)); + + _nodeNames.Add(node.Name); + _nodes.Add(node); + } + + public override OnnxNode CreateNode(string opType, IEnumerable inputs, + IEnumerable outputs, string name, string domain = null) + { + _host.CheckNonEmpty(opType, nameof(opType)); + _host.CheckValue(inputs, nameof(inputs)); + _host.CheckValue(outputs, nameof(outputs)); + _host.CheckNonEmpty(name, nameof(name)); + + var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain); + AddNode(innerNode); + return new OnnxNodeImpl(innerNode); + } + + /// + /// Generates a unique name based on a prefix. + /// + private string GetUniqueName(string prefix, Func pred) + { + _host.CheckNonEmpty(prefix, nameof(prefix)); + _host.CheckValue(pred, nameof(pred)); + + if (!pred(prefix)) + return prefix; + + int count = 0; + while (pred(prefix + count++)) ; + return prefix + --count; + } + + /// + /// Retrieves the variable name that maps to the IDataView column name at a + /// given point in the pipeline execution. + /// + /// Column Name mapping. + public override string GetVariableName(string colName) + { + _host.CheckNonEmpty(colName, nameof(colName)); + _host.Assert(_columnNameMap.ContainsKey(colName)); + + return _columnNameMap[colName]; + } + + /// + /// Retrieves the variable name that maps to the IDataView column name at a + /// given point in the pipeline execution. + /// + /// Column Name mapping. + public string TryGetVariableName(string colName) + { + _host.CheckNonEmpty(colName, nameof(colName)); + if (_columnNameMap.ContainsKey(colName)) + return GetVariableName(colName); + return null; + } + + /// + /// Generates a unique column name based on the IDataView column name if + /// there is a collision between names in the pipeline at any point. + /// + /// IDataView column name. + /// Unique variable name. + private string AddVariable(string colName) + { + _host.CheckNonEmpty(colName, nameof(colName)); + _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); + _variableNames.Add(_columnNameMap[colName]); + return _columnNameMap[colName]; + } + + /// + /// Adds an intermediate column to the list. + /// + public override string AddIntermediateVariable(ColumnType type, string colName, bool skip = false) + { + colName = AddVariable(colName); + // Let the runtime figure the shape. + if (!skip) + { + _host.CheckValue(type, nameof(type)); + _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName)); + } + return colName; + } + + /// + /// Adds an output variable to the list. + /// + public string AddOutputVariable(ColumnType type, string colName, List dim = null) + { + _host.CheckValue(type, nameof(type)); + + if (!ContainsColumn(colName)) + AddVariable(colName); + + colName = GetVariableName(colName); + _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim)); + return colName; + } + + /// + /// Adds an input variable to the list. + /// + public void AddInputVariable(ColumnType type, string colName) + { + _host.CheckValue(type, nameof(type)); + _host.CheckValue(colName, nameof(colName)); + + colName = AddVariable(colName); + _inputs.Add(OnnxUtils.GetModelArgs(type, colName)); + } + + /// + /// Makes the ONNX model based on the context. + /// + public ModelProto MakeModel() + => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues); + } +} diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs b/src/Microsoft.ML.Onnx/OnnxMl.cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs rename to src/Microsoft.ML.Onnx/OnnxMl.cs diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md b/src/Microsoft.ML.Onnx/OnnxMl.md similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md rename to src/Microsoft.ML.Onnx/OnnxMl.md diff --git a/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs new file mode 100644 index 0000000000..9b30fd1d87 --- /dev/null +++ b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + internal sealed class OnnxNodeImpl : OnnxNode + { + private readonly NodeProto _node; + + public OnnxNodeImpl(NodeProto node) + { + Contracts.AssertValue(node); + _node = node; + } + + public override void AddAttribute(string argName, double value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, long value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, DvText value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, string[] value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, string value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, bool value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + } +} diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs similarity index 95% rename from src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs rename to src/Microsoft.ML.Onnx/OnnxUtils.cs index 8ad9f40c20..be5a833643 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -15,7 +14,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// /// Contains methods to create ONNX models in protocol buffer. /// - public sealed class OnnxUtils + internal static class OnnxUtils { private static TypeProto MakeType(TypeProto typeProto, TensorProto.Types.DataType dataType, List dims, List dimsParam) @@ -153,7 +152,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0); - public static NodeProto MakeNode(string opType, List inputs, List outputs, string name, string domain = null) + public static NodeProto MakeNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null) { Contracts.CheckNonEmpty(opType, nameof(opType)); Contracts.CheckValue(inputs, nameof(inputs)); @@ -169,11 +168,6 @@ public static NodeProto MakeNode(string opType, List inputs, List() { inputs }, new List() { outputs }, name); - } - public static void NodeAddAttributes(NodeProto node, string argName, double value) => node.Attribute.Add(MakeAttribute(argName, value)); @@ -241,16 +235,6 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List di } } - public sealed class NodeProtoWrapper - { - public NodeProto Node; - - public NodeProtoWrapper(NodeProto node) - { - Node = node; - } - } - public static ModelProto MakeModel(List nodes, string producerName, string name, string domain, string producerVersion, long modelVersion, List inputs, List outputs, List intermediateValues) diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs similarity index 97% rename from src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs rename to src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index d2dfc93fde..6b72d64af5 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -12,7 +12,6 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; using Newtonsoft.Json; [assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand), @@ -69,7 +68,6 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly HashSet _outputsToDrop; private readonly ITransformModel _model; private const string ProducerName = "ML.NET"; - private const string ProducerVersion = "0.2.0.0000"; private const long ModelVersion = 0; public SaveOnnxCommand(IHostEnvironment env, Arguments args) @@ -164,7 +162,11 @@ private void Run(IChannel ch) GetPipe(ch, view, out source, out end, out transforms); Host.Assert(transforms.Count == 0 || transforms.Last.Value == end); - var ctx = new OnnxContext(Host, _name, ProducerName, ProducerVersion, ModelVersion, _domain); + var assembly = System.Reflection.Assembly.GetExecutingAssembly(); + var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location); + + var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, + ModelVersion, _domain); // If we have a predictor, try to get the scorer for it. if (rawPred != null) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index a3ac7ce72e..3fc63060e8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -236,13 +236,12 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) Host.Check(Utils.Size(outputs) == 1); string opType = "LinearRegressor"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List (outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} - OnnxUtils.NodeAddAttributes(node, "post_transform", 0); - OnnxUtils.NodeAddAttributes(node, "targets", 1); - OnnxUtils.NodeAddAttributes(node, "coefficients", Weight.DenseValues()); - OnnxUtils.NodeAddAttributes(node, "intercepts", Bias); - ctx.AddNode(node); + node.AddAttribute("post_transform", 0); + node.AddAttribute("targets", 1); + node.AddAttribute("coefficients", Weight.DenseValues()); + node.AddAttribute("intercepts", Bias); return true; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index f2fad63794..5fbc6ab74a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -844,14 +844,13 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) Host.CheckValue(ctx, nameof(ctx)); string opType = "LinearClassifier"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List(outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} - OnnxUtils.NodeAddAttributes(node, "post_transform", 0); - OnnxUtils.NodeAddAttributes(node, "multi_class", true); - OnnxUtils.NodeAddAttributes(node, "coefficients", _weights.SelectMany(w => w.DenseValues())); - OnnxUtils.NodeAddAttributes(node, "intercepts", _biases); - OnnxUtils.NodeAddAttributes(node, "classlabels_strings", _labelNames); - ctx.AddNode(node); + node.AddAttribute("post_transform", 0); + node.AddAttribute("multi_class", true); + node.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues())); + node.AddAttribute("intercepts", _biases); + node.AddAttribute("classlabels_strings", _labelNames); return true; } diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 52989472fc..44832ee517 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -616,20 +616,19 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, return false; string opType = "Imputer"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "replaced_value_float", Single.NaN); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("replaced_value_float", Single.NaN); if (!Infos[iinfo].TypeSrc.IsVector) - OnnxUtils.NodeAddAttributes(node, "imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); + node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); else { if (_repIsDefault[iinfo] != null) - OnnxUtils.NodeAddAttributes(node, "imputed_value_floats", (float[])_repValues[iinfo]); + node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]); else - OnnxUtils.NodeAddAttributes(node, "imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); + node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); } - ctx.AddNode(node); return true; } diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json index 3afd6dc171..b2d611b784 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -1,7 +1,7 @@ { "irVersion": "3", "producerName": "ML.NET", - "producerVersion": "0.2.0.0000", + "producerVersion": "##VERSION##", "domain": "Onnx", "graph": { "node": [ diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index bed3dce0eb..ea92975b35 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index de0db6f3a9..43e208ad3c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -329,12 +329,12 @@ public void EntryPointInputBuilderOptionals() ib1.TrySetValue("WeightColumn", "OtherWeight"); Assert.True(instance.WeightColumn.IsExplicit); - Assert.True(string.Compare(instance.WeightColumn.Value, "OtherWeight") == 0); + Assert.Equal("OtherWeight", instance.WeightColumn.Value); var tok = (JToken)JValue.CreateString("AnotherWeight"); ib1.TrySetValueJson("WeightColumn", tok); Assert.True(instance.WeightColumn.IsExplicit); - Assert.True(string.Compare(instance.WeightColumn.Value, "AnotherWeight") == 0); + Assert.Equal("AnotherWeight", instance.WeightColumn.Value); } [Fact] diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 8635fdf798..2a2ea8bca1 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 477a9c6fa6..f4428242be 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -9,6 +9,7 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Trainers; using System.IO; +using System.Text.RegularExpressions; using Xunit; using Xunit.Abstractions; @@ -86,6 +87,11 @@ public void BinaryClassificationSaveModelToOnnxTest() converter.Convert(model); + // Strip the version. + var fileText = File.ReadAllText(onnxAsJsonPath); + fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); + File.WriteAllText(onnxAsJsonPath, fileText); + CheckEquality(subDir, "SaveModelToOnnxTest.json"); Done(); }