Skip to content

Isolate ONNX implementations in separate DLL #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}"
EndProject
Expand Down
13 changes: 13 additions & 0 deletions pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>ML.NET component for exporting ONNX Models</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<Project DefaultTargets="Pack">

<Import Project="Microsoft.ML.Onnx.nupkgproj" />

</Project>
1 change: 0 additions & 1 deletion pkg/Microsoft.ML/Microsoft.ML.nupkgproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
<ProjectReference Include="..\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
<ProjectReference Include="..\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
<ProjectReference Include="..\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />

<NativeAssemblyReference Include="FastTreeNative" />
<NativeAssemblyReference Include="CpuMathNative" />
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo()
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
/// </summary>
internal IDataView View { get; }
public IDataView View { get; }

/// <summary>
/// Creates a loader according to the specified <paramref name="args"/>.
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />
</ItemGroup>

</Project>
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public interface ICanSaveOnnx
}

/// <summary>
/// This data model component is savable as Onnx.
/// This data model component is savable as ONNX.
/// </summary>
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
{
Expand Down
272 changes: 64 additions & 208 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,245 +2,101 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime.Model.Onnx
{
/// <summary>
/// A context for defining a ONNX output.
/// A context for defining an ONNX output. The context internally contains the model-in-progress being built. This
/// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
/// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
/// given to a component, all other components up to that component have already attempted to express themselves in
/// this context, with their outputs possibly available in the ONNX graph.
/// </summary>
public sealed class OnnxContext
public abstract class OnnxContext
{
private readonly List<NodeProto> _nodes;
private readonly List<OnnxUtils.ModelArgs> _inputs;
private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
private readonly List<OnnxUtils.ModelArgs> _outputs;
private readonly Dictionary<string, string> _columnNameMap;
private readonly HashSet<string> _variableMap;
private readonly HashSet<string> _nodeNames;
private readonly string _name;
private readonly string _producerName;
private readonly IHost _host;
private readonly string _domain;
private readonly string _producerVersion;
private readonly long _modelVersion;

public OnnxContext(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain)
{
Contracts.CheckValue(env, nameof(env));
Contracts.CheckValue(name, nameof(name));
Contracts.CheckValue(name, nameof(domain));

_host = env.Register(nameof(OnnxContext));
_nodes = new List<NodeProto>();
_intermediateValues = new List<OnnxUtils.ModelArgs>();
_inputs = new List<OnnxUtils.ModelArgs>();
_outputs = new List<OnnxUtils.ModelArgs>();
_columnNameMap = new Dictionary<string, string>();
_variableMap = new HashSet<string>();
_nodeNames = new HashSet<string>();
_name = name;
_producerName = producerName;
_producerVersion = producerVersion;
_modelVersion = modelVersion;
_domain = domain;
}

public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);

/// <summary>
/// Stops tracking a column. If removeVariable is true then it also removes the
/// variable associated with it, this is useful in the event where an output variable is
/// created before realizing the transform cannot actually save as ONNX.
/// </summary>
/// <param name="colName">IDataView column name to stop tracking</param>
/// <param name="removeVariable">Remove associated ONNX variable at the time.</param>
public void RemoveColumn(string colName, bool removeVariable)
{

if (removeVariable)
{
foreach (var val in _intermediateValues)
{
if (val.Name == _columnNameMap[colName])
{
_intermediateValues.Remove(val);
break;
}
}
}

if (_columnNameMap.ContainsKey(colName))
_columnNameMap.Remove(colName);
}

/// <summary>
/// Removes an ONNX variable. If removeColumn is true then it also removes the
/// IDataView column associated with it.
/// </summary>
/// <param name="variableName">ONNX variable to remove.</param>
/// <param name="removeColumn">IDataView column to stop tracking</param>
public void RemoveVariable(string variableName, bool removeColumn)
{
_host.Assert(_columnNameMap.ContainsValue(variableName));
if (removeColumn)
{
foreach (var val in _intermediateValues)
{
if (val.Name == variableName)
{
_intermediateValues.Remove(val);
break;
}
}
}

string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key;

Contracts.Assert(_variableMap.Contains(columnName));

_columnNameMap.Remove(columnName);
_variableMap.Remove(columnName);
}

/// <summary>
/// Generates a unique name for the node based on a prefix.
/// </summary>
public string GetNodeName(string prefix)
{
_host.CheckValue(prefix, nameof(prefix));
return GetUniqueName(prefix, c => _nodeNames.Contains(c));
}
/// <param name="prefix">The prefix for the node</param>
/// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/></returns>
public abstract string GetNodeName(string prefix);

/// <summary>
/// Adds a node to the node list of the graph.
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
/// safely call <see cref="GetVariableName(string)"/>.
/// </summary>
/// <param name="node"></param>
public void AddNode(NodeProto node)
{
_host.CheckValue(node, nameof(node));
_host.Assert(!_nodeNames.Contains(node.Name));

_nodeNames.Add(node.Name);
_nodes.Add(node);
}
/// <param name="colName">The data view column name</param>
/// <returns>Whether the column is mapped in this context</returns>
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Generates a unique name based on a prefix.
/// Stops tracking a column.
/// </summary>
private string GetUniqueName(string prefix, Func<string, bool> pred)
{
_host.CheckValue(prefix, nameof(prefix));
_host.CheckValue(pred, nameof(pred));

if (!pred(prefix))
return prefix;

int count = 0;
while (pred(prefix + count++)) ;
return prefix + --count;
}
/// <param name="colName">Column name to stop tracking</param>
/// <param name="removeVariable">Remove associated ONNX variable. This is useful in the event where an output
/// variable is created through <see cref="AddIntermediateVariable(ColumnType, string, bool)"/> before realizing
/// the transform cannot actually save as ONNX.</param>
public abstract void RemoveColumn(string colName, bool removeVariable = false);

/// <summary>
/// Retrieves the variable name that maps to the IDataView column name at a
/// given point in the pipeline execution.
/// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the <see
/// cref="IDataView"/> column associated with it.
/// </summary>
/// <returns>Column Name mapping.</returns>
public string GetVariableName(string colName)
{
_host.CheckValue(colName, nameof(colName));
_host.Assert(_columnNameMap.ContainsKey(colName));

return _columnNameMap[colName];
}

/// <summary>
/// Retrieves the variable name that maps to the IDataView column name at a
/// given point in the pipeline execution.
/// </summary>
/// <returns>Column Name mapping.</returns>
public string TryGetVariableName(string colName)
{
if (_columnNameMap.ContainsKey(colName))
return GetVariableName(colName);

return null;
}

/// <summary>
/// Generates a unique column name based on the IDataView column name if
/// there is a collision between names in the pipeline at any point.
/// </summary>
/// <param name="colName">IDataView column name.</param>
/// <returns>Unique variable name.</returns>
private string AddVariable(string colName)
{
_host.CheckValue(colName, nameof(colName));

if (!_columnNameMap.ContainsKey(colName))
_columnNameMap.Add(colName, colName);
else
_columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s));

_variableMap.Add(_columnNameMap[colName]);
return _columnNameMap[colName];
}
/// <param name="variableName">ONNX variable to remove. Note that this is an ONNX variable name, not an <see
/// cref="IDataView"/> column name</param>
/// <param name="removeColumn">IDataView column to stop tracking</param>
public abstract void RemoveVariable(string variableName, bool removeColumn);

