Skip to content

Commit

Permalink
handle boolean type in construction utils. (#183)
Browse files Browse the repository at this point in the history
 handle boolean type in construction utils.
  • Loading branch information
Ivanidzo4ka authored May 21, 2018
1 parent f16737c commit 3f586bd
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Api/ApiUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ private static OpCode GetAssignmentOpCode(Type t)
// REVIEW: This should be a Dictionary<Type, OpCode> based solution.
// DvTexts, strings, arrays, and VBuffers.
if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) ||
t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) ||
t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
{
Expand Down
37 changes: 37 additions & 0 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ private Delegate CreateGetter(int index)
Ch.Assert(colType.IsText);
return CreateStringToTextGetter(index);
}
else if (outputType == typeof(bool))
{
// Bool -> DvBool
Ch.Assert(colType.IsBool);
return CreateBooleanToDvBoolGetter(index);
}
else if (outputType == typeof(bool?))
{
// Bool? -> DvBool
Ch.Assert(colType.IsBool);
return CreateNullableBooleanToDvBoolGetter(index);
}

// T -> T
Ch.Assert(colType.RawType == outputType);
del = CreateDirectGetter<int>;
Expand Down Expand Up @@ -197,6 +210,30 @@ private Delegate CreateStringToTextGetter(int index)
});
}

private Delegate CreateBooleanToDvBoolGetter(int index)
{
var peek = DataView._peeks[index] as Peek<TRow, bool>;
Ch.AssertValue(peek);
bool buf = false;
return (ValueGetter<DvBool>)((ref DvBool dst) =>
{
peek(GetCurrentRowObject(), Position, ref buf);
dst = (DvBool)buf;
});
}

private Delegate CreateNullableBooleanToDvBoolGetter(int index)
{
var peek = DataView._peeks[index] as Peek<TRow, bool?>;
Ch.AssertValue(peek);
bool? buf = null;
return (ValueGetter<DvBool>)((ref DvBool dst) =>
{
peek(GetCurrentRowObject(), Position, ref buf);
dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA;
});
}

private Delegate CreateArrayToVBufferGetter<TDst>(int index)
{
var peek = DataView._peeks[index] as Peek<TRow, TDst[]>;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/DataKind.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind)
kind = DataKind.R8;
else if (type == typeof(DvText))
kind = DataKind.TX;
else if (type == typeof(DvBool) || type == typeof(bool))
else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?))
kind = DataKind.BL;
else if (type == typeof(DvTimeSpan))
kind = DataKind.TS;
Expand Down
49 changes: 49 additions & 0 deletions test/Microsoft.ML.Tests/LearningPipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,54 @@ public void NoTransformPipeline()
pipeline.Add(new FastForestBinaryClassifier());
var model = pipeline.Train<Data, Prediction>();
}

public class BooleanLabelData
{
[ColumnName("Features")]
[VectorType(2)]
public float[] Features;

[ColumnName("Label")]
public bool Label;
}

[Fact]
public void BooleanLabelPipeline()
{
var data = new BooleanLabelData[1];
data[0] = new BooleanLabelData();
data[0].Features = new float[] { 0.0f, 1.0f };
data[0].Label = false;
var pipeline = new LearningPipeline();
pipeline.Add(CollectionDataSource.Create(data));
pipeline.Add(new FastForestBinaryClassifier());
var model = pipeline.Train<Data, Prediction>();
}

public class NullableBooleanLabelData
{
[ColumnName("Features")]
[VectorType(2)]
public float[] Features;

[ColumnName("Label")]
public bool? Label;
}

[Fact]
public void NullableBooleanLabelPipeline()
{
var data = new NullableBooleanLabelData[2];
data[0] = new NullableBooleanLabelData();
data[0].Features = new float[] { 0.0f, 1.0f };
data[0].Label = null;
data[1] = new NullableBooleanLabelData();
data[1].Features = new float[] { 1.0f, 0.0f };
data[1].Label = false;
var pipeline = new LearningPipeline();
pipeline.Add(CollectionDataSource.Create(data));
pipeline.Add(new FastForestBinaryClassifier());
var model = pipeline.Train<Data, Prediction>();
}
}
}

0 comments on commit 3f586bd

Please sign in to comment.