From d3a0aae9034d5ae8821b8f23e1ae24a916141ec2 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 12:49:56 +0300 Subject: [PATCH 1/8] Provide ability to filter by null value # Conflicts: # src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.tt # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.tt --- .../DataFrameColumn.BinaryOperations.cs | 9 +++ ...imitiveDataFrameColumn.BinaryOperations.cs | 74 +++++++++++++++++++ .../PrimitiveDataFrameColumn.cs | 24 ++++++ .../StringDataFrameColumn.BinaryOperations.cs | 26 +++++++ .../DataFrameTests.cs | 68 +++++++++++++++++ 5 files changed, 201 insertions(+) diff --git a/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs b/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs index 8ecd052486..4a3bac6988 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs @@ -316,5 +316,14 @@ public virtual PrimitiveDataFrameColumn ElementwiseLessThan(T value) throw new NotImplementedException(); } + public virtual PrimitiveDataFrameColumn ElementwiseIsNull() + { + throw new NotImplementedException(); + } + + public virtual PrimitiveDataFrameColumn ElementwiseIsNotNull() + { + throw new NotImplementedException(); + } } } diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs index d05af4d699..fc75bda8fe 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs @@ -47,10 +47,12 @@ public override DataFrameColumn Add(DataFrameColumn column, bool inPlace = false return AddImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return AddImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override DataFrameColumn Add(U value, bool inPlace = false) { @@ -61,6 +63,7 @@ public override DataFrameColumn Add(U value, bool inPlace = false) } return AddImplementation(value, inPlace); } + /// public override DataFrameColumn Subtract(DataFrameColumn column, bool inPlace = false) { @@ -94,10 +97,12 @@ public override DataFrameColumn Subtract(DataFrameColumn column, bool inPlace = return SubtractImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return SubtractImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override DataFrameColumn Subtract(U value, bool inPlace = false) { @@ -108,6 +113,7 @@ public override DataFrameColumn Subtract(U value, bool inPlace = false) } return SubtractImplementation(value, inPlace); } + /// public override DataFrameColumn Multiply(DataFrameColumn column, bool inPlace = false) { @@ -141,10 +147,12 @@ public override DataFrameColumn Multiply(DataFrameColumn column, bool inPlace = return MultiplyImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return MultiplyImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override DataFrameColumn Multiply(U value, bool inPlace = false) { @@ -155,6 +163,7 @@ public override DataFrameColumn Multiply(U value, bool inPlace = false) } return MultiplyImplementation(value, inPlace); } + /// public override DataFrameColumn Divide(DataFrameColumn column, bool inPlace = false) { @@ -188,10 +197,12 @@ public override DataFrameColumn Divide(DataFrameColumn column, bool inPlace = fa return DivideImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return DivideImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override DataFrameColumn Divide(U value, bool inPlace = false) { @@ -202,6 +213,7 @@ public override DataFrameColumn Divide(U value, bool inPlace = false) } return DivideImplementation(value, inPlace); } + /// public override DataFrameColumn Modulo(DataFrameColumn column, bool inPlace = false) { @@ -235,10 +247,12 @@ public override DataFrameColumn Modulo(DataFrameColumn column, bool inPlace = fa return ModuloImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return ModuloImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override DataFrameColumn Modulo(U value, bool inPlace = false) { @@ -249,6 +263,7 @@ public override DataFrameColumn Modulo(U value, bool inPlace = false) } return ModuloImplementation(value, inPlace); } + /// public override DataFrameColumn And(DataFrameColumn column, bool inPlace = false) { @@ -282,15 +297,18 @@ public override DataFrameColumn And(DataFrameColumn column, bool inPlace = false return AndImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return AndImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn And(bool value, bool inPlace = false) { return AndImplementation(value, inPlace); } + /// public override DataFrameColumn Or(DataFrameColumn column, bool inPlace = false) { @@ -324,15 +342,18 @@ public override DataFrameColumn Or(DataFrameColumn column, bool inPlace = false) return OrImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return OrImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn Or(bool value, bool inPlace = false) { return OrImplementation(value, inPlace); } + /// public override DataFrameColumn Xor(DataFrameColumn column, bool inPlace = false) { @@ -366,15 +387,18 @@ public override DataFrameColumn Xor(DataFrameColumn column, bool inPlace = false return XorImplementation(ushortColumn, inPlace); case PrimitiveDataFrameColumn DateTimeColumn: return XorImplementation(DateTimeColumn, inPlace); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn Xor(bool value, bool inPlace = false) { return XorImplementation(value, inPlace); } + /// public override DataFrameColumn LeftShift(int value, bool inPlace = false) { @@ -418,10 +442,14 @@ public override PrimitiveDataFrameColumn ElementwiseEquals(DataFrameColumn return ElementwiseEqualsImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseEqualsImplementation(DateTimeColumn); + case null: + return ElementwiseIsNull(); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseEquals(U value) { @@ -432,6 +460,7 @@ public override PrimitiveDataFrameColumn ElementwiseEquals(U value) } return ElementwiseEqualsImplementation(value); } + /// public override PrimitiveDataFrameColumn ElementwiseNotEquals(DataFrameColumn column) { @@ -465,10 +494,14 @@ public override PrimitiveDataFrameColumn ElementwiseNotEquals(DataFrameCol return ElementwiseNotEqualsImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseNotEqualsImplementation(DateTimeColumn); + case null: + return ElementwiseIsNotNull(); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseNotEquals(U value) { @@ -479,6 +512,7 @@ public override PrimitiveDataFrameColumn ElementwiseNotEquals(U value) } return ElementwiseNotEqualsImplementation(value); } + /// public override PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqual(DataFrameColumn column) { @@ -512,10 +546,12 @@ public override PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqual(Dat return ElementwiseGreaterThanOrEqualImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseGreaterThanOrEqualImplementation(DateTimeColumn); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqual(U value) { @@ -526,6 +562,7 @@ public override PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqual( } return ElementwiseGreaterThanOrEqualImplementation(value); } + /// public override PrimitiveDataFrameColumn ElementwiseLessThanOrEqual(DataFrameColumn column) { @@ -559,10 +596,12 @@ public override PrimitiveDataFrameColumn ElementwiseLessThanOrEqual(DataFr return ElementwiseLessThanOrEqualImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseLessThanOrEqualImplementation(DateTimeColumn); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseLessThanOrEqual(U value) { @@ -573,6 +612,7 @@ public override PrimitiveDataFrameColumn ElementwiseLessThanOrEqual(U v } return ElementwiseLessThanOrEqualImplementation(value); } + /// public override PrimitiveDataFrameColumn ElementwiseGreaterThan(DataFrameColumn column) { @@ -606,10 +646,12 @@ public override PrimitiveDataFrameColumn ElementwiseGreaterThan(DataFrameC return ElementwiseGreaterThanImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseGreaterThanImplementation(DateTimeColumn); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseGreaterThan(U value) { @@ -620,6 +662,7 @@ public override PrimitiveDataFrameColumn ElementwiseGreaterThan(U value } return ElementwiseGreaterThanImplementation(value); } + /// public override PrimitiveDataFrameColumn ElementwiseLessThan(DataFrameColumn column) { @@ -653,10 +696,12 @@ public override PrimitiveDataFrameColumn ElementwiseLessThan(DataFrameColu return ElementwiseLessThanImplementation(ushortColumn); case PrimitiveDataFrameColumn DateTimeColumn: return ElementwiseLessThanImplementation(DateTimeColumn); + default: throw new NotSupportedException(); } } + /// public override PrimitiveDataFrameColumn ElementwiseLessThan(U value) { @@ -668,6 +713,7 @@ public override PrimitiveDataFrameColumn ElementwiseLessThan(U value) return ElementwiseLessThanImplementation(value); } + internal DataFrameColumn AddImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -750,6 +796,7 @@ internal DataFrameColumn AddImplementation(PrimitiveDataFrameColumn column throw new NotSupportedException(); } } + internal DataFrameColumn AddImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1035,6 +1082,7 @@ internal DataFrameColumn SubtractImplementation(PrimitiveDataFrameColumn c throw new NotSupportedException(); } } + internal DataFrameColumn SubtractImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1139,6 +1187,7 @@ internal DataFrameColumn SubtractImplementation(U value, bool inPlace) throw new NotSupportedException(); } } + internal DataFrameColumn MultiplyImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1221,6 +1270,7 @@ internal DataFrameColumn MultiplyImplementation(PrimitiveDataFrameColumn c throw new NotSupportedException(); } } + internal DataFrameColumn MultiplyImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1298,6 +1348,7 @@ internal DataFrameColumn MultiplyImplementation(U value, bool inPlace) throw new NotSupportedException(); } } + internal DataFrameColumn DivideImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1380,6 +1431,7 @@ internal DataFrameColumn DivideImplementation(PrimitiveDataFrameColumn col throw new NotSupportedException(); } } + internal DataFrameColumn DivideImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1484,6 +1536,7 @@ internal DataFrameColumn DivideImplementation(U value, bool inPlace) throw new NotSupportedException(); } } + internal DataFrameColumn ModuloImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1566,6 +1619,7 @@ internal DataFrameColumn ModuloImplementation(PrimitiveDataFrameColumn col throw new NotSupportedException(); } } + internal DataFrameColumn ModuloImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1643,6 +1697,7 @@ internal DataFrameColumn ModuloImplementation(U value, bool inPlace) throw new NotSupportedException(); } } + internal DataFrameColumn AndImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1678,6 +1733,7 @@ internal DataFrameColumn AndImplementation(PrimitiveDataFrameColumn column throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn AndImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1708,6 +1764,7 @@ internal PrimitiveDataFrameColumn AndImplementation(U value, bool inPla throw new NotSupportedException(); } } + internal DataFrameColumn OrImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1743,6 +1800,7 @@ internal DataFrameColumn OrImplementation(PrimitiveDataFrameColumn column, throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn OrImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1773,6 +1831,7 @@ internal PrimitiveDataFrameColumn OrImplementation(U value, bool inPlac throw new NotSupportedException(); } } + internal DataFrameColumn XorImplementation(PrimitiveDataFrameColumn column, bool inPlace) where U : unmanaged { @@ -1808,6 +1867,7 @@ internal DataFrameColumn XorImplementation(PrimitiveDataFrameColumn column throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn XorImplementation(U value, bool inPlace) { switch (typeof(T)) @@ -1838,6 +1898,7 @@ internal PrimitiveDataFrameColumn XorImplementation(U value, bool inPla throw new NotSupportedException(); } } + internal DataFrameColumn LeftShiftImplementation(int value, bool inPlace) { switch (typeof(T)) @@ -1901,6 +1962,7 @@ internal DataFrameColumn LeftShiftImplementation(int value, bool inPlace) throw new NotSupportedException(); } } + internal DataFrameColumn RightShiftImplementation(int value, bool inPlace) { switch (typeof(T)) @@ -1964,6 +2026,7 @@ internal DataFrameColumn RightShiftImplementation(int value, bool inPlace) throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseEqualsImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2053,6 +2116,7 @@ internal PrimitiveDataFrameColumn ElementwiseEqualsImplementation(Primi throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseEqualsImplementation(U value) { switch (typeof(T)) @@ -2137,6 +2201,7 @@ internal PrimitiveDataFrameColumn ElementwiseEqualsImplementation(U val throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseNotEqualsImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2226,6 +2291,7 @@ internal PrimitiveDataFrameColumn ElementwiseNotEqualsImplementation(Pr throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseNotEqualsImplementation(U value) { switch (typeof(T)) @@ -2310,6 +2376,7 @@ internal PrimitiveDataFrameColumn ElementwiseNotEqualsImplementation(U throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqualImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2387,6 +2454,7 @@ internal PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqualImplementat throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqualImplementation(U value) { switch (typeof(T)) @@ -2459,6 +2527,7 @@ internal PrimitiveDataFrameColumn ElementwiseGreaterThanOrEqualImplementat throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseLessThanOrEqualImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2536,6 +2605,7 @@ internal PrimitiveDataFrameColumn ElementwiseLessThanOrEqualImplementation throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseLessThanOrEqualImplementation(U value) { switch (typeof(T)) @@ -2608,6 +2678,7 @@ internal PrimitiveDataFrameColumn ElementwiseLessThanOrEqualImplementation throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseGreaterThanImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2685,6 +2756,7 @@ internal PrimitiveDataFrameColumn ElementwiseGreaterThanImplementation( throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseGreaterThanImplementation(U value) { switch (typeof(T)) @@ -2757,6 +2829,7 @@ internal PrimitiveDataFrameColumn ElementwiseGreaterThanImplementation( throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseLessThanImplementation(PrimitiveDataFrameColumn column) where U : unmanaged { @@ -2834,6 +2907,7 @@ internal PrimitiveDataFrameColumn ElementwiseLessThanImplementation(Pri throw new NotSupportedException(); } } + internal PrimitiveDataFrameColumn ElementwiseLessThanImplementation(U value) { switch (typeof(T)) diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index 0fe7820fe2..dbf73a5536 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -821,5 +821,29 @@ public override Dictionary> GetGroupedOccurrences(DataFr { return GetGroupedOccurrences(other, out otherColumnNullIndices); } + + public override PrimitiveDataFrameColumn ElementwiseIsNull() + { + var ret = new BooleanDataFrameColumn(Name, Length); + + for (long i = 0; i < Length; i++) + { + ret[i] = !_columnContainer[i].HasValue; + } + + return ret; + } + + public override PrimitiveDataFrameColumn ElementwiseIsNotNull() + { + var ret = new BooleanDataFrameColumn(Name, Length); + + for (long i = 0; i < Length; i++) + { + ret[i] = _columnContainer[i].HasValue; + } + + return ret; + } } } diff --git a/src/Microsoft.Data.Analysis/StringDataFrameColumn.BinaryOperations.cs b/src/Microsoft.Data.Analysis/StringDataFrameColumn.BinaryOperations.cs index 0bfbd3b6bc..c6ffe4c4cf 100644 --- a/src/Microsoft.Data.Analysis/StringDataFrameColumn.BinaryOperations.cs +++ b/src/Microsoft.Data.Analysis/StringDataFrameColumn.BinaryOperations.cs @@ -91,6 +91,9 @@ internal static PrimitiveDataFrameColumn ElementwiseEqualsImplementation(D /// public override PrimitiveDataFrameColumn ElementwiseEquals(DataFrameColumn column) { + if (column == null) + return ElementwiseIsNull(); + return ElementwiseEqualsImplementation(this, column); } @@ -128,6 +131,26 @@ internal static PrimitiveDataFrameColumn ElementwiseNotEqualsImplementatio return ret; } + public override PrimitiveDataFrameColumn ElementwiseIsNotNull() + { + PrimitiveDataFrameColumn ret = new PrimitiveDataFrameColumn(Name, Length); + for (long i = 0; i < Length; i++) + { + ret[i] = this[i] != null; + } + return ret; + } + + public override PrimitiveDataFrameColumn ElementwiseIsNull() + { + PrimitiveDataFrameColumn ret = new PrimitiveDataFrameColumn(Name, Length); + for (long i = 0; i < Length; i++) + { + ret[i] = this[i] == null; + } + return ret; + } + public PrimitiveDataFrameColumn ElementwiseNotEquals(string value) { PrimitiveDataFrameColumn ret = new PrimitiveDataFrameColumn(Name, Length); @@ -141,6 +164,9 @@ public PrimitiveDataFrameColumn ElementwiseNotEquals(string value) /// public override PrimitiveDataFrameColumn ElementwiseNotEquals(DataFrameColumn column) { + if (column == null) + return ElementwiseIsNotNull(); + return ElementwiseNotEqualsImplementation(this, column); } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index ff7856e984..c42864a162 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -3490,5 +3490,73 @@ public void TestMeanMedian() Assert.Equal(4, df["Decimal"].Median()); } + + [Fact] + public void Test_PrimitiveColumnNotEqualsNull() + { + var col = new DoubleDataFrameColumn("col", new double?[] { 1.23, null, 2, 3 }); + var dfTest = new DataFrame(col); + + var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseNotEquals(null)); + + Assert.True(filteredNullDf.Columns.IndexOf("col") >= 0); + Assert.Equal(3, filteredNullDf.Columns["col"].Length); + + Assert.Equal(1.23, filteredNullDf.Columns["col"][0]); + Assert.Equal(2.0, filteredNullDf.Columns["col"][1]); + Assert.Equal(3.0, filteredNullDf.Columns["col"][2]); + } + + [Fact] + public void Test_PrimitiveColumnEqualsNull() + { + var index = new Int32DataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); + var col = new DoubleDataFrameColumn("col", new double?[] { 1.23, null, 2, 3, null }); ; + var dfTest = new DataFrame(index, col); + + var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseEquals(null)); + + Assert.True(filteredNullDf.Columns.IndexOf("col") >= 0); + Assert.True(filteredNullDf.Columns.IndexOf("index") >= 0); + + Assert.Equal(2, filteredNullDf.Rows.Count); + + Assert.Equal(2, filteredNullDf.Columns["index"][0]); + Assert.Equal(5, filteredNullDf.Columns["index"][1]); + } + + [Fact] + public void Test_StringColumnNotEqualsNull() + { + var col = new StringDataFrameColumn("col", new[] { "One", null, "Two", "Three" }); + var dfTest = new DataFrame(col); + + var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseNotEquals(null)); + + Assert.True(filteredNullDf.Columns.IndexOf("col") >= 0); + Assert.Equal(3, filteredNullDf.Columns["col"].Length); + + Assert.Equal("One", filteredNullDf.Columns["col"][0]); + Assert.Equal("Two", filteredNullDf.Columns["col"][1]); + Assert.Equal("Three", filteredNullDf.Columns["col"][2]); + } + + [Fact] + public void Test_StringColumnEqualsNull() + { + var index = new Int32DataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); + var col = new StringDataFrameColumn("col", new[] { "One", null, "Three", "Four", null }); ; + var dfTest = new DataFrame(index, col); + + var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseEquals(null)); + + Assert.True(filteredNullDf.Columns.IndexOf("col") >= 0); + Assert.True(filteredNullDf.Columns.IndexOf("index") >= 0); + + Assert.Equal(2, filteredNullDf.Rows.Count); + + Assert.Equal(2, filteredNullDf.Columns["index"][0]); + Assert.Equal(5, filteredNullDf.Columns["index"][1]); + } } } From ab7e69810e9a7cc96d7286f858cdcfec8c713904 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 12:51:04 +0300 Subject: [PATCH 2/8] Add comments --- .../DataFrameColumn.BinaryOperations.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs b/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs index 4a3bac6988..1c340575db 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumn.BinaryOperations.cs @@ -316,11 +316,17 @@ public virtual PrimitiveDataFrameColumn ElementwiseLessThan(T value) throw new NotImplementedException(); } + /// + /// Performs an element-wise equal to Null on each value in the column + /// public virtual PrimitiveDataFrameColumn ElementwiseIsNull() { throw new NotImplementedException(); } + /// + /// Performs an element-wise not equal to Null on each value in the column + /// public virtual PrimitiveDataFrameColumn ElementwiseIsNotNull() { throw new NotImplementedException(); From b0daf74a7e5e1e5afd925deeccf48eb26d24a5ff Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 13:03:23 +0300 Subject: [PATCH 3/8] Fix merge issues (broken build) --- .../PrimitiveDataFrameColumn.cs | 4 ++-- test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index dbf73a5536..5aed1c57f7 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -824,7 +824,7 @@ public override Dictionary> GetGroupedOccurrences(DataFr public override PrimitiveDataFrameColumn ElementwiseIsNull() { - var ret = new BooleanDataFrameColumn(Name, Length); + var ret = new PrimitiveDataFrameColumn(Name, Length); for (long i = 0; i < Length; i++) { @@ -836,7 +836,7 @@ public override PrimitiveDataFrameColumn ElementwiseIsNull() public override PrimitiveDataFrameColumn ElementwiseIsNotNull() { - var ret = new BooleanDataFrameColumn(Name, Length); + var ret = new PrimitiveDataFrameColumn(Name, Length); for (long i = 0; i < Length; i++) { diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index c42864a162..cbc6cc9e80 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -3494,7 +3494,7 @@ public void TestMeanMedian() [Fact] public void Test_PrimitiveColumnNotEqualsNull() { - var col = new DoubleDataFrameColumn("col", new double?[] { 1.23, null, 2, 3 }); + var col = new PrimitiveDataFrameColumn("col", new double?[] { 1.23, null, 2, 3 }); var dfTest = new DataFrame(col); var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseNotEquals(null)); @@ -3510,8 +3510,8 @@ public void Test_PrimitiveColumnNotEqualsNull() [Fact] public void Test_PrimitiveColumnEqualsNull() { - var index = new Int32DataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); - var col = new DoubleDataFrameColumn("col", new double?[] { 1.23, null, 2, 3, null }); ; + var index = new PrimitiveDataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); + var col = new PrimitiveDataFrameColumn("col", new double?[] { 1.23, null, 2, 3, null }); ; var dfTest = new DataFrame(index, col); var filteredNullDf = dfTest.Filter(dfTest["col"].ElementwiseEquals(null)); @@ -3544,7 +3544,7 @@ public void Test_StringColumnNotEqualsNull() [Fact] public void Test_StringColumnEqualsNull() { - var index = new Int32DataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); + var index = new PrimitiveDataFrameColumn("index", new int[] { 1, 2, 3, 4, 5 }); var col = new StringDataFrameColumn("col", new[] { "One", null, "Three", "Four", null }); ; var dfTest = new DataFrame(index, col); From 663db1f1b0724c14fcbd5d552e3c68f966013c07 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 14:20:15 +0300 Subject: [PATCH 4/8] Step 1 # Conflicts: # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Computations.tt # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.tt --- .../DateTimeComputation.cs | 61 +++++++++++++------ .../PrimitiveDataFrameColumn.Computations.cs | 16 ++--- .../PrimitiveDataFrameColumnComputations.cs | 32 +++++----- .../DataFrameTests.cs | 52 ++++++++++++++++ 4 files changed, 119 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.Data.Analysis/DateTimeComputation.cs b/src/Microsoft.Data.Analysis/DateTimeComputation.cs index ba14e39292..3e50ec0c82 100644 --- a/src/Microsoft.Data.Analysis/DateTimeComputation.cs +++ b/src/Microsoft.Data.Analysis/DateTimeComputation.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection; using System.Text; namespace Microsoft.Data.Analysis @@ -189,26 +191,37 @@ public void CumulativeSum(PrimitiveColumnContainer column, IEnumerable throw new NotSupportedException(); } - public void Max(PrimitiveColumnContainer column, out DateTime ret) + public void Max(PrimitiveColumnContainer column, out DateTime? ret) { - ret = column.Buffers[0].ReadOnlySpan[0]; + var maxDate = DateTime.MinValue; + bool hasMaxValue = false; + for (int b = 0; b < column.Buffers.Count; b++) { - var buffer = column.Buffers[b]; - var readOnlySpan = buffer.ReadOnlySpan; + var readOnlySpan = column.Buffers[b].ReadOnlySpan; + var bitmapSpan = column.NullBitMapBuffers[b].ReadOnlySpan; for (int i = 0; i < readOnlySpan.Length; i++) { + int byteIndex = (int)((uint)i / 8); + + //Check if bit is not set (value is null) - skip + if (((bitmapSpan[byteIndex] >> (i & 7)) & 1) == 0) + continue; + var val = readOnlySpan[i]; - if (val > ret) + if (val > maxDate) { - ret = val; + maxDate = val; + hasMaxValue = true; } } } + + ret = hasMaxValue ? maxDate : null; } - public void Max(PrimitiveColumnContainer column, IEnumerable rows, out DateTime ret) + public void Max(PrimitiveColumnContainer column, IEnumerable rows, out DateTime? ret) { ret = default; var readOnlySpan = column.Buffers[0].ReadOnlySpan; @@ -237,26 +250,38 @@ public void Max(PrimitiveColumnContainer column, IEnumerable row } } - public void Min(PrimitiveColumnContainer column, out DateTime ret) + public void Min(PrimitiveColumnContainer column, out DateTime? ret) { - ret = column.Buffers[0].ReadOnlySpan[0]; + var minDate = DateTime.MaxValue; + bool hasMinValue = false; + for (int b = 0; b < column.Buffers.Count; b++) { - var buffer = column.Buffers[b]; - var readOnlySpan = buffer.ReadOnlySpan; + var readOnlySpan = column.Buffers[b].ReadOnlySpan; + var bitmapSpan = column.NullBitMapBuffers[b].ReadOnlySpan; + for (int i = 0; i < readOnlySpan.Length; i++) { + int byteIndex = (int)((uint)i / 8); + + //Check if bit is not set (value is null) - skip + if (((bitmapSpan[byteIndex] >> (i & 7)) & 1) == 0) + continue; + var val = readOnlySpan[i]; - if (val < ret) + if (val < minDate) { - ret = val; + minDate = val; + hasMinValue = true; } } } + + ret = hasMinValue ? minDate : null; } - public void Min(PrimitiveColumnContainer column, IEnumerable rows, out DateTime ret) + public void Min(PrimitiveColumnContainer column, IEnumerable rows, out DateTime? ret) { ret = default; var readOnlySpan = column.Buffers[0].ReadOnlySpan; @@ -285,22 +310,22 @@ public void Min(PrimitiveColumnContainer column, IEnumerable row } } - public void Product(PrimitiveColumnContainer column, out DateTime ret) + public void Product(PrimitiveColumnContainer column, out DateTime? ret) { throw new NotSupportedException(); } - public void Product(PrimitiveColumnContainer column, IEnumerable rows, out DateTime ret) + public void Product(PrimitiveColumnContainer column, IEnumerable rows, out DateTime? ret) { throw new NotSupportedException(); } - public void Sum(PrimitiveColumnContainer column, out DateTime ret) + public void Sum(PrimitiveColumnContainer column, out DateTime? ret) { throw new NotSupportedException(); } - public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out DateTime ret) + public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out DateTime? ret) { throw new NotSupportedException(); } diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Computations.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Computations.cs index 58e6a1e7c5..5501236c03 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Computations.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Computations.cs @@ -93,49 +93,49 @@ public override DataFrameColumn CumulativeSum(IEnumerable rowIndices, bool /// public override object Max() { - PrimitiveColumnComputation.Instance.Max(_columnContainer, out T ret); + PrimitiveColumnComputation.Instance.Max(_columnContainer, out T? ret); return ret; } /// public override object Max(IEnumerable rowIndices) { - PrimitiveColumnComputation.Instance.Max(_columnContainer, rowIndices, out T ret); + PrimitiveColumnComputation.Instance.Max(_columnContainer, rowIndices, out T? ret); return ret; } /// public override object Min() { - PrimitiveColumnComputation.Instance.Min(_columnContainer, out T ret); + PrimitiveColumnComputation.Instance.Min(_columnContainer, out T? ret); return ret; } /// public override object Min(IEnumerable rowIndices) { - PrimitiveColumnComputation.Instance.Min(_columnContainer, rowIndices, out T ret); + PrimitiveColumnComputation.Instance.Min(_columnContainer, rowIndices, out T? ret); return ret; } /// public override object Product() { - PrimitiveColumnComputation.Instance.Product(_columnContainer, out T ret); + PrimitiveColumnComputation.Instance.Product(_columnContainer, out T? ret); return ret; } /// public override object Product(IEnumerable rowIndices) { - PrimitiveColumnComputation.Instance.Product(_columnContainer, rowIndices, out T ret); + PrimitiveColumnComputation.Instance.Product(_columnContainer, rowIndices, out T? ret); return ret; } /// public override object Sum() { - PrimitiveColumnComputation.Instance.Sum(_columnContainer, out T ret); + PrimitiveColumnComputation.Instance.Sum(_columnContainer, out T? ret); return ret; } /// public override object Sum(IEnumerable rowIndices) { - PrimitiveColumnComputation.Instance.Sum(_columnContainer, rowIndices, out T ret); + PrimitiveColumnComputation.Instance.Sum(_columnContainer, rowIndices, out T? ret); return ret; } /// diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs index 5410afa7ad..4105cb0c9c 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs @@ -26,14 +26,14 @@ internal interface IPrimitiveColumnComputation void CumulativeProduct(PrimitiveColumnContainer column, IEnumerable rows); void CumulativeSum(PrimitiveColumnContainer column); void CumulativeSum(PrimitiveColumnContainer column, IEnumerable rows); - void Max(PrimitiveColumnContainer column, out T ret); - void Max(PrimitiveColumnContainer column, IEnumerable rows, out T ret); - void Min(PrimitiveColumnContainer column, out T ret); - void Min(PrimitiveColumnContainer column, IEnumerable rows, out T ret); - void Product(PrimitiveColumnContainer column, out T ret); - void Product(PrimitiveColumnContainer column, IEnumerable rows, out T ret); - void Sum(PrimitiveColumnContainer column, out T ret); - void Sum(PrimitiveColumnContainer column, IEnumerable rows, out T ret); + void Max(PrimitiveColumnContainer column, out T? ret); + void Max(PrimitiveColumnContainer column, IEnumerable rows, out T? ret); + void Min(PrimitiveColumnContainer column, out T? ret); + void Min(PrimitiveColumnContainer column, IEnumerable rows, out T? ret); + void Product(PrimitiveColumnContainer column, out T? ret); + void Product(PrimitiveColumnContainer column, IEnumerable rows, out T? ret); + void Sum(PrimitiveColumnContainer column, out T? ret); + void Sum(PrimitiveColumnContainer column, IEnumerable rows, out T? ret); void Round(PrimitiveColumnContainer column); PrimitiveColumnContainer CreateTruncating(PrimitiveColumnContainer column) where U : unmanaged, INumber; } @@ -194,42 +194,42 @@ public void CumulativeSum(PrimitiveColumnContainer column, IEnumerable column, out bool ret) + public void Max(PrimitiveColumnContainer column, out bool? ret) { throw new NotSupportedException(); } - public void Max(PrimitiveColumnContainer column, IEnumerable rows, out bool ret) + public void Max(PrimitiveColumnContainer column, IEnumerable rows, out bool? ret) { throw new NotSupportedException(); } - public void Min(PrimitiveColumnContainer column, out bool ret) + public void Min(PrimitiveColumnContainer column, out bool? ret) { throw new NotSupportedException(); } - public void Min(PrimitiveColumnContainer column, IEnumerable rows, out bool ret) + public void Min(PrimitiveColumnContainer column, IEnumerable rows, out bool? ret) { throw new NotSupportedException(); } - public void Product(PrimitiveColumnContainer column, out bool ret) + public void Product(PrimitiveColumnContainer column, out bool? ret) { throw new NotSupportedException(); } - public void Product(PrimitiveColumnContainer column, IEnumerable rows, out bool ret) + public void Product(PrimitiveColumnContainer column, IEnumerable rows, out bool? ret) { throw new NotSupportedException(); } - public void Sum(PrimitiveColumnContainer column, out bool ret) + public void Sum(PrimitiveColumnContainer column, out bool? ret) { throw new NotSupportedException(); } - public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out bool ret) + public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out bool? ret) { throw new NotSupportedException(); } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index cbc6cc9e80..301c94680d 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1173,6 +1173,58 @@ public void TestComputationsIncludingDateTime() } } + [Fact] + public void TestIntComputations_MaxMin_WithNulls() + { + var column = new Int32DataFrameColumn("Int", new int?[] + { + null, + 2, + 1, + 4, + 3, + null + }); + + Assert.Equal(1, column.Min()); + Assert.Equal(4, column.Max()); + } + + [Fact] + public void TestDateTimeComputations_MaxMin_OnEmptyColumn() + { + var column = new DateTimeDataFrameColumn("DateTime"); + + Assert.Null(column.Min()); + Assert.Null(column.Max()); + } + + [Fact] + public void TestIntComputations_MaxMin_OnEmptyColumn() + { + var column = new Int32DataFrameColumn("Int"); + + Assert.Null(column.Min()); + Assert.Null(column.Max()); + } + + [Fact] + public void TestDateTimeComputations_MaxMin_WithNulls() + { + var dateTimeColumn = new DateTimeDataFrameColumn("DateTime", new DateTime?[] + { + null, + new DateTime(2022, 1, 1), + new DateTime(2020, 1, 1), + new DateTime(2023, 1, 1), + new DateTime(2021, 1, 1), + null + }); + + Assert.Equal(new DateTime(2020, 1, 1), dateTimeColumn.Min()); + Assert.Equal(new DateTime(2023, 1, 1), dateTimeColumn.Max()); + } + [Theory] [InlineData(5, 10)] [InlineData(-15, 10)] From 33b6432ee61bc0d5051181b0da3b1156a4499cc5 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 14:22:07 +0300 Subject: [PATCH 5/8] Cherry pick Step 2 commit from 6733 --- .../NumberMathComputation.cs | 17 ++++++++--------- .../DataFrameTests.cs | 8 ++++---- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.Data.Analysis/NumberMathComputation.cs b/src/Microsoft.Data.Analysis/NumberMathComputation.cs index 029acafb31..d0e7d4102e 100644 --- a/src/Microsoft.Data.Analysis/NumberMathComputation.cs +++ b/src/Microsoft.Data.Analysis/NumberMathComputation.cs @@ -9,7 +9,6 @@ using System; using System.Collections.Generic; using System.Runtime.Versioning; -using Microsoft.ML.Data; namespace Microsoft.Data.Analysis { @@ -75,43 +74,43 @@ public void CumulativeSum(PrimitiveColumnContainer column, IEnumerable CumulativeApply(column, Add, rows); } - public void Max(PrimitiveColumnContainer column, out T ret) + public void Max(PrimitiveColumnContainer column, out T? ret) { ret = CalculateReduction(column, T.Max, column[0].Value); } - public void Max(PrimitiveColumnContainer column, IEnumerable rows, out T ret) + public void Max(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) { ret = CalculateReduction(column, T.Max, rows); } - public void Min(PrimitiveColumnContainer column, out T ret) + public void Min(PrimitiveColumnContainer column, out T? ret) { ret = CalculateReduction(column, T.Min, column[0].Value); } - public void Min(PrimitiveColumnContainer column, IEnumerable rows, out T ret) + public void Min(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) { ret = CalculateReduction(column, T.Min, rows); } - public void Product(PrimitiveColumnContainer column, out T ret) + public void Product(PrimitiveColumnContainer column, out T? ret) { ret = CalculateReduction(column, Multiply, T.One); } - public void Product(PrimitiveColumnContainer column, IEnumerable rows, out T ret) + public void Product(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) { ret = CalculateReduction(column, Multiply, rows); } - public void Sum(PrimitiveColumnContainer column, out T ret) + public void Sum(PrimitiveColumnContainer column, out T? ret) { ret = CalculateReduction(column, Add, T.Zero); } - public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out T ret) + public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) { ret = CalculateReduction(column, Add, rows); } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 301c94680d..db26d1cf82 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -1176,7 +1176,7 @@ public void TestComputationsIncludingDateTime() [Fact] public void TestIntComputations_MaxMin_WithNulls() { - var column = new Int32DataFrameColumn("Int", new int?[] + var column = new PrimitiveDataFrameColumn("Int", new int?[] { null, 2, @@ -1193,7 +1193,7 @@ public void TestIntComputations_MaxMin_WithNulls() [Fact] public void TestDateTimeComputations_MaxMin_OnEmptyColumn() { - var column = new DateTimeDataFrameColumn("DateTime"); + var column = new PrimitiveDataFrameColumn("DateTime"); Assert.Null(column.Min()); Assert.Null(column.Max()); @@ -1202,7 +1202,7 @@ public void TestDateTimeComputations_MaxMin_OnEmptyColumn() [Fact] public void TestIntComputations_MaxMin_OnEmptyColumn() { - var column = new Int32DataFrameColumn("Int"); + var column = new PrimitiveDataFrameColumn("Int"); Assert.Null(column.Min()); Assert.Null(column.Max()); @@ -1211,7 +1211,7 @@ public void TestIntComputations_MaxMin_OnEmptyColumn() [Fact] public void TestDateTimeComputations_MaxMin_WithNulls() { - var dateTimeColumn = new DateTimeDataFrameColumn("DateTime", new DateTime?[] + var dateTimeColumn = new PrimitiveDataFrameColumn("DateTime", new DateTime?[] { null, new DateTime(2022, 1, 1), From 11b5d6f040a564d570bd9ad73f5731e012ed5893 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 14:28:18 +0300 Subject: [PATCH 6/8] Fixed code review findings # Conflicts: # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.cs # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnComputations.tt --- .../DateTimeComputation.cs | 8 +--- .../NumberMathComputation.cs | 10 ++--- .../PrimitiveColumnContainer.cs | 41 ++++++++++--------- .../PrimitiveDataFrameColumn.Sort.cs | 2 +- .../PrimitiveDataFrameColumn.cs | 2 +- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.Data.Analysis/DateTimeComputation.cs b/src/Microsoft.Data.Analysis/DateTimeComputation.cs index 3e50ec0c82..4dae3cfd9d 100644 --- a/src/Microsoft.Data.Analysis/DateTimeComputation.cs +++ b/src/Microsoft.Data.Analysis/DateTimeComputation.cs @@ -202,10 +202,8 @@ public void Max(PrimitiveColumnContainer column, out DateTime? ret) var bitmapSpan = column.NullBitMapBuffers[b].ReadOnlySpan; for (int i = 0; i < readOnlySpan.Length; i++) { - int byteIndex = (int)((uint)i / 8); - //Check if bit is not set (value is null) - skip - if (((bitmapSpan[byteIndex] >> (i & 7)) & 1) == 0) + if (!BitmapHelper.IsValid(bitmapSpan, i)) continue; var val = readOnlySpan[i]; @@ -262,10 +260,8 @@ public void Min(PrimitiveColumnContainer column, out DateTime? ret) for (int i = 0; i < readOnlySpan.Length; i++) { - int byteIndex = (int)((uint)i / 8); - //Check if bit is not set (value is null) - skip - if (((bitmapSpan[byteIndex] >> (i & 7)) & 1) == 0) + if (!BitmapHelper.IsValid(bitmapSpan, i)) continue; var val = readOnlySpan[i]; diff --git a/src/Microsoft.Data.Analysis/NumberMathComputation.cs b/src/Microsoft.Data.Analysis/NumberMathComputation.cs index d0e7d4102e..6e7318df74 100644 --- a/src/Microsoft.Data.Analysis/NumberMathComputation.cs +++ b/src/Microsoft.Data.Analysis/NumberMathComputation.cs @@ -143,7 +143,7 @@ protected void Apply(PrimitiveColumnContainer column, Func func) var bitmap = column.NullBitMapBuffers[b].ReadOnlySpan; for (int i = 0; i < buffer.Length; i++) { - if (column.IsValid(bitmap, i)) + if (BitmapHelper.IsValid(bitmap, i)) { buffer[i] = func(buffer[i]); } @@ -160,7 +160,7 @@ protected void CumulativeApply(PrimitiveColumnContainer column, Func var bitmap = column.NullBitMapBuffers[b].ReadOnlySpan; for (int i = 0; i < buffer.Length; i++) { - if (column.IsValid(bitmap, i)) + if (BitmapHelper.IsValid(bitmap, i)) { ret = func(buffer[i], ret); buffer[i] = ret; @@ -179,7 +179,7 @@ protected T CalculateReduction(PrimitiveColumnContainer column, Func var bitMap = column.NullBitMapBuffers[b].ReadOnlySpan; for (int i = 0; i < buffer.Length; i++) { - if (column.IsValid(bitMap, i)) + if (BitmapHelper.IsValid(bitMap, i)) { ret = checked(func(ret, buffer[i])); } @@ -212,7 +212,7 @@ protected void CumulativeApply(PrimitiveColumnContainer column, Func } row -= minRange; - if (column.IsValid(bitmap, (int)row)) + if (BitmapHelper.IsValid(bitmap, (int)row)) { if (!isInitialized) { @@ -252,7 +252,7 @@ protected T CalculateReduction(PrimitiveColumnContainer column, Func } row -= minRange; - if (column.IsValid(bitMap, (int)row)) + if (BitmapHelper.IsValid(bitMap, (int)row)) { if (!isInitialized) { diff --git a/src/Microsoft.Data.Analysis/PrimitiveColumnContainer.cs b/src/Microsoft.Data.Analysis/PrimitiveColumnContainer.cs index d65255d5be..1a3ae978a0 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveColumnContainer.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveColumnContainer.cs @@ -12,8 +12,24 @@ namespace Microsoft.Data.Analysis { + internal static class BitmapHelper + { + // Faster to use when we already have a span since it avoids indexing + public static bool IsValid(ReadOnlySpan bitMapBufferSpan, int index) + { + int nullBitMapSpanIndex = index / 8; + byte thisBitMap = bitMapBufferSpan[nullBitMapSpanIndex]; + return IsBitSet(thisBitMap, index); + } + + public static bool IsBitSet(byte curBitMap, int index) + { + return ((curBitMap >> (index & 7)) & 1) != 0; + } + } + /// - /// PrimitiveDataFrameColumnContainer is just a store for the column data. APIs that want to change the data must be defined in PrimitiveDataFrameColumn + /// PrimitiveColumnContainer is just a store for the column data. APIs that want to change the data must be defined in PrimitiveDataFrameColumn /// /// internal partial class PrimitiveColumnContainer : IEnumerable @@ -223,7 +239,7 @@ public void ApplyElementwise(Func func) for (int i = 0; i < mutableBuffer.Length; i++) { long curIndex = i + prevLength; - bool isValid = IsValid(mutableNullBitMapBuffer, i); + bool isValid = BitmapHelper.IsValid(mutableNullBitMapBuffer, i); T? value = func(isValid ? mutableBuffer[i] : null, curIndex); mutableBuffer[i] = value.GetValueOrDefault(); SetValidityBit(mutableNullBitMapBuffer, i, value != null); @@ -246,7 +262,7 @@ public void Apply(Func func, PrimitiveColumnContainer(Func func, PrimitiveColumnContainer bitMapBufferSpan, int index) - { - int nullBitMapSpanIndex = index / 8; - byte thisBitMap = bitMapBufferSpan[nullBitMapSpanIndex]; - return IsBitSet(thisBitMap, index); - } - public bool IsValid(long index) => NullCount == 0 || GetValidityBit(index); private byte SetBit(byte curBitMap, int index, bool value) @@ -329,11 +337,6 @@ internal void SetValidityBit(long index, bool value) SetValidityBit(bitMapBuffer.Span, (int)index, value); } - private bool IsBitSet(byte curBitMap, int index) - { - return ((curBitMap >> (index & 7)) & 1) != 0; - } - private bool GetValidityBit(long index) { if ((uint)index >= Length) @@ -350,7 +353,7 @@ private bool GetValidityBit(long index) int bitMapBufferIndex = (int)((uint)index / 8); Debug.Assert(bitMapBuffer.Length > bitMapBufferIndex); byte curBitMap = bitMapBuffer[bitMapBufferIndex]; - return IsBitSet(curBitMap, (int)index); + return BitmapHelper.IsBitSet(curBitMap, (int)index); } public long Length; @@ -512,7 +515,7 @@ public PrimitiveColumnContainer Clone(PrimitiveColumnContainer mapIndic spanIndex = buffer.Length - 1 - i; long mapRowIndex = mapIndicesIntSpan.IsEmpty ? mapIndicesLongSpan[spanIndex] : mapIndicesIntSpan[spanIndex]; - bool mapRowIndexIsValid = mapIndices.IsValid(mapIndicesNullBitMapSpan, spanIndex); + bool mapRowIndexIsValid = BitmapHelper.IsValid(mapIndicesNullBitMapSpan, spanIndex); if (mapRowIndexIsValid && (mapRowIndex < minRange || mapRowIndex >= maxRange)) { int bufferIndex = (int)(mapRowIndex / maxCapacity); @@ -527,7 +530,7 @@ public PrimitiveColumnContainer Clone(PrimitiveColumnContainer mapIndic { mapRowIndex -= minRange; value = thisSpan[(int)mapRowIndex]; - isValid = IsValid(thisNullBitMapSpan, (int)mapRowIndex); + isValid = BitmapHelper.IsValid(thisNullBitMapSpan, (int)mapRowIndex); } retSpan[i] = isValid ? value : default; diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs index 699779a921..0b5ebd2120 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs @@ -45,7 +45,7 @@ private PrimitiveDataFrameColumn GetSortIndices(IComparer comparer, out for (int i = 0; i < sortIndices.Length; i++) { int localSortIndex = sortIndices[i]; - if (_columnContainer.IsValid(nullBitMapSpan, localSortIndex)) + if (BitmapHelper.IsValid(nullBitMapSpan, localSortIndex)) { nonNullSortIndices.Add(sortIndices[i]); } diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index 5aed1c57f7..91421ddb06 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -541,7 +541,7 @@ public override Dictionary> GroupColumnValues(out for (int i = 0; i < readOnlySpan.Length; i++) { long currentLength = i + previousLength; - if (_columnContainer.IsValid(nullBitMapSpan, i)) + if (BitmapHelper.IsValid(nullBitMapSpan, i)) { bool containsKey = multimap.TryGetValue(readOnlySpan[i], out ICollection values); if (containsKey) From cb39ed834492b1e6ec232a17d26e6883044e9f8a Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Thu, 6 Jul 2023 17:40:48 +0300 Subject: [PATCH 7/8] Cherry pick PR 6724 (fix dataframe arithmetics for columns having several value buffers) # Conflicts: # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.cs # src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.tt --- .../PrimitiveDataFrameColumnArithmetic.cs | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.cs index ef3c9b1a8a..7fc3fa521e 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumnArithmetic.cs @@ -1,4 +1,4 @@ - + // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. @@ -265,47 +265,51 @@ public void RightShift(PrimitiveColumnContainer column, int value) } public void ElementwiseEquals(PrimitiveColumnContainer left, PrimitiveColumnContainer right, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < left.Buffers.Count; b++) { var span = left.Buffers[b].ReadOnlySpan; var otherSpan = right.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] == otherSpan[i]); + ret[index++] = (span[i] == otherSpan[i]); } } } public void ElementwiseEquals(PrimitiveColumnContainer column, bool scalar, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < column.Buffers.Count; b++) { var span = column.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] == scalar); + ret[index++] = (span[i] == scalar); } } } public void ElementwiseNotEquals(PrimitiveColumnContainer left, PrimitiveColumnContainer right, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < left.Buffers.Count; b++) { var span = left.Buffers[b].ReadOnlySpan; var otherSpan = right.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] != otherSpan[i]); + ret[index++] = (span[i] != otherSpan[i]); } } } public void ElementwiseNotEquals(PrimitiveColumnContainer column, bool scalar, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < column.Buffers.Count; b++) { var span = column.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] != scalar); + ret[index++] = (span[i] != scalar); } } } @@ -451,47 +455,51 @@ public void RightShift(PrimitiveColumnContainer column, int value) } public void ElementwiseEquals(PrimitiveColumnContainer left, PrimitiveColumnContainer right, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < left.Buffers.Count; b++) { var span = left.Buffers[b].ReadOnlySpan; var otherSpan = right.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] == otherSpan[i]); + ret[index++] = (span[i] == otherSpan[i]); } } } public void ElementwiseEquals(PrimitiveColumnContainer column, DateTime scalar, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < column.Buffers.Count; b++) { var span = column.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] == scalar); + ret[index++] = (span[i] == scalar); } } } public void ElementwiseNotEquals(PrimitiveColumnContainer left, PrimitiveColumnContainer right, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < left.Buffers.Count; b++) { var span = left.Buffers[b].ReadOnlySpan; var otherSpan = right.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] != otherSpan[i]); + ret[index++] = (span[i] != otherSpan[i]); } } } public void ElementwiseNotEquals(PrimitiveColumnContainer column, DateTime scalar, PrimitiveColumnContainer ret) { + long index = 0; for (int b = 0; b < column.Buffers.Count; b++) { var span = column.Buffers[b].ReadOnlySpan; for (int i = 0; i < span.Length; i++) { - ret[i] = (span[i] != scalar); + ret[index++] = (span[i] != scalar); } } } From 2856d3a499d4c3acf74f2c8610a8af8bfc66987b Mon Sep 17 00:00:00 2001 From: Jake Radzikowski Date: Thu, 6 Jul 2023 12:10:11 -0700 Subject: [PATCH 8/8] Fix tests --- .../NumberMathComputation.cs | 25 +++++++++++++------ .../Microsoft.ML.Fairlearn.csproj | 3 ++- .../Microsoft.ML.AutoML.Tests.csproj | 3 ++- .../Microsoft.ML.Core.Tests.csproj | 2 ++ .../Microsoft.ML.Fairlearn.Tests.csproj | 2 ++ .../Microsoft.ML.Tests.csproj | 2 ++ 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.Data.Analysis/NumberMathComputation.cs b/src/Microsoft.Data.Analysis/NumberMathComputation.cs index 6e7318df74..70613017e3 100644 --- a/src/Microsoft.Data.Analysis/NumberMathComputation.cs +++ b/src/Microsoft.Data.Analysis/NumberMathComputation.cs @@ -8,6 +8,8 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; using System.Runtime.Versioning; namespace Microsoft.Data.Analysis @@ -76,7 +78,7 @@ public void CumulativeSum(PrimitiveColumnContainer column, IEnumerable public void Max(PrimitiveColumnContainer column, out T? ret) { - ret = CalculateReduction(column, T.Max, column[0].Value); + ret = CalculateReduction(column, T.Max); } public void Max(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) @@ -86,7 +88,7 @@ public void Max(PrimitiveColumnContainer column, IEnumerable rows, out public void Min(PrimitiveColumnContainer column, out T? ret) { - ret = CalculateReduction(column, T.Min, column[0].Value); + ret = CalculateReduction(column, T.Min); } public void Min(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) @@ -97,7 +99,7 @@ public void Min(PrimitiveColumnContainer column, IEnumerable rows, out public void Product(PrimitiveColumnContainer column, out T? ret) { - ret = CalculateReduction(column, Multiply, T.One); + ret = CalculateReduction(column, Multiply); } public void Product(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) @@ -107,7 +109,7 @@ public void Product(PrimitiveColumnContainer column, IEnumerable rows, public void Sum(PrimitiveColumnContainer column, out T? ret) { - ret = CalculateReduction(column, Add, T.Zero); + ret = CalculateReduction(column, Add); } public void Sum(PrimitiveColumnContainer column, IEnumerable rows, out T? ret) @@ -169,9 +171,10 @@ protected void CumulativeApply(PrimitiveColumnContainer column, Func } } - protected T CalculateReduction(PrimitiveColumnContainer column, Func func, T startValue) + protected T? CalculateReduction(PrimitiveColumnContainer column, Func func) { - var ret = startValue; + T? ret = null; + bool isInitialized = false; for (int b = 0; b < column.Buffers.Count; b++) { @@ -181,7 +184,15 @@ protected T CalculateReduction(PrimitiveColumnContainer column, Func { if (BitmapHelper.IsValid(bitMap, i)) { - ret = checked(func(ret, buffer[i])); + if (!isInitialized) + { + isInitialized = true; + ret = buffer[i]; + } + else + { + ret = checked(func(ret.Value, buffer[i])); + } } } } diff --git a/src/Microsoft.ML.Fairlearn/Microsoft.ML.Fairlearn.csproj b/src/Microsoft.ML.Fairlearn/Microsoft.ML.Fairlearn.csproj index 72b2ad0edb..7ad6c422ef 100644 --- a/src/Microsoft.ML.Fairlearn/Microsoft.ML.Fairlearn.csproj +++ b/src/Microsoft.ML.Fairlearn/Microsoft.ML.Fairlearn.csproj @@ -2,7 +2,8 @@ - netstandard2.0 + net6.0 + net6.0 Microsoft.ML.Fairlearn None diff --git a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj index f6b9a021d9..bdf8c5311e 100644 --- a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj +++ b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj @@ -1,7 +1,8 @@  $(NoWarn) - + net6.0 + net6.0 None diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index 475eb5dbb1..49fddcd74f 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -8,6 +8,8 @@ None + net6.0 + net6.0 diff --git a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj index b950086278..90850d97af 100644 --- a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj +++ b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj @@ -2,6 +2,8 @@ None $(NoWarn);MSML_ParameterLocalVarName;MSML_PrivateFieldName;MSML_ExtendBaseTestClass;MSML_GeneralName + net6.0 + net6.0 diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index c50abd3350..e933dc2f86 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -1,6 +1,8 @@  + net6.0 + net6.0 Microsoft.ML.Tests true Test