diff --git a/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs
new file mode 100644
index 0000000000..63b56e815a
--- /dev/null
+++ b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs
@@ -0,0 +1,70 @@
+using System;
+
+namespace Microsoft.ML.AutoML.Utils
+{
+    /// <summary>
+    /// Helpers for measuring how different two strings are.
+    /// </summary>
+    internal static class StringEditDistance
+    {
+        /// <summary>
+        /// Computes the Levenshtein distance between two strings: the minimum number of
+        /// single-character insertions, deletions and substitutions that turn one into the other.
+        /// Characters are compared by ordinal value, so the comparison is case-sensitive.
+        /// </summary>
+        /// <param name="first">The first string.</param>
+        /// <param name="second">The second string.</param>
+        /// <returns>The edit distance between <paramref name="first"/> and <paramref name="second"/>.</returns>
+        /// <exception cref="ArgumentNullException">Either argument is <see langword="null"/>.</exception>
+        public static int GetLevenshteinDistance(string first, string second)
+        {
+            if (first is null)
+            {
+                throw new ArgumentNullException(nameof(first));
+            }
+
+            if (second is null)
+            {
+                throw new ArgumentNullException(nameof(second));
+            }
+
+            // When one side is empty the distance is simply the length of the other side.
+            if (first.Length == 0 || second.Length == 0)
+            {
+                return first.Length + second.Length;
+            }
+
+            // Only the previous and the current row of the dynamic-programming table are ever
+            // read, so two rows of (second.Length + 1) cells are sufficient.
+            var currentRow = 0;
+            var nextRow = 1;
+            var rows = new int[2, second.Length + 1];
+
+            // Row 0: turning the empty prefix of 'first' into the first j characters of 'second'.
+            for (var j = 0; j <= second.Length; ++j)
+            {
+                rows[currentRow, j] = j;
+            }
+
+            for (var i = 1; i <= first.Length; ++i)
+            {
+                rows[nextRow, 0] = i;
+                for (var j = 1; j <= second.Length; ++j)
+                {
+                    var deletion = rows[currentRow, j] + 1;
+                    var insertion = rows[nextRow, j - 1] + 1;
+                    var substitution = rows[currentRow, j - 1] + (first[i - 1] == second[j - 1] ? 0 : 1);
+
+                    rows[nextRow, j] = Math.Min(deletion, Math.Min(insertion, substitution));
+                }
+
+                // The row just filled becomes the "previous" row of the next iteration.
+                var filledRow = nextRow;
+                nextRow = currentRow;
+                currentRow = filledRow;
+            }
+
+            return rows[currentRow, second.Length];
+        }
+    }
+}
diff --git a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs
index cddbead9d4..6255e526ee 100644
--- a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs
+++ b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs
@@ -6,6 +6,7 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using Microsoft.ML.AutoML.Utils;
 using Microsoft.ML.Data;
 
 namespace Microsoft.ML.AutoML
@@ -248,7 +249,15 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
             var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
             if (nullableColumn == null)
             {
-                throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
+                var closestNamed = ClosestNamed(trainData, columnName, 7);
+
+                var exceptionMessage = $"Provided {columnPurpose} column '{columnName}' not found in training data.";
+                if (closestNamed != string.Empty)
+                {
+                    exceptionMessage += $" Did you mean '{closestNamed}'.";
+                }
+
+                throw new ArgumentException(exceptionMessage);
             }
 
             if(allowedTypes == null)
@@ -272,6 +281,29 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
             }
         }
 
+        /// <summary>
+        /// Returns the name of the column in <paramref name="trainData"/> whose name has the smallest
+        /// Levenshtein distance to <paramref name="columnName"/>, or <see cref="string.Empty"/> when the
+        /// schema has no columns or the smallest distance exceeds <paramref name="maxAllowableEditDistance"/>.
+        /// On ties the first such column in schema order wins.
+        /// </summary>
+        private static string ClosestNamed(IDataView trainData, string columnName, int maxAllowableEditDistance = int.MaxValue)
+        {
+            var minEditDistance = int.MaxValue;
+            var closestNamed = string.Empty;
+            foreach (var column in trainData.Schema)
+            {
+                var editDistance = StringEditDistance.GetLevenshteinDistance(column.Name, columnName);
+                if (editDistance < minEditDistance)
+                {
+                    minEditDistance = editDistance;
+                    closestNamed = column.Name;
+                }
+            }
+
+            return minEditDistance <= maxAllowableEditDistance ? closestNamed : string.Empty;
+        }
+
         private static string FindFirstDuplicate(IEnumerable<string> values)
         {
             var groups = values.GroupBy(v => v);
diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
index 6c007c8279..259acede05 100644
--- a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
@@ -5,6 +5,7 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
+using System.Linq;
 using System.Threading.Tasks;
 using Microsoft.ML.Data;
 using Microsoft.ML.TestFramework;
@@ -43,10 +44,26 @@ public void ValidateExperimentExecuteLabelNotInTrain()
         {
             foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
             {
+                const string columnName = "ReallyLongNonExistingColumnName";
                 var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
-                    new ColumnInformation() { LabelColumnName = "L" }, null, task));
+                    new ColumnInformation() { LabelColumnName = columnName }, null, task));
 
-                Assert.Equal("Provided label column 'L' not found in training data.", ex.Message);
+                Assert.Equal($"Provided label column '{columnName}' not found in training data.", ex.Message);
+            }
+        }
+
+        [Fact]
+        public void ValidateExperimentExecuteLabelNotInTrainMistyped()
+        {
+            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
+            {
+                var originalColumnName = _data.Schema.First().Name;
+                var mistypedColumnName = originalColumnName + "a";
+                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
+                    new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
+
+                Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.",
+                    ex.Message);
             }
         }
 