diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
index d7462bbf25..804cc362dd 100644
--- a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
@@ -62,6 +62,9 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
var textLoader = context.Data.CreateTextLoader(typedLoaderOptions);
var dataView = textLoader.Load(path);
+ // Validate all columns specified in column info exist in inferred data view
+ ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView);
+
var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo);
// start building result objects
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs
new file mode 100644
index 0000000000..b941779137
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceValidationUtil.cs
@@ -0,0 +1,28 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Auto
+{
+    internal static class ColumnInferenceValidationUtil
+    {
+        /// <summary>
+        /// Ensure every column named in <paramref name="columnInfo"/> is present in the schema of <paramref name="dataView"/>.
+        /// </summary>
+        public static void ValidateSpecifiedColumnsExist(ColumnInformation columnInfo,
+            IDataView dataView)
+        {
+            // Fail on the first specified name the schema cannot resolve.
+            foreach (var columnName in ColumnInformationUtil.GetColumnNames(columnInfo))
+            {
+                if (dataView.Schema.GetColumnOrNull(columnName) != null)
+                {
+                    continue;
+                }
+                throw new ArgumentException($"Specified column {columnName} " + $"is not found in the dataset.");
+            }
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
index c567730f6f..7a9f6c14b6 100644
--- a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
@@ -89,5 +89,38 @@ public static ColumnInformation BuildColumnInfo(IEnumerable c
{
return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose)));
}
+
+        /// <summary>
+        /// Get all column names that are in <paramref name="columnInformation"/>.
+        /// </summary>
+        /// <param name="columnInformation">Column information.</param>
+        /// <returns>The non-null names: label, example weight and sampling key first, then the categorical, ignored, numeric and text lists.</returns>
+        public static IEnumerable<string> GetColumnNames(ColumnInformation columnInformation)
+        {
+            var columnNames = new List<string>();
+            AddStringToListIfNotNull(columnNames, columnInformation.LabelColumnName);
+            AddStringToListIfNotNull(columnNames, columnInformation.ExampleWeightColumnName);
+            AddStringToListIfNotNull(columnNames, columnInformation.SamplingKeyColumnName);
+            AddStringsToListIfNotNull(columnNames, columnInformation.CategoricalColumnNames);
+            AddStringsToListIfNotNull(columnNames, columnInformation.IgnoredColumnNames);
+            AddStringsToListIfNotNull(columnNames, columnInformation.NumericColumnNames);
+            AddStringsToListIfNotNull(columnNames, columnInformation.TextColumnNames);
+            return columnNames;
+        }
+
+        /// <summary>Append each non-null entry of <paramref name="strings"/> to <paramref name="list"/>.</summary>
+        private static void AddStringsToListIfNotNull(List<string> list, IEnumerable<string> strings)
+        {
+            list.AddRange(strings.Where(str => str != null));
+        }
+
+        /// <summary>Append <paramref name="str"/> to <paramref name="list"/> unless it is null.</summary>
+        private static void AddStringToListIfNotNull(List<string> list, string str)
+        {
+            if (str != null)
+            {
+                list.Add(str);
+            }
+        }
}
}
diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs
new file mode 100644
index 0000000000..bae3feb1df
--- /dev/null
+++ b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceValidationUtilTests.cs
@@ -0,0 +1,29 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.IO;
+using Microsoft.ML.Data;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Microsoft.ML.Auto.Test
+{
+    [TestClass]
+    public class ColumnInferenceValidationUtilTests
+    {
+        /// <summary>A column named in the column info but absent from the data view schema must be rejected.</summary>
+        [TestMethod]
+        public void ValidateColumnNotContainedInData()
+        {
+            var schemaBuilder = new DataViewSchema.Builder();
+            schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
+            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
+            var dataView = new EmptyDataView(new MLContext(), schemaBuilder.ToSchema());
+            var columnInfo = new ColumnInformation();
+            columnInfo.CategoricalColumnNames.Add("Categorical");
+            // Assert on the validation call alone so an ArgumentException thrown during setup fails the test.
+            Assert.ThrowsException<ArgumentException>(() => ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView));
+        }
+    }
+}
diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs
index a3631768da..d8b183dbfa 100644
--- a/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace Microsoft.ML.Auto.Test
@@ -32,5 +33,25 @@ public void GetColumnPurpose()
Assert.AreEqual(ColumnPurpose.Ignore, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Ignored"));
Assert.AreEqual(null, ColumnInformationUtil.GetColumnPurpose(columnInfo, "NonExistent"));
}
+
+        // GetColumnNames should report exactly the names that were set, across scalar and list properties.
+        [TestMethod]
+        public void GetColumnNamesTest()
+        {
+            var info = new ColumnInformation
+            {
+                LabelColumnName = "Label",
+                SamplingKeyColumnName = "SamplingKey",
+            };
+            foreach (var categorical in new[] { "Cat1", "Cat2" })
+            {
+                info.CategoricalColumnNames.Add(categorical);
+            }
+            info.NumericColumnNames.Add("Num");
+
+            var expected = new[] { "Label", "SamplingKey", "Cat1", "Cat2", "Num" };
+            var actual = ColumnInformationUtil.GetColumnNames(info).ToList();
+            CollectionAssert.AreEquivalent(expected, actual);
+        }
}
-}
+}
\ No newline at end of file