Skip to content

Commit

Permalink
[AutoML] Generated project - FastTree nuget package inclusion dynamic…
Browse files Browse the repository at this point in the history
…ally (#3567)

* added support for fast tree nuget pack inclusion in generated project

* fix testcase

* changed the tool name in telemetry message

* dummy commit

* remove space

* dummy commit to trigger build
  • Loading branch information
srsaggam authored Apr 24, 2019
1 parent 7191ebe commit c9cb3cc
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 27 deletions.
31 changes: 19 additions & 12 deletions src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ internal class CodeGenerator : IProjectGenerator
private readonly ColumnInferenceResults columnInferenceResult;
private readonly HashSet<string> LightGBMTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
private readonly HashSet<string> mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
private readonly HashSet<string> FastTreeTrainers = new HashSet<string>() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(), TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString() };


internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
{
Expand All @@ -36,15 +38,16 @@ public void GenerateOutput()

bool includeLightGbmPackage = false;
bool includeMklComponentsPackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage);
bool includeFastTreeePackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage, ref includeFastTreeePackage);

// Get Namespace
var namespaceValue = Utils.Normalize(settings.OutputName);
var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
Type labelTypeCsharp = Utils.GetCSharpType(labelType);

// Generate Model Project
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);

// Write files to disk.
var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
Expand All @@ -56,7 +59,7 @@ public void GenerateOutput()
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

// Generate ConsoleApp Project
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);

// Write files to disk.
var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
Expand All @@ -74,7 +77,7 @@ public void GenerateOutput()
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
}

private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage)
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage, ref bool includeFastTreePackage)
{
foreach (var node in trainerNodes)
{
Expand All @@ -92,15 +95,19 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
{
includeMklComponentsPackage = true;
}
else if (FastTreeTrainers.Contains(currentNode.Name))
{
includeFastTreePackage = true;
}
}
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage);
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);

var transformsAndTrainers = GenerateTransformsAndTrainers();
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
Expand All @@ -109,14 +116,14 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
}

internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var classLabels = this.GenerateClassLabels();
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
}

Expand Down Expand Up @@ -248,9 +255,9 @@ internal IList<string> GenerateClassLabels()
}

#region Model project
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage)
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage };
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeFastTreePackage = includeFastTreePackage };
return modelProject.TransformText();
}

Expand All @@ -268,9 +275,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
#endregion

#region Predict Project
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage)
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage };
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage, IncludeFastTreePackage = includeFastTreePackage };
return predictProjectFileContent.TransformText();
}

Expand Down
1 change: 0 additions & 1 deletion src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Auto;
using static Microsoft.ML.CLI.CodeGenerator.CSharp.TrainerGenerators;

Expand Down
2 changes: 0 additions & 2 deletions src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,6 @@ public override string[] GenerateUsings()
{
return binaryTrainerUsings;
}

}

}
}
6 changes: 3 additions & 3 deletions src/mlnet/Telemetry/MlTelemetry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class MlTelemetry

public void SetCommandAndParameters(string command, IEnumerable<string> parameters)
{
if(parameters != null)
if (parameters != null)
{
_parameters.AddRange(parameters);
}
Expand All @@ -28,7 +28,7 @@ public void LogAutoTrainMlCommand(string dataFileName, string task, long dataFil
{
CheckFistTimeUse();

if(!_enabled)
if (!_enabled)
{
return;
}
Expand Down Expand Up @@ -71,7 +71,7 @@ private void CheckFistTimeUse()
@"Welcome to the ML.NET CLI!
--------------------------
Learn more about ML.NET CLI: https://aka.ms/mlnet-cli
Use 'dotnet ml --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs
Use 'mlnet --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs
Telemetry
---------
Expand Down
17 changes: 16 additions & 1 deletion src/mlnet/Templates/Console/ModelProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ public virtual string TransformText()
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden

#line 19 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
if(IncludeFastTreePackage){

#line default
#line hidden
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
"\n");

#line 21 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden
this.Write(" </ItemGroup>\r\n\r\n <ItemGroup>\r\n <None Update=\"MLModel.zip\">\r\n <CopyToOu" +
Expand All @@ -65,10 +79,11 @@ public virtual string TransformText()
return this.GenerationEnvironment.ToString();
}

#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"

public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}


#line default
Expand Down
4 changes: 4 additions & 0 deletions src/mlnet/Templates/Console/ModelProject.tt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
<#}#>
<# if(IncludeMklComponentsPackage){ #>
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<#}#>
<# if(IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
<#}#>
</ItemGroup>

Expand All @@ -28,4 +31,5 @@
<#+
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
#>
5 changes: 5 additions & 0 deletions src/mlnet/Templates/Console/PredictProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public virtual string TransformText()
if(IncludeMklComponentsPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Mkl.Components\" Version=\"1.0.0-previe" +
"w\" />\r\n");
}
if(IncludeFastTreePackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
"\n");
}
this.Write(" </ItemGroup>\r\n <ItemGroup>\r\n <ProjectReference Include=\"..\\");
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
Expand All @@ -49,6 +53,7 @@ public virtual string TransformText()
public string Namespace {get;set;}
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}

}
#region Base class
Expand Down
4 changes: 4 additions & 0 deletions src/mlnet/Templates/Console/PredictProject.tt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
<#}#>
<# if(IncludeMklComponentsPackage){ #>
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<#}#>
<# if(IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
<#}#>
</ItemGroup>
<ItemGroup>
Expand All @@ -27,4 +30,5 @@
public string Namespace {get;set;}
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
#>
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
</ItemGroup>

<ItemGroup>
Expand Down
16 changes: 8 additions & 8 deletions test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void ConsoleAppModelBuilderCSFileContentOvaTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -65,7 +65,7 @@ public void ConsoleAppModelBuilderCSFileContentBinaryTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -88,7 +88,7 @@ public void ConsoleAppModelBuilderCSFileContentRegressionTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -111,7 +111,7 @@ public void ModelProjectFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, true);

Approvals.Verify(result.ModelProjectFileContent);
}
Expand All @@ -134,7 +134,7 @@ public void ObservationCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ObservationCSFileContent);
}
Expand All @@ -158,7 +158,7 @@ public void PredictionCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.PredictionCSFileContent);
}
Expand All @@ -181,7 +181,7 @@ public void ConsoleAppProgramCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ConsoleAppProgramCSFileContent);
}
Expand All @@ -204,7 +204,7 @@ public void ConsoleAppProjectFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ConsoleAppProjectFileContent);
}
Expand Down

0 comments on commit c9cb3cc

Please sign in to comment.