// ---- src/Microsoft.Data.Analysis/DataFrameColumn.cs (members of DataFrameColumn) ----

/// <summary>
/// Appends a value to this <see cref="DataFrameColumn"/> using <paramref name="cursor"/>.
/// The base implementation always throws; column types that can be built from an
/// <see cref="IDataView"/> override it.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="ValueGetter">The cached ValueGetter for this column.</param>
/// <exception cref="NotImplementedException">Always, unless overridden.</exception>
protected internal virtual void AddValueUsingCursor(DataViewRowCursor cursor, Delegate ValueGetter) => throw new NotImplementedException();

/// <summary>
/// Returns the ValueGetter for each active column in <paramref name="cursor"/> as a delegate to be cached.
/// The base implementation always throws; column types that can be built from an
/// <see cref="IDataView"/> override it.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="schemaColumn">The <see cref="DataViewSchema.Column"/> to return the ValueGetter for.</param>
/// <exception cref="NotImplementedException">Always, unless overridden.</exception>
protected internal virtual Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) => throw new NotImplementedException();

// ---- src/Microsoft.Data.Analysis/IDataView.Extension.cs (new file) ----

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.Data.Analysis;
using Microsoft.ML.Data;

namespace Microsoft.ML
{
    /// <summary>
    /// Extension methods that materialize an <see cref="IDataView"/> into a <see cref="DataFrame"/>.
    /// </summary>
    public static class IDataViewExtensions
    {
        private const int defaultMaxRows = 100;

        /// <summary>
        /// Returns a <see cref="DataFrame"/> from this <paramref name="dataView"/>.
        /// </summary>
        /// <param name="dataView">The current <see cref="IDataView"/>.</param>
        /// <param name="maxRows">The max number of rows in the <see cref="DataFrame"/>. Defaults to 100. Use -1 to construct a DataFrame using all the rows in <paramref name="dataView"/>.</param>
        /// <returns>A <see cref="DataFrame"/> with at most <paramref name="maxRows"/> rows.</returns>
        public static DataFrame ToDataFrame(this IDataView dataView, long maxRows = defaultMaxRows)
        {
            return ToDataFrame(dataView, maxRows, null);
        }

        /// <summary>
        /// Returns a <see cref="DataFrame"/> with the first 100 rows of this <paramref name="dataView"/>.
        /// </summary>
        /// <param name="dataView">The current <see cref="IDataView"/>.</param>
        /// <param name="selectColumns">The columns selected for the resultant DataFrame</param>
        /// <returns>A <see cref="DataFrame"/> with the selected columns and 100 rows.</returns>
        public static DataFrame ToDataFrame(this IDataView dataView, params string[] selectColumns)
        {
            return ToDataFrame(dataView, defaultMaxRows, selectColumns);
        }

