Skip to content

Commit

Permalink
Sanitize the column names in CLI (dotnet#162)
Browse files Browse the repository at this point in the history
* added sanitization layer in CLI

* fix test

* changed exception.StackTrace to exception.ToString()
  • Loading branch information
srsaggam authored and Dmitry-A committed Aug 22, 2019
1 parent b92039a commit dbed126
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
13 changes: 4 additions & 9 deletions src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
public bool HasAllBooleanValues()
{
if (this.RawData.Skip(1)
.All(x => {
.All(x =>
{
bool value;
// (note: Conversions.TryParse parses an empty string as a Boolean)
return !string.IsNullOrEmpty(x.ToString()) &&
Expand Down Expand Up @@ -358,7 +359,7 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
var labelColumn = GetAndValidateLabelColumn(args, cols);

// if label column has all Boolean values, set its type as Boolean
if(labelColumn.HasAllBooleanValues())
if (labelColumn.HasAllBooleanValues())
{
labelColumn.SuggestedType = BoolType.Instance;
}
Expand All @@ -371,13 +372,7 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
private static string SuggestName(IntermediateColumn column, bool hasHeader)
{
var header = column.RawData[0].ToString();
return (hasHeader && !string.IsNullOrWhiteSpace(header)) ? Sanitize(header) : string.Format("col{0}", column.ColumnId);
}

private static string Sanitize(string header)
{
// replace all non-letters and non-digits with '_'.
return string.Join("", header.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
return (hasHeader && !string.IsNullOrWhiteSpace(header)) ? header : string.Format("col{0}", column.ColumnId);
}

private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)
Expand Down
4 changes: 2 additions & 2 deletions src/Test/ColumnInferenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ public void InferColumnsLabelIndex()
var result = new MLContext().Data.InferColumns(DatasetUtil.DownloadUciAdultDataset(), 14, hasHeader: true);
Assert.AreEqual(true, result.TextLoaderArgs.HasHeader);
var labelCol = result.TextLoaderArgs.Column.First(c => c.Source[0].Min == 14 && c.Source[0].Max == 14);
Assert.AreEqual("hours_per_week", labelCol.Name);
Assert.AreEqual("hours-per-week", labelCol.Name);
var labelPurposes = result.ColumnPurpopses.Where(c => c.Purpose == ColumnPurpose.Label);
Assert.AreEqual(1, labelPurposes.Count());
Assert.AreEqual("hours_per_week", labelPurposes.First().Name);
Assert.AreEqual("hours-per-week", labelPurposes.First().Name);
}

[TestMethod]
Expand Down
9 changes: 8 additions & 1 deletion src/mlnet/Commands/New/NewCommandHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Auto;
using Microsoft.ML.CLI.CodeGenerator.CSharp;
Expand Down Expand Up @@ -32,6 +33,8 @@ public void Execute()
// Infer columns
(TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) columnInference = InferColumns(context);

Array.ForEach(columnInference.TextLoaderArgs.Column, t => t.Name = Sanitize(t.Name));

// Load data
(IDataView trainData, IDataView validationData) = LoadData(context, columnInference.TextLoaderArgs);

Expand All @@ -45,7 +48,7 @@ public void Execute()
catch (Exception e)
{
logger.Log(LogLevel.Error, $"{Strings.ExplorePipelineException}:");
logger.Log(LogLevel.Error, e.StackTrace);
logger.Log(LogLevel.Error, e.ToString());
logger.Log(LogLevel.Error, Strings.Exiting);
return;
}
Expand Down Expand Up @@ -157,5 +160,9 @@ internal static void SaveModel(ITransformer model, string ModelPath, string mode
model.SaveTo(mlContext, fs);
}

private static string Sanitize(string name)
{
return string.Join("", name.Select(x => Char.IsLetterOrDigit(x) ? x : '_'));
}
}
}

0 comments on commit dbed126

Please sign in to comment.