diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs index d7462bbf25..804cc362dd 100644 --- a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs @@ -62,6 +62,9 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path var textLoader = context.Data.CreateTextLoader(typedLoaderOptions); var dataView = textLoader.Load(path); + // Validate all columns specified in column info exist in inferred data view + ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView); + var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo); // start building result objects diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs new file mode 100644 index 0000000000..b941779137 --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Auto +{ + internal static class ColumnInferenceValidationUtil + { + /// + /// Validate all columns specified in column info exist in inferred data view. + /// + public static void ValidateSpecifiedColumnsExist(ColumnInformation columnInfo, + IDataView dataView) + { + var columnNames = ColumnInformationUtil.GetColumnNames(columnInfo); + foreach (var columnName in columnNames) + { + if (dataView.Schema.GetColumnOrNull(columnName) == null) + { + throw new ArgumentException($"Specified column {columnName} " + + $"is not found in the dataset."); + } + } + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs index c567730f6f..7a9f6c14b6 100644 --- a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs @@ -89,5 +89,38 @@ public static ColumnInformation BuildColumnInfo(IEnumerable c { return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose))); } + + /// + /// Get all column names that are in . + /// + /// Column information. + public static IEnumerable GetColumnNames(ColumnInformation columnInformation) + { + var columnNames = new List(); + AddStringToListIfNotNull(columnNames, columnInformation.LabelColumnName); + AddStringToListIfNotNull(columnNames, columnInformation.ExampleWeightColumnName); + AddStringToListIfNotNull(columnNames, columnInformation.SamplingKeyColumnName); + AddStringsToListIfNotNull(columnNames, columnInformation.CategoricalColumnNames); + AddStringsToListIfNotNull(columnNames, columnInformation.IgnoredColumnNames); + AddStringsToListIfNotNull(columnNames, columnInformation.NumericColumnNames); + AddStringsToListIfNotNull(columnNames, columnInformation.TextColumnNames); + return columnNames; + } + + private static void AddStringsToListIfNotNull(List list, IEnumerable strings) + { + foreach (var str in strings) + { + AddStringToListIfNotNull(list, str); + } + } + + private static void AddStringToListIfNotNull(List list, string str) + { + if (str != null) + { + list.Add(str); + } + } } } diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs new file mode 100644 index 0000000000..bae3feb1df --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class ColumnInferenceValidationUtilTests + { + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateColumnNotContainedInData() + { + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single); + schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single); + var schema = schemaBuilder.ToSchema(); + var dataView = new EmptyDataView(new MLContext(), schema); + var columnInfo = new ColumnInformation(); + columnInfo.CategoricalColumnNames.Add("Categorical"); + ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs index a3631768da..d8b183dbfa 100644 --- a/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Linq; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.ML.Auto.Test @@ -32,5 +33,25 @@ public void GetColumnPurpose() Assert.AreEqual(ColumnPurpose.Ignore, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Ignored")); Assert.AreEqual(null, ColumnInformationUtil.GetColumnPurpose(columnInfo, "NonExistent")); } + + [TestMethod] + public void GetColumnNamesTest() + { + var columnInfo = new ColumnInformation() + { + LabelColumnName = "Label", + SamplingKeyColumnName = "SamplingKey", + }; + columnInfo.CategoricalColumnNames.Add("Cat1"); + columnInfo.CategoricalColumnNames.Add("Cat2"); + columnInfo.NumericColumnNames.Add("Num"); + var columnNames = ColumnInformationUtil.GetColumnNames(columnInfo); + Assert.AreEqual(5, columnNames.Count()); + Assert.IsTrue(columnNames.Contains("Label")); + Assert.IsTrue(columnNames.Contains("SamplingKey")); + Assert.IsTrue(columnNames.Contains("Cat1")); + Assert.IsTrue(columnNames.Contains("Cat2")); + Assert.IsTrue(columnNames.Contains("Num")); + } } -} +} \ No newline at end of file