        /// <summary>
        /// Returns a <see cref="DataFrame"/> with the first <paramref name="maxRows"/> of this <paramref name="dataView"/>.
        /// </summary>
        /// <param name="dataView">The current <see cref="IDataView"/>.</param>
        /// <param name="maxRows">The max number of rows in the <see cref="DataFrame"/>. Use -1 to construct a DataFrame using all the rows in <paramref name="dataView"/>.</param>
        /// <param name="selectColumns">The columns selected for the resultant DataFrame. When null or empty, every non-hidden column is used.</param>
        /// <returns>A <see cref="DataFrame"/> with the selected columns and <paramref name="maxRows"/> rows.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="dataView"/> is null.</exception>
        /// <exception cref="NotSupportedException">A selected column has a type with no matching DataFrame column.</exception>
        public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, params string[] selectColumns)
        {
            if (dataView == null)
            {
                throw new ArgumentNullException(nameof(dataView));
            }

            DataViewSchema schema = dataView.Schema;
            List<DataFrameColumn> dataFrameColumns = new List<DataFrameColumn>(schema.Count);

            // -1 is the "no limit" sentinel.
            maxRows = maxRows == -1 ? long.MaxValue : maxRows;

            HashSet<string> selectColumnsSet = null;
            if (selectColumns != null && selectColumns.Length > 0)
            {
                selectColumnsSet = new HashSet<string>(selectColumns);
            }

            List<DataViewSchema.Column> activeDataViewColumns = new List<DataViewSchema.Column>();
            foreach (DataViewSchema.Column dataViewColumn in schema)
            {
                if (dataViewColumn.IsHidden || (selectColumnsSet != null && !selectColumnsSet.Contains(dataViewColumn.Name)))
                {
                    continue;
                }

                // The two lists are filled in lock-step so index i refers to the same column in both.
                activeDataViewColumns.Add(dataViewColumn);
                dataFrameColumns.Add(CreateDataFrameColumn(dataViewColumn));
            }

            using (DataViewRowCursor cursor = dataView.GetRowCursor(activeDataViewColumns))
            {
                // Resolve each getter once up front; they are reused for every row.
                Delegate[] activeColumnDelegates = new Delegate[activeDataViewColumns.Count];
                for (int i = 0; i < activeColumnDelegates.Length; i++)
                {
                    activeColumnDelegates[i] = dataFrameColumns[i].GetValueGetterUsingCursor(cursor, activeDataViewColumns[i]);
                }

                // Test the limit BEFORE MoveNext so the cursor is never advanced past the last
                // requested row (and not advanced at all when maxRows is 0).
                long rowsRead = 0;
                while (rowsRead < maxRows && cursor.MoveNext())
                {
                    for (int i = 0; i < activeColumnDelegates.Length; i++)
                    {
                        dataFrameColumns[i].AddValueUsingCursor(cursor, activeColumnDelegates[i]);
                    }

                    rowsRead++;
                }
            }

            return new DataFrame(dataFrameColumns);
        }

        /// <summary>
        /// Creates an empty DataFrame column whose element type matches the type of <paramref name="dataViewColumn"/>.
        /// </summary>
        private static DataFrameColumn CreateDataFrameColumn(DataViewSchema.Column dataViewColumn)
        {
            DataViewType type = dataViewColumn.Type;
            string name = dataViewColumn.Name;

            if (type == BooleanDataViewType.Instance)
            {
                return new BooleanDataFrameColumn(name);
            }
            if (type == NumberDataViewType.Byte)
            {
                return new ByteDataFrameColumn(name);
            }
            if (type == NumberDataViewType.Double)
            {
                return new DoubleDataFrameColumn(name);
            }
            if (type == NumberDataViewType.Single)
            {
                return new SingleDataFrameColumn(name);
            }
            if (type == NumberDataViewType.Int32)
            {
                return new Int32DataFrameColumn(name);
            }
            if (type == NumberDataViewType.Int64)
            {
                return new Int64DataFrameColumn(name);
            }
            if (type == NumberDataViewType.SByte)
            {
                return new SByteDataFrameColumn(name);
            }
            if (type == NumberDataViewType.Int16)
            {
                return new Int16DataFrameColumn(name);
            }
            if (type == NumberDataViewType.UInt32)
            {
                return new UInt32DataFrameColumn(name);
            }
            if (type == NumberDataViewType.UInt64)
            {
                return new UInt64DataFrameColumn(name);
            }
            if (type == NumberDataViewType.UInt16)
            {
                return new UInt16DataFrameColumn(name);
            }
            if (type == TextDataViewType.Instance)
            {
                return new StringDataFrameColumn(name);
            }

            throw new NotSupportedException(String.Format(Microsoft.Data.Strings.NotSupportedColumnType, type.RawType.Name));
        }
    }
}
// ---- src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs (members of PrimitiveDataFrameColumn<T>) ----

/// <summary>
/// Reads the current row's value through <paramref name="getter"/> and stores it at the
/// cursor's position: overwrites when the row already exists, appends when the row is
/// exactly one past the end.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="getter">The cached getter for this column; must be a <c>ValueGetter&lt;T&gt;</c>.</param>
/// <exception cref="ArgumentException"><paramref name="getter"/> is null or not a <c>ValueGetter&lt;T&gt;</c>.</exception>
/// <exception cref="IndexOutOfRangeException">The cursor position is more than one past the end of this column.</exception>
protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter)
{
    // A type pattern fails loudly in release builds too; the previous `as` cast was only
    // guarded by Debug.Assert and surfaced as a NullReferenceException.
    if (!(getter is ValueGetter<T> typedGetter))
    {
        throw new ArgumentException($"Expected a ValueGetter<{typeof(T).Name}> delegate.", nameof(getter));
    }

    long row = cursor.Position;
    T value = default;
    typedGetter(ref value);

    if (Length > row)
    {
        this[row] = value;
    }
    else if (Length == row)
    {
        Append(value);
    }
    else
    {
        throw new IndexOutOfRangeException(nameof(row));
    }
}

