From 0de0311e86c7c3d8f10e1d436d33815cf2e373a8 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 29 Jun 2018 11:13:33 -0700 Subject: [PATCH 1/9] Abstraction of ONNX exporting to interfaces, and isolation of actual implementation to separate DLL. --- .../DataLoadSave/CompositeDataLoader.cs | 2 +- .../Microsoft.ML.Data.csproj | 1 - .../Model/Onnx/ICanSaveOnnx.cs | 8 +- .../Model/Onnx/IOnnxContext.cs | 82 +++++++++++++++++++ src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs | 33 ++++++++ .../Prediction/Calibrator.cs | 19 ++--- .../Scorers/BinaryClassifierScorer.cs | 8 +- .../Scorers/GenericScorer.cs | 2 +- .../Scorers/MultiClassClassifierScorer.cs | 2 +- .../Scorers/PredictedLabelScorerBase.cs | 2 +- .../Scorers/SchemaBindablePredictorWrapper.cs | 6 +- .../Transforms/ConcatTransform.cs | 10 +-- .../Transforms/KeyToVectorTransform.cs | 9 +- .../Transforms/NormalizeColumn.cs | 20 ++--- .../Transforms/NormalizeColumnDbl.cs | 14 ++-- .../Transforms/NormalizeColumnSng.cs | 14 ++-- .../Transforms/NormalizeTransform.cs | 14 ++-- .../Transforms/TermTransform.cs | 11 ++- .../Transforms/TransformBase.cs | 4 +- src/Microsoft.ML.FastTree/FastTree.cs | 37 ++++----- .../Standard/LinearPredictor.cs | 13 ++- .../MulticlassLogisticRegression.cs | 15 ++-- .../NAReplaceTransform.cs | 13 ++- .../Microsoft.ML.UniversalModelFormat.csproj | 4 + .../Onnx/OnnxContext.cs | 12 ++- .../Onnx/OnnxNode..cs | 46 +++++++++++ .../Onnx/OnnxUtils.cs | 18 +--- .../Onnx/SaveOnnxCommand.cs | 1 - .../Microsoft.ML.Tests.csproj | 1 + 29 files changed, 282 insertions(+), 139 deletions(-) create mode 100644 src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs create mode 100644 src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs rename src/{Microsoft.ML.Data/Model => Microsoft.ML.UniversalModelFormat}/Onnx/OnnxContext.cs (95%) create mode 100644 src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs rename src/{Microsoft.ML.Data/Model => Microsoft.ML.UniversalModelFormat}/Onnx/OnnxUtils.cs (96%) rename 
src/{Microsoft.ML.Data/Model => Microsoft.ML.UniversalModelFormat}/Onnx/SaveOnnxCommand.cs (99%) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index cab6467043..a2ab3a7b16 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo() /// Returns the underlying data view of the composite loader. /// This can be used to programmatically explore the chain of transforms that's inside the composite loader. /// - internal IDataView View { get; } + public IDataView View { get; } /// /// Creates a loader according to the specified . diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj index 38fb6075ce..8d5b0fd2d0 100644 --- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj +++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj @@ -15,7 +15,6 @@ - diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 0e25840c3b..2ae075ba85 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -27,7 +27,7 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform /// Save as ONNX. /// /// The ONNX program being built - void SaveAsOnnx(OnnxContext ctx); + void SaveAsOnnx(IOnnxContext ctx); } /// @@ -52,7 +52,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// the outputs produced by this bindable mapper. This is the array that holds /// those names, so that implementors of this method know what to produce in /// . 
- bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames); + bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames); } /// @@ -61,7 +61,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// public interface ISingleCanSaveOnnx : ICanSaveOnnx { - bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); + bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn); } /// @@ -70,6 +70,6 @@ public interface ISingleCanSaveOnnx : ICanSaveOnnx /// public interface IDistCanSaveOnnx : ISingleCanSaveOnnx, IValueMapperDist { - new bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); + new bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs new file mode 100644 index 0000000000..932fedbd1f --- /dev/null +++ b/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + /// + /// A context for defining a ONNX output. This is iteratively given to exportable components + /// via the interface and subinterfaces, that attempt to express + /// their operations as ONNX nodes, if they can. At the point that it is given to a component, + /// all other components up to that component have already attempted to express themselves in + /// this context, with their outputs possibly available in the ONNX graph. + /// + public interface IOnnxContext + { + bool ContainsColumn(string colName); + + /// + /// Stops tracking a column. 
If removeVariable is true then it also removes the + /// variable associated with it; this is useful in the event where an output variable is + /// created before realizing the transform cannot actually save as ONNX. + /// + /// IDataView column name to stop tracking + /// Whether to also remove the associated ONNX variable at this time. + void RemoveColumn(string colName, bool removeVariable); + + /// + /// Removes an ONNX variable. If removeColumn is true then it also removes the + /// tracking for the column associated with it. + /// + /// ONNX variable to remove. + /// Whether to also stop tracking the associated IDataView column. + void RemoveVariable(string variableName, bool removeColumn); + + /// + /// Generates a unique name for the node based on a prefix. + /// + string GetNodeName(string prefix); + + /// + /// Retrieves the variable name that maps to the IDataView column name at a + /// given point in the pipeline execution. + /// + /// Column Name mapping. + string GetVariableName(string colName); + + /// + /// Retrieves the variable name that maps to the IDataView column name at a + /// given point in the pipeline execution. + /// + /// Column Name mapping. + string TryGetVariableName(string colName); + + /// + /// Adds an intermediate column to the list. 
+ /// + string AddIntermediateVariable(ColumnType type, string colName, bool skip = false); + + /// + /// Creates an ONNX node of + /// + /// The name of the ONNX operator to apply + /// The names of the variables as inputs + /// The names of the variables to create as outputs, + /// which ought to have been something + /// + /// + /// + IOnnxNode CreateNode(string opType, List inputs, + List outputs, string name, string domain = null); + } + + public static class OnnxContextExtensions + { + public static IOnnxNode CreateNode(this IOnnxContext ctx, + string opType, string inputs, string outputs, string name) + => ctx.CreateNode(opType, new List() { inputs }, new List() { outputs }, name); + } +} diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs new file mode 100644 index 0000000000..f442a72c25 --- /dev/null +++ b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + /// + /// An abstraction for an ONNX node as created by + /// . + /// That method creates a with inputs and outputs, but this object can modify the node further + /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations). 
+ /// + public interface IOnnxNode + { + void AddAttribute(string argName, double value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, long value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, DvText value); + void AddAttribute(string argName, string[] value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, string value); + void AddAttribute(string argName, bool value); + } +} diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 487726572b..4250666701 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -296,7 +296,7 @@ public void SaveAsPfa(BoundPfaContext ctx, JToken input, probToken = ctx.DeclareVar(prob, probExpression); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName) + public bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumnName) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(outputNames, nameof(outputNames)); @@ -658,7 +658,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out ctx.Hide(outputs); } - public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputs) + public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputs) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs"); @@ -1429,7 +1429,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) PfaUtils.Call("+", -ParamB, PfaUtils.Call("*", -ParamA, input))); } - public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string 
featureColumnName) + public bool SaveAsOnnx(IOnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) { _host.CheckValue(ctx, nameof(ctx)); _host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames)); @@ -1437,20 +1437,15 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str string opType = "Affine"; string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); - var node = OnnxUtils.MakeNode(opType, new List { scoreProbablityColumnNames[0] }, + var node = ctx.CreateNode(opType, new List { scoreProbablityColumnNames[0] }, new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); - - OnnxUtils.NodeAddAttributes(node, "alpha", ParamA * -1); - OnnxUtils.NodeAddAttributes(node, "beta", -0.0000001); - - ctx.AddNode(node); + node.AddAttribute("alpha", ParamA * -1); + node.AddAttribute("beta", -0.0000001); opType = "Sigmoid"; - node = OnnxUtils.MakeNode(opType, new List { linearOutput }, + node = ctx.CreateNode(opType, new List { linearOutput }, new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); - ctx.AddNode(node); - return true; } diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 6fde9a5815..13020f0c94 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -184,7 +184,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_threshold); } - public override void SaveAsOnnx(OnnxContext ctx) + public override void SaveAsOnnx(IOnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); @@ -206,11 +206,9 @@ public override void SaveAsOnnx(OnnxContext ctx) if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2])) { string opType = "Binarizer"; - var node = OnnxUtils.MakeNode(opType, new List { 
ctx.GetVariableName(outColumnNames[2]) }, + var node = ctx.CreateNode(opType, new List { ctx.GetVariableName(outColumnNames[2]) }, new List { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); - - OnnxUtils.NodeAddAttributes(node, "threshold", 0.5); - ctx.AddNode(node); + node.AddAttribute("threshold", 0.5); } } diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 91c84a0734..4363b9a435 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -214,7 +214,7 @@ public void SaveAsPfa(BoundPfaContext ctx) pfaBindable.SaveAsPfa(ctx, schema, outColNames); } - public void SaveAsOnnx(OnnxContext ctx) + public void SaveAsOnnx(IOnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 4832f92cd4..953cebba05 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -203,7 +203,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out ((IBindableCanSavePfa)_bindable).SaveAsPfa(ctx, schema, outputNames); } - public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index fe69585b78..d80ff3bc17 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -365,7 +365,7 @@ public void SaveAsPfa(BoundPfaContext ctx) protected abstract JToken 
PredictedLabelPfa(string[] mapperOutputs); - public virtual void SaveAsOnnx(OnnxContext ctx) + public virtual void SaveAsOnnx(IOnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 1e07f587f7..2f7e887f10 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -99,7 +99,7 @@ public virtual void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, stri ctx.Hide(outputNames); } - public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; + public virtual bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) { @@ -289,7 +289,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str ctx.DeclareVar(outputNames[0], scoreToken); } - public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); @@ -403,7 +403,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str Contracts.Assert(ctx.TokenOrNullForName(outputNames[1]) == probToken.ToString()); } - public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index 
a6c0e3f490..a8fc11b130 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -687,7 +687,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - public void SaveAsOnnx(OnnxContext ctx) + public void SaveAsOnnx(IOnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(CanSaveOnnx); @@ -720,13 +720,11 @@ public void SaveAsOnnx(OnnxContext ctx) Source.Schema.GetColumnType(srcIndex).ValueCount)); } - var node = OnnxUtils.MakeNode(opType, new List(inputList.Select(t => t.Key)), + var node = ctx.CreateNode(opType, new List(inputList.Select(t => t.Key)), new List { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); - ctx.AddNode(node); - - OnnxUtils.NodeAddAttributes(node, "inputList", inputList.Select(x => x.Key)); - OnnxUtils.NodeAddAttributes(node, "inputdimensions", inputList.Select(x => x.Value)); + node.AddAttribute("inputList", inputList.Select(x => x.Key)); + node.AddAttribute("inputdimensions", inputList.Select(x => x.Value)); } } diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index bffbaa881c..5b2b2dc6c6 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -241,13 +241,12 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." 
+ funcName)); } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { string opType = "OneHotEncoder"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); - OnnxUtils.NodeAddAttributes(node, "zeros", true); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("cats_int64s", Enumerable.Range(1, info.TypeSrc.ItemType.KeyCount).Select(x => (long)x)); + node.AddAttribute("zeros", true); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index f6e8851f51..307236673b 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -434,8 +434,8 @@ private AffineColumnFunction(IHost host) public abstract void Save(ModelSaveContext ctx); public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - - public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount); + public bool CanSaveOnnx => true; + public abstract bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount); public abstract Delegate GetGetter(IRow input, int icol); @@ -546,10 +546,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) - { - return false; - } + public bool CanSaveOnnx => false; + + public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) + => throw 
Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); @@ -671,10 +671,10 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) - { - return false; - } + public bool CanSaveOnnx => false; + + public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) + => throw Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 41e55ee338..70a660e353 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) { - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount)); + nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); + nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount)); return true; } @@ -648,12 +648,12 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) 
{ - if (Offset != null) - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale); + if (Offset != null) + node.AddAttribute("offset", Offset); + node.AddAttribute("scale", Scale); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index ef5eef8551..0cdd8807ce 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -577,10 +577,10 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) { - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(Offset, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Enumerable.Repeat(Scale, featureCount)); + node.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); + node.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount)); return true; } @@ -648,14 +648,14 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) { if (Offset != null) - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Offset); + node.AddAttribute("offset", Offset); else - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "offset", Enumerable.Repeat(0, featureCount)); + node.AddAttribute("offset", 
Enumerable.Repeat(0, featureCount)); - OnnxUtils.NodeAddAttributes(nodeProtoWrapper.Node, "scale", Scale); + node.AddAttribute("scale", Scale); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index 318bbb44f1..a857d3f80f 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -66,7 +66,9 @@ public interface IColumnFunction : ICanSaveModel JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount); + bool CanSaveOnnx { get; } + + bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount); } public sealed partial class NormalizeTransform : OneToOneTransformBase @@ -306,7 +308,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo return _functions[iinfo].PfaInfo(ctx, srcToken); } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { Contracts.AssertValue(ctx); Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -316,11 +318,11 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, if (info.TypeSrc.ValueCount == 0) return false; - string opType = "Scaler"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - if (_functions[iinfo].OnnxInfo(ctx, new OnnxUtils.NodeProtoWrapper(node), info.TypeSrc.ValueCount)) + if (_functions[iinfo].CanSaveOnnx) { - ctx.AddNode(node); + string opType = "Scaler"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + _functions[iinfo].OnnxInfo(ctx, node, info.TypeSrc.ValueCount); return true; } diff --git 
a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index da7442f90e..57ab877eab 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -681,7 +681,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo return PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, srcToken), PfaUtils.Index(cellRef, srcToken), -1); } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { if (!info.TypeSrc.ItemType.IsText) return false; @@ -690,11 +690,10 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, TermMap map = (TermMap)_termMap[iinfo].Map; map.GetTerms(ref terms); string opType = "LabelEncoder"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "classes_strings", terms.DenseValues()); - OnnxUtils.NodeAddAttributes(node, "default_int64", -1); - OnnxUtils.NodeAddAttributes(node, "default_string", DvText.Empty); - ctx.AddNode(node); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("classes_strings", terms.DenseValues()); + node.AddAttribute("default_int64", -1); + node.AddAttribute("default_string", DvText.Empty); return true; } diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 263a3cf4ca..42d5ff7a2a 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -571,7 +571,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - public void SaveAsOnnx(OnnxContext ctx) + public void 
SaveAsOnnx(IOnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(CanSaveOnnx); @@ -616,7 +616,7 @@ protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo i return null; } - protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, + protected virtual bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) => false; public sealed override ISchema Schema => _bindings; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 633c484716..06d00324be 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3065,7 +3065,7 @@ private enum AggregateFunction Max } - public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + public virtual bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); @@ -3130,26 +3130,25 @@ public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fea } string opType = "TreeEnsembleRegressor"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, + var node = ctx.CreateNode(opType, new List { featureColumn }, new List(outputNames), ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "post_transform", PostTransform.None.GetDescription()); - OnnxUtils.NodeAddAttributes(node, "n_targets", 1); - OnnxUtils.NodeAddAttributes(node, "base_values", new List() { 0 }); - OnnxUtils.NodeAddAttributes(node, "aggregate_function", AggregateFunction.Sum.GetDescription()); - OnnxUtils.NodeAddAttributes(node, "nodes_treeids", nodesTreeids); - OnnxUtils.NodeAddAttributes(node, "nodes_nodeids", nodesIds); - OnnxUtils.NodeAddAttributes(node, "nodes_featureids", nodesFeatureIds); - OnnxUtils.NodeAddAttributes(node, "nodes_modes", nodeModes); - OnnxUtils.NodeAddAttributes(node, "nodes_values", nodesValues); - 
OnnxUtils.NodeAddAttributes(node, "nodes_truenodeids", nodesTrueNodeIds); - OnnxUtils.NodeAddAttributes(node, "nodes_falsenodeids", nodesFalseNodeIds); - OnnxUtils.NodeAddAttributes(node, "nodes_missing_value_tracks_true", missingValueTracksTrue); - OnnxUtils.NodeAddAttributes(node, "target_treeids", classTreeIds); - OnnxUtils.NodeAddAttributes(node, "target_nodeids", classNodeIds); - OnnxUtils.NodeAddAttributes(node, "target_ids", classIds); - OnnxUtils.NodeAddAttributes(node, "target_weights", classWeights); - ctx.AddNode(node); + node.AddAttribute("post_transform", PostTransform.None.GetDescription()); + node.AddAttribute("n_targets", 1); + node.AddAttribute("base_values", new List() { 0 }); + node.AddAttribute("aggregate_function", AggregateFunction.Sum.GetDescription()); + node.AddAttribute("nodes_treeids", nodesTreeids); + node.AddAttribute("nodes_nodeids", nodesIds); + node.AddAttribute("nodes_featureids", nodesFeatureIds); + node.AddAttribute("nodes_modes", nodeModes); + node.AddAttribute("nodes_values", nodesValues); + node.AddAttribute("nodes_truenodeids", nodesTrueNodeIds); + node.AddAttribute("nodes_falsenodeids", nodesFalseNodeIds); + node.AddAttribute("nodes_missing_value_tracks_true", missingValueTracksTrue); + node.AddAttribute("target_treeids", classTreeIds); + node.AddAttribute("target_nodeids", classNodeIds); + node.AddAttribute("target_ids", classIds); + node.AddAttribute("target_weights", classWeights); return true; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index a3ac7ce72e..61e322838c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -230,19 +230,18 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("model.reg.linear", input, cellRef); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) + 
public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) == 1); string opType = "LinearRegressor"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List (outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new List { featureColumn }, new List (outputs), ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} - OnnxUtils.NodeAddAttributes(node, "post_transform", 0); - OnnxUtils.NodeAddAttributes(node, "targets", 1); - OnnxUtils.NodeAddAttributes(node, "coefficients", Weight.DenseValues()); - OnnxUtils.NodeAddAttributes(node, "intercepts", Bias); - ctx.AddNode(node); + node.AddAttribute("post_transform", 0); + node.AddAttribute("targets", 1); + node.AddAttribute("coefficients", Weight.DenseValues()); + node.AddAttribute("intercepts", Bias); return true; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index f2fad63794..5b4854906b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -839,19 +839,18 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("m.link.softmax", PfaUtils.Call("model.reg.linear", input, cellRef)); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) + public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); string opType = "LinearClassifier"; - var node = OnnxUtils.MakeNode(opType, new List { featureColumn }, new List(outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new List { featureColumn }, new List(outputs), 
ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} - OnnxUtils.NodeAddAttributes(node, "post_transform", 0); - OnnxUtils.NodeAddAttributes(node, "multi_class", true); - OnnxUtils.NodeAddAttributes(node, "coefficients", _weights.SelectMany(w => w.DenseValues())); - OnnxUtils.NodeAddAttributes(node, "intercepts", _biases); - OnnxUtils.NodeAddAttributes(node, "classlabels_strings", _labelNames); - ctx.AddNode(node); + node.AddAttribute("post_transform", 0); + node.AddAttribute("multi_class", true); + node.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues())); + node.AddAttribute("intercepts", _biases); + node.AddAttribute("classlabels_strings", _labelNames); return true; } diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 52989472fc..156d94f52e 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -601,7 +601,7 @@ private Delegate ComposeGetterVec(IRow input, int iinfo) }; } - protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { DataKind rawKind; var type = Infos[iinfo].TypeSrc; @@ -616,20 +616,19 @@ protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, return false; string opType = "Imputer"; - var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - OnnxUtils.NodeAddAttributes(node, "replaced_value_float", Single.NaN); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("replaced_value_float", Single.NaN); if (!Infos[iinfo].TypeSrc.IsVector) - OnnxUtils.NodeAddAttributes(node, "imputed_value_float", 
Enumerable.Repeat((float)_repValues[iinfo], 1)); + node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); else { if (_repIsDefault[iinfo] != null) - OnnxUtils.NodeAddAttributes(node, "imputed_value_floats", (float[])_repValues[iinfo]); + node.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]); else - OnnxUtils.NodeAddAttributes(node, "imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); + node.AddAttribute("imputed_value_float", Enumerable.Repeat((float)_repValues[iinfo], 1)); } - ctx.AddNode(node); return true; } diff --git a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj b/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj index 4244681bd6..800525aca2 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj +++ b/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj @@ -9,4 +9,8 @@ + + + + diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs similarity index 95% rename from src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs rename to src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs index 46c083173f..1340d9778f 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// /// A context for defining a ONNX output. /// - public sealed class OnnxContext + internal sealed class OnnxContext : IOnnxContext { private readonly List _nodes; private readonly List _inputs; @@ -121,7 +121,7 @@ public string GetNodeName(string prefix) /// Adds a node to the node list of the graph. 
/// /// - public void AddNode(NodeProto node) + private void AddNode(NodeProto node) { _host.CheckValue(node, nameof(node)); _host.Assert(!_nodeNames.Contains(node.Name)); @@ -130,6 +130,14 @@ public void AddNode(NodeProto node) _nodes.Add(node); } + public IOnnxNode CreateNode(string opType, List inputs, + List outputs, string name, string domain = null) + { + var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain); + AddNode(innerNode); + return new OnnxNode(innerNode); + } + /// /// Generates a unique name based on a prefix. /// diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs new file mode 100644 index 0000000000..efbd7f676d --- /dev/null +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + internal sealed class OnnxNode : IOnnxNode + { + private readonly NodeProto _node; + + public OnnxNode(NodeProto node) + { + Contracts.AssertValue(node); + _node = node; + } + + public void AddAttribute(string argName, double value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, long value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, DvText value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, string[] value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, IEnumerable value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, string value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + public void AddAttribute(string argName, bool value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); + } +} diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs similarity index 96% rename from src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs rename to 
src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs index 8ad9f40c20..cf61b42445 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -15,7 +14,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// /// Contains methods to create ONNX models in protocol buffer. /// - public sealed class OnnxUtils + internal static class OnnxUtils { private static TypeProto MakeType(TypeProto typeProto, TensorProto.Types.DataType dataType, List dims, List dimsParam) @@ -169,11 +168,6 @@ public static NodeProto MakeNode(string opType, List inputs, List() { inputs }, new List() { outputs }, name); - } - public static void NodeAddAttributes(NodeProto node, string argName, double value) => node.Attribute.Add(MakeAttribute(argName, value)); @@ -241,16 +235,6 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List di } } - public sealed class NodeProtoWrapper - { - public NodeProto Node; - - public NodeProtoWrapper(NodeProto node) - { - Node = node; - } - } - public static ModelProto MakeModel(List nodes, string producerName, string name, string domain, string producerVersion, long modelVersion, List inputs, List outputs, List intermediateValues) diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/SaveOnnxCommand.cs similarity index 99% rename from src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs rename to src/Microsoft.ML.UniversalModelFormat/Onnx/SaveOnnxCommand.cs index d2dfc93fde..0b2b8cb6ad 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/SaveOnnxCommand.cs @@ -12,7 +12,6 @@ using Microsoft.ML.Runtime.EntryPoints; using 
Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; using Newtonsoft.Json; [assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand), diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 8635fdf798..d12f875d01 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -11,6 +11,7 @@ + From 671084e78f484fc1d80ccfd84b5e75c5e6170a88 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 29 Jun 2018 22:44:55 -0700 Subject: [PATCH 2/9] Documentation changes, minor bug fixed Remove needless delegate --- .../Model/Onnx/IOnnxContext.cs | 101 +++++++++++------- src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs | 2 +- .../Prediction/Calibrator.cs | 8 +- .../Scorers/BinaryClassifierScorer.cs | 4 +- .../Transforms/ConcatTransform.cs | 4 +- src/Microsoft.ML.FastTree/FastTree.cs | 3 +- .../Standard/LinearPredictor.cs | 2 +- .../MulticlassLogisticRegression.cs | 2 +- .../Onnx/OnnxContext.cs | 33 +++--- .../Onnx/OnnxUtils.cs | 2 +- 10 files changed, 90 insertions(+), 71 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs index 932fedbd1f..dc81cf5419 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs @@ -8,75 +8,102 @@ namespace Microsoft.ML.Runtime.Model.Onnx { /// - /// A context for defining a ONNX output. This is iteratively given to exportable components - /// via the interface and subinterfaces, that attempt to express - /// their operations as ONNX nodes, if they can. At the point that it is given to a component, - /// all other components up to that component have already attempted to express themselves in + /// A context for defining a ONNX output. 
The context internally contains the model-in-progress being built. This + /// same context object is iteratively given to exportable components via the interface + /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is + /// given to a component, all other components up to that component have already attempted to express themselves in /// this context, with their outputs possibly available in the ONNX graph. + /// + /// /// public interface IOnnxContext { - bool ContainsColumn(string colName); - /// - /// Stops tracking a column. If removeVariable is true then it also removes the - /// variable associated with it, this is useful in the event where an output variable is - /// created before realizing the transform cannot actually save as ONNX. + /// Generates a unique name for the node based on a prefix. /// - /// IDataView column name to stop tracking - /// Remove associated ONNX variable at the time. - void RemoveColumn(string colName, bool removeVariable); + /// The prefix for the node + /// A name that has not yet been returned from this function, starting with + string GetNodeName(string prefix); /// - /// Removes an ONNX variable. If removeColumn is true then it also removes the - /// tracking for the column associated with it. + /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can + /// safely call . /// - /// ONNX variable to remove. - /// IDataView column to stop tracking - void RemoveVariable(string variableName, bool removeColumn); + /// The data view column name + /// Whether the column is mapped in this context + bool ContainsColumn(string colName); /// - /// Generates a unique name for the node based on a prefix. + /// Stops tracking a column. /// - string GetNodeName(string prefix); + /// Column name to stop tracking + /// Remove associated ONNX variable. 
This is useful in the event where an output + /// variable is created through before realizing + /// the transform cannot actually save as ONNX. + void RemoveColumn(string colName, bool removeVariable = false); /// - /// Retrieves the variable name that maps to the IDataView column name at a - /// given point in the pipeline execution. + /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it. /// - /// Column Name mapping. - string GetVariableName(string colName); + /// ONNX variable to remove. Note that this is an ONNX variable name, not an column name + /// IDataView column to stop tracking + void RemoveVariable(string variableName, bool removeColumn); /// - /// Retrieves the variable name that maps to the IDataView column name at a - /// given point in the pipeline execution. + /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding + /// 's column names will map to a variable in the ONNX graph if the intermediate steps + /// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps + /// to the column name at a given point in the pipeline execution. Callers should + /// probably confirm with whether a mapping for that data view column + /// already exists. /// - /// Column Name mapping. - string TryGetVariableName(string colName); + /// The data view column name + /// The ONNX variable name corresponding to that data view column + string GetVariableName(string colName); /// - /// Adds an intermediate column to the list. + /// Establishes a new mapping from an data view column in the context, if necessary generates a unique name, and + /// returns that newly allocated name. /// + /// The data view type associated with this column name + /// The data view column name + /// Whether we should skip the process of establishing the mapping from data view column to + /// ONNX variable name. 
+ /// The returned value is the name of the variable corresponding string AddIntermediateVariable(ColumnType type, string colName, bool skip = false); /// - /// Creates an ONNX node of + /// Creates an ONNX node /// /// The name of the ONNX operator to apply /// The names of the variables as inputs /// The names of the variables to create as outputs, - /// which ought to have been something - /// - /// - /// - IOnnxNode CreateNode(string opType, List inputs, - List outputs, string name, string domain = null); + /// which ought to have been something returned from + /// The name of the operator, which ought to be something returned from + /// The domain of the ONNX operator, if non-default + /// A node added to the in-progress ONNX graph, that attributes can be set on + IOnnxNode CreateNode(string opType, IEnumerable inputs, + IEnumerable outputs, string name, string domain = null); } public static class OnnxContextExtensions { + /// + /// Convenience alternative to + /// for the case where there is exactly one input and output. 
+ /// + /// The ONNX save context + /// The name of the ONNX operator to apply + /// The name of the variable as input + /// The name of the variable as output, + /// which ought to have been something returned from + /// The name of the operator, which ought to be something returned from + /// The domain of the ONNX operator, if non-default + /// A node added to the in-progress ONNX graph, that attributes can be set on public static IOnnxNode CreateNode(this IOnnxContext ctx, - string opType, string inputs, string outputs, string name) - => ctx.CreateNode(opType, new List() { inputs }, new List() { outputs }, name); + string opType, string input, string output, string name, string domain = null) + => ctx.CreateNode(opType, new[] { input }, new[] { output }, name, domain); } } diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs index f442a72c25..0e8dfb7ede 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx { /// /// An abstraction for an ONNX node as created by - /// . + /// . /// That method creates a with inputs and outputs, but this object can modify the node further /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations). 
/// diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 4250666701..c8cbecfb5b 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1437,14 +1437,14 @@ public bool SaveAsOnnx(IOnnxContext ctx, string[] scoreProbablityColumnNames, st string opType = "Affine"; string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); - var node = ctx.CreateNode(opType, new List { scoreProbablityColumnNames[0] }, - new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); + var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] }, + new[] { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); node.AddAttribute("alpha", ParamA * -1); node.AddAttribute("beta", -0.0000001); opType = "Sigmoid"; - node = ctx.CreateNode(opType, new List { linearOutput }, - new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); + node = ctx.CreateNode(opType, new[] { linearOutput }, + new[] { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); return true; } diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 13020f0c94..bb97c0b425 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -206,8 +206,8 @@ public override void SaveAsOnnx(IOnnxContext ctx) if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2])) { string opType = "Binarizer"; - var node = ctx.CreateNode(opType, new List { ctx.GetVariableName(outColumnNames[2]) }, - new List { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) }, + new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); node.AddAttribute("threshold", 0.5); } } diff --git 
a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index a8fc11b130..537a412597 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -720,8 +720,8 @@ public void SaveAsOnnx(IOnnxContext ctx) Source.Schema.GetColumnType(srcIndex).ValueCount)); } - var node = ctx.CreateNode(opType, new List(inputList.Select(t => t.Key)), - new List { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, inputList.Select(t => t.Key), + new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType)); node.AddAttribute("inputList", inputList.Select(x => x.Key)); node.AddAttribute("inputdimensions", inputList.Select(x => x.Value)); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 06d00324be..29ff38b3aa 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3130,8 +3130,7 @@ public virtual bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string fe } string opType = "TreeEnsembleRegressor"; - var node = ctx.CreateNode(opType, new List { featureColumn }, - new List(outputNames), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", PostTransform.None.GetDescription()); node.AddAttribute("n_targets", 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 61e322838c..fb1b2be1fb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -236,7 +236,7 @@ public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) Host.Check(Utils.Size(outputs) == 1); string opType = "LinearRegressor"; - 
var node = ctx.CreateNode(opType, new List { featureColumn }, new List (outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} node.AddAttribute("post_transform", 0); node.AddAttribute("targets", 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 5b4854906b..f4ed5f8fe8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -844,7 +844,7 @@ public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) Host.CheckValue(ctx, nameof(ctx)); string opType = "LinearClassifier"; - var node = ctx.CreateNode(opType, new List { featureColumn }, new List(outputs), ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} node.AddAttribute("post_transform", 0); node.AddAttribute("multi_class", true); diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs index 1340d9778f..83903704eb 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs @@ -17,10 +17,13 @@ internal sealed class OnnxContext : IOnnxContext { private readonly List _nodes; private readonly List _inputs; + // The map from IDataView column names to variable names. 
private readonly List _intermediateValues; private readonly List _outputs; private readonly Dictionary _columnNameMap; - private readonly HashSet _variableMap; + // All existing variable names. New variables must not exist in this set. + private readonly HashSet _variableNames; + // All existing node names. New node names must not alrady exist in this set. private readonly HashSet _nodeNames; private readonly string _name; private readonly string _producerName; @@ -42,7 +45,7 @@ public OnnxContext(IHostEnvironment env, string name, string producerName, _inputs = new List(); _outputs = new List(); _columnNameMap = new Dictionary(); - _variableMap = new HashSet(); + _variableNames = new HashSet(); _nodeNames = new HashSet(); _name = name; _producerName = producerName; @@ -102,10 +105,10 @@ public void RemoveVariable(string variableName, bool removeColumn) string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key; - Contracts.Assert(_variableMap.Contains(columnName)); + Contracts.Assert(_variableNames.Contains(columnName)); _columnNameMap.Remove(columnName); - _variableMap.Remove(columnName); + _variableNames.Remove(columnName); } /// @@ -114,7 +117,7 @@ public void RemoveVariable(string variableName, bool removeColumn) public string GetNodeName(string prefix) { _host.CheckValue(prefix, nameof(prefix)); - return GetUniqueName(prefix, c => _nodeNames.Contains(c)); + return GetUniqueName(prefix, _nodeNames.Contains); } /// @@ -130,8 +133,8 @@ private void AddNode(NodeProto node) _nodes.Add(node); } - public IOnnxNode CreateNode(string opType, List inputs, - List outputs, string name, string domain = null) + public IOnnxNode CreateNode(string opType, IEnumerable inputs, + IEnumerable outputs, string name, string domain = null) { var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain); AddNode(innerNode); @@ -176,7 +179,6 @@ public string TryGetVariableName(string colName) { if 
(_columnNameMap.ContainsKey(colName)) return GetVariableName(colName); - return null; } @@ -189,13 +191,8 @@ public string TryGetVariableName(string colName) private string AddVariable(string colName) { _host.CheckValue(colName, nameof(colName)); - - if (!_columnNameMap.ContainsKey(colName)) - _columnNameMap.Add(colName, colName); - else - _columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s)); - - _variableMap.Add(_columnNameMap[colName]); + _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); + _variableNames.Add(_columnNameMap[colName]); return _columnNameMap[colName]; } @@ -204,17 +201,13 @@ private string AddVariable(string colName) /// public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false) { - colName = AddVariable(colName); - - //Let the runtime figure the shape. + // Let the runtime figure the shape. if (!skip) { _host.CheckValue(type, nameof(type)); - _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName)); } - return colName; } diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs index cf61b42445..be5a833643 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs @@ -152,7 +152,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 
1 : 0); - public static NodeProto MakeNode(string opType, List inputs, List outputs, string name, string domain = null) + public static NodeProto MakeNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null) { Contracts.CheckNonEmpty(opType, nameof(opType)); Contracts.CheckValue(inputs, nameof(inputs)); From 971eda67c747bcbb376fb26f724e0547404785ec Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 29 Jun 2018 23:06:08 -0700 Subject: [PATCH 3/9] Change "universal model format" to "ONNX". Forgot one dependency in Tests -> Onnx --- Microsoft.ML.sln | 2 +- .../Microsoft.ML.Onnx.csproj} | 1 + .../Onnx => Microsoft.ML.Onnx}/OnnxContext.cs | 0 .../Onnx => Microsoft.ML.Onnx}/OnnxMl.cs | 0 .../Onnx => Microsoft.ML.Onnx}/OnnxMl.md | 0 .../Onnx => Microsoft.ML.Onnx}/OnnxNode..cs | 0 .../Onnx => Microsoft.ML.Onnx}/OnnxUtils.cs | 0 .../Onnx => Microsoft.ML.Onnx}/SaveOnnxCommand.cs | 0 test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj | 2 +- 9 files changed, 3 insertions(+), 2 deletions(-) rename src/{Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj => Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj} (86%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/OnnxContext.cs (100%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/OnnxMl.cs (100%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/OnnxMl.md (100%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/OnnxNode..cs (100%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/OnnxUtils.cs (100%) rename src/{Microsoft.ML.UniversalModelFormat/Onnx => Microsoft.ML.Onnx}/SaveOnnxCommand.cs (100%) diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 63a15c977c..a1373d87ec 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti EndProject 
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}" EndProject diff --git a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj similarity index 86% rename from src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj rename to src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj index 800525aca2..4a7540dc38 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Microsoft.ML.UniversalModelFormat.csproj +++ b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj @@ -3,6 +3,7 @@ netstandard2.0 Microsoft.ML + Microsoft.ML.Runtime.Model.Onnx diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs b/src/Microsoft.ML.Onnx/OnnxContext.cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxContext.cs rename to src/Microsoft.ML.Onnx/OnnxContext.cs diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs b/src/Microsoft.ML.Onnx/OnnxMl.cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs rename to src/Microsoft.ML.Onnx/OnnxMl.cs diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md b/src/Microsoft.ML.Onnx/OnnxMl.md similarity index 100% rename from 
src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md rename to src/Microsoft.ML.Onnx/OnnxMl.md diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs b/src/Microsoft.ML.Onnx/OnnxNode..cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxNode..cs rename to src/Microsoft.ML.Onnx/OnnxNode..cs diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxUtils.cs rename to src/Microsoft.ML.Onnx/OnnxUtils.cs diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs similarity index 100% rename from src/Microsoft.ML.UniversalModelFormat/Onnx/SaveOnnxCommand.cs rename to src/Microsoft.ML.Onnx/SaveOnnxCommand.cs diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index d12f875d01..2a2ea8bca1 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -11,7 +11,7 @@ - + From 4bed4169d97159e86d678460c736ed7563dab5b6 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Sat, 30 Jun 2018 08:00:10 -0700 Subject: [PATCH 4/9] Add reference to Core.Tests to ONNX project for entry-point catalog --- test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index bed3dce0eb..ea92975b35 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -11,6 +11,7 @@ + From 1ba5cbd39472c9060b58a9fb429f6fc39d23d62f Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Sat, 30 Jun 2018 08:05:18 -0700 Subject: [PATCH 5/9] Onnx -> ONNX in prose text --- src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 2ae075ba85..7c81b4d7a3 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -19,7 +19,7 @@ public interface ICanSaveOnnx } /// - /// This data model component is savable as Onnx. + /// This data model component is savable as ONNX. /// public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform { From d32598c358f1e58ce6ff5c1507ad4e3dc03e0399 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 2 Jul 2018 15:46:48 -0700 Subject: [PATCH 6/9] Update console project reference to use ONNX project. Add new nuget. Review comments. --- pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj | 13 +++++++++++++ .../Microsoft.ML.Onnx.symbols.nupkgproj | 5 +++++ pkg/Microsoft.ML/Microsoft.ML.nupkgproj | 1 - .../Microsoft.ML.Console.csproj | 2 +- src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs | 13 +++++++------ src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj | 2 +- 6 files changed, 27 insertions(+), 9 deletions(-) create mode 100644 pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj create mode 100644 pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj new file mode 100644 index 0000000000..b068970336 --- /dev/null +++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj @@ -0,0 +1,13 @@ + + + + netstandard2.0 + ML.NET component for Exporting ONNX Models + + + + + + + + diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj new file mode 100644 index 0000000000..07807bb54b --- /dev/null +++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj 
b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj index e28b198bcf..2886b39ad3 100644 --- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj +++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj @@ -6,7 +6,6 @@ - diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index 1256bf75ba..5a5a67ccda 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -19,13 +19,13 @@ + - diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs index 0e8dfb7ede..81fb8e4dd9 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs @@ -18,16 +18,17 @@ namespace Microsoft.ML.Runtime.Model.Onnx public interface IOnnxNode { void AddAttribute(string argName, double value); + void AddAttribute(string argName, long value); + void AddAttribute(string argName, DvText value); + void AddAttribute(string argName, string value); + void AddAttribute(string argName, bool value); + void AddAttribute(string argName, IEnumerable value); void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, long value); void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, DvText value); - void AddAttribute(string argName, string[] value); void AddAttribute(string argName, IEnumerable value); + void AddAttribute(string argName, string[] value); void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, string value); - void AddAttribute(string argName, bool value); + void AddAttribute(string argName, IEnumerable value); } } diff --git a/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj index 4a7540dc38..145dd8be8c 100644 --- a/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj +++ 
b/src/Microsoft.ML.Onnx/Microsoft.ML.Onnx.csproj @@ -2,7 +2,7 @@ netstandard2.0 - Microsoft.ML + Microsoft.ML.Onnx Microsoft.ML.Runtime.Model.Onnx From 8d1c469b71454df9bc630d72c5edf71374fa0a85 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 2 Jul 2018 17:27:48 -0700 Subject: [PATCH 7/9] * Make extension points abstract classes. * Improve ML.NET version handling to actually rely on the assembly vs. a constant string we have to change. * Tighten checks of arguments to the context implementation. --- .../Model/Onnx/ICanSaveOnnx.cs | 8 ++-- src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs | 34 -------------- .../Onnx/{IOnnxContext.cs => OnnxContext.cs} | 24 +++++----- src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs | 32 +++++++++++++ .../Prediction/Calibrator.cs | 6 +-- .../Scorers/BinaryClassifierScorer.cs | 2 +- .../Scorers/GenericScorer.cs | 2 +- .../Scorers/MultiClassClassifierScorer.cs | 2 +- .../Scorers/PredictedLabelScorerBase.cs | 2 +- .../Scorers/SchemaBindablePredictorWrapper.cs | 6 +-- .../Transforms/ConcatTransform.cs | 2 +- .../Transforms/KeyToVectorTransform.cs | 2 +- .../Transforms/NormalizeColumn.cs | 6 +-- .../Transforms/NormalizeColumnDbl.cs | 4 +- .../Transforms/NormalizeColumnSng.cs | 4 +- .../Transforms/NormalizeTransform.cs | 4 +- .../Transforms/TermTransform.cs | 2 +- .../Transforms/TransformBase.cs | 4 +- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- .../{OnnxContext.cs => OnnxContextImpl.cs} | 46 +++++++++++-------- .../{OnnxNode..cs => OnnxNodeImpl..cs} | 28 +++++------ src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 7 ++- .../Standard/LinearPredictor.cs | 2 +- .../MulticlassLogisticRegression.cs | 2 +- .../NAReplaceTransform.cs | 2 +- .../BreastCancer/SaveModelToOnnxTest.json | 2 +- test/Microsoft.ML.Tests/OnnxTests.cs | 6 +++ 27 files changed, 130 insertions(+), 113 deletions(-) delete mode 100644 src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs rename src/Microsoft.ML.Data/Model/Onnx/{IOnnxContext.cs => OnnxContext.cs} (85%) create mode 100644 
src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs rename src/Microsoft.ML.Onnx/{OnnxContext.cs => OnnxContextImpl.cs} (83%) rename src/Microsoft.ML.Onnx/{OnnxNode..cs => OnnxNodeImpl..cs} (56%) diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 7c81b4d7a3..36d839b93d 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -27,7 +27,7 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform /// Save as ONNX. /// /// The ONNX program being built - void SaveAsOnnx(IOnnxContext ctx); + void SaveAsOnnx(OnnxContext ctx); } /// @@ -52,7 +52,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// the outputs produced by this bindable mapper. This is the array that holds /// those names, so that implementors of this method know what to produce in /// . - bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames); + bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames); } /// @@ -61,7 +61,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// public interface ISingleCanSaveOnnx : ICanSaveOnnx { - bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn); + bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); } /// @@ -70,6 +70,6 @@ public interface ISingleCanSaveOnnx : ICanSaveOnnx /// public interface IDistCanSaveOnnx : ISingleCanSaveOnnx, IValueMapperDist { - new bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn); + new bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs deleted file mode 100644 index 81fb8e4dd9..0000000000 --- a/src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs +++ /dev/null @@ 
-1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime.Data; -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Runtime.Model.Onnx -{ - /// - /// An abstraction for an ONNX node as created by - /// . - /// That method creates a with inputs and outputs, but this object can modify the node further - /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations). - /// - public interface IOnnxNode - { - void AddAttribute(string argName, double value); - void AddAttribute(string argName, long value); - void AddAttribute(string argName, DvText value); - void AddAttribute(string argName, string value); - void AddAttribute(string argName, bool value); - - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, string[] value); - void AddAttribute(string argName, IEnumerable value); - void AddAttribute(string argName, IEnumerable value); - } -} diff --git a/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs similarity index 85% rename from src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs rename to src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index dc81cf5419..8ea8446ea6 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/IOnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -16,14 +16,14 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// /// /// - public interface IOnnxContext + public abstract class OnnxContext { /// /// Generates a unique name for the node based on a prefix. 
/// /// The prefix for the node /// A name that has not yet been returned from this function, starting with - string GetNodeName(string prefix); + public abstract string GetNodeName(string prefix); /// /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can @@ -31,7 +31,7 @@ public interface IOnnxContext /// /// The data view column name /// Whether the column is mapped in this context - bool ContainsColumn(string colName); + public abstract bool ContainsColumn(string colName); /// /// Stops tracking a column. @@ -40,7 +40,7 @@ public interface IOnnxContext /// Remove associated ONNX variable. This is useful in the event where an output /// variable is created through before realizing /// the transform cannot actually save as ONNX. - void RemoveColumn(string colName, bool removeVariable = false); + public abstract void RemoveColumn(string colName, bool removeVariable = false); /// /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the ONNX variable to remove. Note that this is an ONNX variable name, not an column name /// IDataView column to stop tracking - void RemoveVariable(string variableName, bool removeColumn); + public abstract void RemoveVariable(string variableName, bool removeColumn); /// /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding @@ -61,7 +61,7 @@ public interface IOnnxContext /// /// The data view column name /// The ONNX variable name corresponding to that data view column - string GetVariableName(string colName); + public abstract string GetVariableName(string colName); /// /// Establishes a new mapping from an data view column in the context, if necessary generates a unique name, and @@ -72,7 +72,7 @@ public interface IOnnxContext /// Whether we should skip the process of establishing the mapping from data view column to /// ONNX variable name. 
/// The returned value is the name of the variable corresponding - string AddIntermediateVariable(ColumnType type, string colName, bool skip = false); + public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false); /// /// Creates an ONNX node @@ -84,25 +84,25 @@ public interface IOnnxContext /// The name of the operator, which ought to be something returned from /// The domain of the ONNX operator, if non-default /// A node added to the in-progress ONNX graph, that attributes can be set on - IOnnxNode CreateNode(string opType, IEnumerable inputs, + public abstract OnnxNode CreateNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null); } public static class OnnxContextExtensions { /// - /// Convenience alternative to + /// Convenience alternative to /// for the case where there is exactly one input and output. /// /// The ONNX save context /// The name of the ONNX operator to apply /// The name of the variable as input /// The name of the variable as output, - /// which ought to have been something returned from - /// The name of the operator, which ought to be something returned from + /// which ought to have been something returned from + /// The name of the operator, which ought to be something returned from /// The domain of the ONNX operator, if non-default /// A node added to the in-progress ONNX graph, that attributes can be set on - public static IOnnxNode CreateNode(this IOnnxContext ctx, + public static OnnxNode CreateNode(this OnnxContext ctx, string opType, string input, string output, string name, string domain = null) => ctx.CreateNode(opType, new[] { input }, new[] { output }, name, domain); } diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs new file mode 100644 index 0000000000..259a6d27d4 --- /dev/null +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or 
more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Runtime.Model.Onnx +{ + /// + /// An abstraction for an ONNX node as created by + /// . + /// That method creates a with inputs and outputs, but this object can modify the node further + /// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations). + /// + public abstract class OnnxNode + { + public abstract void AddAttribute(string argName, double value); + public abstract void AddAttribute(string argName, long value); + public abstract void AddAttribute(string argName, DvText value); + public abstract void AddAttribute(string argName, string value); + public abstract void AddAttribute(string argName, bool value); + + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, string[] value); + public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, IEnumerable value); + } +} diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index c8cbecfb5b..835ba5d99a 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -296,7 +296,7 @@ public void SaveAsPfa(BoundPfaContext ctx, JToken input, probToken = ctx.DeclareVar(prob, probExpression); } - public bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumnName) + public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName) { 
Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(outputNames, nameof(outputNames)); @@ -658,7 +658,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out ctx.Hide(outputs); } - public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputs) + public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputs) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs"); @@ -1429,7 +1429,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) PfaUtils.Call("+", -ParamB, PfaUtils.Call("*", -ParamA, input))); } - public bool SaveAsOnnx(IOnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) + public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) { _host.CheckValue(ctx, nameof(ctx)); _host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames)); diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index bb97c0b425..91f799b2e7 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -184,7 +184,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_threshold); } - public override void SaveAsOnnx(IOnnxContext ctx) + public override void SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 4363b9a435..91c84a0734 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -214,7 +214,7 @@ public void SaveAsPfa(BoundPfaContext ctx) pfaBindable.SaveAsPfa(ctx, schema, outColNames); } - public void SaveAsOnnx(IOnnxContext ctx) + public void 
SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 953cebba05..4832f92cd4 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -203,7 +203,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out ((IBindableCanSavePfa)_bindable).SaveAsPfa(ctx, schema, outputNames); } - public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index d80ff3bc17..fe69585b78 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -365,7 +365,7 @@ public void SaveAsPfa(BoundPfaContext ctx) protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); - public virtual void SaveAsOnnx(IOnnxContext ctx) + public virtual void SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 2f7e887f10..1e07f587f7 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -99,7 +99,7 @@ public virtual void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, stri ctx.Hide(outputNames); } - public virtual bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] 
outputNames) => false; + public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) { @@ -289,7 +289,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str ctx.DeclareVar(outputNames[0], scoreToken); } - public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); @@ -403,7 +403,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str Contracts.Assert(ctx.TokenOrNullForName(outputNames[1]) == probToken.ToString()); } - public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index 537a412597..db572fc93f 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -687,7 +687,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - public void SaveAsOnnx(IOnnxContext ctx) + public void SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(CanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 5b2b2dc6c6..d177d75647 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -241,7 +241,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext 
ctx, int iinfo, ColInfo PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName)); } - protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { string opType = "OneHotEncoder"; var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 307236673b..6092c2d8b9 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -435,7 +435,7 @@ private AffineColumnFunction(IHost host) public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); public bool CanSaveOnnx => true; - public abstract bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount); + public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); public abstract Delegate GetGetter(IRow input, int icol); @@ -548,7 +548,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) public bool CanSaveOnnx => false; - public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) + public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); @@ -673,7 +673,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) public bool CanSaveOnnx => false; - public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) + public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); public abstract Delegate GetGetter(IRow input, int icol); diff --git 
a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 70a660e353..e577b9370e 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -577,7 +577,7 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) { nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount)); @@ -648,7 +648,7 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { if (Offset != null) diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 0cdd8807ce..4c6e1fb011 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -577,7 +577,7 @@ public override void Save(ModelSaveContext ctx) public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale); - public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { node.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount)); node.AddAttribute("scale", Enumerable.Repeat(Scale, 
featureCount)); @@ -648,7 +648,7 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType))); } - public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount) + public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) { if (Offset != null) node.AddAttribute("offset", Offset); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index a857d3f80f..22fa9686ed 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -68,7 +68,7 @@ public interface IColumnFunction : ICanSaveModel bool CanSaveOnnx { get; } - bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount); + bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); } public sealed partial class NormalizeTransform : OneToOneTransformBase @@ -308,7 +308,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo return _functions[iinfo].PfaInfo(ctx, srcToken); } - protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { Contracts.AssertValue(ctx); Contracts.Assert(0 <= iinfo && iinfo < Infos.Length); diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 57ab877eab..bb9adf21e1 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -681,7 +681,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo return PfaUtils.If(PfaUtils.Call("map.containsKey", cellRef, srcToken), PfaUtils.Index(cellRef, 
srcToken), -1); } - protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { if (!info.TypeSrc.ItemType.IsText) return false; diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 42d5ff7a2a..263a3cf4ca 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -571,7 +571,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - public void SaveAsOnnx(IOnnxContext ctx) + public void SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(CanSaveOnnx); @@ -616,7 +616,7 @@ protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo i return null; } - protected virtual bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, + protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) => false; public sealed override ISchema Schema => _bindings; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 29ff38b3aa..01668eb3f2 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3065,7 +3065,7 @@ private enum AggregateFunction Max } - public virtual bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn) + public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Onnx/OnnxContext.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs similarity index 83% rename from src/Microsoft.ML.Onnx/OnnxContext.cs rename to src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 
83903704eb..9cee8c8332 100644 --- a/src/Microsoft.ML.Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// /// A context for defining a ONNX output. /// - internal sealed class OnnxContext : IOnnxContext + internal sealed class OnnxContextImpl : OnnxContext { private readonly List _nodes; private readonly List _inputs; @@ -32,14 +32,14 @@ internal sealed class OnnxContext : IOnnxContext private readonly string _producerVersion; private readonly long _modelVersion; - public OnnxContext(IHostEnvironment env, string name, string producerName, + public OnnxContextImpl(IHostEnvironment env, string name, string producerName, string producerVersion, long modelVersion, string domain) { Contracts.CheckValue(env, nameof(env)); - Contracts.CheckValue(name, nameof(name)); - Contracts.CheckValue(name, nameof(domain)); - _host = env.Register(nameof(OnnxContext)); + _host.CheckValue(name, nameof(name)); + _host.CheckValue(name, nameof(domain)); + _nodes = new List(); _intermediateValues = new List(); _inputs = new List(); @@ -54,7 +54,7 @@ public OnnxContext(IHostEnvironment env, string name, string producerName, _domain = domain; } - public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); + public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); /// /// Stops tracking a column. If removeVariable is true then it also removes the @@ -63,8 +63,9 @@ public OnnxContext(IHostEnvironment env, string name, string producerName, /// /// IDataView column name to stop tracking /// Remove associated ONNX variable at the time. - public void RemoveColumn(string colName, bool removeVariable) + public override void RemoveColumn(string colName, bool removeVariable) { + _host.CheckNonEmpty(colName, nameof(colName)); if (removeVariable) { @@ -88,9 +89,12 @@ public void RemoveColumn(string colName, bool removeVariable) /// /// ONNX variable to remove. 
/// IDataView column to stop tracking - public void RemoveVariable(string variableName, bool removeColumn) + public override void RemoveVariable(string variableName, bool removeColumn) { - _host.Assert(_columnNameMap.ContainsValue(variableName)); + _host.CheckNonEmpty(variableName, nameof(variableName)); + if (!_columnNameMap.ContainsValue(variableName)) + throw _host.ExceptParam(nameof(variableName), $"Could not find '{variableName}' declared in ONNX graph"); + if (removeColumn) { foreach (var val in _intermediateValues) @@ -114,9 +118,9 @@ public void RemoveVariable(string variableName, bool removeColumn) /// /// Generates a unique name for the node based on a prefix. /// - public string GetNodeName(string prefix) + public override string GetNodeName(string prefix) { - _host.CheckValue(prefix, nameof(prefix)); + _host.CheckNonEmpty(prefix, nameof(prefix)); return GetUniqueName(prefix, _nodeNames.Contains); } @@ -133,12 +137,17 @@ private void AddNode(NodeProto node) _nodes.Add(node); } - public IOnnxNode CreateNode(string opType, IEnumerable inputs, + public override OnnxNode CreateNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null) { + _host.CheckNonEmpty(opType, nameof(opType)); + _host.CheckValue(inputs, nameof(inputs)); + _host.CheckValue(outputs, nameof(outputs)); + _host.CheckNonEmpty(name, nameof(name)); + var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain); AddNode(innerNode); - return new OnnxNode(innerNode); + return new OnnxNodeImpl(innerNode); } /// @@ -146,7 +155,7 @@ public IOnnxNode CreateNode(string opType, IEnumerable inputs, /// private string GetUniqueName(string prefix, Func pred) { - _host.CheckValue(prefix, nameof(prefix)); + _host.CheckNonEmpty(prefix, nameof(prefix)); _host.CheckValue(pred, nameof(pred)); if (!pred(prefix)) @@ -162,9 +171,9 @@ private string GetUniqueName(string prefix, Func pred) /// given point in the pipeline execution. /// /// Column Name mapping. 
- public string GetVariableName(string colName) + public override string GetVariableName(string colName) { - _host.CheckValue(colName, nameof(colName)); + _host.CheckNonEmpty(colName, nameof(colName)); _host.Assert(_columnNameMap.ContainsKey(colName)); return _columnNameMap[colName]; @@ -177,6 +186,7 @@ public string GetVariableName(string colName) /// Column Name mapping. public string TryGetVariableName(string colName) { + _host.CheckNonEmpty(colName, nameof(colName)); if (_columnNameMap.ContainsKey(colName)) return GetVariableName(colName); return null; @@ -190,7 +200,7 @@ public string TryGetVariableName(string colName) /// Unique variable name. private string AddVariable(string colName) { - _host.CheckValue(colName, nameof(colName)); + _host.CheckNonEmpty(colName, nameof(colName)); _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); _variableNames.Add(_columnNameMap[colName]); return _columnNameMap[colName]; @@ -199,7 +209,7 @@ private string AddVariable(string colName) /// /// Adds an intermediate column to the list. /// - public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false) + public override string AddIntermediateVariable(ColumnType type, string colName, bool skip = false) { colName = AddVariable(colName); // Let the runtime figure the shape. 
diff --git a/src/Microsoft.ML.Onnx/OnnxNode..cs b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs similarity index 56% rename from src/Microsoft.ML.Onnx/OnnxNode..cs rename to src/Microsoft.ML.Onnx/OnnxNodeImpl..cs index efbd7f676d..9b30fd1d87 100644 --- a/src/Microsoft.ML.Onnx/OnnxNode..cs +++ b/src/Microsoft.ML.Onnx/OnnxNodeImpl..cs @@ -8,39 +8,39 @@ namespace Microsoft.ML.Runtime.Model.Onnx { - internal sealed class OnnxNode : IOnnxNode + internal sealed class OnnxNodeImpl : OnnxNode { private readonly NodeProto _node; - public OnnxNode(NodeProto node) + public OnnxNodeImpl(NodeProto node) { Contracts.AssertValue(node); _node = node; } - public void AddAttribute(string argName, double value) + public override void AddAttribute(string argName, double value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, long value) + public override void AddAttribute(string argName, long value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, DvText value) + public override void AddAttribute(string argName, DvText value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, string[] value) + public override 
void AddAttribute(string argName, string[] value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, IEnumerable value) + public override void AddAttribute(string argName, IEnumerable value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, string value) + public override void AddAttribute(string argName, string value) => OnnxUtils.NodeAddAttributes(_node, argName, value); - public void AddAttribute(string argName, bool value) + public override void AddAttribute(string argName, bool value) => OnnxUtils.NodeAddAttributes(_node, argName, value); } } diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 0b2b8cb6ad..6b72d64af5 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -68,7 +68,6 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly HashSet _outputsToDrop; private readonly ITransformModel _model; private const string ProducerName = "ML.NET"; - private const string ProducerVersion = "0.2.0.0000"; private const long ModelVersion = 0; public SaveOnnxCommand(IHostEnvironment env, Arguments args) @@ -163,7 +162,11 @@ private void Run(IChannel ch) GetPipe(ch, view, out source, out end, out transforms); Host.Assert(transforms.Count == 0 || transforms.Last.Value == end); - var ctx = new OnnxContext(Host, _name, ProducerName, ProducerVersion, ModelVersion, _domain); + var assembly = System.Reflection.Assembly.GetExecutingAssembly(); + var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location); + + var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, + ModelVersion, _domain); // If we have a predictor, try to get 
the scorer for it. if (rawPred != null) { diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index fb1b2be1fb..3fc63060e8 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -230,7 +230,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("model.reg.linear", input, cellRef); } - public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) + public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) == 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index f4ed5f8fe8..5fbc6ab74a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -839,7 +839,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("m.link.softmax", PfaUtils.Call("model.reg.linear", input, cellRef)); } - public bool SaveAsOnnx(IOnnxContext ctx, string[] outputs, string featureColumn) + public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 156d94f52e..44832ee517 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -601,7 +601,7 @@ private Delegate ComposeGetterVec(IRow input, int iinfo) }; } - protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, 
string dstVariableName) + protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { DataKind rawKind; var type = Infos[iinfo].TypeSrc; diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json index 3afd6dc171..b2d611b784 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -1,7 +1,7 @@ { "irVersion": "3", "producerName": "ML.NET", - "producerVersion": "0.2.0.0000", + "producerVersion": "##VERSION##", "domain": "Onnx", "graph": { "node": [ diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 477a9c6fa6..f4428242be 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -9,6 +9,7 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Trainers; using System.IO; +using System.Text.RegularExpressions; using Xunit; using Xunit.Abstractions; @@ -86,6 +87,11 @@ public void BinaryClassificationSaveModelToOnnxTest() converter.Convert(model); + // Strip the version. + var fileText = File.ReadAllText(onnxAsJsonPath); + fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); + File.WriteAllText(onnxAsJsonPath, fileText); + CheckEquality(subDir, "SaveModelToOnnxTest.json"); Done(); } From 9429d88d7645c1120d8b03148dc50b094625f7d7 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 3 Jul 2018 09:34:26 -0700 Subject: [PATCH 8/9] CreateNode now regular method. Exporting => exporting. 
--- pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj | 2 +- src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj index b068970336..bcc86939e2 100644 --- a/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj +++ b/pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj @@ -2,7 +2,7 @@ netstandard2.0 - ML.NET component for Exporting ONNX Models + ML.NET component for exporting ONNX Models diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 8ea8446ea6..96fc1046ce 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -86,15 +86,11 @@ public abstract class OnnxContext /// A node added to the in-progress ONNX graph, that attributes can be set on public abstract OnnxNode CreateNode(string opType, IEnumerable inputs, IEnumerable outputs, string name, string domain = null); - } - public static class OnnxContextExtensions - { /// - /// Convenience alternative to + /// Convenience alternative to /// for the case where there is exactly one input and output. 
/// - /// The ONNX save context /// The name of the ONNX operator to apply /// The name of the variable as input /// The name of the variable as output, @@ -102,8 +98,7 @@ public static class OnnxContextExtensions /// The name of the operator, which ought to be something returned from /// The domain of the ONNX operator, if non-default /// A node added to the in-progress ONNX graph, that attributes can be set on - public static OnnxNode CreateNode(this OnnxContext ctx, - string opType, string input, string output, string name, string domain = null) - => ctx.CreateNode(opType, new[] { input }, new[] { output }, name, domain); + public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null) + => CreateNode(opType, new[] { input }, new[] { output }, name, domain); } } From 689f1d925a16b2bd59da50907d113ffb110cd686 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 3 Jul 2018 16:56:41 -0700 Subject: [PATCH 9/9] Remove extra lines, no string.Compare silliness --- src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs | 2 -- src/Microsoft.ML.Onnx/OnnxContextImpl.cs | 6 ++---- test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs | 4 ++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 96fc1046ce..bdef784b29 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -13,8 +13,6 @@ namespace Microsoft.ML.Runtime.Model.Onnx /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is /// given to a component, all other components up to that component have already attempted to express themselves in /// this context, with their outputs possibly available in the ONNX graph. 
- /// - /// /// public abstract class OnnxContext { diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 9cee8c8332..f37a1ea557 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -78,9 +78,7 @@ public override void RemoveColumn(string colName, bool removeVariable) } } } - - if (_columnNameMap.ContainsKey(colName)) - _columnNameMap.Remove(colName); + _columnNameMap.Remove(colName); } /// @@ -107,7 +105,7 @@ public override void RemoveVariable(string variableName, bool removeColumn) } } - string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key; + string columnName = _columnNameMap.Single(kvp => kvp.Value == variableName).Key; Contracts.Assert(_variableNames.Contains(columnName)); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index de0db6f3a9..43e208ad3c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -329,12 +329,12 @@ public void EntryPointInputBuilderOptionals() ib1.TrySetValue("WeightColumn", "OtherWeight"); Assert.True(instance.WeightColumn.IsExplicit); - Assert.True(string.Compare(instance.WeightColumn.Value, "OtherWeight") == 0); + Assert.Equal("OtherWeight", instance.WeightColumn.Value); var tok = (JToken)JValue.CreateString("AnotherWeight"); ib1.TrySetValueJson("WeightColumn", tok); Assert.True(instance.WeightColumn.IsExplicit); - Assert.True(string.Compare(instance.WeightColumn.Value, "AnotherWeight") == 0); + Assert.Equal("AnotherWeight", instance.WeightColumn.Value); } [Fact]