Skip to content

Commit

Permalink
exception fixes (dotnet#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored and Dmitry-A committed Aug 22, 2019
1 parent 4365d98 commit 87b6766
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 46 deletions.
21 changes: 6 additions & 15 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
Expand All @@ -16,13 +15,7 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
{
var sample = TextFileSample.CreateFromFullFile(path);
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);

// If label column index > inferred # of columns, throw error
if (labelColumnIndex >= typeInference.Columns.Count())
{
throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) is >= than # of inferred columns ({typeInference.Columns.Count()}).");
}
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);

// if no column is named label,
// rename label column to default ML.NET label column name
Expand All @@ -40,7 +33,7 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
{
var sample = TextFileSample.CreateFromFullFile(path);
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
var typeInference = InferColumnTypes(context, sample, splitInference, true);
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, label);
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
}

Expand All @@ -49,10 +42,6 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
bool trimWhitespace, bool groupColumns)
{
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
if (!loaderColumns.Any(t => label.Equals(t.Name)))
{
throw new InferenceException(InferenceType.Label, $"Specified Label Column '{label}' was not found.");
}
var typedLoaderArgs = new TextLoader.Arguments
{
Column = loaderColumns,
Expand Down Expand Up @@ -121,7 +110,7 @@ private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample samp
}

private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample,
TextFileContents.ColumnSplitResult splitInference, bool hasHeader)
TextFileContents.ColumnSplitResult splitInference, bool hasHeader, uint? labelColumnIndex, string label)
{
// infer column types
var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample,
Expand All @@ -131,7 +120,9 @@ private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext co
Separator = splitInference.Separator.Value,
AllowSparse = splitInference.AllowSparse,
AllowQuote = splitInference.AllowQuote,
HasHeader = hasHeader
HasHeader = hasHeader,
LabelColumnIndex = labelColumnIndex,
Label = label
});

if (!typeInferenceResult.IsSuccess)
Expand Down
90 changes: 67 additions & 23 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ internal sealed class Arguments
public int ColumnCount;
public bool HasHeader;
public int MaxRowsToRead;
public uint? LabelColumnIndex;
public string Label;

public Arguments()
{
Expand Down Expand Up @@ -68,13 +70,31 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
}

public ReadOnlyMemory<char>[] RawData { get { return _data; } }

public string Name { get; set; }

public bool HasAllBooleanValues()
{
if (this.RawData.Skip(1)
.All(x => {
bool value;
// (note: Conversions.TryParse parses an empty string as a Boolean)
return !string.IsNullOrEmpty(x.ToString()) &&
Conversions.TryParse(in x, out value);
}))
{
return true;
}

return false;
}
}