/// <summary>
/// Returns the typed getter of <paramref name="cursor"/> for <paramref name="schemaColumn"/> so the caller can cache it.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="schemaColumn">The schema column to return the getter for.</param>
protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
{
    return cursor.GetGetter<T>(schemaColumn);
}

// ---- src/Microsoft.Data.Analysis/StringDataFrameColumn.cs (members of StringDataFrameColumn) ----

/// <summary>
/// Reads the current row's text through <paramref name="getter"/>, converts it to a string and
/// stores it at the cursor's position: overwrites when the row already exists, appends when
/// the row is exactly one past the end.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="getter">The cached getter for this column; must be a <c>ValueGetter&lt;ReadOnlyMemory&lt;char&gt;&gt;</c>.</param>
/// <exception cref="ArgumentException"><paramref name="getter"/> is null or of the wrong delegate type.</exception>
/// <exception cref="IndexOutOfRangeException">The cursor position is more than one past the end of this column.</exception>
protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter)
{
    if (!(getter is ValueGetter<ReadOnlyMemory<char>> typedGetter))
    {
        throw new ArgumentException("Expected a ValueGetter<ReadOnlyMemory<char>> delegate.", nameof(getter));
    }

    long row = cursor.Position;
    ReadOnlyMemory<char> value = default;
    typedGetter(ref value);

    // ReadOnlyMemory<char>.ToString() yields the text itself, not a type name.
    string text = value.ToString();

    if (Length > row)
    {
        this[row] = text;
    }
    else if (Length == row)
    {
        Append(text);
    }
    else
    {
        throw new IndexOutOfRangeException(nameof(row));
    }
}

/// <summary>
/// Returns the text getter of <paramref name="cursor"/> for <paramref name="schemaColumn"/> so the caller can cache it.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="schemaColumn">The schema column to return the getter for.</param>
protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
{
    return cursor.GetGetter<ReadOnlyMemory<char>>(schemaColumn);
}
+++ b/src/Microsoft.Data.Analysis/strings.Designer.cs @@ -258,6 +258,15 @@ internal static string NonSeekableStream { } } + /// <summary> + /// Looks up a localized string similar to {0} is not a supported column type.. + /// </summary> + internal static string NotSupportedColumnType { + get { + return ResourceManager.GetString("NotSupportedColumnType", resourceCulture); + } + } + /// <summary> /// Looks up a localized string similar to numeric column. /// </summary> diff --git a/src/Microsoft.Data.Analysis/strings.resx b/src/Microsoft.Data.Analysis/strings.resx index 267140834a..ad9f114050 100644 --- a/src/Microsoft.Data.Analysis/strings.resx +++ b/src/Microsoft.Data.Analysis/strings.resx @@ -183,10 +183,13 @@ Expected a seekable stream + + {0} is not a supported column type. + numeric column Cannot span multiple buffers - + \ No newline at end of file diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.IDataView.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs similarity index 57% rename from test/Microsoft.Data.Analysis.Tests/DataFrameTests.IDataView.cs rename to test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 9ed4963b7f..dea8099876 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.IDataView.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -10,12 +10,12 @@ namespace Microsoft.Data.Analysis.Tests { - public partial class DataFrameTests + public partial class DataFrameIDataViewTests { [Fact] public void TestIDataView() { - IDataView dataView = MakeDataFrameWithAllColumnTypes(10, withNulls: false); + IDataView dataView = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false); DataDebuggerPreview preview = dataView.Preview(); Assert.Equal(10, preview.RowView.Length); @@ -85,7 +85,7 @@ public void TestIDataView() [Fact] public void TestIDataViewSchemaInvalidate() { - DataFrame df = MakeDataFrameWithAllMutableColumnTypes(10, withNulls: false); + DataFrame df = 
// ---- test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs (new members of DataFrameIDataViewTests) ----