/// <summary>
/// Adds an intermediate column to the list.
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
/// <see cref="IDataView"/>'s column names will map to a variable in the ONNX graph if the intermediate steps
/// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
/// to the <see cref="IDataView"/> column name at a given point in the pipeline execution. Callers should
/// probably confirm with <see cref="ContainsColumn(string)"/> whether a mapping for that data view column
/// already exists.
/// </summary>
public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
{

colName = AddVariable(colName);

//Let the runtime figure the shape.
if (!skip)
{
_host.CheckValue(type, nameof(type));

_intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
}

return colName;
}
/// <param name="colName">The data view column name</param>
/// <returns>The ONNX variable name corresponding to that data view column</returns>
public abstract string GetVariableName(string colName);

/// <summary>
/// Adds an output variable to the list.
/// Establishes a new mapping from a data view column in the context, generating a unique name if necessary, and
/// returns that newly allocated name.
/// </summary>
public string AddOutputVariable(ColumnType type, string colName, List<long> dim = null)
{
_host.CheckValue(type, nameof(type));

if (!ContainsColumn(colName))
AddVariable(colName);

colName = GetVariableName(colName);
_outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
return colName;
}
/// <param name="type">The data view type associated with this column name</param>
/// <param name="colName">The data view column name</param>
/// <param name="skip">Whether we should skip the process of establishing the mapping from data view column to
/// ONNX variable name.</param>
/// <returns>The name of the ONNX variable corresponding to the data view column</returns>
public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);

/// <summary>
/// Adds an input variable to the list.
/// Creates an ONNX node and adds it to the in-progress ONNX graph.
/// </summary>
public void AddInputVariable(ColumnType type, string colName)
{
_host.CheckValue(type, nameof(type));
_host.CheckValue(colName, nameof(colName));

colName = AddVariable(colName);
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
}
/// <param name="opType">The name of the ONNX operator to apply</param>
/// <param name="inputs">The names of the variables as inputs</param>
/// <param name="outputs">The names of the variables to create as outputs,
/// which ought to have been something returned from <see cref="AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
IEnumerable<string> outputs, string name, string domain = null);

/// <summary>
/// Makes the ONNX model based on the context.
/// Convenience alternative to <see cref="CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
/// for the case where there is exactly one input and output.
/// </summary>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
/// <param name="opType">The name of the ONNX operator to apply</param>
/// <param name="input">The name of the variable as input</param>
/// <param name="output">The name of the variable as output,
/// which ought to have been something returned from <see cref="OnnxContext.AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="OnnxContext.GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);
}
}
Loading