public struct Column
public class Column
{
public readonly int ColumnIndex;
public readonly PrimitiveType ItemType;

public PrimitiveType ItemType;
public string SuggestedName;

public Column(int columnIndex, string suggestedName, PrimitiveType itemType)
Expand Down Expand Up @@ -131,13 +151,10 @@ public void Apply(IntermediateColumn[] columns)
{
foreach (var col in columns)
{
if (!col.RawData.Skip(1)
.All(x =>
{
bool value;
return Conversions.TryParse(in x, out value);
})
)
// skip columns that already have a suggested type,
// or that don't have all Boolean values
if (col.SuggestedType != null ||
!col.HasAllBooleanValues())
{
continue;
}
Expand All @@ -156,12 +173,6 @@ public void Apply(IntermediateColumn[] columns)
{
foreach (var col in columns)
{
// skip columns that already have a suggested type
if(col.SuggestedType != null)
{
continue;
}

if (!col.RawData.Skip(1)
.All(x =>
{
Expand Down Expand Up @@ -215,9 +226,9 @@ public void Apply(IntermediateColumn[] columns)
private static IEnumerable<ITypeInferenceExpert> GetExperts()
{
// Current logic is pretty primitive: if every value (except the first) of a column
// parses as a boolean it's boolean, if it parses as numeric then it's numeric. Otherwise, it is text.
yield return new Experts.BooleanValues();
// parses as numeric then it's numeric. Else if it parses as a Boolean, it's Boolean. Otherwise, it is text.
yield return new Experts.AllNumericValues();
yield return new Experts.BooleanValues();
yield return new Experts.EverythingText();
}

Expand Down Expand Up @@ -329,7 +340,6 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
}

// suggest names
var names = new List<string>();
usedNames.Clear();
foreach (var col in cols)
{
Expand All @@ -338,14 +348,23 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
name0 = name = SuggestName(col, args.HasHeader);
int i = 0;
while (!usedNames.Add(name))
{
name = string.Format("{0}_{1:00}", name0, i++);
names.Add(name);
}
col.Name = name;
}

// validate & retrieve label column
var labelColumn = GetAndValidateLabelColumn(args, cols);

// if label column has all Boolean values, set its type as Boolean
if(labelColumn.HasAllBooleanValues())
{
labelColumn.SuggestedType = BoolType.Instance;
}
var outCols =
cols.Select((x, i) => new Column(x.ColumnId, names[i], x.SuggestedType)).ToArray();

var numerics = outCols.Count(x => x.ItemType.IsNumber());
var outCols = cols.Select(x => new Column(x.ColumnId, x.Name, x.SuggestedType)).ToArray();

return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
}

Expand All @@ -361,6 +380,31 @@ private static string Sanitize(string header)
return string.Join("", header.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
}

private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)
{
IntermediateColumn labelColumn = null;
if (args.LabelColumnIndex != null)
{
// if label column index > inferred # of columns, throw error
if (args.LabelColumnIndex >= cols.Count())
{
throw new ArgumentOutOfRangeException(nameof(args.LabelColumnIndex), $"Label column index ({args.LabelColumnIndex}) is >= than # of inferred columns ({cols.Count()}).");
}

labelColumn = cols[args.LabelColumnIndex.Value];
}
else
{
labelColumn = cols.FirstOrDefault(c => c.Name == args.Label);
if (labelColumn == null)
{
throw new ArgumentException($"Specified label column '{args.Label}' was not found.");
}
}

return labelColumn;
}

public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
{
var loaderColumns = new List<TextLoader.Column>();
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ public void Apply(IntermediateColumn[] columns)
Double avgSpaces = 1.0 * sumSpaces / data.Length;
if (cardinalityRatio < 0.7 || seen.Count < 100)
column.SuggestedPurpose = ColumnPurpose.CategoricalFeature;
else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1))
// (note: the columns.Count() == 1 condition below, in case a dataset has only
// a 'name' and a 'label' column, forces what would be a 'name' column to become a text feature)
else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1 || columns.Count() == 1))
column.SuggestedPurpose = ColumnPurpose.TextFeature;
else if (cardinalityRatio >= 0.9)
column.SuggestedPurpose = ColumnPurpose.Name;
Expand Down
7 changes: 1 addition & 6 deletions src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,13 @@ public static void ValidateInferColumnsArgs(string path)
ValidatePath(path);
}

public static void ValidateAutoReadArgs(string path, string label)
{
ValidateLabel(label);
ValidatePath(path);
}

private static void ValidateTrainData(IDataView trainData)
{
if (trainData == null)
{
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}

var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
if (type != null && type != NumberType.R4)
{
Expand Down
48 changes: 47 additions & 1 deletion src/Test/ColumnInferenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void IncorrectLabelColumnTest()
{
var dataPath = DatasetUtil.DownloadUciAdultDataset();
var context = new MLContext();
Assert.ThrowsException<InferenceException>(new System.Action(() => context.Data.InferColumns(dataPath, "Junk", groupColumns: false)));
Assert.ThrowsException<ArgumentException>(new System.Action(() => context.Data.InferColumns(dataPath, "Junk", groupColumns: false)));
}

[TestMethod]
Expand Down Expand Up @@ -62,5 +62,51 @@ public void InferColumnsLabelIndexNoHeaders()
Assert.AreEqual(1, labelPurposes.Count());
Assert.AreEqual(DefaultColumnNames.Label, labelPurposes.First().Name);
}

[TestMethod]
public void InferColumnsWithDatasetWithEmptyColumn()
{
var result = new MLContext().Data.InferColumns(@".\TestData\DatasetWithEmptyColumn.txt", DefaultColumnNames.Label);
var emptyColumn = result.TextLoaderArgs.Column.First(c => c.Name == "Empty");
Assert.AreEqual(DataKind.TX, emptyColumn.Type);
}

[TestMethod]
public void InferColumnsWithDatasetWithBoolColumn()
{
var result = new MLContext().Data.InferColumns(@".\TestData\BinaryDatasetWithBoolColumn.txt", DefaultColumnNames.Label);
Assert.AreEqual(2, result.TextLoaderArgs.Column.Count());
Assert.AreEqual(2, result.ColumnPurpopses.Count());

var boolColumn = result.TextLoaderArgs.Column.First(c => c.Name == "Bool");
var labelColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Label);
// ensure non-label Boolean column is detected as R4
Assert.AreEqual(DataKind.R4, boolColumn.Type);
Assert.AreEqual(DataKind.BL, labelColumn.Type);

// ensure non-label Boolean column is detected as R4
var boolPurpose = result.ColumnPurpopses.First(c => c.Name == "Bool").Purpose;
var labelPurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Label).Purpose;
Assert.AreEqual(ColumnPurpose.NumericFeature, boolPurpose);
Assert.AreEqual(ColumnPurpose.Label, labelPurpose);
}

[TestMethod]
public void InferColumnsWhereNameColumnIsOnlyFeature()
{
var result = new MLContext().Data.InferColumns(@".\TestData\NameColumnIsOnlyFeatureDataset.txt", DefaultColumnNames.Label);
Assert.AreEqual(2, result.TextLoaderArgs.Column.Count());
Assert.AreEqual(2, result.ColumnPurpopses.Count());

var nameColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Name);
var labelColumn = result.TextLoaderArgs.Column.First(c => c.Name == DefaultColumnNames.Label);
Assert.AreEqual(DataKind.TX, nameColumn.Type);
Assert.AreEqual(DataKind.BL, labelColumn.Type);

var namePurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Name).Purpose;
var labelPurpose = result.ColumnPurpopses.First(c => c.Name == DefaultColumnNames.Label).Purpose;
Assert.AreEqual(ColumnPurpose.TextFeature, namePurpose);
Assert.AreEqual(ColumnPurpose.Label, labelPurpose);
}
}
}
12 changes: 12 additions & 0 deletions src/Test/Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,16 @@
<ProjectReference Include="..\Microsoft.ML.Auto\Microsoft.ML.Auto.csproj" />
</ItemGroup>

<ItemGroup>
<None Update="TestData\NameColumnIsOnlyFeatureDataset.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="TestData\BinaryDatasetWithBoolColumn.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="TestData\DatasetWithEmptyColumn.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions src/Test/TestData/BinaryDatasetWithBoolColumn.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Label,Bool
0,1
0,0
1,1
1,0
4 changes: 4 additions & 0 deletions src/Test/TestData/DatasetWithEmptyColumn.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Label,Feature1,Empty
0,2,
0,4,
1,1,
Loading

0 comments on commit 87b6766

Please sign in to comment.