diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 098fa7974b87b..75172957746ae 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -34,6 +34,7 @@
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
@@ -81,6 +82,22 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
handleMessage(msgObj, client, callback);
}
+ @Override
+ public StreamCallbackWithID receiveStream(
+ TransportClient client,
+ ByteBuffer messageHeader,
+ RpcResponseCallback callback) {
+ BlockTransferMessage header = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader);
+ return handleStream(header, client, callback);
+ }
+
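+ /**
+ * Extension point for subclasses that can accept streamed uploads. The default implementation
+ * rejects the stream with an UnsupportedOperationException.
+ */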
+ protected StreamCallbackWithID handleStream(
+ BlockTransferMessage header,
+ TransportClient client,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException("Unexpected message header: " + header);
+ }
+
protected void handleMessage(
BlockTransferMessage msgObj,
TransportClient client,
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 0b7a27402369d..7b9c75cd07795 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -19,13 +19,13 @@
import java.io.*;
import java.nio.charset.StandardCharsets;
+import java.util.regex.Pattern;
import java.util.*;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.regex.Matcher;
-import java.util.regex.Pattern;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index e49e27ab5aa79..2a013f5497a67 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -140,7 +140,8 @@ public void registerWithShuffleServer(
ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException {
checkInit();
try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) {
- ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
+ ByteBuffer registerMessage =
+ new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
client.sendRpcSync(registerMessage, registrationTimeoutMs);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java
new file mode 100644
index 0000000000000..9b3736a9ecf1b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java
@@ -0,0 +1,150 @@
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+import java.nio.file.StandardOpenOption;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.StreamCallbackWithID;
+
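+/**
+ * A {@link StreamCallbackWithID} that writes an incoming stream of shuffle bytes straight to a
+ * data or index file on the shuffle service's local disk. The callback must be open()ed before
+ * any data arrives; on failure the partially written file is deleted.
+ */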
+public class FileWriterStreamCallback implements StreamCallbackWithID {
+
+ private static final Logger logger = LoggerFactory.getLogger(FileWriterStreamCallback.class);
+
+ public enum FileType {
+ DATA("shuffle-data"),
+ INDEX("shuffle-index");
+
+ private final String typeString;
+
+ FileType(String typeString) {
+ this.typeString = typeString;
+ }
+
+ @Override
+ public String toString() {
+ return typeString;
+ }
+ }
+
+ private final String appId;
+ private final int shuffleId;
+ private final int mapId;
+ private final File file;
+ private final FileType fileType;
+ private WritableByteChannel fileOutputChannel = null;
+
+ public FileWriterStreamCallback(
+ String appId,
+ int shuffleId,
+ int mapId,
+ File file,
+ FileWriterStreamCallback.FileType fileType) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.file = file;
+ this.fileType = fileType;
+ }
+
+ public void open() {
+ logger.info(
+ "Opening {} for remote writing. File type: {}", file.getAbsolutePath(), fileType);
+ if (fileOutputChannel != null) {
+ throw new IllegalStateException(
+ String.format(
+ "File %s for is already open for writing (type: %s).",
+ file.getAbsolutePath(),
+ fileType));
+ }
+ if (!file.exists()) {
+ try {
+ if (!file.getParentFile().isDirectory() && !file.getParentFile().mkdirs()) {
+ throw new IOException(
+ String.format(
+ "Failed to create shuffle file directory at %s (type: %s).",
+ file.getParentFile().getAbsolutePath(), fileType));
+ }
+
+ if (!file.createNewFile()) {
+ throw new IOException(
+ String.format(
+ "Failed to create shuffle file (type: %s).", fileType));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(
+ String.format(
+ "Failed to create shuffle file at %s for backup (type: %s).",
+ file.getAbsolutePath(),
+ fileType),
+ e);
+ }
+ }
+ try {
+ // TODO encryption
+ fileOutputChannel = FileChannel.open(file.toPath(), StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ throw new RuntimeException(
+ String.format(
+ "Failed to find file for writing at %s (type: %s).",
+ file.getAbsolutePath(),
+ fileType),
+ e);
+ }
+ }
+
+ @Override
+ public String getID() {
+ return String.format("%s-%d-%d-%s",
+ appId,
+ shuffleId,
+ mapId,
+ fileType);
+ }
+
+ @Override
+ public void onData(String streamId, ByteBuffer buf) throws IOException {
+ verifyShuffleFileOpenForWriting();
+ while (buf.hasRemaining()) {
+ fileOutputChannel.write(buf);
+ }
+ }
+
+ @Override
+ public void onComplete(String streamId) throws IOException {
+ logger.info(
+ "Finished writing {}. File type: {}", file.getAbsolutePath(), fileType);
+ fileOutputChannel.close();
+ }
+
+ @Override
+ public void onFailure(String streamId, Throwable cause) throws IOException {
+ logger.warn("Failed to write shuffle file at {} (type: %s).",
+ file.getAbsolutePath(),
+ fileType,
+ cause);
+ fileOutputChannel.close();
+ // TODO delete parent dirs too
+ if (!file.delete()) {
+ logger.warn(
+ "Failed to delete incomplete remote shuffle file at %s (type: %s)",
+ file.getAbsolutePath(),
+ fileType);
+ }
+ }
+
+ private void verifyShuffleFileOpenForWriting() {
+ if (fileOutputChannel == null) {
+ throw new RuntimeException(
+ String.format(
+ "Shuffle file at %s not open for writing (type: %s).",
+ file.getAbsolutePath(),
+ fileType));
+ }
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java
new file mode 100644
index 0000000000000..06af3bd141fba
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.k8s;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.shuffle.ExternalShuffleClient;
+import org.apache.spark.network.shuffle.protocol.RegisterDriver;
+import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A client for talking to the external shuffle service when running on Kubernetes.
+ *
+ * This is used by the Spark driver to register with each external shuffle service on the cluster.
+ * The driver has to talk to the service so that shuffle files can be cleaned up reliably after
+ * the application exits; Kubernetes does not provide a good way to do this automatically, so
+ * Spark has to detect the application's exit itself.
+ */
+public class KubernetesExternalShuffleClient extends ExternalShuffleClient {
+ private static final Logger logger =
+ LoggerFactory.getLogger(KubernetesExternalShuffleClient.class);
+
+ private final ScheduledExecutorService heartbeaterThread =
+ Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat("kubernetes-external-shuffle-client-heartbeater")
+ .build());
+
+ /**
+ * Creates a Kubernetes external shuffle client that wraps the {@link ExternalShuffleClient}.
+ * Please refer to docs on {@link ExternalShuffleClient} for more information.
+ */
+ public KubernetesExternalShuffleClient(
+ TransportConf conf,
+ SecretKeyHolder secretKeyHolder,
+ boolean authEnabled,
+ long registrationTimeoutMs) {
+ super(conf, secretKeyHolder, authEnabled, registrationTimeoutMs);
+ }
+
+ public void registerDriverWithShuffleService(
+ String host,
+ int port,
+ long heartbeatTimeoutMs,
+ long heartbeatIntervalMs) throws IOException, InterruptedException {
+
+ checkInit();
+ ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer();
+ logger.info("Registering with external shuffle service at " + host + ":" + port);
+ TransportClient client = clientFactory.createClient(host, port);
+ client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs));
+ }
+
+ private class RegisterDriverCallback implements RpcResponseCallback {
+ private final TransportClient client;
+ private final long heartbeatIntervalMs;
+
+ private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) {
+ this.client = client;
+ this.heartbeatIntervalMs = heartbeatIntervalMs;
+ }
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ heartbeaterThread.scheduleAtFixedRate(
+ new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS);
+ logger.info("Successfully registered app " + appId + " with external shuffle service.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.warn("Unable to register app " + appId + " with external shuffle service. " +
+ "Please manually remove shuffle data after driver exit. Error: " + e);
+ }
+ }
+
+ @Override
+ public void close() {
+ heartbeaterThread.shutdownNow();
+ super.close();
+ }
+
+ private class Heartbeater implements Runnable {
+
+ private final TransportClient client;
+
+ private Heartbeater(TransportClient client) {
+ this.client = client;
+ }
+
+ @Override
+ public void run() {
+ // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout
+ client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer());
+ }
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 60179f126bc44..3510509f20eee 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -24,7 +24,7 @@
import java.util.concurrent.TimeUnit;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
+import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -32,7 +32,7 @@
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+import org.apache.spark.network.shuffle.protocol.RegisterDriver;
import org.apache.spark.network.util.TransportConf;
/**
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index a68a297519b66..f5196638f9140 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -23,8 +23,6 @@
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
-import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat;
/**
* Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
@@ -42,7 +40,8 @@ public abstract class BlockTransferMessage implements Encodable {
/** Preceding every serialized message is its type, which allows us to deserialize it. */
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
- HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6);
+ HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7),
+ REGISTER_SHUFFLE_INDEX(8), OPEN_SHUFFLE_PARTITION(9), UPLOAD_SHUFFLE_INDEX(10);
private final byte id;
@@ -68,6 +67,10 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 4: return RegisterDriver.decode(buf);
case 5: return ShuffleServiceHeartbeat.decode(buf);
case 6: return UploadBlockStream.decode(buf);
+ case 7: return UploadShufflePartitionStream.decode(buf);
+ case 8: return RegisterShuffleIndex.decode(buf);
+ case 9: return OpenShufflePartition.decode(buf);
+ case 10: return UploadShuffleIndex.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java
new file mode 100644
index 0000000000000..63d2387bd6d1a
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java
@@ -0,0 +1,76 @@
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
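+/**
+ * Request to open (fetch) the bytes of a single shuffle partition, identified by app, shuffle,
+ * map and partition id, from the external shuffle service.
+ */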
+public class OpenShufflePartition extends BlockTransferMessage {
+ public final String appId;
+ public final int shuffleId;
+ public final int mapId;
+ public final int partitionId;
+
+ public OpenShufflePartition(
+ String appId, int shuffleId, int mapId, int partitionId) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.partitionId = partitionId;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof OpenShufflePartition) {
+ OpenShufflePartition o = (OpenShufflePartition) other;
+ return Objects.equal(appId, o.appId)
+ && shuffleId == o.shuffleId
+ && mapId == o.mapId
+ && partitionId == o.partitionId;
+ }
+ return false;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.OPEN_SHUFFLE_PARTITION;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, shuffleId, mapId, partitionId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("shuffleId", shuffleId)
+ .add("mapId", mapId)
+ .add("partitionId", partitionId)
+ .toString();
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(mapId);
+ buf.writeInt(partitionId);
+ }
+
+ public static OpenShufflePartition decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ int shuffleId = buf.readInt();
+ int mapId = buf.readInt();
+ int partitionId = buf.readInt();
+ return new OpenShufflePartition(appId, shuffleId, mapId, partitionId);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java
similarity index 94%
rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java
index d5f53ccb7f741..516a51ad7cc14 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java
@@ -15,13 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle.protocol.mesos;
+package org.apache.spark.network.shuffle.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
new file mode 100644
index 0000000000000..bc870e440274b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Register shuffle index to the External Shuffle Service.
+ */
+public class RegisterShuffleIndex extends BlockTransferMessage {
+ public final String appId;
+ public final int shuffleId;
+ public final int mapId;
+
+ public RegisterShuffleIndex(
+ String appId,
+ int shuffleId,
+ int mapId) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof RegisterShuffleIndex) {
+ RegisterShuffleIndex o = (RegisterShuffleIndex) other;
+ return Objects.equal(appId, o.appId)
+ && shuffleId == o.shuffleId
+ && mapId == o.mapId;
+ }
+ return false;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.REGISTER_SHUFFLE_INDEX;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, shuffleId, mapId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("shuffleId", shuffleId)
+ .add("mapId", mapId)
+ .toString();
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId) + 4 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(mapId);
+ }
+
+ public static RegisterShuffleIndex decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ int shuffleId = buf.readInt();
+ int mapId = buf.readInt();
+ return new RegisterShuffleIndex(appId, shuffleId, mapId);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java
similarity index 92%
rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java
index b30bb9aed55b6..1a6ffc0f91333 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java
@@ -15,11 +15,10 @@
* limitations under the License.
*/
-package org.apache.spark.network.shuffle.protocol.mesos;
+package org.apache.spark.network.shuffle.protocol;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.protocol.Encoders;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
new file mode 100644
index 0000000000000..b11a02f6b9219
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Upload shuffle index request to the External Shuffle Service.
+ */
+public class UploadShuffleIndex extends BlockTransferMessage {
+ public final String appId;
+ public final int shuffleId;
+ public final int mapId;
+
+ public UploadShuffleIndex(
+ String appId,
+ int shuffleId,
+ int mapId) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof UploadShuffleIndex) {
+ UploadShuffleIndex o = (UploadShuffleIndex) other;
+ return Objects.equal(appId, o.appId)
+ && shuffleId == o.shuffleId
+ && mapId == o.mapId;
+ }
+ return false;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.UPLOAD_SHUFFLE_INDEX;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, shuffleId, mapId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("shuffleId", shuffleId)
+ .add("mapId", mapId)
+ .toString();
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId) + 4 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(mapId);
+ }
+
+ public static UploadShuffleIndex decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ int shuffleId = buf.readInt();
+ int mapId = buf.readInt();
+ return new UploadShuffleIndex(appId, shuffleId, mapId);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java
new file mode 100644
index 0000000000000..ad8f5405192fc
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Upload shuffle partition request to the External Shuffle Service.
+ */
+public class UploadShufflePartitionStream extends BlockTransferMessage {
+ public final String appId;
+ public final int shuffleId;
+ public final int mapId;
+ public final int partitionId;
+ public final int partitionLength;
+
+ public UploadShufflePartitionStream(
+ String appId,
+ int shuffleId,
+ int mapId,
+ int partitionId,
+ int partitionLength) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.partitionId = partitionId;
+ this.partitionLength = partitionLength;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof UploadShufflePartitionStream) {
+ UploadShufflePartitionStream o = (UploadShufflePartitionStream) other;
+ return Objects.equal(appId, o.appId)
+ && shuffleId == o.shuffleId
+ && mapId == o.mapId
+ && partitionId == o.partitionId
+ && partitionLength == o.partitionLength;
+ }
+ return false;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.UPLOAD_SHUFFLE_PARTITION_STREAM;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, shuffleId, mapId, partitionId, partitionLength);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("shuffleId", shuffleId)
+ .add("mapId", mapId)
+ .add("partitionId", partitionId)
+ .add("partitionLength", partitionLength)
+ .toString();
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(mapId);
+ buf.writeInt(partitionId);
+ buf.writeInt(partitionLength);
+ }
+
+ public static UploadShufflePartitionStream decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ int shuffleId = buf.readInt();
+ int mapId = buf.readInt();
+ int partitionId = buf.readInt();
+ int partitionLength = buf.readInt();
+ return new UploadShufflePartitionStream(
+ appId, shuffleId, mapId, partitionId, partitionLength);
+ }
+}
diff --git a/core/pom.xml b/core/pom.xml
index 49b1a54e32598..544ae61279c4d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -352,6 +352,11 @@
      <artifactId>py4j</artifactId>
      <version>0.10.8.1</version>
    </dependency>
+    <dependency>
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-java8-compat_${scala.binary.version}</artifactId>
+      <version>0.9.0</version>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-tags_${scala.binary.version}</artifactId>
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java
new file mode 100644
index 0000000000000..7846fad70b159
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java
@@ -0,0 +1,23 @@
+package org.apache.spark.shuffle.api;
+
+import java.util.Optional;
+
+import org.apache.spark.storage.ShuffleLocation;
+
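+/**
+ * Metadata describing a shuffle partition that has been durably committed by a
+ * {@link ShufflePartitionWriter}.
+ */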
+public interface CommittedPartition {
+
+ /**
+ * Indicates the number of bytes written in a committed partition.
+ * Note that returning the length is mainly for backwards compatibility
+ * and should be removed in a more polished variant.
+ */
+ long length();
+
+ /**
+ * Indicates the shuffle location to which this partition was written.
+ * Some implementations may not need to specify a shuffle location.
+ */
+ Optional<ShuffleLocation> shuffleLocation();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
index b091e231f2cd7..19cd94712a8ae 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
@@ -16,11 +16,13 @@
*/
package org.apache.spark.shuffle.api;
+import java.io.IOException;
+
public interface ShuffleDataIO {
- void initialize();
+ void initialize() throws IOException;
- ShuffleReadSupport readSupport();
+ ShuffleReadSupport readSupport() throws IOException;
- ShuffleWriteSupport writeSupport();
+ ShuffleWriteSupport writeSupport() throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 06415dba72d34..becb9413a8f4f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -17,11 +17,13 @@
package org.apache.spark.shuffle.api;
+import java.io.IOException;
+
public interface ShuffleMapOutputWriter {
- ShufflePartitionWriter newPartitionWriter(int partitionId);
+ ShufflePartitionWriter newPartitionWriter(int partitionId) throws IOException;
- void commitAllPartitions();
+ void commitAllPartitions() throws IOException;
- void abort(Exception exception);
+ void abort(Exception exception) throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java
index 59eae0a782200..46d1699724981 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java
@@ -18,8 +18,13 @@
package org.apache.spark.shuffle.api;
import java.io.InputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.spark.storage.ShuffleLocation;
public interface ShufflePartitionReader {
- InputStream fetchPartition(int reduceId);
+ InputStream fetchPartition(int reduceId, Optional<ShuffleLocation> shuffleLocation)
+ throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
index 7fa667cf137ea..e7cc6dd913d11 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.api;
-import java.io.InputStream;
+import java.io.IOException;
import java.io.OutputStream;
/**
@@ -28,20 +28,18 @@ public interface ShufflePartitionWriter {
/**
* Return a stream that should persist the bytes for this partition.
*/
- OutputStream openPartitionStream();
+ OutputStream openPartitionStream() throws IOException;
/**
- * Indicate that the partition was written successfully and there are no more incoming bytes. Returns
- * the length of the partition that is written. Note that returning the length is mainly for backwards
- * compatibility and should be removed in a more polished variant. After this method is called, the writer
- * will be discarded; it's expected that the implementation will close any underlying resources.
+ * Indicate that the partition was written successfully and there are no more incoming bytes.
+ * Returns a {@link CommittedPartition} indicating information about that written partition.
*/
- long commitAndGetTotalLength();
+ CommittedPartition commitPartition() throws IOException;
/**
* Indicate that the write has failed for some reason and the implementation can handle the
* failure reason. After this method is called, this writer will be discarded; it's expected that
* the implementation will close any underlying resources.
*/
- void abort(Exception failureReason);
+ void abort(Exception failureReason) throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java
index b1be7c1de98ac..ebe8fd12dccdb 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java
@@ -17,8 +17,11 @@
package org.apache.spark.shuffle.api;
+import java.io.IOException;
+
public interface ShuffleReadSupport {
- ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId);
+ ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId)
+ throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java
index 2f61dbaa17c69..f88555f8a1bd9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java
@@ -17,7 +17,10 @@
package org.apache.spark.shuffle.api;
+import java.io.IOException;
+
public interface ShuffleWriteSupport {
- ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId);
+ ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId)
+ throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java
new file mode 100644
index 0000000000000..7e37659dbb3f9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java
@@ -0,0 +1,32 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.util.Optional;
+
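+/**
+ * {@link CommittedPartition} for partitions written to an external shuffle service, optionally
+ * carrying the {@link ShuffleLocation} the bytes were written to.
+ */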
+public class ExternalCommittedPartition implements CommittedPartition {
+
+ private final long length;
+ private final Optional<ShuffleLocation> shuffleLocation;
+
+ public ExternalCommittedPartition(long length) {
+ this.length = length;
+ this.shuffleLocation = Optional.empty();
+ }
+
+ public ExternalCommittedPartition(long length, ShuffleLocation shuffleLocation) {
+ this.length = length;
+ this.shuffleLocation = Optional.of(shuffleLocation);
+ }
+
+ @Override
+ public long length() {
+ return length;
+ }
+
+ @Override
+ public Optional<ShuffleLocation> shuffleLocation() {
+ return shuffleLocation;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
new file mode 100644
index 0000000000000..10e4093759c38
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
@@ -0,0 +1,86 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.*;
+import org.apache.spark.SecurityManager;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.shuffle.api.ShuffleDataIO;
+import org.apache.spark.shuffle.api.ShuffleReadSupport;
+import org.apache.spark.shuffle.api.ShuffleWriteSupport;
+import org.apache.spark.network.shuffle.k8s.KubernetesExternalShuffleClient;
+
+import org.apache.spark.storage.BlockManager;
+import scala.Tuple2;
+
+import java.util.List;
+import java.util.Random;
+
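+/**
+ * {@link ShuffleDataIO} plugin that routes shuffle reads and writes through remote (external)
+ * shuffle service instances instead of local disk. On the driver, initialize() also registers
+ * the application with each shuffle service so that shuffle files can be cleaned up and
+ * heartbeats sent.
+ */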
+public class ExternalShuffleDataIO implements ShuffleDataIO {
+
+ private final SparkConf conf;
+ private final TransportConf transportConf;
+ private final TransportContext context;
+ private static MapOutputTracker mapOutputTracker;
+ private static SecurityManager securityManager;
+ private static List<Tuple2<String, Integer>> hostPorts;
+ private static Boolean isDriver;
+ private static KubernetesExternalShuffleClient shuffleClient;
+
+ public ExternalShuffleDataIO(
+ SparkConf sparkConf) {
+ this.conf = sparkConf;
+ // TODO: Grab numUsableCores
+ this.transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1);
+ // Close idle connections
+ this.context = new TransportContext(transportConf, new NoOpRpcHandler(), true, true);
+ }
+
+ @Override
+ public void initialize() {
+ SparkEnv env = SparkEnv.get();
+ mapOutputTracker = env.mapOutputTracker();
+ securityManager = env.securityManager();
+ isDriver = env.blockManager().blockManagerId().isDriver();
+ hostPorts = mapOutputTracker.getRemoteShuffleServiceAddress();
+ if (isDriver) {
+ shuffleClient = new KubernetesExternalShuffleClient(transportConf, securityManager,
+ securityManager.isAuthenticationEnabled(), conf.getTimeAsMs(
+ package$.MODULE$.SHUFFLE_REGISTRATION_TIMEOUT().key(), "5000ms"));
+ shuffleClient.init(conf.getAppId());
+ for (Tuple2<String, Integer> hp : hostPorts) {
+ try {
+ shuffleClient.registerDriverWithShuffleService(
+ hp._1, hp._2,
+ conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs",
+ conf.getTimeAsSeconds("spark.network.timeout", "120s") + "s"),
+ conf.getTimeAsSeconds(
+ package$.MODULE$.EXECUTOR_HEARTBEAT_INTERVAL().key(), "10s"));
+ } catch (Exception e) {
+ throw new RuntimeException("Unable to register driver with ESS", e);
+ }
+ }
+ BlockManager.ShuffleMetricsSource metricSource =
+ new BlockManager.ShuffleMetricsSource(
+ "RemoteShuffleService", shuffleClient.shuffleMetrics());
+ env.metricsSystem().registerSource(metricSource);
+ }
+ }
+
+ @Override
+ public ShuffleReadSupport readSupport() {
+ return new ExternalShuffleReadSupport(
+ transportConf, context, securityManager.isAuthenticationEnabled(), securityManager);
+ }
+
+ @Override
+ public ShuffleWriteSupport writeSupport() {
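+ // Pick one of the known shuffle service instances at random to receive this map output.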
+ int rnd = new Random().nextInt(hostPorts.size());
+ Tuple2<String, Integer> hostPort = hostPorts.get(rnd);
+ return new ExternalShuffleWriteSupport(
+ transportConf, context, securityManager.isAuthenticationEnabled(),
+ securityManager, hostPort._1, hostPort._2);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java
new file mode 100644
index 0000000000000..c1178da2411ff
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java
@@ -0,0 +1,38 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.io.*;
+
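+/**
+ * {@link ShuffleLocation} identifying the external shuffle service (hostname and port) that
+ * holds a written shuffle partition.
+ */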
+public class ExternalShuffleLocation implements ShuffleLocation {
+
+ private String shuffleHostname;
+ private int shufflePort;
+
+ public ExternalShuffleLocation() { /* for serialization */ }
+
+ public ExternalShuffleLocation(String shuffleHostname, int shufflePort) {
+ this.shuffleHostname = shuffleHostname;
+ this.shufflePort = shufflePort;
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeUTF(shuffleHostname);
+ out.writeInt(shufflePort);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ this.shuffleHostname = in.readUTF();
+ this.shufflePort = in.readInt();
+ }
+
+ public String getShuffleHostname() {
+ return this.shuffleHostname;
+ }
+
+ public int getShufflePort() {
+ return this.shufflePort;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..8866d14feca53
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
@@ -0,0 +1,95 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex;
+import org.apache.spark.network.shuffle.protocol.UploadShuffleIndex;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+
+
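+/**
+ * {@link ShuffleMapOutputWriter} that registers a map task's shuffle index with the external
+ * shuffle service on creation, hands out per-partition writers, and uploads the index once all
+ * partitions have been committed.
+ */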
+public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter {
+
+ private static final Logger logger =
+ LoggerFactory.getLogger(ExternalShuffleMapOutputWriter.class);
+
+ private final TransportClientFactory clientFactory;
+ private final String hostName;
+ private final int port;
+ private final String appId;
+ private final int shuffleId;
+ private final int mapId;
+
+ public ExternalShuffleMapOutputWriter(
+ TransportClientFactory clientFactory,
+ String hostName,
+ int port,
+ String appId,
+ int shuffleId,
+ int mapId) {
+ this.clientFactory = clientFactory;
+ this.hostName = hostName;
+ this.port = port;
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+
+ TransportClient client = null;
+ try {
+ client = clientFactory.createUnmanagedClient(hostName, port);
+ ByteBuffer registerShuffleIndex = new RegisterShuffleIndex(
+ appId, shuffleId, mapId).toByteBuffer();
+ String requestID = String.format(
+ "index-register-%s-%d-%d", appId, shuffleId, mapId);
+ client.setClientId(requestID);
+ logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+ client.sendRpcSync(registerShuffleIndex, 60000);
+ } catch (Exception e) {
+ if (client != null) {
+ client.close();
+ }
+ logger.error("Encountered error while registering shuffle index with ESS", e);
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public ShufflePartitionWriter newPartitionWriter(int partitionId) {
+ try {
+ return new ExternalShufflePartitionWriter(clientFactory,
+ hostName, port, appId, shuffleId, mapId, partitionId);
+ } catch (Exception e) {
+ clientFactory.close();
+ logger.error("Encountered error while creating transport client", e);
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void commitAllPartitions() {
+ TransportClient client = null;
+ try {
+ client = clientFactory.createUnmanagedClient(hostName, port);
+ ByteBuffer uploadShuffleIndex = new UploadShuffleIndex(
+ appId, shuffleId, mapId).toByteBuffer();
+ String requestID = String.format(
+ "index-upload-%s-%d-%d", appId, shuffleId, mapId);
+ client.setClientId(requestID);
+ logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+ client.sendRpcSync(uploadShuffleIndex, 60000);
+ } catch (Exception e) {
+ logger.error("Encountered error while uploading shuffle index to ESS", e);
+ if (client != null) {
+ client.close();
+ }
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void abort(Exception exception) {
+ clientFactory.close();
+ logger.error("Encountered error while " +
+ "attempting to add partitions to ESS", exception);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java
new file mode 100644
index 0000000000000..10f1b71008472
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java
@@ -0,0 +1,79 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.shuffle.protocol.OpenShufflePartition;
+import org.apache.spark.shuffle.api.ShufflePartitionReader;
+import org.apache.spark.storage.ShuffleLocation;
+import org.apache.spark.util.ByteBufferInputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Optional;
+
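+/**
+ * {@link ShufflePartitionReader} that fetches a partition's bytes from the external shuffle
+ * service identified by the given {@link ShuffleLocation} via an OpenShufflePartition RPC.
+ */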
+public class ExternalShufflePartitionReader implements ShufflePartitionReader {
+
+ private static final Logger logger =
+ LoggerFactory.getLogger(ExternalShufflePartitionReader.class);
+
+ private final TransportClientFactory clientFactory;
+ private final String appId;
+ private final int shuffleId;
+ private final int mapId;
+
+ public ExternalShufflePartitionReader(
+ TransportClientFactory clientFactory,
+ String appId,
+ int shuffleId,
+ int mapId) {
+ this.clientFactory = clientFactory;
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ }
+
+ @Override
+ public InputStream fetchPartition(int reduceId, Optional<ShuffleLocation> shuffleLocation) {
+ assert shuffleLocation.isPresent() &&
+ shuffleLocation.get() instanceof ExternalShuffleLocation;
+ ExternalShuffleLocation externalShuffleLocation =
+ (ExternalShuffleLocation) shuffleLocation.get();
+ logger.info(String.format("Found external shuffle location on node: %s:%d",
+ externalShuffleLocation.getShuffleHostname(),
+ externalShuffleLocation.getShufflePort()));
+ String hostname = externalShuffleLocation.getShuffleHostname();
+ int port = externalShuffleLocation.getShufflePort();
+
+ OpenShufflePartition openMessage =
+ new OpenShufflePartition(appId, shuffleId, mapId, reduceId);
+ TransportClient client = null;
+ try {
+ client = clientFactory.createUnmanagedClient(hostname, port);
+ String requestID = String.format(
+ "read-%s-%d-%d-%d", appId, shuffleId, mapId, reduceId);
+ client.setClientId(requestID);
+ logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+
+ ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000);
+ logger.info("response is: " + response.toString() +
+ " " + response.array() + " " + response.hasArray());
+ if (response.hasArray()) {
+ logger.info("response hashcode: " + Arrays.hashCode(response.array()));
+ // use heap buffer; no array is created; only the reference is used
+ return new ByteArrayInputStream(response.array());
+ }
+ return new ByteBufferInputStream(response);
+
+ } catch (Exception e) {
+ if (client != null) {
+ client.close();
+ }
+ logger.error("Encountered exception while trying to fetch blocks", e);
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
new file mode 100644
index 0000000000000..edf046a32ffe3
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
@@ -0,0 +1,110 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream;
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
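+/**
+ * {@link ShufflePartitionWriter} that buffers a partition's bytes in memory and streams them to
+ * the external shuffle service when the partition is committed.
+ */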
+public class ExternalShufflePartitionWriter implements ShufflePartitionWriter {
+
+ private static final Logger logger =
+ LoggerFactory.getLogger(ExternalShufflePartitionWriter.class);
+
+ private final TransportClientFactory clientFactory;
+ private final String hostName;
+ private final int port;
+ private final String appId;
+ private final int shuffleId;
+ private final int mapId;
+ private final int partitionId;
+
+ private long totalLength = 0;
+ private ByteArrayOutputStream partitionBuffer;
+
+ public ExternalShufflePartitionWriter(
+ TransportClientFactory clientFactory,
+ String hostName,
+ int port,
+ String appId,
+ int shuffleId,
+ int mapId,
+ int partitionId) {
+ this.clientFactory = clientFactory;
+ this.hostName = hostName;
+ this.port = port;
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.partitionId = partitionId;
+ // TODO: Set buffer size
+ this.partitionBuffer = new ByteArrayOutputStream();
+ }
+
+ @Override
+ public OutputStream openPartitionStream() { return partitionBuffer; }
+
+ @Override
+ public CommittedPartition commitPartition() {
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ logger.info("Successfully uploaded partition");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Encountered an error uploading partition", e);
+ }
+ };
+ TransportClient client = null;
+ try {
+ byte[] buf = partitionBuffer.toByteArray();
+ int size = buf.length;
+ ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId,
+ partitionId, size).toByteBuffer();
+ ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf));
+ client = clientFactory.createUnmanagedClient(hostName, port);
+ client.setClientId(String.format("data-%s-%d-%d-%d",
+ appId, shuffleId, mapId, partitionId));
+ logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+ logger.info("THE BUFFER HASH CODE IS: " + Arrays.hashCode(buf));
+ client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback);
+ totalLength += size;
+ logger.info("Partition Length: " + totalLength);
+ logger.info("Size: " + size);
+ } catch (Exception e) {
+ if (client != null) {
+ partitionBuffer = null;
+ client.close();
+ }
+ logger.error("Encountered error while attempting to upload partition to ESS", e);
+ throw new RuntimeException(e);
+ } finally {
+ logger.info("Successfully sent partition to ESS");
+ }
+ return new ExternalCommittedPartition(
+ totalLength, new ExternalShuffleLocation(hostName, port));
+ }
+
+ @Override
+ public void abort(Exception failureReason) {
+ try {
+ partitionBuffer.close();
+ partitionBuffer = null;
+ } catch(IOException e) {
+ logger.error("Failed to close streams after failing to upload partition", e);
+ }
+ logger.error("Encountered error while attempting" +
+ "to upload partition to ESS", failureReason);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
new file mode 100644
index 0000000000000..0bde10a777668
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
@@ -0,0 +1,56 @@
+package org.apache.spark.shuffle.external;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.shuffle.api.ShufflePartitionReader;
+import org.apache.spark.shuffle.api.ShuffleReadSupport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+
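+/**
+ * Creates {@link ExternalShufflePartitionReader}s backed by transport clients, optionally
+ * bootstrapped with authentication.
+ */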
+public class ExternalShuffleReadSupport implements ShuffleReadSupport {
+
+ private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleReadSupport.class);
+
+ private final TransportConf conf;
+ private final TransportContext context;
+ private final boolean authEnabled;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public ExternalShuffleReadSupport(
+ TransportConf conf,
+ TransportContext context,
+ boolean authEnabled,
+ SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.context = context;
+ this.authEnabled = authEnabled;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ @Override
+ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) {
+ // TODO combine this into a function with ExternalShuffleWriteSupport
+ List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
+ if (authEnabled) {
+ bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
+ }
+ TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
+ try {
+ return new ExternalShufflePartitionReader(clientFactory,
+ appId,
+ shuffleId,
+ mapId);
+ } catch (Exception e) {
+ clientFactory.close();
+ logger.error("Encountered creating transport client for partition reader");
+ throw new RuntimeException(e); // what is standard practice here?
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java
new file mode 100644
index 0000000000000..413c2fd63f20a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java
@@ -0,0 +1,54 @@
+package org.apache.spark.shuffle.external;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShuffleWriteSupport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+
+public class ExternalShuffleWriteSupport implements ShuffleWriteSupport {
+
+ private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class);
+
+ private final TransportConf conf;
+ private final TransportContext context;
+ private final boolean authEnabled;
+ private final SecretKeyHolder secretKeyHolder;
+ private final String hostname;
+ private final int port;
+
+ public ExternalShuffleWriteSupport(
+ TransportConf conf,
+ TransportContext context,
+ boolean authEnabled,
+ SecretKeyHolder secretKeyHolder,
+ String hostname,
+ int port) {
+ this.conf = conf;
+ this.context = context;
+ this.authEnabled = authEnabled;
+ this.secretKeyHolder = secretKeyHolder;
+ this.hostname = hostname;
+ this.port = port;
+ }
+
+ @Override
+ public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) {
+ List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
+ if (authEnabled) {
+ bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
+ }
+ TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
+ logger.info("Clientfactory: " + clientFactory.toString());
+ return new ExternalShuffleMapOutputWriter(
+ clientFactory, hostname, port, appId, shuffleId, mapId);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index c683b2854b17c..b21d37401c059 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -17,24 +17,8 @@
package org.apache.spark.shuffle.sort;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStream;
-import javax.annotation.Nullable;
-
-import scala.None$;
-import scala.Option;
-import scala.Product2;
-import scala.Tuple2;
-import scala.collection.Iterator;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
@@ -42,14 +26,26 @@
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.api.CommittedPartition;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.ShuffleWriteSupport;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.None$;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import javax.annotation.Nullable;
+import java.io.*;
+import java.util.Arrays;
/**
* This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
@@ -94,7 +90,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private DiskBlockObjectWriter[] partitionWriters;
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
- private long[] partitionLengths;
+ private CommittedPartition[] committedPartitions;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -131,7 +127,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
- partitionLengths = new long[numPartitions];
+ long[] partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
@@ -166,25 +162,31 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
if (pluggableWriteSupport != null) {
- partitionLengths = combineAndWritePartitionsUsingPluggableWriter();
+ committedPartitions = combineAndWritePartitionsUsingPluggableWriter();
+ logger.info("Successfully wrote partitions with pluggable writer");
} else {
File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
try {
- partitionLengths = combineAndWritePartitions(tmp);
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ committedPartitions = combineAndWritePartitions(tmp);
+ logger.info("Successfully wrote partitions without shuffle");
+ // TODO: Investigate when commitedPartitions is null or returns empty
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId,
+ mapId,
+ Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(),
+ tmp);
} finally {
if (tmp != null && tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions);
}
@VisibleForTesting
long[] getPartitionLengths() {
- return partitionLengths;
+ return Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray();
}
/**
@@ -192,12 +194,12 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] combineAndWritePartitions(File outputFile) throws IOException {
+ private CommittedPartition[] combineAndWritePartitions(File outputFile) throws IOException {
// Track location of the partition starts in the output file
- final long[] lengths = new long[numPartitions];
+ final CommittedPartition[] partitions = new CommittedPartition[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
- return lengths;
+ return partitions;
}
assert(outputFile != null);
final FileOutputStream out = new FileOutputStream(outputFile, true);
@@ -210,7 +212,8 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException {
final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
try {
- lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ partitions[i] =
+ new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled));
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
@@ -225,21 +228,21 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
- return lengths;
+ return partitions;
}
- private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOException {
+ private CommittedPartition[] combineAndWritePartitionsUsingPluggableWriter() throws IOException {
// Track location of the partition starts in the output file
- final long[] lengths = new long[numPartitions];
+ final CommittedPartition[] partitions = new CommittedPartition[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
- return lengths;
+ return partitions;
}
assert(pluggableWriteSupport != null);
final long writeStartTime = System.nanoTime();
ShuffleMapOutputWriter mapOutputWriter = pluggableWriteSupport.newMapOutputWriter(
- appId, shuffleId, mapId);
+ appId, shuffleId, mapId);
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
@@ -251,7 +254,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio
try (OutputStream out = writer.openPartitionStream()) {
Utils.copyStream(in, out, false, false);
}
- lengths[i] = writer.commitAndGetTotalLength();
+ partitions[i] = writer.commitPartition();
copyThrewException = false;
} catch (Exception e) {
try {
@@ -279,7 +282,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
- return lengths;
+ return partitions;
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java
new file mode 100644
index 0000000000000..817855d957966
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java
@@ -0,0 +1,25 @@
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.shuffle.api.CommittedPartition;
+import org.apache.spark.storage.ShuffleLocation;
+
+import java.util.Optional;
+
+public class LocalCommittedPartition implements CommittedPartition {
+
+ private final long length;
+
+ public LocalCommittedPartition(long length) {
+ this.length = length;
+ }
+
+ @Override
+ public long length() {
+ return length;
+ }
+
+ @Override
+ public Optional<ShuffleLocation> shuffleLocation() {
+ return Optional.empty();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 150a783aa87fb..97f34bf460495 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -20,8 +20,11 @@
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
+import java.util.Arrays;
import java.util.Iterator;
+import java.util.stream.Collectors;
+import org.apache.spark.shuffle.api.CommittedPartition;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -85,7 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int initialSortBufferSize;
private final int inputBufferSizeInBytes;
private final int outputBufferSizeInBytes;
- private final ShuffleWriteSupport pluggableWriteSupport; // TODO initialize
+ private final ShuffleWriteSupport pluggableWriteSupport;
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@@ -236,12 +239,12 @@ void closeAndWriteOutput() throws IOException {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
- final long[] partitionLengths;
+ final CommittedPartition[] committedPartitions;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
try {
- partitionLengths = mergeSpills(spills, tmp);
+ committedPartitions = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
@@ -250,14 +253,17 @@ void closeAndWriteOutput() throws IOException {
}
}
if (pluggableWriteSupport == null) {
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId,
+ mapId,
+ Arrays.stream(committedPartitions).mapToLong(CommittedPartition::length).toArray(),
+ tmp);
}
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions);
}
@VisibleForTesting
@@ -289,7 +295,7 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec =
compressionEnabled ? CompressionCodec$.MODULE$.createCodec(sparkConf) : null;
@@ -301,18 +307,18 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
+ return new CommittedPartition[partitioner.numPartitions()];
} else if (spills.length == 1) {
if (pluggableWriteSupport != null) {
- writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec);
+ return writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec);
} else {
// Here, we don't need to perform any metrics updates because the bytes written to this
// output file would have already been counted as shuffle bytes written.
Files.move(spills[0].file, outputFile);
}
- return spills[0].partitionLengths;
+ return toLocalCommittedPartition(spills[0].partitionLengths);
} else {
- final long[] partitionLengths;
+ final CommittedPartition[] committedPartitions;
// There are multiple spills to merge, so none of these spill files' lengths were counted
// towards our shuffle write count or shuffle write time. If we use the slow merge path,
// then the final output file's size won't necessarily be equal to the sum of the spill
@@ -324,21 +330,24 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
// shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
// branch in ExternalSorter.
if (pluggableWriteSupport != null) {
- partitionLengths = mergeSpillsWithPluggableWriter(spills, compressionCodec);
+ committedPartitions = mergeSpillsWithPluggableWriter(spills, compressionCodec);
} else if (fastMergeEnabled && fastMergeIsSupported) {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ committedPartitions =
+ toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile));
} else {
logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ committedPartitions = toLocalCommittedPartition(
+ mergeSpillsWithFileStream(spills, outputFile, null));
}
} else {
logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ committedPartitions = toLocalCommittedPartition(
+ mergeSpillsWithFileStream(spills, outputFile, compressionCodec));
}
// When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
// in-memory records, we write out the in-memory records to a file but do not count that
@@ -349,7 +358,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
if (pluggableWriteSupport == null) {
writeMetrics.incBytesWritten(outputFile.length());
}
- return partitionLengths;
+ return committedPartitions;
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
@@ -359,6 +368,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
}
}
+ private static CommittedPartition[] toLocalCommittedPartition(long[] partitionLengths) {
+ return Arrays.stream(partitionLengths)
+ .mapToObj(length -> new LocalCommittedPartition(length))
+ .collect(Collectors.toList()).toArray(new CommittedPartition[partitionLengths.length]);
+ }
+
/**
* Merges spill files using Java FileStreams. This code path is typically slower than
* the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
@@ -512,13 +527,13 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
/**
* Merges spill files using the ShufflePartitionWriter API.
*/
- private long[] mergeSpillsWithPluggableWriter(
+ private CommittedPartition[] mergeSpillsWithPluggableWriter(
SpillInfo[] spills,
@Nullable CompressionCodec compressionCodec) throws IOException {
assert (spills.length >= 2);
assert(pluggableWriteSupport != null);
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
+ final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
boolean threwException = true;
@@ -533,6 +548,10 @@ private long[] mergeSpillsWithPluggableWriter(
ShufflePartitionWriter writer = mapOutputWriter.newPartitionWriter(partition);
try {
try (OutputStream partitionOutput = writer.openPartitionStream()) {
+ OutputStream partitionOutputStream = partitionOutput;
+ if (compressionCodec != null) {
+ partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput);
+ }
for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
if (partitionLengthInSpill > 0) {
@@ -542,17 +561,18 @@ private long[] mergeSpillsWithPluggableWriter(
partitionInputStream = blockManager.serializerManager().wrapForEncryption(
partitionInputStream);
if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ partitionInputStream =
+ compressionCodec.compressedInputStream(partitionInputStream);
}
- Utils.copyStream(partitionInputStream, partitionOutput, false, false);
+ Utils.copyStream(partitionInputStream, partitionOutputStream, false, false);
} finally {
partitionInputStream.close();
}
}
}
}
- partitionLengths[partition] = writer.commitAndGetTotalLength();
- writeMetrics.incBytesWritten(partitionLengths[partition]);
+ committedPartitions[partition] = writer.commitPartition();
+ writeMetrics.incBytesWritten(committedPartitions[partition].length());
} catch (Exception e) {
try {
writer.abort(e);
@@ -578,14 +598,15 @@ private long[] mergeSpillsWithPluggableWriter(
Closeables.close(stream, threwException);
}
}
- return partitionLengths;
+ return committedPartitions;
}
- private void writeSingleSpillFileUsingPluggableWriter(
+ private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter(
SpillInfo spillInfo,
@Nullable CompressionCodec compressionCodec) throws IOException {
assert(pluggableWriteSupport != null);
final int numPartitions = partitioner.numPartitions();
+ final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions];
boolean threwException = true;
InputStream spillInputStream = new NioBufferedFileInputStream(
spillInfo.file,
@@ -604,7 +625,11 @@ private void writeSingleSpillFileUsingPluggableWriter(
partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
}
try (OutputStream partitionOutput = writer.openPartitionStream()) {
- Utils.copyStream(partitionInputStream, partitionOutput, false, false);
+ OutputStream partitionOutputStream = partitionOutput;
+ if (compressionCodec != null) {
+ partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput);
+ }
+ Utils.copyStream(partitionInputStream, partitionOutputStream, false, false);
}
} catch (Exception e) {
try {
@@ -616,9 +641,11 @@ private void writeSingleSpillFileUsingPluggableWriter(
} finally {
partitionInputStream.close();
}
- writeMetrics.incBytesWritten(writer.commitAndGetTotalLength());
+ committedPartitions[partition] = writer.commitPartition();
+ writeMetrics.incBytesWritten(committedPartitions[partition].length());
}
threwException = false;
+ mapOutputWriter.commitAllPartitions();
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -630,6 +657,7 @@ private void writeSingleSpillFileUsingPluggableWriter(
Closeables.close(spillInputStream, threwException);
}
writeMetrics.decBytesWritten(spillInfo.file.length());
+ return committedPartitions;
}
@Override
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1c4fa4bc6541f..214ff3ee18feb 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -33,8 +33,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.shuffle._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleLocation}
import org.apache.spark.util._
/**
@@ -213,6 +213,7 @@ private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
+private[spark] case object GetRemoteShuffleServiceAddresses extends MapOutputTrackerMessage
private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext)
@@ -233,6 +234,9 @@ private[spark] class MapOutputTrackerMasterEndpoint(
logInfo("MapOutputTrackerMasterEndpoint stopped!")
context.reply(true)
stop()
+
+ case GetRemoteShuffleServiceAddresses =>
+ context.reply(tracker.getRemoteShuffleServiceAddress())
}
}
@@ -298,6 +302,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+ def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation]
+
+ def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)]
+
/**
* Deletes map output status information for the specified shuffle stage.
*/
@@ -318,7 +326,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
private[spark] class MapOutputTrackerMaster(
conf: SparkConf,
broadcastManager: BroadcastManager,
- isLocal: Boolean)
+ isLocal: Boolean,
+ shuffleServiceAddressProvider: ShuffleServiceAddressProvider
+ = DefaultShuffleServiceAddressProvider)
extends MapOutputTracker(conf) {
// The size at which we use Broadcast to send the map output statuses to the executors
@@ -644,6 +654,10 @@ private[spark] class MapOutputTrackerMaster(
}
}
+ override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] =
+ shuffleServiceAddressProvider
+ .getShuffleServiceAddresses().map { case (h, p) => (h, Integer.valueOf(p)) }.asJava
+
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
@@ -666,6 +680,14 @@ private[spark] class MapOutputTrackerMaster(
trackerEndpoint = null
shuffleStatuses.clear()
}
+
+ override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int):
+ Option[ShuffleLocation] = {
+ shuffleStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocationForBlock(reduceId)
+ case None => Option.empty
+ }
+ }
}
/**
@@ -779,6 +801,19 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
}
}
}
+
+ override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int):
+ Option[ShuffleLocation] = {
+ mapStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocationForBlock(reduceId)
+ case None => Option.empty
+ }
+ }
+
+ override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] = {
+ trackerEndpoint
+ .askSync[java.util.List[(String, Integer)]](GetRemoteShuffleServiceAddresses)
+ }
}
private[spark] object MapOutputTracker extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 845a3d5f6d6f9..247016584d1f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -576,6 +576,10 @@ class SparkContext(config: SparkConf) extends Logging {
_env.metricsSystem.registerSource(e.executorAllocationManagerSource)
}
appStatusSource.foreach(_env.metricsSystem.registerSource(_))
+
+ // Initialize the ShuffleDataIo
+ _env.shuffleDataIO.foreach(_.initialize())
+
// Make sure the context is stopped if the user forgets about it. This avoids leaving
// unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM
// is killed, though.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 66038eeaea54f..c2b56864bf36d 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -21,11 +21,10 @@ import java.io.File
import java.net.Socket
import java.util.Locale
+import com.google.common.collect.MapMaker
import scala.collection.mutable
import scala.util.Properties
-import com.google.common.collect.MapMaker
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
@@ -39,7 +38,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory}
+import org.apache.spark.shuffle.api.ShuffleDataIO
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
@@ -66,6 +66,7 @@ class SparkEnv (
val blockManager: BlockManager,
val securityManager: SecurityManager,
val metricsSystem: MetricsSystem,
+ val shuffleDataIO: Option[ShuffleDataIO],
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
@@ -302,7 +303,26 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
+ val master = conf.get("spark.master")
+ val shuffleProvider = conf.get(SHUFFLE_SERVICE_PROVIDER_CLASS)
+ .map(clazz => Utils.loadExtensions(
+ classOf[ShuffleServiceAddressProviderFactory],
+ Seq(clazz), conf)).getOrElse(Seq())
+ val serviceLoaders = shuffleProvider
+ .filter(_.canCreate(master))
+ if (serviceLoaders.size > 1) {
+ throw new SparkException(
+ s"Multiple external cluster managers registered for the url $master: $serviceLoaders")
+ }
+ val loader = Utils.getContextOrSparkClassLoader
+ logInfo(s"Loader: $loader")
+ logInfo(s"Service loader: $serviceLoaders")
+ val shuffleServiceAddressProvider = serviceLoaders.headOption
+ .map(_.create(conf))
+ .getOrElse(DefaultShuffleServiceAddressProvider)
+ shuffleServiceAddressProvider.start()
+
+ new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleServiceAddressProvider)
} else {
new MapOutputTrackerWorker(conf)
}
@@ -365,6 +385,9 @@ object SparkEnv extends Logging {
ms
}
+ val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
+ .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head)
+
val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
new OutputCommitCoordinator(conf, isDriver)
}
@@ -384,6 +407,7 @@ object SparkEnv extends Logging {
blockManager,
securityManager,
metricsSystem,
+ shuffleIoPlugin,
memoryManager,
outputCommitCoordinator,
conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index a30a501e5d4a1..b6f4fc1921bf4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -118,6 +118,8 @@ private[spark] class Executor(
env.blockManager.initialize(conf.getAppId)
env.metricsSystem.registerSource(executorSource)
env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource)
+ // Initialize the ShuffleDataIO
+ env.shuffleDataIO.foreach(_.initialize())
}
// Whether to load classes in user jars before those in Spark jars
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index bede012e33977..c249565fa1519 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -182,6 +182,9 @@ package object config {
private[spark] val SHUFFLE_SERVICE_ENABLED =
ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false)
+ private[spark] val K8S_SHUFFLE_SERVICE_ENABLED =
+ ConfigBuilder("spark.k8s.shuffle.service.enabled").booleanConf.createWithDefault(false)
+
private[spark] val SHUFFLE_SERVICE_PORT =
ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337)
@@ -441,6 +444,12 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val SHUFFLE_SERVICE_PROVIDER_CLASS =
+ ConfigBuilder("spark.shuffle.provider.plugin.class")
+ .doc("Experimental. Specify a class that can handle detecting shuffle service pods.")
+ .stringConf
+ .createOptional
+
private[spark] val SHUFFLE_IO_PLUGIN_CLASS =
ConfigBuilder("spark.shuffle.io.plugin.class")
.doc("Experimental. Specify a class that can handle reading and writing shuffle blocks to" +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 64f0a060a247c..21613a5946f68 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,13 +19,13 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
-import scala.collection.mutable
-
import org.roaringbitmap.RoaringBitmap
+import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.internal.config
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.shuffle.api.CommittedPartition
+import org.apache.spark.storage.{BlockManagerId, ShuffleLocation}
import org.apache.spark.util.Utils
/**
@@ -36,6 +36,8 @@ private[spark] sealed trait MapStatus {
/** Location where this task was run. */
def location: BlockManagerId
+ def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation]
+
/**
* Estimated size for the reduce block, in bytes.
*
@@ -56,11 +58,29 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
+ def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = {
+ val shuffleLocationsArray = committedPartitions.collect {
+ case partition if partition != null && partition.shuffleLocation().isPresent
+ => partition.shuffleLocation().get()
+ case _ => null
+ }
+ val lengthsArray = committedPartitions.collect {
+ case partition if partition != null => partition.length()
+ case _ => 0
+
+ }
+ if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) {
+ HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray)
+ } else {
+ new CompressedMapStatus(loc, lengthsArray, shuffleLocationsArray)
+ }
+ }
+
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation])
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation])
}
}
@@ -103,17 +123,28 @@ private[spark] object MapStatus {
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
- private[this] var compressedSizes: Array[Byte])
+ private[this] var compressedSizes: Array[Byte],
+ private[this] var shuffleLocations: Array[ShuffleLocation])
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ // For deserialization only
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]], null)
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long],
+ shuffleLocations: Array[ShuffleLocation]) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLocations)
}
override def location: BlockManagerId = loc
+ override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = {
+ if (shuffleLocations.apply(reduceId) == null) {
+ Option.empty
+ } else {
+ Option.apply(shuffleLocations.apply(reduceId))
+ }
+ }
+
override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
@@ -122,6 +153,7 @@ private[spark] class CompressedMapStatus(
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
+ out.writeObject(shuffleLocations)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -129,6 +161,7 @@ private[spark] class CompressedMapStatus(
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
+ shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]]
}
}
@@ -148,17 +181,26 @@ private[spark] class HighlyCompressedMapStatus private (
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
- private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte])
+ private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
+ private[this] var shuffleLocations: Array[ShuffleLocation])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null, null) // For deserialization only
override def location: BlockManagerId = loc
+ override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = {
+ if (shuffleLocations.apply(reduceId) == null) {
+ Option.empty
+ } else {
+ Option.apply(shuffleLocations.apply(reduceId))
+ }
+ }
+
override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
@@ -180,6 +222,7 @@ private[spark] class HighlyCompressedMapStatus private (
out.writeInt(kv._1)
out.writeByte(kv._2)
}
+ out.writeObject(shuffleLocations)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -195,11 +238,17 @@ private[spark] class HighlyCompressedMapStatus private (
hugeBlockSizesImpl(block) = size
}
hugeBlockSizes = hugeBlockSizesImpl
+ shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]]
}
}
private[spark] object HighlyCompressedMapStatus {
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ apply(loc, uncompressedSizes, Array.empty[ShuffleLocation])
+ }
+
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long],
+ shuffleLocation: Array[ShuffleLocation]): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -240,6 +289,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes)
+ hugeBlockSizes, shuffleLocation)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 7632c35f0318a..caeecedc5d36e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle
+import scala.compat.java8.OptionConverters
+
import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
@@ -53,9 +55,13 @@ private[spark] class BlockStoreShuffleReader[K, C](
appId, handle.shuffleId, mapId)
blockIds.map {
case blockId@ShuffleBlockId(_, _, reduceId) =>
- (blockId, reader.fetchPartition(reduceId))
+ (blockId, serializerManager.wrapStream(blockId,
+ reader.fetchPartition(reduceId, OptionConverters.toJava(
+ mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId)))))
case dataBlockId@ShuffleDataBlockId(_, _, reduceId) =>
- (dataBlockId, reader.fetchPartition(reduceId))
+ (dataBlockId, serializerManager.wrapStream(dataBlockId,
+ reader.fetchPartition(reduceId, OptionConverters.toJava(
+ mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId)))))
case invalid =>
throw new IllegalArgumentException(s"Invalid block id $invalid")
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala
index 2eed51962181c..8a776281041d1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala
@@ -20,38 +20,36 @@ package org.apache.spark.shuffle
import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
-import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.api.ShufflePartitionWriter
+import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.{ByteBufferInputStream, Utils}
class ShufflePartitionWriterOutputStream(
- partitionWriter: ShufflePartitionWriter, buffer: ByteBuffer, bufferSize: Int)
- extends OutputStream {
+ blockId: ShuffleBlockId,
+ partitionWriter: ShufflePartitionWriter,
+ buffer: ByteBuffer,
+ serializerManager: SerializerManager)
+ extends OutputStream {
- private var currentChunkSize = 0
- private val bufferForRead = buffer.asReadOnlyBuffer()
private var underlyingOutputStream: OutputStream = _
override def write(b: Int): Unit = {
- buffer.putInt(b)
- currentChunkSize += 1
- if (currentChunkSize == bufferSize) {
+ buffer.put(b.asInstanceOf[Byte])
+ if (buffer.remaining() == 0) {
pushBufferedBytesToUnderlyingOutput()
}
}
private def pushBufferedBytesToUnderlyingOutput(): Unit = {
- bufferForRead.reset()
- var bufferInputStream: InputStream = new ByteBufferInputStream(bufferForRead)
- if (currentChunkSize < bufferSize) {
- bufferInputStream = new LimitedInputStream(bufferInputStream, currentChunkSize)
- }
+ buffer.flip()
+ var bufferInputStream: InputStream = new ByteBufferInputStream(buffer)
if (underlyingOutputStream == null) {
- underlyingOutputStream = partitionWriter.openPartitionStream()
+ underlyingOutputStream = serializerManager.wrapStream(blockId,
+ partitionWriter.openPartitionStream())
}
Utils.copyStream(bufferInputStream, underlyingOutputStream, false, false)
- buffer.reset()
- currentChunkSize = 0
+ buffer.clear()
}
override def flush(): Unit = {
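Note (illustrative, not part of the patch): the rewritten ShufflePartitionWriterOutputStream above replaces the manual chunk counter with the standard ByteBuffer cycle: bytes are appended with put() until remaining() reaches zero, the buffer is flip()-ped before draining into the serializer-wrapped partition stream, then clear()-ed. A tiny standalone sketch of that cycle, with the drain step stubbed out:

    import java.nio.ByteBuffer

    val buffer = ByteBuffer.allocate(4)
    (1 to 10).foreach { b =>
      buffer.put(b.toByte)            // write side: append until the buffer is full
      if (buffer.remaining() == 0) {
        buffer.flip()                 // switch to read mode before draining
        while (buffer.hasRemaining) {
          buffer.get()                // the real code copies these bytes to the partition stream
        }
        buffer.clear()                // make the whole capacity writable again
      }
    }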
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala
similarity index 60%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala
rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala
index 83daddf714489..96d529872b306 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala
@@ -14,24 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster.k8s
-import io.fabric8.kubernetes.api.model.Pod
+package org.apache.spark.shuffle
-sealed trait ExecutorPodState {
- def pod: Pod
+trait ShuffleServiceAddressProvider {
+ def start(): Unit = {}
+ def getShuffleServiceAddresses(): List[(String, Int)]
+ def stop(): Unit = {}
}
-case class PodRunning(pod: Pod) extends ExecutorPodState
-
-case class PodPending(pod: Pod) extends ExecutorPodState
-
-sealed trait FinalPodState extends ExecutorPodState
-
-case class PodSucceeded(pod: Pod) extends FinalPodState
-
-case class PodFailed(pod: Pod) extends FinalPodState
-
-case class PodDeleted(pod: Pod) extends FinalPodState
-
-case class PodUnknown(pod: Pod) extends ExecutorPodState
+private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider {
+ override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)]
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala
new file mode 100644
index 0000000000000..68adb8e44585c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.SparkConf
+
+trait ShuffleServiceAddressProviderFactory {
+ def canCreate(masterUrl: String): Boolean
+ def create(conf: SparkConf): ShuffleServiceAddressProvider
+}
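Note (illustrative, not part of the patch): a minimal sketch of a provider/factory pair implementing the two traits above. The config key spark.shuffle.provider.static.addresses and both class names are hypothetical; such a factory would be referenced via spark.shuffle.provider.plugin.class so that SparkEnv can load it with Utils.loadExtensions.

    import org.apache.spark.SparkConf
    import org.apache.spark.shuffle.{ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory}

    // Hypothetical provider that always returns the same host/port pairs.
    class StaticShuffleServiceAddressProvider(addresses: List[(String, Int)])
      extends ShuffleServiceAddressProvider {
      override def getShuffleServiceAddresses(): List[(String, Int)] = addresses
    }

    class StaticShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory {
      // Only claim Kubernetes master URLs in this example.
      override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://")

      override def create(conf: SparkConf): ShuffleServiceAddressProvider = {
        // Hypothetical key holding comma-separated host:port pairs.
        val addresses = conf.get("spark.shuffle.provider.static.addresses", "")
          .split(",").toList.filter(_.nonEmpty)
          .map { hostPort =>
            val Array(host, port) = hostPort.split(":")
            (host, port.toInt)
          }
        new StaticShuffleServiceAddressProvider(addresses)
      }
    }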
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 53a5c4f3afba4..eb7ae313918ed 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -77,11 +77,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
" Shuffle will continue to spill to disk when necessary.")
}
- private val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
- .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head)
-
- shuffleIoPlugin.foreach(_.initialize())
-
/**
* A mapping from shuffle ids to the number of mappers producing output for those shuffles.
*/
@@ -131,7 +126,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
endPartition,
context,
metrics,
- shuffleIoPlugin.map(_.readSupport()))
+ SparkEnv.get.shuffleDataIO.map(_.readSupport()))
}
/** Get a writer for a given partition. Called on executors by map tasks. */
@@ -154,7 +149,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
context,
env.conf,
metrics,
- shuffleIoPlugin.map(_.writeSupport()).orNull)
+ env.shuffleDataIO.map(_.writeSupport()).orNull)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
@@ -163,10 +158,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
mapId,
env.conf,
metrics,
- shuffleIoPlugin.map(_.writeSupport()).orNull)
+ env.shuffleDataIO.map(_.writeSupport()).orNull)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(
- shuffleBlockResolver, other, mapId, context, shuffleIoPlugin.map(_.writeSupport()))
+ shuffleBlockResolver, other, mapId, context, env.shuffleDataIO.map(_.writeSupport()))
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 1c804c99d0e31..b6ab2f354e81b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,13 +70,16 @@ private[spark] class SortShuffleWriter[K, V, C](
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = pluggableWriteSupport.map { writeSupport =>
- sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport)
+ val committedPartitions = pluggableWriteSupport.map { writeSupport =>
+ sorter.writePartitionedToExternalShuffleWriteSupport(blockId, writeSupport)
}.getOrElse(sorter.writePartitionedFile(blockId, tmp))
if (pluggableWriteSupport.isEmpty) {
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId,
+ mapId,
+ committedPartitions.map(_.length()),
+ tmp)
}
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, committedPartitions)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1dfbc6effb346..f604f20ee7220 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -132,13 +132,14 @@ private[spark] class BlockManager(
private[spark] val externalShuffleServiceEnabled =
conf.get(config.SHUFFLE_SERVICE_ENABLED)
+
private val remoteReadNioBufferConversion =
conf.getBoolean("spark.network.remoteReadNioBufferConversion", false)
val diskBlockManager = {
// Only perform cleanup if an external service is not serving our shuffle files.
- val deleteFilesOnStop =
- !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER
+ val deleteFilesOnStop = !externalShuffleServiceEnabled ||
+ executorId == SparkContext.DRIVER_IDENTIFIER
new DiskBlockManager(conf, deleteFilesOnStop)
}
@@ -259,8 +260,8 @@ private[spark] class BlockManager(
blockManagerId
}
- // Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ // Register Executors' configuration with the local shuffle service, if one should exist.
registerWithExternalShuffleServer()
}
@@ -1663,7 +1664,7 @@ private[spark] object BlockManager {
blockManagers.toMap
}
- private class ShuffleMetricsSource(
+ class ShuffleMetricsSource(
override val sourceName: String,
metricSet: MetricSet) extends Source {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 86f7c08eddcb5..cf8e4793c448e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -437,7 +437,6 @@ final class ShuffleBlockFetcherIterator(
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
throwFetchFailedException(blockId, address, new IOException(msg))
}
-
val in = try {
buf.createInputStream()
} catch {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala
new file mode 100644
index 0000000000000..72846cb001c8a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import java.io.Externalizable
+
+trait ShuffleLocation extends Externalizable {
+
+}
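Note (illustrative, not part of the patch): ShuffleLocation only extends Externalizable, so implementations control their own wire format. A hedged sketch of a host/port-based location; the class name is hypothetical, though the patch's Java ExternalShuffleLocation has a similar shape.

    import java.io.{ObjectInput, ObjectOutput}
    import org.apache.spark.storage.ShuffleLocation

    class HostPortShuffleLocation(private var host: String, private var port: Int)
      extends ShuffleLocation {

      def this() = this(null, -1)   // no-arg constructor required for Externalizable deserialization

      override def writeExternal(out: ObjectOutput): Unit = {
        out.writeUTF(host)
        out.writeInt(port)
      }

      override def readExternal(in: ObjectInput): Unit = {
        host = in.readUTF()
        port = in.readInt()
      }
    }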
diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
index b03276b2ce16f..b2263a51051a0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala
@@ -19,9 +19,9 @@ package org.apache.spark.storage
import java.nio.ByteBuffer
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream
-import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter}
/**
* Replicates the concept of {@link DiskBlockObjectWriter}, but with some key differences:
@@ -30,10 +30,12 @@ import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWri
* left to the implementation of the underlying implementation of the writer plugin.
*/
private[spark] class ShufflePartitionObjectWriter(
+ blockId: ShuffleBlockId,
bufferSize: Int,
serializerInstance: SerializerInstance,
+ serializerManager: SerializerManager,
mapOutputWriter: ShuffleMapOutputWriter)
- extends PairsWriter {
+ extends PairsWriter {
// Reused buffer. Experiments should be done with off-heap at some point.
private val buffer = ByteBuffer.allocate(bufferSize)
@@ -44,22 +46,21 @@ private[spark] class ShufflePartitionObjectWriter(
def startNewPartition(partitionId: Int): Unit = {
require(buffer.position() == 0,
"Buffer was not flushed to the underlying output on the previous partition.")
- buffer.reset()
currentWriter = mapOutputWriter.newPartitionWriter(partitionId)
val currentWriterStream = new ShufflePartitionWriterOutputStream(
- currentWriter, buffer, bufferSize)
+ blockId, currentWriter, buffer, serializerManager)
objectOutputStream = serializerInstance.serializeStream(currentWriterStream)
}
- def commitCurrentPartition(): Long = {
+ def commitCurrentPartition(): CommittedPartition = {
require(objectOutputStream != null, "Cannot commit a partition that has not been started.")
require(currentWriter != null, "Cannot commit a partition that has not been started.")
objectOutputStream.close()
- val length = currentWriter.commitAndGetTotalLength()
- buffer.reset()
+ val committedPartition = currentWriter.commitPartition()
+ buffer.clear()
currentWriter = null
objectOutputStream = null
- length
+ committedPartition
}
def abortCurrentPartition(throwable: Exception): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 569c8bd092f37..70c36c40865ba 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -18,19 +18,19 @@
package org.apache.spark.util.collection
import java.io._
-import java.util.Comparator
+import java.util.{Comparator, Optional}
+import com.google.common.io.ByteStreams
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.ByteStreams
-
import org.apache.spark.{util, _}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer._
-import org.apache.spark.shuffle.api.ShuffleWriteSupport
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShufflePartitionObjectWriter}
+import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport}
+import org.apache.spark.shuffle.sort.LocalCommittedPartition
+import org.apache.spark.storage._
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -683,10 +683,10 @@ private[spark] class ExternalSorter[K, V, C](
*/
def writePartitionedFile(
blockId: BlockId,
- outputFile: File): Array[Long] = {
+ outputFile: File): Array[CommittedPartition] = {
// Track location of each range in the output file
- val lengths = new Array[Long](numPartitions)
+ val committedPartitions = new Array[CommittedPartition](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
@@ -700,7 +700,7 @@ private[spark] class ExternalSorter[K, V, C](
it.writeNext(writer)
}
val segment = writer.commitAndGet()
- lengths(partitionId) = segment.length
+ committedPartitions(partitionId) = new LocalCommittedPartition(segment.length)
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -710,7 +710,7 @@ private[spark] class ExternalSorter[K, V, C](
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
- lengths(id) = segment.length
+ committedPartitions(id) = new LocalCommittedPartition(segment.length)
}
}
}
@@ -720,21 +720,25 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
- lengths
+ committedPartitions
}
/**
* Write all partitions to some backend that is pluggable.
*/
def writePartitionedToExternalShuffleWriteSupport(
- mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[Long] = {
+ blockId: ShuffleBlockId,
+ writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = {
// Track location of each range in the output file
- val lengths = new Array[Long](numPartitions)
- val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId)
+ val committedPartitions = new Array[CommittedPartition](numPartitions)
+ val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, blockId.shuffleId,
+ blockId.mapId)
val writer = new ShufflePartitionObjectWriter(
+ blockId,
Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt,
serInstance,
+ serializerManager,
mapOutputWriter)
try {
@@ -749,7 +753,7 @@ private[spark] class ExternalSorter[K, V, C](
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
- lengths(partitionId) = writer.commitCurrentPartition()
+ committedPartitions(partitionId) = writer.commitCurrentPartition()
} catch {
case e: Exception =>
util.Utils.tryLogNonFatalError {
@@ -767,7 +771,7 @@ private[spark] class ExternalSorter[K, V, C](
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
- lengths(id) = writer.commitCurrentPartition()
+ committedPartitions(id) = writer.commitCurrentPartition()
} catch {
case e: Exception =>
util.Utils.tryLogNonFatalError {
@@ -781,6 +785,7 @@ private[spark] class ExternalSorter[K, V, C](
mapOutputWriter.commitAllPartitions()
} catch {
case e: Exception =>
+ logError("Error writing shuffle data.", e)
util.Utils.tryLogNonFatalError {
writer.abortCurrentPartition(e)
mapOutputWriter.abort(e)
@@ -791,7 +796,7 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
- lengths
+ committedPartitions
}
def stop(): Unit = {
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 360c1769ad31a..93ab301a4cb9c 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -23,7 +23,7 @@
import java.nio.file.StandardOpenOption;
import java.util.*;
-import org.apache.commons.io.IOUtils;
+import org.apache.spark.shuffle.api.CommittedPartition;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -351,7 +351,8 @@ private void testMergingSpills(
private void testMergingSpills(
boolean transferToEnabled,
boolean useShuffleWriterPlugin) throws IOException {
- final UnsafeShuffleWriter