/// <summary>
/// Round-trips a DataFrame through IDataView with default arguments and compares every column.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView()
{
    DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
    // IDataView returns chars as UInt16, so the round trip would compare a
    // CharDataFrameColumn with a UInt16DataFrameColumn and fail the asserts.
    df.Columns.Remove("Char");
    IDataView dfAsIDataView = df;

    DataFrame newDf = dfAsIDataView.ToDataFrame();

    Assert.Equal(dfAsIDataView.GetRowCount(), newDf.Rows.Count);
    Assert.Equal(dfAsIDataView.Schema.Count, newDf.Columns.Count);
    for (int i = 0; i < df.Columns.Count; i++)
    {
        Assert.True(df.Columns[i].ElementwiseEquals(newDf.Columns[i]).All());
    }
}

/// <summary>
/// Selecting columns by name keeps all rows but only the named columns.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView_SelectColumns()
{
    DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
    IDataView dfAsIDataView = df;

    DataFrame newDf = dfAsIDataView.ToDataFrame("Int", "Double");

    Assert.Equal(dfAsIDataView.GetRowCount(), newDf.Rows.Count);
    Assert.Equal(2, newDf.Columns.Count);
    Assert.True(df.Columns["Int"].ElementwiseEquals(newDf.Columns["Int"]).All());
    Assert.True(df.Columns["Double"].ElementwiseEquals(newDf.Columns["Double"]).All());
}

/// <summary>
/// Row limits: an explicit limit, the default limit (exercised when rowSize is 100) and -1 for "all rows".
/// </summary>
[Theory]
[InlineData(10, 5)]
[InlineData(110, 100)]
[InlineData(110, -1)]
public void TestDataFrameFromIDataView_SelectRows(int dataFrameSize, int rowSize)
{
    DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(dataFrameSize, withNulls: false);
    // DataViewSchema reports chars as UInt16 and decimals as Double, so those columns
    // would come back with a different column type and fail the value comparison below.
    df.Columns.Remove("Char");
    df.Columns.Remove("Decimal");
    IDataView dfAsIDataView = df;

    DataFrame newDf;
    if (rowSize == 100)
    {
        // Test default
        newDf = dfAsIDataView.ToDataFrame();
    }
    else
    {
        newDf = dfAsIDataView.ToDataFrame(rowSize);
    }

    if (rowSize == -1)
    {
        rowSize = dataFrameSize;
    }

    Assert.Equal(rowSize, newDf.Rows.Count);
    Assert.Equal(df.Columns.Count, newDf.Columns.Count);
    for (int i = 0; i < newDf.Columns.Count; i++)
    {
        Assert.Equal(rowSize, newDf.Columns[i].Length);
        Assert.Equal(df.Columns[i].Name, newDf.Columns[i].Name);
    }
    Assert.Equal(dfAsIDataView.Schema.Count, newDf.Columns.Count);
    for (int c = 0; c < df.Columns.Count; c++)
    {
        for (int r = 0; r < rowSize; r++)
        {
            Assert.Equal(df.Columns[c][r], newDf.Columns[c][r]);
        }
    }
}

/// <summary>
/// Combines a row limit with a column selection.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView_SelectColumnsAndRows()
{
    DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
    IDataView dfAsIDataView = df;

    DataFrame newDf = dfAsIDataView.ToDataFrame(5, "Int", "Double");

    Assert.Equal(5, newDf.Rows.Count);
    for (int i = 0; i < newDf.Columns.Count; i++)
    {
        Assert.Equal(5, newDf.Columns[i].Length);
    }
    Assert.Equal(2, newDf.Columns.Count);
    for (int r = 0; r < 5; r++)
    {
        Assert.Equal(df.Columns["Int"][r], newDf.Columns["Int"][r]);
        Assert.Equal(df.Columns["Double"][r], newDf.Columns["Double"][r]);
    }
}

/// <summary>
/// Row shape used to build a small IDataView through MLContext.
/// </summary>
private class InputData
{
    public string Name { get; set; }
    public bool FilterNext { get; set; }
    public float Value { get; set; }
}

