Expose DataStreamWriter.Foreach API #387

Merged 28 commits on Jan 30, 2020

Changes from 19 commits

Commits
994ea3e  initial commit (suhsteve, Jan 8, 2020)
b9ab9e0  add descriptions and cleanup (suhsteve, Jan 8, 2020)
13a3d02  cleanup and refactor (suhsteve, Jan 9, 2020)
6325e71  Move TaskContext class definition from Microsoft.Spark.Worker to Micr… (suhsteve, Jan 9, 2020)
214f1ad  epochId is a long (suhsteve, Jan 9, 2020)
d7ec1dc  add DataStreamWriter.Foreach E2E test (suhsteve, Jan 9, 2020)
b60fbb4  edit description (suhsteve, Jan 9, 2020)
102fc3a  update test (suhsteve, Jan 10, 2020)
6221e92  cleanup (suhsteve, Jan 10, 2020)
b0d0a27  improve test (suhsteve, Jan 11, 2020)
ed782a1  cleanup (suhsteve, Jan 11, 2020)
5339306  PR comments (suhsteve, Jan 14, 2020)
c2aa3be  cleanup comments (suhsteve, Jan 16, 2020)
24fe2bf  Update src/csharp/Microsoft.Spark/Sql/ForeachWriter.cs (suhsteve, Jan 23, 2020)
9cb49ef  Merge branch 'master' into foreach (suhsteve, Jan 24, 2020)
cfb9e37  Address PR comments. (suhsteve, Jan 25, 2020)
3a9b287  PR comments (suhsteve, Jan 25, 2020)
75ed8ed  remove newline (suhsteve, Jan 25, 2020)
41ebfde  typo (suhsteve, Jan 25, 2020)
4e93d8e  PR comments (suhsteve, Jan 29, 2020)
8860a44  update Worker version check (suhsteve, Jan 29, 2020)
57a028b  Merge branch 'master' into foreach (imback82, Jan 29, 2020)
33daff3  PR comments (suhsteve, Jan 30, 2020)
ec35490  remove methodname check (suhsteve, Jan 30, 2020)
985a3af  remove UdfWrapperMethodName (suhsteve, Jan 30, 2020)
a62197f  readd UdfWrapperMethodName (suhsteve, Jan 30, 2020)
02bf06b  readd UdfWrapperMethodName check (suhsteve, Jan 30, 2020)
66c9837  rename ForeachWriterWrapper.Execute to Process (suhsteve, Jan 30, 2020)
@@ -2,7 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
using Xunit;
@@ -59,5 +63,190 @@ public void TestSignaturesV2_3_X()

Assert.IsType<DataStreamWriter>(dsw.Trigger(Trigger.Once()));
}

[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
public void TestForeach()
{
// Temporary folder to put our test stream input.
using var srcTempDirectory = new TemporaryDirectory();
string streamInputPath = Path.Combine(srcTempDirectory.Path, "streamInput");

// [1, 2, ..., 99]
_spark.Range(1, 100).Write().Json(streamInputPath);

// Test a scenario where IForeachWriter runs without issues.
// If everything is working as expected, then:
// - Triggering the stream will not throw an exception.
// - 3 CSV files will be created in the temporary directory.
// - 0 exception files will be created in the temporary directory.
// - The CSV files will contain valid data to read, with the
//   expected entries [101, 102, ..., 199].
TestAndValidateForeach(
streamInputPath,
new TestForeachWriter(),
3,
0,
Enumerable.Range(101, 99));

// Test a scenario where IForeachWriter.Open returns false.
// When IForeachWriter.Open returns false, IForeachWriter.Process
// is not called. Verify that:
// - Triggering the stream will not throw an exception.
// - 3 CSV files will be created in the temporary directory.
// - 0 exception files will be created in the temporary directory.
// - The CSV files will not contain valid data to read.
TestAndValidateForeach(
streamInputPath,
new TestForeachWriterOpenFailure(),
3,
0,
Enumerable.Empty<int>());

// Test a scenario where IForeachWriter.Process throws an exception.
// When IForeachWriter.Process throws, the exception is rethrown by
// ForeachWriterWrapper. We limit the partitions to 1 to make
// validating this scenario simpler. Verify that:
// - Triggering the stream throws an exception.
// - 1 CSV file will be created in the temporary directory.
// - 1 exception file will be created in the temporary directory. The
//   exception thrown from Process() is passed to Close().
// - The CSV file will not contain valid data to read.
TestAndValidateForeach(
streamInputPath,
new TestForeachWriterProcessFailure(),
1,
1,
Enumerable.Empty<int>());
}

private void TestAndValidateForeach(
string streamInputPath,
TestForeachWriter foreachWriter,
int expectedCSVFiles,
int expectedExceptionFiles,
IEnumerable<int> expectedOutput)
{
// Temporary folder the TestForeachWriter will write to.
using var dstTempDirectory = new TemporaryDirectory();
foreachWriter.WritePath = dstTempDirectory.Path;

// Read streamInputPath, repartition data, then
// call TestForeachWriter on the data.
DataStreamWriter dsw = _spark
.ReadStream()
.Schema("id INT")
.Json(streamInputPath)
.Repartition(expectedCSVFiles)
.WriteStream()
.Foreach(foreachWriter);

// Trigger the stream batch once.
if (expectedExceptionFiles > 0)
{
Assert.Throws<Exception>(
() => dsw.Trigger(Trigger.Once()).Start().AwaitTermination());
}
else
{
dsw.Trigger(Trigger.Once()).Start().AwaitTermination();
}

// Verify that TestForeachWriter created a unique .csv file for
// each partitionId that ForeachWriter.Open was called with.
Assert.Equal(
expectedCSVFiles,
Directory.GetFiles(dstTempDirectory.Path, "*.csv").Length);

// ForeachWriter.Close(Exception) creates a file with the .exception
// extension only if ForeachWriter.Process(Row) throws an exception.
Assert.Equal(
expectedExceptionFiles,
Directory.GetFiles(dstTempDirectory.Path, "*.exception").Length);

// Read in the *.csv file(s) generated by the TestForeachWriter.
// If there are multiple input files, sorting by "id" will make
// validation simpler. Contents of the *.csv will only be populated
// on successful calls to the ForeachWriter.Process method.
DataFrame foreachWriterOutputDF = _spark
.Read()
.Schema("id INT")
.Csv(dstTempDirectory.Path)
.Sort("id");

// Validate expected *.csv data.
Assert.Equal(
expectedOutput.Select(i => new object[] { i }),
foreachWriterOutputDF.Collect().Select(r => r.Values));
}

[Serializable]
private class TestForeachWriter : IForeachWriter
{
[NonSerialized]
private StreamWriter _streamWriter;

private long _partitionId;

private long _epochId;

internal string WritePath { get; set; }

public void Close(Exception errorOrNull)
{
if (errorOrNull != null)
{
FileStream fs = File.Create(
Path.Combine(
WritePath,
$"Close-{_partitionId}-{_epochId}.exception"));
fs.Dispose();
}

_streamWriter?.Dispose();
}

public virtual bool Open(long partitionId, long epochId)
{
_partitionId = partitionId;
_epochId = epochId;
try
{
_streamWriter = new StreamWriter(
Path.Combine(
WritePath,
$"sink-foreachWriter-{_partitionId}-{_epochId}.csv"));
return true;
}
catch
{
return false;
}
}

public virtual void Process(Row value)
{
_streamWriter.WriteLine(string.Join(",", value.Values.Select(v => 100 + (int)v)));
}
}

[Serializable]
private class TestForeachWriterOpenFailure : TestForeachWriter
{
public override bool Open(long partitionId, long epochId)
{
base.Open(partitionId, epochId);
return false;
}
}

[Serializable]
private class TestForeachWriterProcessFailure : TestForeachWriter
{
public override void Process(Row value)
{
throw new Exception("TestForeachWriterProcessFailure Process(Row) failure.");
}
}
}
}
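The three scenarios above pin down the worker-side contract: Open gates Process, and an exception from Process is handed to Close and then rethrown so the query fails. A rough sketch of that control flow, inferred from the test comments; the actual ForeachWriterWrapper in Microsoft.Spark (whose Execute method this PR renames to Process in commit 66c9837) may differ in its details:

using System;
using System.Collections.Generic;
using Microsoft.Spark.Sql;

// Sketch only: drives an IForeachWriter over one partition of one epoch.
internal class ForeachWriterWrapperSketch
{
    private readonly IForeachWriter _writer;

    internal ForeachWriterWrapperSketch(IForeachWriter writer) => _writer = writer;

    internal void Process(long partitionId, long epochId, IEnumerable<Row> rows)
    {
        Exception error = null;
        try
        {
            // Open decides whether this partition is processed at all.
            if (_writer.Open(partitionId, epochId))
            {
                foreach (Row row in rows)
                {
                    _writer.Process(row);
                }
            }
        }
        catch (Exception e)
        {
            error = e;
            throw; // rethrown so the task, and ultimately the query, fails
        }
        finally
        {
            // Close always runs and receives the Process exception, if any.
            _writer.Close(error);
        }
    }
}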
71 changes: 0 additions & 71 deletions src/csharp/Microsoft.Spark.Worker/Payload.cs
@@ -3,81 +3,10 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Utils;

namespace Microsoft.Spark.Worker
{
/// <summary>
/// TaskContext stores information related to a task.
/// </summary>
internal class TaskContext
{
internal int StageId { get; set; }

internal int PartitionId { get; set; }

internal int AttemptNumber { get; set; }

internal long AttemptId { get; set; }

internal bool IsBarrier { get; set; }

internal int Port { get; set; }

internal string Secret { get; set; }

internal IEnumerable<Resource> Resources { get; set; } = new List<Resource>();

internal Dictionary<string, string> LocalProperties { get; set; } =
new Dictionary<string, string>();

public override bool Equals(object obj)
{
if (!(obj is TaskContext other))
{
return false;
}

return (StageId == other.StageId) &&
(PartitionId == other.PartitionId) &&
(AttemptNumber == other.AttemptNumber) &&
(AttemptId == other.AttemptId) &&
Resources.SequenceEqual(other.Resources) &&
(LocalProperties.Count == other.LocalProperties.Count) &&
!LocalProperties.Except(other.LocalProperties).Any();
}

public override int GetHashCode()
{
return StageId;
}

internal class Resource
{
internal string Key { get; set; }
internal string Value { get; set; }
internal IEnumerable<string> Addresses { get; set; } = new List<string>();

public override bool Equals(object obj)
{
if (!(obj is Resource other))
{
return false;
}

return (Key == other.Key) &&
(Value == other.Value) &&
Addresses.SequenceEqual(other.Addresses);
}

public override int GetHashCode()
{
return Key.GetHashCode();
}
}
}

/// <summary>
/// BroadcastVariables stores information on broadcast variables.
/// </summary>
@@ -54,7 +54,10 @@ internal Payload Process(Stream stream)

payload.SplitIndex = BinaryPrimitives.ReadInt32BigEndian(splitIndexBytes);
payload.Version = SerDe.ReadString(stream);

payload.TaskContext = new TaskContextProcessor(_version).Process(stream);
TaskContextHolder.Set(payload.TaskContext);

payload.SparkFilesDir = SerDe.ReadString(stream);

if (Utils.SettingUtils.IsDatabricks)
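The new TaskContextHolder.Set call stores the just-deserialized TaskContext where later worker-side code (for example, code that needs the current partition id) can reach it without it being threaded through every call. A minimal sketch of such a holder, assuming each worker task runs on its own thread; the actual Microsoft.Spark.Worker type and its accessor may differ:

// Sketch only: a thread-local holder for the current task's TaskContext.
internal static class TaskContextHolder
{
    [ThreadStatic]
    private static TaskContext s_taskContext;

    internal static void Set(TaskContext taskContext) => s_taskContext = taskContext;

    internal static TaskContext Get() => s_taskContext;
}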
8 changes: 8 additions & 0 deletions src/csharp/Microsoft.Spark/Attributes.cs
@@ -73,4 +73,12 @@ public DeprecatedAttribute(string version)
{
}
}

/// <summary>
/// Custom attribute to denote that a class is a UDF wrapper.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
internal sealed class UdfWrapperAttribute : Attribute
{
}
}
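The attribute carries no state; it only tags the wrapper classes (see the [UdfWrapper] annotations added in RDD.cs below) so that worker-side code can recognize them. A hypothetical illustration of such a check, not taken from this PR:

// Illustration only. Attribute.IsDefined also reports attributes declared on a
// generic type definition when passed a constructed type such as
// MapUdfWrapper<int, int>.
internal static bool IsUdfWrapper(Type type) =>
    Attribute.IsDefined(type, typeof(UdfWrapperAttribute));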
22 changes: 13 additions & 9 deletions src/csharp/Microsoft.Spark/RDD.cs
@@ -261,16 +261,14 @@ public RDD<T> Sample(bool withReplacement, double fraction, long? seed = null)
 public IEnumerable<T> Collect()
 {
     (int port, string secret) = CollectAndServe();
-    using (ISocketWrapper socket = SocketFactory.CreateSocket())
-    {
-        socket.Connect(IPAddress.Loopback, port, secret);
+    using ISocketWrapper socket = SocketFactory.CreateSocket();
+    socket.Connect(IPAddress.Loopback, port, secret);

-        var collector = new RDD.Collector();
-        System.IO.Stream stream = socket.InputStream;
-        foreach (T element in collector.Collect(stream, _serializedMode).Cast<T>())
-        {
-            yield return element;
-        }
-    }
+    var collector = new RDD.Collector();
+    System.IO.Stream stream = socket.InputStream;
+    foreach (T element in collector.Collect(stream, _serializedMode).Cast<T>())
+    {
+        yield return element;
+    }
 }
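The change above swaps a using statement for a C# 8 using declaration, which disposes at the end of the enclosing scope and removes one level of nesting; in an iterator method like Collect(), disposal happens when iteration completes or the enumerator is disposed. A standalone illustration, not from the PR:

using System.Collections.Generic;
using System.IO;

internal static class UsingDeclarationExample
{
    internal static IEnumerable<string> ReadLines(string path)
    {
        // Disposed when control leaves the iterator body, with no using block.
        using var reader = new StreamReader(path);
        string line;
        while ((line = reader.ReadLine()) != null)
        {
            yield return line;
        }
    }
}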

@@ -341,6 +339,7 @@ private JvmObjectReference GetJvmRef()
/// </summary>
/// <typeparam name="TArg">Input type</typeparam>
/// <typeparam name="TResult">Output type</typeparam>
[UdfWrapper]
internal sealed class MapUdfWrapper<TArg, TResult>
{
private readonly Func<TArg, TResult> _func;
@@ -361,6 +360,7 @@ internal IEnumerable<object> Execute(int pid, IEnumerable<object> input)
/// </summary>
/// <typeparam name="TArg">Input type</typeparam>
/// <typeparam name="TResult">Output type</typeparam>
[UdfWrapper]
internal sealed class FlatMapUdfWrapper<TArg, TResult>
{
private readonly Func<TArg, IEnumerable<TResult>> _func;
@@ -382,6 +382,7 @@ internal IEnumerable<object> Execute(int pid, IEnumerable<object> input)
/// </summary>
/// <typeparam name="TArg">Input type</typeparam>
/// <typeparam name="TResult">Output type</typeparam>
[UdfWrapper]
internal sealed class MapPartitionsUdfWrapper<TArg, TResult>
{
private readonly Func<IEnumerable<TArg>, IEnumerable<TResult>> _func;
@@ -403,6 +404,7 @@ internal IEnumerable<object> Execute(int pid, IEnumerable<object> input)
/// </summary>
/// <typeparam name="TArg">Input type</typeparam>
/// <typeparam name="TResult">Output type</typeparam>
[UdfWrapper]
internal sealed class MapPartitionsWithIndexUdfWrapper<TArg, TResult>
{
private readonly Func<int, IEnumerable<TArg>, IEnumerable<TResult>> _func;
@@ -423,9 +425,11 @@ internal IEnumerable<object> Execute(int pid, IEnumerable<object> input)
/// Helper to map the UDF for Filter() to
/// <see cref="RDD.WorkerFunction.ExecuteDelegate"/>.
/// </summary>
[UdfWrapper]
internal class FilterUdfWrapper
{
private readonly Func<T, bool> _func;

internal FilterUdfWrapper(Func<T, bool> func)
{
_func = func;