Skip to content

Commit

Permalink
Fix some stuff and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Prashanth Govindarajan committed Mar 16, 2021
1 parent 1260e22 commit 96bb44a
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.Data.Analysis/DataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,14 @@ public virtual DataFrameColumn Sort(bool ascending = true)
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="schemaColumn">The <see cref="DataViewSchema.Column"/> in <see cref="DataViewSchema"/></param>
/// <param name="ValueGetter">The cached ValueGetter for this column.</param>
internal virtual void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate ValueGetter) => throw new NotImplementedException();
protected internal virtual void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate ValueGetter) => throw new NotImplementedException();

/// <summary>
/// Returns the ValueGetter for each active column in <paramref name="cursor"/> as a delegate to be cached.
/// </summary>
/// <param name="cursor">The row cursor which has the current position</param>
/// <param name="schemaColumn">The <see cref="DataViewSchema.Column"/> in <see cref="DataViewSchema"/></param>
internal virtual Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) => throw new NotImplementedException();
protected internal virtual Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) => throw new NotImplementedException();

/// <summary>
/// Clamps values beyond the specified thresholds
Expand Down
27 changes: 14 additions & 13 deletions src/Microsoft.Data.Analysis/IDataView.Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,25 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param
}
}

List<Delegate> activeColumnDelegates = new List<Delegate>();

DataViewRowCursor cursor = dataView.GetRowCursor(activeColumns);
int columnIndex = 0;
foreach (DataViewSchema.Column column in activeColumns)
{
Delegate valueGetter = columns[columnIndex].GetValueGetterUsingCursor(cursor, column);
activeColumnDelegates.Add(valueGetter);
columnIndex++;
}
while (cursor.MoveNext() && cursor.Position < maxRows)
using (DataViewRowCursor cursor = dataView.GetRowCursor(activeColumns))
{
columnIndex = 0;
List<Delegate> activeColumnDelegates = new List<Delegate>();
int columnIndex = 0;
foreach (DataViewSchema.Column column in activeColumns)
{
columns[columnIndex].AddValueUsingCursor(cursor, column, activeColumnDelegates[columnIndex]);
Delegate valueGetter = columns[columnIndex].GetValueGetterUsingCursor(cursor, column);
activeColumnDelegates.Add(valueGetter);
columnIndex++;
}
while (cursor.MoveNext() && cursor.Position < maxRows)
{
columnIndex = 0;
foreach (DataViewSchema.Column column in activeColumns)
{
columns[columnIndex].AddValueUsingCursor(cursor, column, activeColumnDelegates[columnIndex]);
columnIndex++;
}
}
}

return new DataFrame(columns);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ private static ValueGetter<ushort> CreateCharValueGetterDelegate(DataViewRowCurs
private static ValueGetter<double> CreateDecimalValueGetterDelegate(DataViewRowCursor cursor, PrimitiveDataFrameColumn<decimal> column) =>
(ref double value) => value = (double?)column[cursor.Position] ?? double.NaN;

internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column column, Delegate getter)
protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column column, Delegate getter)
{
long row = cursor.Position;
T value = default;
Expand All @@ -797,7 +797,7 @@ internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSch
}
}

internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
{
return cursor.GetGetter<T>(schemaColumn);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.Data.Analysis/StringDataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor)
private ValueGetter<ReadOnlyMemory<char>> CreateValueGetterDelegate(DataViewRowCursor cursor) =>
(ref ReadOnlyMemory<char> value) => value = this[cursor.Position].AsMemory();

internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate getter)
protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn, Delegate getter)
{
long row = cursor.Position;
ReadOnlyMemory<char> value = default;
Expand All @@ -489,7 +489,7 @@ internal override void AddValueUsingCursor(DataViewRowCursor cursor, DataViewSch
throw new IndexOutOfRangeException(nameof(row));
}
}
internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
{
return cursor.GetGetter<ReadOnlyMemory<char>>(schemaColumn);
}
Expand Down
74 changes: 64 additions & 10 deletions test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,20 @@ private IDataView GetASampleIDataView()
return data;
}

private void VerifyDataFrameColumnAndDataViewColumnValues<T>(string columnName, IDataView data, DataFrame df, int maxRows = -1)
{
int cc = 0;
var nameDataViewColumn = data.GetColumn<T>(columnName);
foreach (var value in nameDataViewColumn)
{
if (maxRows != -1 && cc >= maxRows)
{
return;
}
Assert.Equal(value, df.Columns[columnName][cc++]);
}
}

[Fact]
public void TestDataFrameFromIDataView_MLData()
{
Expand All @@ -334,20 +348,60 @@ public void TestDataFrameFromIDataView_MLData()
Assert.Equal(6, column.Length);
}

void VerifyDataFrameColumnAndDataViewColumnValues<T>(string columnName)
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}

[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumns()
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame("Name", "Value");
Assert.Equal(6, df.Rows.Count);
Assert.Equal(2, df.Columns.Count);
foreach (var column in df.Columns)
{
int cc = 0;
var nameDataViewColumn = data.GetColumn<T>(columnName);
foreach (var value in nameDataViewColumn)
{
Assert.Equal(value, df.Columns[columnName][cc++]);
}
Assert.Equal(6, column.Length);
}

VerifyDataFrameColumnAndDataViewColumnValues<string>("Name");
VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext");
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value");
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}

[Theory]
[InlineData(3)]
[InlineData(0)]
public void TestDataFrameFromIDataView_MLData_SelectRows(int maxRows)
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame(maxRows);
Assert.Equal(maxRows, df.Rows.Count);
Assert.Equal(3, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(maxRows, column.Length);
}

VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, maxRows);
VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df, maxRows);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, maxRows);
}

[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumnsAndRows()
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame(3, "Name", "Value");
Assert.Equal(3, df.Rows.Count);
Assert.Equal(2, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(3, column.Length);
}

VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, 3);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, 3);
}
}
}

0 comments on commit 96bb44a

Please sign in to comment.