/// <summary>
/// Builds a six-row IDataView with the columns Name, FilterNext and Value.
/// </summary>
private IDataView GetASampleIDataView()
{
    var mlContext = new MLContext();

    // Get a small dataset as an IEnumerable.
    var enumerableOfData = new[]
    {
        new InputData() { Name = "Joey", FilterNext = false, Value = 1.0f },
        new InputData() { Name = "Chandler", FilterNext = false, Value = 2.0f },
        new InputData() { Name = "Ross", FilterNext = false, Value = 3.0f },
        new InputData() { Name = "Monica", FilterNext = true, Value = 4.0f },
        new InputData() { Name = "Rachel", FilterNext = true, Value = 5.0f },
        new InputData() { Name = "Phoebe", FilterNext = false, Value = 6.0f },
    };

    IDataView data = mlContext.Data.LoadFromEnumerable(enumerableOfData);
    return data;
}

/// <summary>
/// Asserts that the first <paramref name="maxRows"/> values of <paramref name="columnName"/>
/// in <paramref name="data"/> equal the values at the same rows in <paramref name="df"/>.
/// </summary>
/// <typeparam name="T">Element type used to read the column from the IDataView.</typeparam>
/// <param name="columnName">Name of the column present in both inputs.</param>
/// <param name="data">The source IDataView.</param>
/// <param name="df">The DataFrame produced from <paramref name="data"/>.</param>
/// <param name="maxRows">Number of leading rows to compare; -1 compares every row.</param>
private void VerifyDataFrameColumnAndDataViewColumnValues<T>(string columnName, IDataView data, DataFrame df, int maxRows = -1)
{
    int cc = 0;
    var nameDataViewColumn = data.GetColumn<T>(columnName);
    foreach (var value in nameDataViewColumn)
    {
        if (maxRows != -1 && cc >= maxRows)
        {
            return;
        }
        Assert.Equal(value, df.Columns[columnName][cc++]);
    }
}

/// <summary>
/// Converts an MLContext-built IDataView with default arguments.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView_MLData()
{
    IDataView data = GetASampleIDataView();

    DataFrame df = data.ToDataFrame();

    Assert.Equal(6, df.Rows.Count);
    Assert.Equal(3, df.Columns.Count);
    foreach (var column in df.Columns)
    {
        Assert.Equal(6, column.Length);
    }

    VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
    VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df);
    VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}

/// <summary>
/// Converts an MLContext-built IDataView keeping only the named columns.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumns()
{
    IDataView data = GetASampleIDataView();

    DataFrame df = data.ToDataFrame("Name", "Value");

    Assert.Equal(6, df.Rows.Count);
    Assert.Equal(2, df.Columns.Count);
    foreach (var column in df.Columns)
    {
        Assert.Equal(6, column.Length);
    }

    VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
    VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}

/// <summary>
/// Converts an MLContext-built IDataView with a row limit, including a limit of zero.
/// </summary>
[Theory]
[InlineData(3)]
[InlineData(0)]
public void TestDataFrameFromIDataView_MLData_SelectRows(int maxRows)
{
    IDataView data = GetASampleIDataView();

    DataFrame df = data.ToDataFrame(maxRows);

    Assert.Equal(maxRows, df.Rows.Count);
    Assert.Equal(3, df.Columns.Count);
    foreach (var column in df.Columns)
    {
        Assert.Equal(maxRows, column.Length);
    }

    VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, maxRows);
    VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df, maxRows);
    VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, maxRows);
}

/// <summary>
/// Converts an MLContext-built IDataView with both a row limit and a column selection.
/// </summary>
[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumnsAndRows()
{
    IDataView data = GetASampleIDataView();

    DataFrame df = data.ToDataFrame(3, "Name", "Value");

    Assert.Equal(3, df.Rows.Count);
    Assert.Equal(2, df.Columns.Count);
    foreach (var column in df.Columns)
    {
        Assert.Equal(3, column.Length);
    }

    VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, 3);
    VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, 3);
}
offsetMemory[offsetIndex++] = offsetValueBytes[2]; + offsetMemory[offsetIndex++] = offsetValueBytes[3]; } int nullCount = withNulls ? 1 : 0;