diff --git a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs index 34c211969a..c9ba930c21 100644 --- a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs +++ b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs @@ -54,8 +54,8 @@ public void GenerateOutput() var dataModelsDir = Path.Combine(modelprojectDir, "DataModels"); var modelProjectName = $"{settings.OutputName}.Model.csproj"; - Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "SampleObservation.cs", dataModelsDir); - Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "SamplePrediction.cs", dataModelsDir); + Utils.WriteOutputToFiles(modelProjectContents.ModelInputCSFileContent, "ModelInput.cs", dataModelsDir); + Utils.WriteOutputToFiles(modelProjectContents.ModelOutputCSFileContent, "ModelOutput.cs", dataModelsDir); Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir); // Generate ConsoleApp Project @@ -116,15 +116,15 @@ private void SetRequiredNugetPackages(IEnumerable trainerNodes, re return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent); } - internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage) + internal (string ModelInputCSFileContent, string ModelOutputCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage) { var classLabels = this.GenerateClassLabels(); - var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels); - observationCSFileContent = Utils.FormatCode(observationCSFileContent); - var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue); - predictionCSFileContent = Utils.FormatCode(predictionCSFileContent); + var modelInputCSFileContent = GenerateModelInputCSFileContent(namespaceValue, classLabels); + modelInputCSFileContent = Utils.FormatCode(modelInputCSFileContent); + var modelOutputCSFileContent = GenerateModelOutputCSFileContent(labelTypeCsharp.Name, namespaceValue); + modelOutputCSFileContent = Utils.FormatCode(modelOutputCSFileContent); var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage); - return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent); + return (modelInputCSFileContent, modelOutputCSFileContent, modelProjectFileContent); } internal (string Usings, string TrainerMethod, List PreTrainerTransforms, List PostTrainerTransforms) GenerateTransformsAndTrainers() @@ -261,16 +261,16 @@ private static string GenerateModelProjectFileContent(bool includeLightGbmPackag return modelProject.TransformText(); } - private string GeneratePredictionCSFileContent(string predictionLabelType, string namespaceValue) + private string GenerateModelOutputCSFileContent(string predictionLabelType, string namespaceValue) { - PredictionClass predictionClass = new PredictionClass() { TaskType = settings.MlTask.ToString(), PredictionLabelType = predictionLabelType, Namespace = namespaceValue }; - return predictionClass.TransformText(); + ModelOutputClass modelOutputClass = new ModelOutputClass() { TaskType = settings.MlTask.ToString(), PredictionLabelType = predictionLabelType, Namespace = namespaceValue }; + return modelOutputClass.TransformText(); } - private string GenerateObservationCSFileContent(string namespaceValue, IList classLabels) + private string GenerateModelInputCSFileContent(string namespaceValue, IList classLabels) { - ObservationClass observationClass = new ObservationClass() { Namespace = namespaceValue, ClassLabels = classLabels }; - return observationClass.TransformText(); + ModelInputClass modelInputClass = new ModelInputClass() { Namespace = namespaceValue, ClassLabels = classLabels }; + return modelInputClass.TransformText(); } #endregion diff --git a/src/mlnet/Templates/Console/ModelBuilder.cs b/src/mlnet/Templates/Console/ModelBuilder.cs index 53447e4e22..feb2a993f4 100644 --- a/src/mlnet/Templates/Console/ModelBuilder.cs +++ b/src/mlnet/Templates/Console/ModelBuilder.cs @@ -65,7 +65,7 @@ public virtual string TransformText() public static void CreateModel() { // Load Data - IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( path: TRAIN_DATA_FILEPATH, hasHeader : "); this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); @@ -77,9 +77,9 @@ public static void CreateModel() this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); this.Write(");\r\n\r\n"); if(!string.IsNullOrEmpty(TestPath)){ - this.Write(" IDataView testDataView = mlContext.Data.LoadFromTextFile(\r\n path: TEST_DATA_FILEPATH,\r\n" + - " hasHeader : "); + this.Write(" IDataView testDataView = mlContext.Data.LoadFromTextFile(" + + "\r\n path: TEST_DATA_FILEPATH,\r\n " + + " hasHeader : "); this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); this.Write(",\r\n separatorChar : \'"); this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString()))); diff --git a/src/mlnet/Templates/Console/ModelBuilder.tt b/src/mlnet/Templates/Console/ModelBuilder.tt index cf0c7346fc..cd9e75d39b 100644 --- a/src/mlnet/Templates/Console/ModelBuilder.tt +++ b/src/mlnet/Templates/Console/ModelBuilder.tt @@ -36,7 +36,7 @@ namespace <#= Namespace #>.ConsoleApp public static void CreateModel() { // Load Data - IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( path: TRAIN_DATA_FILEPATH, hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', @@ -44,7 +44,7 @@ namespace <#= Namespace #>.ConsoleApp allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>); <# if(!string.IsNullOrEmpty(TestPath)){ #> - IDataView testDataView = mlContext.Data.LoadFromTextFile( + IDataView testDataView = mlContext.Data.LoadFromTextFile( path: TEST_DATA_FILEPATH, hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', diff --git a/src/mlnet/Templates/Console/ObservationClass.cs b/src/mlnet/Templates/Console/ModelInputClass.cs similarity index 98% rename from src/mlnet/Templates/Console/ObservationClass.cs rename to src/mlnet/Templates/Console/ModelInputClass.cs index 411f62d016..6205ffc683 100644 --- a/src/mlnet/Templates/Console/ObservationClass.cs +++ b/src/mlnet/Templates/Console/ModelInputClass.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console /// Class to produce the template output /// [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] - public partial class ObservationClass : ObservationClassBase + public partial class ModelInputClass : ModelInputClassBase { /// /// Create the template output @@ -35,7 +35,7 @@ public virtual string TransformText() namespace "); this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); - this.Write(".Model.DataModels\r\n{\r\n public class SampleObservation\r\n {\r\n"); + this.Write(".Model.DataModels\r\n{\r\n public class ModelInput\r\n {\r\n"); foreach(var label in ClassLabels){ this.Write(" "); this.Write(this.ToStringHelper.ToStringWithCulture(label)); @@ -54,7 +54,7 @@ namespace "); /// Base class for this transformation /// [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] - public class ObservationClassBase + public class ModelInputClassBase { #region Fields private global::System.Text.StringBuilder generationEnvironmentField; diff --git a/src/mlnet/Templates/Console/ObservationClass.tt b/src/mlnet/Templates/Console/ModelInputClass.tt similarity index 96% rename from src/mlnet/Templates/Console/ObservationClass.tt rename to src/mlnet/Templates/Console/ModelInputClass.tt index 296e9ef6c4..94eb6b76f6 100644 --- a/src/mlnet/Templates/Console/ObservationClass.tt +++ b/src/mlnet/Templates/Console/ModelInputClass.tt @@ -13,7 +13,7 @@ using Microsoft.ML.Data; namespace <#= Namespace #>.Model.DataModels { - public class SampleObservation + public class ModelInput { <#foreach(var label in ClassLabels){#> <#=label#> diff --git a/src/mlnet/Templates/Console/PredictionClass.cs b/src/mlnet/Templates/Console/ModelOutputClass.cs similarity index 98% rename from src/mlnet/Templates/Console/PredictionClass.cs rename to src/mlnet/Templates/Console/ModelOutputClass.cs index 8b571f93e5..767ae5d0da 100644 --- a/src/mlnet/Templates/Console/PredictionClass.cs +++ b/src/mlnet/Templates/Console/ModelOutputClass.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console /// Class to produce the template output /// [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] - public partial class PredictionClass : PredictionClassBase + public partial class ModelOutputClass : ModelOutputClassBase { /// /// Create the template output @@ -36,7 +36,7 @@ public virtual string TransformText() namespace "); this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); - this.Write(".Model.DataModels\r\n{\r\n public class SamplePrediction\r\n {\r\n"); + this.Write(".Model.DataModels\r\n{\r\n public class ModelOutput\r\n {\r\n"); if("BinaryClassification".Equals(TaskType)){ this.Write(" // ColumnName attribute is used to change the column name from\r\n /" + "/ its default value, which is the name of the field.\r\n [ColumnName(\"Predi" + @@ -67,7 +67,7 @@ namespace "); /// Base class for this transformation /// [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] - public class PredictionClassBase + public class ModelOutputClassBase { #region Fields private global::System.Text.StringBuilder generationEnvironmentField; diff --git a/src/mlnet/Templates/Console/PredictionClass.tt b/src/mlnet/Templates/Console/ModelOutputClass.tt similarity index 97% rename from src/mlnet/Templates/Console/PredictionClass.tt rename to src/mlnet/Templates/Console/ModelOutputClass.tt index 5cea1d14fd..e00f998627 100644 --- a/src/mlnet/Templates/Console/PredictionClass.tt +++ b/src/mlnet/Templates/Console/ModelOutputClass.tt @@ -14,7 +14,7 @@ using Microsoft.ML.Data; namespace <#= Namespace #>.Model.DataModels { - public class SamplePrediction + public class ModelOutput { <#if("BinaryClassification".Equals(TaskType)){ #> // ColumnName attribute is used to change the column name from diff --git a/src/mlnet/Templates/Console/PredictProgram.cs b/src/mlnet/Templates/Console/PredictProgram.cs index 40bb69d39b..22d7d0ebee 100644 --- a/src/mlnet/Templates/Console/PredictProgram.cs +++ b/src/mlnet/Templates/Console/PredictProgram.cs @@ -62,13 +62,13 @@ static void Main(string[] args) //ModelBuilder.CreateModel(); ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema); - var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); // Create sample data to do a single prediction with it - SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + ModelInput sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); // Try a single prediction - SamplePrediction predictionResult = predEngine.Predict(sampleData); + ModelOutput predictionResult = predEngine.Predict(sampleData); "); if("BinaryClassification".Equals(TaskType)){ @@ -92,10 +92,10 @@ static void Main(string[] args) // Method to load single row of data to try a single prediction // You can change this code and create your own sample data here (Hardcoded or from any source) - private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + private static ModelInput CreateSingleDataSample(MLContext mlContext, string dataFilePath) { // Read dataset to get a single row for trying a prediction - IDataView dataView = mlContext.Data.LoadFromTextFile( + IDataView dataView = mlContext.Data.LoadFromTextFile( path: dataFilePath, hasHeader : "); this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); @@ -107,8 +107,8 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); this.Write(@"); - // Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. - SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + // Here (ModelInput object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) .First(); return sampleForPrediction; } diff --git a/src/mlnet/Templates/Console/PredictProgram.tt b/src/mlnet/Templates/Console/PredictProgram.tt index 96b07606e4..8a2c3b502b 100644 --- a/src/mlnet/Templates/Console/PredictProgram.tt +++ b/src/mlnet/Templates/Console/PredictProgram.tt @@ -40,13 +40,13 @@ namespace <#= Namespace #>.ConsoleApp //ModelBuilder.CreateModel(); ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema); - var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); // Create sample data to do a single prediction with it - SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + ModelInput sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); // Try a single prediction - SamplePrediction predictionResult = predEngine.Predict(sampleData); + ModelOutput predictionResult = predEngine.Predict(sampleData); <#if("BinaryClassification".Equals(TaskType)){ #> Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction}"); @@ -62,18 +62,18 @@ namespace <#= Namespace #>.ConsoleApp // Method to load single row of data to try a single prediction // You can change this code and create your own sample data here (Hardcoded or from any source) - private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + private static ModelInput CreateSingleDataSample(MLContext mlContext, string dataFilePath) { // Read dataset to get a single row for trying a prediction - IDataView dataView = mlContext.Data.LoadFromTextFile( + IDataView dataView = mlContext.Data.LoadFromTextFile( path: dataFilePath, hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>, allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>); - // Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. - SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + // Here (ModelInput object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) .First(); return sampleForPrediction; } diff --git a/src/mlnet/mlnet.csproj b/src/mlnet/mlnet.csproj index ce16cea0fd..b84dbefd81 100644 --- a/src/mlnet/mlnet.csproj +++ b/src/mlnet/mlnet.csproj @@ -48,15 +48,15 @@ True ModelProject.tt - + True True - ObservationClass.tt + ModelInputClass.tt - + True True - PredictionClass.tt + ModelOutputClass.tt True @@ -90,13 +90,13 @@ TextTemplatingFilePreprocessor ModelProject.cs - + TextTemplatingFilePreprocessor - ObservationClass.cs + ModelInputClass.cs - + TextTemplatingFilePreprocessor - PredictionClass.cs + ModelOutputClass.cs TextTemplatingFilePreprocessor diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt index 7a4f93beff..353b2f3bf2 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt @@ -27,14 +27,14 @@ namespace TestNamespace.ConsoleApp public static void CreateModel() { // Load Data - IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( path: TRAIN_DATA_FILEPATH, hasHeader: true, separatorChar: ',', allowQuoting: true, allowSparse: true); - IDataView testDataView = mlContext.Data.LoadFromTextFile( + IDataView testDataView = mlContext.Data.LoadFromTextFile( path: TEST_DATA_FILEPATH, hasHeader: true, separatorChar: ',', diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt index 7a8649b242..374a227d7e 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt @@ -27,14 +27,14 @@ namespace TestNamespace.ConsoleApp public static void CreateModel() { // Load Data - IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( path: TRAIN_DATA_FILEPATH, hasHeader: true, separatorChar: ',', allowQuoting: true, allowSparse: true); - IDataView testDataView = mlContext.Data.LoadFromTextFile( + IDataView testDataView = mlContext.Data.LoadFromTextFile( path: TEST_DATA_FILEPATH, hasHeader: true, separatorChar: ',', diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt index 7f71a0b7d5..0e2823da5d 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt @@ -27,14 +27,14 @@ namespace TestNamespace.ConsoleApp public static void CreateModel() { // Load Data - IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( path: TRAIN_DATA_FILEPATH, hasHeader: true, separatorChar: ',', allowQuoting: true, allowSparse: true); - IDataView testDataView = mlContext.Data.LoadFromTextFile( + IDataView testDataView = mlContext.Data.LoadFromTextFile( path: TEST_DATA_FILEPATH, hasHeader: true, separatorChar: ',', diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt index 431dbe2dbd..e7476eced0 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt @@ -29,13 +29,13 @@ namespace TestNamespace.ConsoleApp //ModelBuilder.CreateModel(); ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema); - var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); // Create sample data to do a single prediction with it - SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + ModelInput sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); // Try a single prediction - SamplePrediction predictionResult = predEngine.Predict(sampleData); + ModelOutput predictionResult = predEngine.Predict(sampleData); Console.WriteLine($"Single Prediction --> Actual value: {sampleData.Label} | Predicted value: {predictionResult.Prediction}"); @@ -45,18 +45,18 @@ namespace TestNamespace.ConsoleApp // Method to load single row of data to try a single prediction // You can change this code and create your own sample data here (Hardcoded or from any source) - private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + private static ModelInput CreateSingleDataSample(MLContext mlContext, string dataFilePath) { // Read dataset to get a single row for trying a prediction - IDataView dataView = mlContext.Data.LoadFromTextFile( + IDataView dataView = mlContext.Data.LoadFromTextFile( path: dataFilePath, hasHeader: true, separatorChar: ',', allowQuoting: true, allowSparse: true); - // Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. - SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + // Here (ModelInput object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) .First(); return sampleForPrediction; } diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt index 12f935ee2a..506f739200 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt @@ -8,7 +8,7 @@ using Microsoft.ML.Data; namespace TestNamespace.Model.DataModels { - public class SampleObservation + public class ModelInput { [ColumnName("Label"), LoadColumn(0)] public bool Label { get; set; } diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt index 4e0a7e5b9c..367b53fe5b 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt @@ -9,7 +9,7 @@ using Microsoft.ML.Data; namespace TestNamespace.Model.DataModels { - public class SamplePrediction + public class ModelOutput { // ColumnName attribute is used to change the column name from // its default value, which is the name of the field. diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs index a604c5d086..b6d3c2ddfb 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs @@ -136,7 +136,7 @@ public void ObservationCSFileContentTest() }); var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false); - Approvals.Verify(result.ObservationCSFileContent); + Approvals.Verify(result.ModelInputCSFileContent); } @@ -160,7 +160,7 @@ public void PredictionCSFileContentTest() }); var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false); - Approvals.Verify(result.PredictionCSFileContent); + Approvals.Verify(result.ModelOutputCSFileContent); } [TestMethod]