diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87b..75172957746ae 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -34,6 +34,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; @@ -81,6 +82,22 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb handleMessage(msgObj, client, callback); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + BlockTransferMessage header = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader); + return handleStream(header, client, callback); + } + + protected StreamCallbackWithID handleStream( + BlockTransferMessage header, + TransportClient client, + RpcResponseCallback callback) { + throw new UnsupportedOperationException("Unexpected message header: " + header); + } + protected void handleMessage( BlockTransferMessage msgObj, TransportClient client, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369d..7b9c75cd07795 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -19,13 +19,13 @@ import java.io.*; import java.nio.charset.StandardCharsets; +import java.util.regex.Pattern; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.regex.Matcher; -import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index e49e27ab5aa79..2a013f5497a67 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -140,7 +140,8 @@ public void registerWithShuffleServer( ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { checkInit(); try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { - ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); + ByteBuffer registerMessage = + new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, registrationTimeoutMs); } } diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java new file mode 100644 index 0000000000000..9b3736a9ecf1b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -0,0 +1,150 @@ +package org.apache.spark.network.shuffle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.file.StandardOpenOption; + +import org.apache.spark.network.client.StreamCallbackWithID; + +public class FileWriterStreamCallback implements StreamCallbackWithID { + + private static final Logger logger = LoggerFactory.getLogger(FileWriterStreamCallback.class); + + public enum FileType { + DATA("shuffle-data"), + INDEX("shuffle-index"); + + private final String typeString; + + FileType(String typeString) { + this.typeString = typeString; + } + + @Override + public String toString() { + return typeString; + } + } + + private final String appId; + private final int shuffleId; + private final int mapId; + private final File file; + private final FileType fileType; + private WritableByteChannel fileOutputChannel = null; + + public FileWriterStreamCallback( + String appId, + int shuffleId, + int mapId, + File file, + FileWriterStreamCallback.FileType fileType) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.file = file; + this.fileType = fileType; + } + + public void open() { + logger.info( + "Opening {} for remote writing. File type: {}", file.getAbsolutePath(), fileType); + if (fileOutputChannel != null) { + throw new IllegalStateException( + String.format( + "File %s for is already open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + if (!file.exists()) { + try { + if (!file.getParentFile().isDirectory() && !file.getParentFile().mkdirs()) { + throw new IOException( + String.format( + "Failed to create shuffle file directory at" + + file.getParentFile().getAbsolutePath() + "(type: %s).", fileType)); + } + + if (!file.createNewFile()) { + throw new IOException( + String.format( + "Failed to create shuffle file (type: %s).", fileType)); + } + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to create shuffle file at %s for backup (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + try { + // TODO encryption + fileOutputChannel = FileChannel.open(file.toPath(), StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to find file for writing at %s (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + + @Override + public String getID() { + return String.format("%s-%d-%d-%s", + appId, + shuffleId, + mapId, + fileType); + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + verifyShuffleFileOpenForWriting(); + while (buf.hasRemaining()) { + fileOutputChannel.write(buf); + } + } + + @Override + public void onComplete(String streamId) throws IOException { + logger.info( + "Finished writing {}. 
File type: {}", file.getAbsolutePath(), fileType); + fileOutputChannel.close(); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + logger.warn("Failed to write shuffle file at {} (type: %s).", + file.getAbsolutePath(), + fileType, + cause); + fileOutputChannel.close(); + // TODO delete parent dirs too + if (!file.delete()) { + logger.warn( + "Failed to delete incomplete remote shuffle file at %s (type: %s)", + file.getAbsolutePath(), + fileType); + } + } + + private void verifyShuffleFileOpenForWriting() { + if (fileOutputChannel == null) { + throw new RuntimeException( + String.format( + "Shuffle file at %s not open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java new file mode 100644 index 0000000000000..06af3bd141fba --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.k8s; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; +import org.apache.spark.network.util.TransportConf; + +/** + * A client for talking to the external shuffle service in Kubernetes coarse-grained mode. + * + * This is used by the Spark driver to register with each external shuffle service on the cluster. + * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably + * after the application exits. Kubernetes does not provide a great alternative to do this, so Spark + * has to detect this itself. 
+ */ +public class KubernetesExternalShuffleClient extends ExternalShuffleClient { + private static final Logger logger = + LoggerFactory.getLogger(KubernetesExternalShuffleClient.class); + + private final ScheduledExecutorService heartbeaterThread = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("kubernetes-external-shuffle-client-heartbeater") + .build()); + + /** + * Creates a Kubernetes external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. + */ + public KubernetesExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean authEnabled, + long registrationTimeoutMs) { + super(conf, secretKeyHolder, authEnabled, registrationTimeoutMs); + } + + public void registerDriverWithShuffleService( + String host, + int port, + long heartbeatTimeoutMs, + long heartbeatIntervalMs) throws IOException, InterruptedException { + + checkInit(); + ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); + logger.info("Registering with external shuffle service at " + host + ":" + port); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + private final TransportClient client; + private final long heartbeatIntervalMs; + + private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) { + this.client = client; + this.heartbeatIntervalMs = heartbeatIntervalMs; + } + + @Override + public void onSuccess(ByteBuffer response) { + heartbeaterThread.scheduleAtFixedRate( + new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS); + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. 
Error: " + e); + } + } + + @Override + public void close() { + heartbeaterThread.shutdownNow(); + super.close(); + } + + private class Heartbeater implements Runnable { + + private final TransportClient client; + + private Heartbeater(TransportClient client) { + this.client = client; + } + + @Override + public void run() { + // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout + client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 60179f126bc44..3510509f20eee 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; +import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +32,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; import org.apache.spark.network.util.TransportConf; /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a68a297519b66..f5196638f9140 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -23,8 +23,6 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -42,7 +40,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. 
*/ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), + REGISTER_SHUFFLE_INDEX(8), OPEN_SHUFFLE_PARTITION(9), UPLOAD_SHUFFLE_INDEX(10); private final byte id; @@ -68,6 +67,10 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); + case 7: return UploadShufflePartitionStream.decode(buf); + case 8: return RegisterShuffleIndex.decode(buf); + case 9: return OpenShufflePartition.decode(buf); + case 10: return UploadShuffleIndex.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java new file mode 100644 index 0000000000000..63d2387bd6d1a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java @@ -0,0 +1,76 @@ +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +public class OpenShufflePartition extends BlockTransferMessage { + public final String appId; + public final int shuffleId; + public final int mapId; + public final int partitionId; + + public OpenShufflePartition( + String appId, int shuffleId, int mapId, int partitionId) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenShufflePartition) { + OpenShufflePartition o = (OpenShufflePartition) other; + return Objects.equal(appId, o.appId) + && shuffleId == o.shuffleId + && mapId == o.mapId + && partitionId == o.partitionId; + } + return false; + } + + @Override + protected Type type() { + return Type.OPEN_SHUFFLE_PARTITION; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, shuffleId, mapId, partitionId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .add("partitionId", partitionId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + buf.writeInt(partitionId); + } + + public static OpenShufflePartition decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + int partitionId = buf.readInt(); + return new OpenShufflePartition(appId, shuffleId, mapId, partitionId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java similarity index 94% rename from 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java index d5f53ccb7f741..516a51ad7cc14 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; // Needed by ScalaDoc. See SPARK-7726 import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java new file mode 100644 index 0000000000000..bc870e440274b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Register shuffle index to the External Shuffle Service. 
+ */ +public class RegisterShuffleIndex extends BlockTransferMessage { + public final String appId; + public final int shuffleId; + public final int mapId; + + public RegisterShuffleIndex( + String appId, + int shuffleId, + int mapId) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterShuffleIndex) { + RegisterShuffleIndex o = (RegisterShuffleIndex) other; + return Objects.equal(appId, o.appId) + && shuffleId == o.shuffleId + && mapId == o.mapId; + } + return false; + } + + @Override + protected Type type() { + return Type.REGISTER_SHUFFLE_INDEX; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, shuffleId, mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static RegisterShuffleIndex decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new RegisterShuffleIndex(appId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java similarity index 92% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java index b30bb9aed55b6..1a6ffc0f91333 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; // Needed by ScalaDoc. See SPARK-7726 import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java new file mode 100644 index 0000000000000..b11a02f6b9219 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Upload shuffle index request to the External Shuffle Service. + */ +public class UploadShuffleIndex extends BlockTransferMessage { + public final String appId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleIndex( + String appId, + int shuffleId, + int mapId) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadShuffleIndex) { + UploadShuffleIndex o = (UploadShuffleIndex) other; + return Objects.equal(appId, o.appId) + && shuffleId == o.shuffleId + && mapId == o.mapId; + } + return false; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_INDEX; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, shuffleId, mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleIndex decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleIndex(appId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java new file mode 100644 index 0000000000000..ad8f5405192fc --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Upload shuffle partition request to the External Shuffle Service. + */ +public class UploadShufflePartitionStream extends BlockTransferMessage { + public final String appId; + public final int shuffleId; + public final int mapId; + public final int partitionId; + public final int partitionLength; + + public UploadShufflePartitionStream( + String appId, + int shuffleId, + int mapId, + int partitionId, + int partitionLength) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + this.partitionLength = partitionLength; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadShufflePartitionStream) { + UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + return Objects.equal(appId, o.appId) + && shuffleId == o.shuffleId + && mapId == o.mapId + && partitionId == o.partitionId + && partitionLength == o.partitionLength; + } + return false; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_PARTITION_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, shuffleId, mapId, partitionId, partitionLength); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .add("partitionId", partitionId) + .add("partitionLength", partitionLength) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + buf.writeInt(partitionId); + buf.writeInt(partitionLength); + } + + public static UploadShufflePartitionStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + int partitionId = buf.readInt(); + int partitionLength = buf.readInt(); + return new UploadShufflePartitionStream( + appId, shuffleId, mapId, partitionId, partitionLength); + } +} diff --git a/core/pom.xml b/core/pom.xml index 49b1a54e32598..544ae61279c4d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -352,6 +352,11 @@ py4j 0.10.8.1 + + org.scala-lang.modules + scala-java8-compat_${scala.binary.version} + 0.9.0 + org.apache.spark spark-tags_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java new file mode 100644 index 0000000000000..7846fad70b159 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java @@ -0,0 +1,23 @@ +package org.apache.spark.shuffle.api; + +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public interface CommittedPartition { + + /** + * Indicates the number of bytes written in a committed partition. + * Note that returning the length is mainly for backwards compatibility + * and should be removed in a more polished variant. 
After this method + * is called, the writer will be discarded; it's expected that the + * implementation will close any underlying resources. + */ + long length(); + + /** + * Indicates the shuffle location to which this partition was written. + * Some implementations may not need to specify a shuffle location. + */ + Optional shuffleLocation(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java index b091e231f2cd7..19cd94712a8ae 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -16,11 +16,13 @@ */ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleDataIO { - void initialize(); + void initialize() throws IOException; - ShuffleReadSupport readSupport(); + ShuffleReadSupport readSupport() throws IOException; - ShuffleWriteSupport writeSupport(); + ShuffleWriteSupport writeSupport() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 06415dba72d34..becb9413a8f4f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -17,11 +17,13 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleMapOutputWriter { - ShufflePartitionWriter newPartitionWriter(int partitionId); + ShufflePartitionWriter newPartitionWriter(int partitionId) throws IOException; - void commitAllPartitions(); + void commitAllPartitions() throws IOException; - void abort(Exception exception); + void abort(Exception exception) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java index 59eae0a782200..46d1699724981 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java @@ -18,8 +18,13 @@ package org.apache.spark.shuffle.api; import java.io.InputStream; +import java.io.IOException; +import java.util.Optional; + +import org.apache.spark.storage.ShuffleLocation; public interface ShufflePartitionReader { - InputStream fetchPartition(int reduceId); + InputStream fetchPartition(int reduceId, Optional shuffleLocation) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 7fa667cf137ea..e7cc6dd913d11 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.api; -import java.io.InputStream; +import java.io.IOException; import java.io.OutputStream; /** @@ -28,20 +28,18 @@ public interface ShufflePartitionWriter { /** * Return a stream that should persist the bytes for this partition. */ - OutputStream openPartitionStream(); + OutputStream openPartitionStream() throws IOException; /** - * Indicate that the partition was written successfully and there are no more incoming bytes. Returns - * the length of the partition that is written. 
Note that returning the length is mainly for backwards - * compatibility and should be removed in a more polished variant. After this method is called, the writer - * will be discarded; it's expected that the implementation will close any underlying resources. + * Indicate that the partition was written successfully and there are no more incoming bytes. + * Returns a {@link CommittedPartition} indicating information about that written partition. */ - long commitAndGetTotalLength(); + CommittedPartition commitPartition() throws IOException; /** * Indicate that the write has failed for some reason and the implementation can handle the * failure reason. After this method is called, this writer will be discarded; it's expected that * the implementation will close any underlying resources. */ - void abort(Exception failureReason); + void abort(Exception failureReason) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java index b1be7c1de98ac..ebe8fd12dccdb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java @@ -17,8 +17,11 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleReadSupport { - ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId); + ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java index 2f61dbaa17c69..f88555f8a1bd9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java @@ -17,7 +17,10 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleWriteSupport { - ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId); + ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java new file mode 100644 index 0000000000000..7e37659dbb3f9 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java @@ -0,0 +1,32 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class ExternalCommittedPartition implements CommittedPartition { + + private final long length; + private final Optional shuffleLocation; + + public ExternalCommittedPartition(long length) { + this.length = length; + this.shuffleLocation = Optional.empty(); + } + + public ExternalCommittedPartition(long length, ShuffleLocation shuffleLocation) { + this.length = length; + this.shuffleLocation = Optional.of(shuffleLocation); + } + + @Override + public long length() { + return length; + } + + @Override + public Optional shuffleLocation() { + return shuffleLocation; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java 
b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
new file mode 100644
index 0000000000000..10e4093759c38
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
@@ -0,0 +1,86 @@
+package org.apache.spark.shuffle.external;
+
+import org.apache.spark.*;
+import org.apache.spark.SecurityManager;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.shuffle.api.ShuffleDataIO;
+import org.apache.spark.shuffle.api.ShuffleReadSupport;
+import org.apache.spark.shuffle.api.ShuffleWriteSupport;
+import org.apache.spark.network.shuffle.k8s.KubernetesExternalShuffleClient;
+
+import org.apache.spark.storage.BlockManager;
+import scala.Tuple2;
+
+import java.util.List;
+import java.util.Random;
+
+public class ExternalShuffleDataIO implements ShuffleDataIO {
+
+  private final SparkConf conf;
+  private final TransportConf transportConf;
+  private final TransportContext context;
+  private static MapOutputTracker mapOutputTracker;
+  private static SecurityManager securityManager;
+  private static List<Tuple2<String, Integer>> hostPorts;
+  private static Boolean isDriver;
+  private static KubernetesExternalShuffleClient shuffleClient;
+
+  public ExternalShuffleDataIO(
+      SparkConf sparkConf) {
+    this.conf = sparkConf;
+    // TODO: Grab numUsableCores
+    this.transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1);
+    // Close idle connections
+    this.context = new TransportContext(transportConf, new NoOpRpcHandler(), true, true);
+  }
+
+  @Override
+  public void initialize() {
+    SparkEnv env = SparkEnv.get();
+    mapOutputTracker = env.mapOutputTracker();
+    securityManager = env.securityManager();
+    isDriver = env.blockManager().blockManagerId().isDriver();
+    hostPorts = mapOutputTracker.getRemoteShuffleServiceAddress();
+    if (isDriver) {
+      shuffleClient = new KubernetesExternalShuffleClient(transportConf, securityManager,
+          securityManager.isAuthenticationEnabled(), conf.getTimeAsMs(
+              package$.MODULE$.SHUFFLE_REGISTRATION_TIMEOUT().key(), "5000ms"));
+      shuffleClient.init(conf.getAppId());
+      for (Tuple2<String, Integer> hp : hostPorts) {
+        try {
+          shuffleClient.registerDriverWithShuffleService(
+              hp._1, hp._2,
+              conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs",
+                  conf.getTimeAsSeconds("spark.network.timeout", "120s") + "s"),
+              conf.getTimeAsMs(
+                  package$.MODULE$.EXECUTOR_HEARTBEAT_INTERVAL().key(), "10s"));
+        } catch (Exception e) {
+          throw new RuntimeException("Unable to register driver with ESS", e);
+        }
+      }
+      BlockManager.ShuffleMetricsSource metricSource =
+          new BlockManager.ShuffleMetricsSource(
+              "RemoteShuffleService", shuffleClient.shuffleMetrics());
+      env.metricsSystem().registerSource(metricSource);
+    }
+  }
+
+  @Override
+  public ShuffleReadSupport readSupport() {
+    return new ExternalShuffleReadSupport(
+        transportConf, context, securityManager.isAuthenticationEnabled(), securityManager);
+  }
+
+  @Override
+  public ShuffleWriteSupport writeSupport() {
+    int rnd = new Random().nextInt(hostPorts.size());
+    Tuple2<String, Integer> hostPort = hostPorts.get(rnd);
+    return new ExternalShuffleWriteSupport(
+        transportConf, context, securityManager.isAuthenticationEnabled(),
+        securityManager, hostPort._1, hostPort._2);
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java 
b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java new file mode 100644 index 0000000000000..c1178da2411ff --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -0,0 +1,38 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.storage.ShuffleLocation; + +import java.io.*; + +public class ExternalShuffleLocation implements ShuffleLocation { + + private String shuffleHostname; + private int shufflePort; + + public ExternalShuffleLocation() { /* for serialization */ } + + public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { + this.shuffleHostname = shuffleHostname; + this.shufflePort = shufflePort; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeUTF(shuffleHostname); + out.writeInt(shufflePort); + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + this.shuffleHostname = in.readUTF(); + this.shufflePort = in.readInt(); + } + + public String getShuffleHostname() { + return this.shuffleHostname; + } + + public int getShufflePort() { + return this.shufflePort; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..8866d14feca53 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -0,0 +1,95 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex; +import org.apache.spark.network.shuffle.protocol.UploadShuffleIndex; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + + +public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private final TransportClientFactory clientFactory; + private final String hostName; + private final int port; + private final String appId; + private final int shuffleId; + private final int mapId; + + public ExternalShuffleMapOutputWriter( + TransportClientFactory clientFactory, + String hostName, + int port, + String appId, + int shuffleId, + int mapId) { + this.clientFactory = clientFactory; + this.hostName = hostName; + this.port = port; + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + + TransportClient client = null; + try { + client = clientFactory.createUnmanagedClient(hostName, port); + ByteBuffer registerShuffleIndex = new RegisterShuffleIndex( + appId, shuffleId, mapId).toByteBuffer(); + String requestID = String.format( + "index-register-%s-%d-%d", appId, shuffleId, mapId); + client.setClientId(requestID); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + client.sendRpcSync(registerShuffleIndex, 60000); + } catch (Exception e) { + client.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); + } + } + + private static final Logger logger = + LoggerFactory.getLogger(ExternalShuffleMapOutputWriter.class); + + @Override + public ShufflePartitionWriter newPartitionWriter(int partitionId) { + try { + return new 
ExternalShufflePartitionWriter(clientFactory, + hostName, port, appId, shuffleId, mapId, partitionId); + } catch (Exception e) { + clientFactory.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); + } + } + + @Override + public void commitAllPartitions() { + TransportClient client = null; + try { + client = clientFactory.createUnmanagedClient(hostName, port); + ByteBuffer uploadShuffleIndex = new UploadShuffleIndex( + appId, shuffleId, mapId).toByteBuffer(); + String requestID = String.format( + "index-upload-%s-%d-%d", appId, shuffleId, mapId); + client.setClientId(requestID); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + client.sendRpcSync(uploadShuffleIndex, 60000); + } catch (Exception e) { + logger.error("Encountered error while creating transport client", e); + client.close(); + throw new RuntimeException(e); + } + } + + @Override + public void abort(Exception exception) { + clientFactory.close(); + logger.error("Encountered error while " + + "attempting to add partitions to ESS", exception); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java new file mode 100644 index 0000000000000..10f1b71008472 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -0,0 +1,79 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; +import org.apache.spark.shuffle.api.ShufflePartitionReader; +import org.apache.spark.storage.ShuffleLocation; +import org.apache.spark.util.ByteBufferInputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Optional; + +public class ExternalShufflePartitionReader implements ShufflePartitionReader { + + private static final Logger logger = + LoggerFactory.getLogger(ExternalShufflePartitionReader.class); + + private final TransportClientFactory clientFactory; + private final String appId; + private final int shuffleId; + private final int mapId; + + public ExternalShufflePartitionReader( + TransportClientFactory clientFactory, + String appId, + int shuffleId, + int mapId) { + this.clientFactory = clientFactory; + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public InputStream fetchPartition(int reduceId, Optional shuffleLocation) { + assert shuffleLocation.isPresent() && + shuffleLocation.get() instanceof ExternalShuffleLocation; + ExternalShuffleLocation externalShuffleLocation = + (ExternalShuffleLocation) shuffleLocation.get(); + logger.info(String.format("Found external shuffle location on node: %s:%d", + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort())); + String hostname = externalShuffleLocation.getShuffleHostname(); + int port = externalShuffleLocation.getShufflePort(); + + OpenShufflePartition openMessage = + new OpenShufflePartition(appId, shuffleId, mapId, reduceId); + TransportClient client = null; + try { + client = clientFactory.createUnmanagedClient(hostname, port); + String requestID = String.format( + "read-%s-%d-%d-%d", appId, 
shuffleId, mapId, reduceId); + client.setClientId(requestID); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + + ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); + logger.info("response is: " + response.toString() + + " " + response.array() + " " + response.hasArray()); + if (response.hasArray()) { + logger.info("response hashcode: " + Arrays.hashCode(response.array())); + // use heap buffer; no array is created; only the reference is used + return new ByteArrayInputStream(response.array()); + } + return new ByteBufferInputStream(response); + + } catch (Exception e) { + if (client != null) { + client.close(); + } + logger.error("Encountered exception while trying to fetch blocks", e); + throw new RuntimeException(e); + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java new file mode 100644 index 0000000000000..edf046a32ffe3 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -0,0 +1,110 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.ByteBuffer; +import java.util.Arrays; + +public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { + + private static final Logger logger = + LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); + + private final TransportClientFactory clientFactory; + private final String hostName; + private final int port; + private final String appId; + private final int shuffleId; + private final int mapId; + private final int partitionId; + + private long totalLength = 0; + private ByteArrayOutputStream partitionBuffer; + + public ExternalShufflePartitionWriter( + TransportClientFactory clientFactory, + String hostName, + int port, + String appId, + int shuffleId, + int mapId, + int partitionId) { + this.clientFactory = clientFactory; + this.hostName = hostName; + this.port = port; + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + // TODO: Set buffer size + this.partitionBuffer = new ByteArrayOutputStream(); + } + + @Override + public OutputStream openPartitionStream() { return partitionBuffer; } + + @Override + public CommittedPartition commitPartition() { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully uploaded partition"); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Encountered an error uploading partition", e); + } + }; + TransportClient client = null; + try { + byte[] buf = partitionBuffer.toByteArray(); + int size = buf.length; + ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, + partitionId, size).toByteBuffer(); + ManagedBuffer managedBuffer = new 
NioManagedBuffer(ByteBuffer.wrap(buf));
+      client = clientFactory.createUnmanagedClient(hostName, port);
+      client.setClientId(String.format("data-%s-%d-%d-%d",
+          appId, shuffleId, mapId, partitionId));
+      logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+      logger.info("THE BUFFER HASH CODE IS: " + Arrays.hashCode(buf));
+      client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback);
+      totalLength += size;
+      logger.info("Partition Length: " + totalLength);
+      logger.info("Size: " + size);
+    } catch (Exception e) {
+      if (client != null) {
+        partitionBuffer = null;
+        client.close();
+      }
+      logger.error("Encountered error while attempting to upload partition to ESS", e);
+      throw new RuntimeException(e);
+    }
+    logger.info("Successfully sent partition to ESS");
+    return new ExternalCommittedPartition(
+        totalLength, new ExternalShuffleLocation(hostName, port));
+  }
+
+  @Override
+  public void abort(Exception failureReason) {
+    try {
+      partitionBuffer.close();
+      partitionBuffer = null;
+    } catch (IOException e) {
+      logger.error("Failed to close streams after failing to upload partition", e);
+    }
+    logger.error("Encountered error while attempting "
+        + "to upload partition to ESS", failureReason);
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
new file mode 100644
index 0000000000000..0bde10a777668
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
@@ -0,0 +1,56 @@
+package org.apache.spark.shuffle.external;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.shuffle.api.ShufflePartitionReader;
+import org.apache.spark.shuffle.api.ShuffleReadSupport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+
+public class ExternalShuffleReadSupport implements ShuffleReadSupport {
+
+  private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleReadSupport.class);
+
+  private final TransportConf conf;
+  private final TransportContext context;
+  private final boolean authEnabled;
+  private final SecretKeyHolder secretKeyHolder;
+
+  public ExternalShuffleReadSupport(
+      TransportConf conf,
+      TransportContext context,
+      boolean authEnabled,
+      SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.context = context;
+    this.authEnabled = authEnabled;
+    this.secretKeyHolder = secretKeyHolder;
+  }
+
+  @Override
+  public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) {
+    // TODO combine this into a function with ExternalShuffleWriteSupport
+    List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
+    if (authEnabled) {
+      bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
+    }
+    TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
+    try {
+      return new ExternalShufflePartitionReader(clientFactory,
+          appId,
+          shuffleId,
+          mapId);
+    } catch (Exception e) {
+      clientFactory.close();
+      logger.error("Encountered error while creating transport client for partition reader");
+      throw new RuntimeException(e); // 
what is standard practice here? + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java new file mode 100644 index 0000000000000..413c2fd63f20a --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -0,0 +1,54 @@ +package org.apache.spark.shuffle.external; + +import com.google.common.collect.Lists; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.crypto.AuthClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShuffleWriteSupport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { + + private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class); + + private final TransportConf conf; + private final TransportContext context; + private final boolean authEnabled; + private final SecretKeyHolder secretKeyHolder; + private final String hostname; + private final int port; + + public ExternalShuffleWriteSupport( + TransportConf conf, + TransportContext context, + boolean authEnabled, + SecretKeyHolder secretKeyHolder, + String hostname, + int port) { + this.conf = conf; + this.context = context; + this.authEnabled = authEnabled; + this.secretKeyHolder = secretKeyHolder; + this.hostname = hostname; + this.port = port; +} + + @Override + public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + } + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + logger.info("Clientfactory: " + clientFactory.toString()); + return new ExternalShuffleMapOutputWriter( + clientFactory, hostname, port, appId, shuffleId, mapId); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index c683b2854b17c..b21d37401c059 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -17,24 +17,8 @@ package org.apache.spark.shuffle.sort; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import javax.annotation.Nullable; - -import scala.None$; -import scala.Option; -import scala.Product2; -import scala.Tuple2; -import scala.collection.Iterator; - import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -42,14 +26,26 @@ import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; -import 
org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.None$; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import javax.annotation.Nullable; +import java.io.*; +import java.util.Arrays; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -94,7 +90,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; + private CommittedPartition[] committedPartitions; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -131,7 +127,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; + long[] partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; @@ -166,25 +162,31 @@ public void write(Iterator> records) throws IOException { } if (pluggableWriteSupport != null) { - partitionLengths = combineAndWritePartitionsUsingPluggableWriter(); + committedPartitions = combineAndWritePartitionsUsingPluggableWriter(); + logger.info("Successfully wrote partitions with pluggable writer"); } else { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); try { - partitionLengths = combineAndWritePartitions(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + committedPartitions = combineAndWritePartitions(tmp); + logger.info("Successfully wrote partitions without shuffle"); + // TODO: Investigate when commitedPartitions is null or returns empty + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(), + tmp); } finally { if (tmp != null && tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(); } /** @@ -192,12 +194,12 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). 
*/ - private long[] combineAndWritePartitions(File outputFile) throws IOException { + private CommittedPartition[] combineAndWritePartitions(File outputFile) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(outputFile != null); final FileOutputStream out = new FileOutputStream(outputFile, true); @@ -210,7 +212,8 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { final FileInputStream in = new FileInputStream(file); boolean copyThrewException = true; try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + partitions[i] = + new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled)); copyThrewException = false; } finally { Closeables.close(in, copyThrewException); @@ -225,21 +228,21 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } - private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { + private CommittedPartition[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(pluggableWriteSupport != null); final long writeStartTime = System.nanoTime(); ShuffleMapOutputWriter mapOutputWriter = pluggableWriteSupport.newMapOutputWriter( - appId, shuffleId, mapId); + appId, shuffleId, mapId); try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); @@ -251,7 +254,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio try (OutputStream out = writer.openPartitionStream()) { Utils.copyStream(in, out, false, false); } - lengths[i] = writer.commitAndGetTotalLength(); + partitions[i] = writer.commitPartition(); copyThrewException = false; } catch (Exception e) { try { @@ -279,7 +282,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java new file mode 100644 index 0000000000000..817855d957966 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java @@ -0,0 +1,25 @@ +package org.apache.spark.shuffle.sort; + +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class LocalCommittedPartition implements CommittedPartition { + + private final long length; + + public LocalCommittedPartition(long length) { + this.length = length; + } + + @Override + public long length() { + return length; + } + + @Override + public Optional shuffleLocation() { + return Optional.empty(); + } +} diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 150a783aa87fb..97f34bf460495 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,8 +20,11 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.util.Arrays; import java.util.Iterator; +import java.util.stream.Collectors; +import org.apache.spark.shuffle.api.CommittedPartition; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -85,7 +88,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final int initialSortBufferSize; private final int inputBufferSizeInBytes; private final int outputBufferSizeInBytes; - private final ShuffleWriteSupport pluggableWriteSupport; // TODO initialize + private final ShuffleWriteSupport pluggableWriteSupport; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -236,12 +239,12 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { try { - partitionLengths = mergeSpills(spills, tmp); + committedPartitions = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -250,14 +253,17 @@ void closeAndWriteOutput() throws IOException { } } if (pluggableWriteSupport == null) { - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(CommittedPartition::length).toArray(), + tmp); } } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting @@ -289,7 +295,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = compressionEnabled ? 
CompressionCodec$.MODULE$.createCodec(sparkConf) : null; @@ -301,18 +307,18 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; + return new CommittedPartition[partitioner.numPartitions()]; } else if (spills.length == 1) { if (pluggableWriteSupport != null) { - writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); + return writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); } else { // Here, we don't need to perform any metrics updates because the bytes written to this // output file would have already been counted as shuffle bytes written. Files.move(spills[0].file, outputFile); } - return spills[0].partitionLengths; + return toLocalCommittedPartition(spills[0].partitionLengths); } else { - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -324,21 +330,24 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" // branch in ExternalSorter. if (pluggableWriteSupport != null) { - partitionLengths = mergeSpillsWithPluggableWriter(spills, compressionCodec); + committedPartitions = mergeSpillsWithPluggableWriter(spills, compressionCodec); } else if (fastMergeEnabled && fastMergeIsSupported) { // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. 
if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + committedPartitions = + toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile)); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + committedPartitions = toLocalCommittedPartition( + mergeSpillsWithFileStream(spills, outputFile, null)); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + committedPartitions = toLocalCommittedPartition( + mergeSpillsWithFileStream(spills, outputFile, compressionCodec)); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -349,7 +358,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti if (pluggableWriteSupport == null) { writeMetrics.incBytesWritten(outputFile.length()); } - return partitionLengths; + return committedPartitions; } } catch (IOException e) { if (outputFile.exists() && !outputFile.delete()) { @@ -359,6 +368,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti } } + private static CommittedPartition[] toLocalCommittedPartition(long[] partitionLengths) { + return Arrays.stream(partitionLengths) + .mapToObj(length -> new LocalCommittedPartition(length)) + .collect(Collectors.toList()).toArray(new CommittedPartition[partitionLengths.length]); + } + /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], @@ -512,13 +527,13 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th /** * Merges spill files using the ShufflePartitionWriter API. 
*/ - private long[] mergeSpillsWithPluggableWriter( + private CommittedPartition[] mergeSpillsWithPluggableWriter( SpillInfo[] spills, @Nullable CompressionCodec compressionCodec) throws IOException { assert (spills.length >= 2); assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -533,6 +548,10 @@ private long[] mergeSpillsWithPluggableWriter( ShufflePartitionWriter writer = mapOutputWriter.newPartitionWriter(partition); try { try (OutputStream partitionOutput = writer.openPartitionStream()) { + OutputStream partitionOutputStream = partitionOutput; + if (compressionCodec != null) { + partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput); + } for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { @@ -542,17 +561,18 @@ private long[] mergeSpillsWithPluggableWriter( partitionInputStream = blockManager.serializerManager().wrapForEncryption( partitionInputStream); if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + partitionInputStream = + compressionCodec.compressedInputStream(partitionInputStream); } - Utils.copyStream(partitionInputStream, partitionOutput, false, false); + Utils.copyStream(partitionInputStream, partitionOutputStream, false, false); } finally { partitionInputStream.close(); } } } } - partitionLengths[partition] = writer.commitAndGetTotalLength(); - writeMetrics.incBytesWritten(partitionLengths[partition]); + committedPartitions[partition] = writer.commitPartition(); + writeMetrics.incBytesWritten(committedPartitions[partition].length()); } catch (Exception e) { try { writer.abort(e); @@ -578,14 +598,15 @@ private long[] mergeSpillsWithPluggableWriter( Closeables.close(stream, threwException); } } - return partitionLengths; + return committedPartitions; } - private void writeSingleSpillFileUsingPluggableWriter( + private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter( SpillInfo spillInfo, @Nullable CompressionCodec compressionCodec) throws IOException { assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; boolean threwException = true; InputStream spillInputStream = new NioBufferedFileInputStream( spillInfo.file, @@ -604,7 +625,11 @@ private void writeSingleSpillFileUsingPluggableWriter( partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } try (OutputStream partitionOutput = writer.openPartitionStream()) { - Utils.copyStream(partitionInputStream, partitionOutput, false, false); + OutputStream partitionOutputStream = partitionOutput; + if (compressionCodec != null) { + partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput); + } + Utils.copyStream(partitionInputStream, partitionOutputStream, false, false); } } catch (Exception e) { try { @@ -616,9 +641,11 @@ private void writeSingleSpillFileUsingPluggableWriter( } finally { partitionInputStream.close(); } - writeMetrics.incBytesWritten(writer.commitAndGetTotalLength()); + committedPartitions[partition] = writer.commitPartition(); + 
writeMetrics.incBytesWritten(committedPartitions[partition].length()); } threwException = false; + mapOutputWriter.commitAllPartitions(); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -630,6 +657,7 @@ private void writeSingleSpillFileUsingPluggableWriter( Closeables.close(spillInputStream, threwException); } writeMetrics.decBytesWritten(spillInfo.file.length()); + return committedPartitions; } @Override diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541f..214ff3ee18feb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -33,8 +33,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.shuffle._ +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleLocation} import org.apache.spark.util._ /** @@ -213,6 +213,7 @@ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case object GetRemoteShuffleServiceAddresses extends MapOutputTrackerMessage private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) @@ -233,6 +234,9 @@ private[spark] class MapOutputTrackerMasterEndpoint( logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) stop() + + case GetRemoteShuffleServiceAddresses => + context.reply(tracker.getRemoteShuffleServiceAddress()) } } @@ -298,6 +302,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation] + + def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -318,7 +326,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private[spark] class MapOutputTrackerMaster( conf: SparkConf, broadcastManager: BroadcastManager, - isLocal: Boolean) + isLocal: Boolean, + shuffleServiceAddressProvider: ShuffleServiceAddressProvider + = DefaultShuffleServiceAddressProvider) extends MapOutputTracker(conf) { // The size at which we use Broadcast to send the map output statuses to the executors @@ -644,6 +654,10 @@ private[spark] class MapOutputTrackerMaster( } } + override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] = + shuffleServiceAddressProvider + .getShuffleServiceAddresses().map { case (h, p) => (h, new Integer(p))}.asJava + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. 
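Note on the new getShuffleLocation(shuffleId, mapId, reduceId) lookup introduced above: it returns an Option[ShuffleLocation], and the ShuffleLocation trait added later in this patch only requires java.io.Externalizable, so each plugin decides what a "location" carries. The Scala sketch below is purely illustrative — it is not the patch's ExternalShuffleLocation, whose source is not part of this excerpt — and shows the minimum a host/port location would need, including the no-arg constructor that Externalizable deserialization requires.

import java.io.{ObjectInput, ObjectOutput}

import org.apache.spark.storage.ShuffleLocation

// Hypothetical host/port shuffle location; the class name and fields are illustrative only.
class HostPortShuffleLocation(private var host: String, private var port: Int)
  extends ShuffleLocation {

  // Externalizable deserialization instantiates the class reflectively,
  // so a public no-arg constructor is mandatory.
  def this() = this(null, -1)

  def hostName: String = host
  def portNumber: Int = port

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeUTF(host)
    out.writeInt(port)
  }

  override def readExternal(in: ObjectInput): Unit = {
    host = in.readUTF()
    port = in.readInt()
  }
}

The MapStatus changes further down serialize such locations alongside the per-block sizes via writeObject/readObject, so keeping the payload to a couple of primitives matters when the number of reducers is large.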
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) @@ -666,6 +680,14 @@ private[spark] class MapOutputTrackerMaster( trackerEndpoint = null shuffleStatuses.clear() } + + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocationForBlock(reduceId) + case None => Option.empty + } + } } /** @@ -779,6 +801,19 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } } + + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { + mapStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocationForBlock(reduceId) + case None => Option.empty + } + } + + override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] = { + trackerEndpoint + .askSync[java.util.List[(String, Integer)]](GetRemoteShuffleServiceAddresses) + } } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 845a3d5f6d6f9..247016584d1f2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -576,6 +576,10 @@ class SparkContext(config: SparkConf) extends Logging { _env.metricsSystem.registerSource(e.executorAllocationManagerSource) } appStatusSource.foreach(_env.metricsSystem.registerSource(_)) + + // Initialize the ShuffleDataIo + _env.shuffleDataIO.foreach(_.initialize()) + // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. 
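The MapOutputTrackerMaster above answers GetRemoteShuffleServiceAddresses by delegating to a pluggable ShuffleServiceAddressProvider, and the SparkEnv change that follows loads the matching ShuffleServiceAddressProviderFactory from spark.shuffle.provider.plugin.class. As a rough sketch of what such a plugin could look like — only the two traits come from this patch; the class names and the spark.shuffle.static.hosts key are hypothetical:

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.{ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory}

// Hypothetical provider that serves a fixed list of shuffle service endpoints.
class StaticShuffleServiceAddressProvider(hostPorts: List[(String, Int)])
  extends ShuffleServiceAddressProvider {
  override def getShuffleServiceAddresses(): List[(String, Int)] = hostPorts
}

class StaticShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory {
  // SparkEnv keeps only the factories whose canCreate() accepts the current master URL.
  override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("local")

  override def create(conf: SparkConf): ShuffleServiceAddressProvider = {
    // e.g. spark.shuffle.static.hosts=host1:7337,host2:7337 (illustrative key, not in this patch)
    val hostPorts = conf.get("spark.shuffle.static.hosts", "")
      .split(",")
      .filter(_.nonEmpty)
      .map { hostPort =>
        val Array(host, port) = hostPort.split(":")
        (host, port.toInt)
      }
      .toList
    new StaticShuffleServiceAddressProvider(hostPorts)
  }
}

SparkEnv starts the selected provider once on the driver and hands it to MapOutputTrackerMaster, so executors can ask the tracker endpoint for the current address list instead of each executor being configured separately.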
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 66038eeaea54f..c2b56864bf36d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -21,11 +21,10 @@ import java.io.File import java.net.Socket import java.util.Locale +import com.google.common.collect.MapMaker import scala.collection.mutable import scala.util.Properties -import com.google.common.collect.MapMaker - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager @@ -39,7 +38,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.api.ShuffleDataIO import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -66,6 +66,7 @@ class SparkEnv ( val blockManager: BlockManager, val securityManager: SecurityManager, val metricsSystem: MetricsSystem, + val shuffleDataIO: Option[ShuffleDataIO], val memoryManager: MemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -302,7 +303,26 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + val master = conf.get("spark.master") + val shuffleProvider = conf.get(SHUFFLE_SERVICE_PROVIDER_CLASS) + .map(clazz => Utils.loadExtensions( + classOf[ShuffleServiceAddressProviderFactory], + Seq(clazz), conf)).getOrElse(Seq()) + val serviceLoaders = shuffleProvider + .filter(_.canCreate(conf.get("spark.master"))) + if (serviceLoaders.size > 1) { + throw new SparkException( + s"Multiple external cluster managers registered for the url $master: $serviceLoaders") + } + val loader = Utils.getContextOrSparkClassLoader + logInfo(s"Loader: $loader") + logInfo(s"Service loader: $serviceLoaders") + val shuffleServiceAddressProvider = serviceLoaders.headOption + .map(_.create(conf)) + .getOrElse(DefaultShuffleServiceAddressProvider) + shuffleServiceAddressProvider.start() + + new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleServiceAddressProvider) } else { new MapOutputTrackerWorker(conf) } @@ -365,6 +385,9 @@ object SparkEnv extends Logging { ms } + val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf, isDriver) } @@ -384,6 +407,7 @@ object SparkEnv extends Logging { blockManager, securityManager, metricsSystem, + shuffleIoPlugin, memoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a30a501e5d4a1..b6f4fc1921bf4 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ 
-118,6 +118,8 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) env.metricsSystem.registerSource(executorSource) env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource) + // Initialize the ShuffleDataIO + env.shuffleDataIO.foreach(_.initialize()) } // Whether to load classes in user jars before those in Spark jars diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bede012e33977..c249565fa1519 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -182,6 +182,9 @@ package object config { private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val K8S_SHUFFLE_SERVICE_ENABLED = + ConfigBuilder("spark.k8s.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PORT = ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) @@ -441,6 +444,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PROVIDER_CLASS = + ConfigBuilder("spark.shuffle.provider.plugin.class") + .doc("Experimental. Specify a class that can handle detecting shuffle service pods.") + .stringConf + .createOptional + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = ConfigBuilder("spark.shuffle.io.plugin.class") .doc("Experimental. Specify a class that can handle reading and writing shuffle blocks to" + diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..21613a5946f68 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,13 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} -import scala.collection.mutable - import org.roaringbitmap.RoaringBitmap +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.config -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.api.CommittedPartition +import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils /** @@ -36,6 +36,8 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId + def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] + /** * Estimated size for the reduce block, in bytes. 
* @@ -56,11 +58,29 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = { + val shuffleLocationsArray = committedPartitions.collect { + case partition if partition != null && partition.shuffleLocation().isPresent + => partition.shuffleLocation().get() + case _ => null + } + val lengthsArray = committedPartitions.collect { + case partition if partition != null => partition.length() + case _ => 0 + + } + if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) { + HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) + } else { + new CompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) + } + } + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } } @@ -103,17 +123,28 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], null) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocations: Array[ShuffleLocation]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLocations) } override def location: BlockManagerId = loc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } @@ -122,6 +153,7 @@ private[spark] class CompressedMapStatus( loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -129,6 +161,7 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } @@ -148,17 +181,26 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == 
null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, null) // For deserialization only override def location: BlockManagerId = loc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -180,6 +222,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -195,11 +238,17 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } private[spark] object HighlyCompressedMapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + apply(loc, uncompressedSizes, Array.empty[ShuffleLocation]) + } + + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocation: Array[ShuffleLocation]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -240,6 +289,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + hugeBlockSizes, shuffleLocation) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7632c35f0318a..caeecedc5d36e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle +import scala.compat.java8.OptionConverters + import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager @@ -53,9 +55,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( appId, handle.shuffleId, mapId) blockIds.map { case blockId@ShuffleBlockId(_, _, reduceId) => - (blockId, reader.fetchPartition(reduceId)) + (blockId, serializerManager.wrapStream(blockId, + reader.fetchPartition(reduceId, OptionConverters.toJava( + mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId))))) case dataBlockId@ShuffleDataBlockId(_, _, reduceId) => - (dataBlockId, reader.fetchPartition(reduceId)) + (dataBlockId, serializerManager.wrapStream(dataBlockId, + reader.fetchPartition(reduceId, OptionConverters.toJava( + mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId))))) case invalid => throw new IllegalArgumentException(s"Invalid block id $invalid") } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala index 2eed51962181c..8a776281041d1 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala @@ -20,38 +20,36 @@ package org.apache.spark.shuffle import java.io.{InputStream, OutputStream} import java.nio.ByteBuffer -import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.{ByteBufferInputStream, Utils} class ShufflePartitionWriterOutputStream( - partitionWriter: ShufflePartitionWriter, buffer: ByteBuffer, bufferSize: Int) - extends OutputStream { + blockId: ShuffleBlockId, + partitionWriter: ShufflePartitionWriter, + buffer: ByteBuffer, + serializerManager: SerializerManager) + extends OutputStream { - private var currentChunkSize = 0 - private val bufferForRead = buffer.asReadOnlyBuffer() private var underlyingOutputStream: OutputStream = _ override def write(b: Int): Unit = { - buffer.putInt(b) - currentChunkSize += 1 - if (currentChunkSize == bufferSize) { + buffer.put(b.asInstanceOf[Byte]) + if (buffer.remaining() == 0) { pushBufferedBytesToUnderlyingOutput() } } private def pushBufferedBytesToUnderlyingOutput(): Unit = { - bufferForRead.reset() - var bufferInputStream: InputStream = new ByteBufferInputStream(bufferForRead) - if (currentChunkSize < bufferSize) { - bufferInputStream = new LimitedInputStream(bufferInputStream, currentChunkSize) - } + buffer.flip() + var bufferInputStream: InputStream = new ByteBufferInputStream(buffer) if (underlyingOutputStream == null) { - underlyingOutputStream = partitionWriter.openPartitionStream() + underlyingOutputStream = serializerManager.wrapStream(blockId, + partitionWriter.openPartitionStream()) } Utils.copyStream(bufferInputStream, underlyingOutputStream, false, false) - buffer.reset() - currentChunkSize = 0 + buffer.clear() } override def flush(): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala similarity index 60% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala index 83daddf714489..96d529872b306 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala @@ -14,24 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster.k8s -import io.fabric8.kubernetes.api.model.Pod +package org.apache.spark.shuffle -sealed trait ExecutorPodState { - def pod: Pod +trait ShuffleServiceAddressProvider { + def start(): Unit = {} + def getShuffleServiceAddresses(): List[(String, Int)] + def stop(): Unit = {} } -case class PodRunning(pod: Pod) extends ExecutorPodState - -case class PodPending(pod: Pod) extends ExecutorPodState - -sealed trait FinalPodState extends ExecutorPodState - -case class PodSucceeded(pod: Pod) extends FinalPodState - -case class PodFailed(pod: Pod) extends FinalPodState - -case class PodDeleted(pod: Pod) extends FinalPodState - -case class PodUnknown(pod: Pod) extends ExecutorPodState +private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { + override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..68adb8e44585c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf + +trait ShuffleServiceAddressProviderFactory { + def canCreate(masterUrl: String): Boolean + def create(conf: SparkConf): ShuffleServiceAddressProvider +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 53a5c4f3afba4..eb7ae313918ed 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -77,11 +77,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } - private val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) - - shuffleIoPlugin.foreach(_.initialize()) - /** * A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ @@ -131,7 +126,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - shuffleIoPlugin.map(_.readSupport())) + SparkEnv.get.shuffleDataIO.map(_.readSupport())) } /** Get a writer for a given partition. Called on executors by map tasks. 
*/ @@ -154,7 +149,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context, env.conf, metrics, - shuffleIoPlugin.map(_.writeSupport()).orNull) + env.shuffleDataIO.map(_.writeSupport()).orNull) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, @@ -163,10 +158,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId, env.conf, metrics, - shuffleIoPlugin.map(_.writeSupport()).orNull) + env.shuffleDataIO.map(_.writeSupport()).orNull) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleIoPlugin.map(_.writeSupport())) + shuffleBlockResolver, other, mapId, context, env.shuffleDataIO.map(_.writeSupport())) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 1c804c99d0e31..b6ab2f354e81b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,13 +70,16 @@ private[spark] class SortShuffleWriter[K, V, C]( val tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = pluggableWriteSupport.map { writeSupport => - sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) + val committedPartitions = pluggableWriteSupport.map { writeSupport => + sorter.writePartitionedToExternalShuffleWriteSupport(blockId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) if (pluggableWriteSupport.isEmpty) { - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, + mapId, + committedPartitions.map(_.length()), + tmp) } - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, committedPartitions) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dfbc6effb346..f604f20ee7220 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -132,13 +132,14 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) + private val remoteReadNioBufferConversion = conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. - val deleteFilesOnStop = - !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER + val deleteFilesOnStop = !externalShuffleServiceEnabled || + executorId == SparkContext.DRIVER_IDENTIFIER new DiskBlockManager(conf, deleteFilesOnStop) } @@ -259,8 +260,8 @@ private[spark] class BlockManager( blockManagerId } - // Register Executors' configuration with the local shuffle service, if one should exist. 
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + // Register Executors' configuration with the local shuffle service, if one should exist. registerWithExternalShuffleServer() } @@ -1663,7 +1664,7 @@ private[spark] object BlockManager { blockManagers.toMap } - private class ShuffleMetricsSource( + class ShuffleMetricsSource( override val sourceName: String, metricSet: MetricSet) extends Source { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 86f7c08eddcb5..cf8e4793c448e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -437,7 +437,6 @@ final class ShuffleBlockFetcherIterator( s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" throwFetchFailedException(blockId, address, new IOException(msg)) } - val in = try { buf.createInputStream() } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala new file mode 100644 index 0000000000000..72846cb001c8a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import java.io.Externalizable + +trait ShuffleLocation extends Externalizable { + +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala index b03276b2ce16f..b2263a51051a0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala @@ -19,9 +19,9 @@ package org.apache.spark.storage import java.nio.ByteBuffer -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter} /** * Replicates the concept of {@link DiskBlockObjectWriter}, but with some key differences: @@ -30,10 +30,12 @@ import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWri * left to the implementation of the underlying implementation of the writer plugin. 
*/ private[spark] class ShufflePartitionObjectWriter( + blockId: ShuffleBlockId, bufferSize: Int, serializerInstance: SerializerInstance, + serializerManager: SerializerManager, mapOutputWriter: ShuffleMapOutputWriter) - extends PairsWriter { + extends PairsWriter { // Reused buffer. Experiments should be done with off-heap at some point. private val buffer = ByteBuffer.allocate(bufferSize) @@ -44,22 +46,21 @@ private[spark] class ShufflePartitionObjectWriter( def startNewPartition(partitionId: Int): Unit = { require(buffer.position() == 0, "Buffer was not flushed to the underlying output on the previous partition.") - buffer.reset() currentWriter = mapOutputWriter.newPartitionWriter(partitionId) val currentWriterStream = new ShufflePartitionWriterOutputStream( - currentWriter, buffer, bufferSize) + blockId, currentWriter, buffer, serializerManager) objectOutputStream = serializerInstance.serializeStream(currentWriterStream) } - def commitCurrentPartition(): Long = { + def commitCurrentPartition(): CommittedPartition = { require(objectOutputStream != null, "Cannot commit a partition that has not been started.") require(currentWriter != null, "Cannot commit a partition that has not been started.") objectOutputStream.close() - val length = currentWriter.commitAndGetTotalLength() - buffer.reset() + val committedPartition = currentWriter.commitPartition() + buffer.clear() currentWriter = null objectOutputStream = null - length + committedPartition } def abortCurrentPartition(throwable: Exception): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 569c8bd092f37..70c36c40865ba 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,19 +18,19 @@ package org.apache.spark.util.collection import java.io._ -import java.util.Comparator +import java.util.{Comparator, Optional} +import com.google.common.io.ByteStreams import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams - import org.apache.spark.{util, _} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ -import org.apache.spark.shuffle.api.ShuffleWriteSupport -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShufflePartitionObjectWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport} +import org.apache.spark.shuffle.sort.LocalCommittedPartition +import org.apache.spark.storage._ /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -683,10 +683,10 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) + val committedPartitions = new Array[CommittedPartition](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) @@ -700,7 +700,7 @@ private[spark] class ExternalSorter[K, V, C]( it.writeNext(writer) } val segment = writer.commitAndGet() - lengths(partitionId) = segment.length + committedPartitions(partitionId) = new 
LocalCommittedPartition(segment.length) } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. @@ -710,7 +710,7 @@ private[spark] class ExternalSorter[K, V, C]( writer.write(elem._1, elem._2) } val segment = writer.commitAndGet() - lengths(id) = segment.length + committedPartitions(id) = new LocalCommittedPartition(segment.length) } } } @@ -720,21 +720,25 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } /** * Write all partitions to some backend that is pluggable. */ def writePartitionedToExternalShuffleWriteSupport( - mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[Long] = { + blockId: ShuffleBlockId, + writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) - val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId) + val committedPartitions = new Array[CommittedPartition](numPartitions) + val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, blockId.shuffleId, + blockId.mapId) val writer = new ShufflePartitionObjectWriter( + blockId, Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt, serInstance, + serializerManager, mapOutputWriter) try { @@ -749,7 +753,7 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) } - lengths(partitionId) = writer.commitCurrentPartition() + committedPartitions(partitionId) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -767,7 +771,7 @@ private[spark] class ExternalSorter[K, V, C]( for (elem <- elements) { writer.write(elem._1, elem._2) } - lengths(id) = writer.commitCurrentPartition() + committedPartitions(id) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -781,6 +785,7 @@ private[spark] class ExternalSorter[K, V, C]( mapOutputWriter.commitAllPartitions() } catch { case e: Exception => + logError("Error writing shuffle data.", e) util.Utils.tryLogNonFatalError { writer.abortCurrentPartition(e) mapOutputWriter.abort(e) @@ -791,7 +796,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } def stop(): Unit = { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 360c1769ad31a..93ab301a4cb9c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,7 +23,7 @@ import java.nio.file.StandardOpenOption; import java.util.*; -import org.apache.commons.io.IOUtils; +import org.apache.spark.shuffle.api.CommittedPartition; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -351,7 +351,8 @@ private void testMergingSpills( private void testMergingSpills( boolean transferToEnabled, boolean useShuffleWriterPlugin) throws IOException { - final UnsafeShuffleWriter writer = createWriter(transferToEnabled, useShuffleWriterPlugin); + final UnsafeShuffleWriter writer = + createWriter(transferToEnabled, 
useShuffleWriterPlugin); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { dataToWrite.add(new Tuple2<>(i, i)); @@ -644,11 +645,13 @@ public void testPeakMemoryUsed() throws Exception { private final class TestShuffleWriteSupport implements ShuffleWriteSupport { @Override - public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { + public ShuffleMapOutputWriter newMapOutputWriter( + String appId, int shuffleId, int mapId) { try { if (!mergedOutputFile.exists() && !mergedOutputFile.createNewFile()) { throw new IllegalStateException( - String.format("Failed to create merged output file %s.", mergedOutputFile.getAbsolutePath())); + String.format("Failed to create merged output file %s.", + mergedOutputFile.getAbsolutePath())); } } catch (IOException e) { throw new RuntimeException(e); @@ -673,7 +676,7 @@ public OutputStream openPartitionStream() { } @Override - public long commitAndGetTotalLength() { + public CommittedPartition commitPartition() { byte[] partitionBytes = byteBuffer.toByteArray(); try { Files.write(mergedOutputFile.toPath(), partitionBytes, StandardOpenOption.APPEND); @@ -682,7 +685,7 @@ public long commitAndGetTotalLength() { } int length = partitionBytes.length; partitionSizesInMergedFile[partitionId] = length; - return length; + return new LocalCommittedPartition(length); } @Override diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..90f6c3523ece8 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer - import org.mockito.Matchers.any import org.mockito.Mockito._ @@ -27,7 +26,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleLocation} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -84,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array[Long](compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 
1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -260,7 +259,8 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), + Array.empty[ShuffleLocation])) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index 3a68fded945b3..579fc9a45ba9b 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -19,9 +19,11 @@ package org.apache.spark import java.io._ import java.nio.file.Paths +import java.util.Optional import javax.ws.rs.core.UriBuilder import org.apache.spark.shuffle.api._ +import org.apache.spark.storage.ShuffleLocation import org.apache.spark.util.Utils class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { @@ -31,7 +33,7 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { override def initialize(): Unit = {} override def readSupport(): ShuffleReadSupport = (appId: String, shuffleId: Int, mapId: Int) => { - reduceId: Int => { + (reduceId: Int, shuffleLocation: Optional[ShuffleLocation]) => { new FileInputStream(resolvePartitionFile(appId, shuffleId, mapId, reduceId)) } } @@ -49,8 +51,14 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { new FileOutputStream(shuffleFile) } - override def commitAndGetTotalLength(): Long = - resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + override def commitPartition(): CommittedPartition = { + new CommittedPartition { + override def length(): Long = + resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + + override def shuffleLocation(): Optional[ShuffleLocation] = Optional.empty() + } + } override def abort(failureReason: Exception): Unit = {} } @@ -64,7 +72,6 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { private def resolvePartitionFile( appId: String, shuffleId: Int, mapId: Int, reduceId: Int): File = { - import java.io.OutputStream Paths.get(UriBuilder.fromUri(shuffleDir.toURI) .path(appId) .path(shuffleId.toString) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 467e49026a029..2e1950bbc7f57 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -26,7 +26,6 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag - import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import org.roaringbitmap.RoaringBitmap diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 75202b57833ae..8919692a8f761 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -49,7 +49,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private val conf: SparkConf = new SparkConf(loadDefaults = false).set("spark.app.id", "spark-app-id") + private val conf: SparkConf = + new SparkConf(loadDefaults = false).set("spark.app.id", "spark-app-id") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala new file mode 100644 index 0000000000000..883ac10718dfd --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples + +import org.apache.spark.sql.SparkSession + +/** + * Usage: GroupByShuffleTest + */ +object GroupByShuffleTest { + def main(args: Array[String]) { + val spark = SparkSession + .builder + .appName("GroupByShuffle Test") + .getOrCreate() + + val words = Array("one", "two", "two", "three", "three", "three") + val wordPairsRDD = spark.sparkContext.parallelize(words).map(word => (word, 1)) + + val wordCountsWithGroup = wordPairsRDD + .groupByKey() + .map(t => (t._1, t._2.sum)) + .collect() + + println(wordCountsWithGroup.mkString(",")) + + val wordPairsRDD2 = spark.sparkContext.parallelize(words, 1).map(word => (word, 1)) + + val wordCountsWithGroup2 = wordPairsRDD2 + .groupByKey() + .map(t => (t._1, t._2.sum)) + .collect() + + println(wordCountsWithGroup2.mkString(",")) + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 4d3c34041bc17..a1e6f83f7cba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -23,7 +23,7 @@ import java.util.Random import org.apache.spark.sql.SparkSession /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + * Usage: SkewedGroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object SkewedGroupByTest { def main(args: Array[String]) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e8bf16df190e8..44d5a2c22dd42 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -289,6 +289,26 @@ private[spark] object Config extends Logging { .booleanConf .createWithDefault(true) + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE = + ConfigBuilder("spark.kubernetes.shuffle.service.remote.pods.namespace") + .doc("Namespace of the pods that are running the shuffle service instances for remote" + + " pushing of shuffle data.") + .stringConf + .createOptional + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.kubernetes.shuffle.service.remote.port") + .doc("Port of the external k8s shuffle service pods") + .intConf + .createWithDefault(7337) + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL = + ConfigBuilder("spark.kubernetes.shuffle.service.cleanup.interval") + .doc("Cleanup interval for the shuffle service to take down an app id") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("30s") + + val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." @@ -313,4 +333,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS = + "spark.kubernetes.shuffle.service.remote.label." 
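+
+  // Illustrative example (hypothetical values, sketching how the options above would be set
+  // on a job that targets a remote K8s-hosted shuffle service):
+  //   --conf spark.kubernetes.shuffle.service.remote.pods.namespace=<namespace of shuffle pods>
+  //   --conf spark.kubernetes.shuffle.service.remote.port=7337
+  //   --conf spark.kubernetes.shuffle.service.cleanup.interval=30s
+  //   --conf spark.kubernetes.shuffle.service.remote.label.<label name>=<label value>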
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala new file mode 100644 index 0000000000000..07dbffacc31f5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import java.io.{DataOutputStream, File, FileOutputStream} +import java.nio.ByteBuffer +import java.nio.file.Paths +import java.util +import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} +import java.util.function.BiFunction + +import com.codahale.metrics._ +import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} +import scala.collection.JavaConverters._ +import scala.collection.immutable.TreeMap + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.FileSegmentManagedBuffer +import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.protocol._ +import org.apache.spark.network.util.{JavaUtils, TransportConf} +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Kubernetes. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. 
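+ * Besides driver registration and heartbeats, it handles the messages used for remote shuffle
+ * storage: RegisterShuffleIndex, UploadShuffleIndex and OpenShufflePartition as RPCs, and
+ * UploadShufflePartitionStream as a stream (see handleStream below).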
+ */ +private[spark] class KubernetesExternalShuffleBlockHandler( + transportConf: TransportConf, + cleanerIntervals: Long, + indexCacheSize: String) + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { + + ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") + .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervals, TimeUnit.SECONDS) + + // Stores a map of app id to app state (timeout value and last heartbeat) + private val connectedApps = new ConcurrentHashMap[String, AppState]() + private val indexCacheLoader = new CacheLoader[File, ShuffleIndexInformation]() { + override def load(file: File): ShuffleIndexInformation = new ShuffleIndexInformation(file) + } + private val shuffleIndexCache = CacheBuilder.newBuilder() + .maximumWeight(JavaUtils.byteStringAsBytes(indexCacheSize)) + .weigher(new Weigher[File, ShuffleIndexInformation]() { + override def weigh(file: File, indexInfo: ShuffleIndexInformation): Int = + indexInfo.getSize + }) + .build(indexCacheLoader) + + // TODO: Investigate cleanup if appId is terminated + private val globalPartitionLengths = new ConcurrentHashMap[(String, Int, Int), TreeMap[Int, Long]] + + private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir") + + private final val metricSet: RemoteShuffleMetrics = new RemoteShuffleMetrics() + + private def scanLeft[a, b](xs: Iterable[a])(s: b)(f: (b, a) => b) = + xs.foldLeft(List(s))( (acc, x) => f(acc.head, x) :: acc).reverse + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterDriverParam(appId, appState) => + val responseDelayContext = metricSet.registerDriverRequestLatencyMillis.time() + val address = client.getSocketAddress + val timeout = appState.heartbeatTimeout + logInfo(s"Received registration request from app $appId (remote address $address, " + + s"heartbeat timeout $timeout ms).") + if (connectedApps.containsKey(appId)) { + logWarning(s"Received a registration request from app $appId, but it was already " + + s"registered") + } + val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile + if (!driverDir.mkdir()) { + throw new RuntimeException(s"Failed to create dir ${driverDir.getAbsolutePath}") + } + connectedApps.put(appId, appState) + responseDelayContext.stop() + callback.onSuccess(ByteBuffer.allocate(0)) + + case Heartbeat(appId) => + val address = client.getSocketAddress + Option(connectedApps.get(appId)) match { + case Some(existingAppState) => + logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " + + s"address $address).") + existingAppState.lastHeartbeat = System.nanoTime() + case None => + logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + + s"address $address, appId '$appId').") + } + + case RegisterIndexParam(appId, shuffleId, mapId) => + logInfo(s"Received register index param from app $appId") + globalPartitionLengths.putIfAbsent( + (appId, shuffleId, mapId), TreeMap.empty[Int, Long]) + callback.onSuccess(ByteBuffer.allocate(0)) + + case UploadIndexParam(appId, shuffleId, mapId) => + val responseDelayContext = metricSet.writeIndexRequestLatencyMillis.time() + try { + logInfo(s"Received upload index param from app $appId") + val partitionMap = globalPartitionLengths.get((appId, shuffleId, mapId)) + val out = new DataOutputStream( + new FileOutputStream(getFile(appId, shuffleId, mapId, "index"))) + scanLeft(partitionMap.values)(0L)(_ + 
_).foreach(l => out.writeLong(l)) + out.close() + callback.onSuccess(ByteBuffer.allocate(0)) + } finally { + responseDelayContext.stop() + } + + case OpenParam(appId, shuffleId, mapId, partitionId) => + logInfo(s"Received open param from app $appId") + val responseDelayContext = metricSet.openBlockRequestLatencyMillis.time() + val indexFile = getFile(appId, shuffleId, mapId, "index") + logInfo(s"Map: " + + s"${globalPartitionLengths.get((appId, shuffleId, mapId)).toString()}" + + s" for partitionId: $partitionId") + try { + val shuffleIndexInformation = shuffleIndexCache.get(indexFile) + val shuffleIndexRecord = shuffleIndexInformation.getIndex(partitionId) + val managedBuffer = new FileSegmentManagedBuffer( + transportConf, + getFile(appId, shuffleId, mapId, "data"), + shuffleIndexRecord.getOffset, + shuffleIndexRecord.getLength) + callback.onSuccess(managedBuffer.nioByteBuffer()) + } catch { + case e: ExecutionException => logError(s"Unable to read index file $indexFile", e) + } finally { + responseDelayContext.stop() + } + case _ => super.handleMessage(message, client, callback) + } + } + + protected override def handleStream( + header: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): StreamCallbackWithID = { + header match { + case UploadParam( + appId, shuffleId, mapId, partitionId, partitionLength) => + val responseDelayContext = metricSet.writeBlockRequestLatencyMillis.time() + try { + logInfo(s"Received upload param from app $appId") + val lengthMap = TreeMap(partitionId -> partitionLength.toLong) + globalPartitionLengths.merge((appId, shuffleId, mapId), lengthMap, + new BiFunction[TreeMap[Int, Long], TreeMap[Int, Long], TreeMap[Int, Long]]() { + override def apply(t: TreeMap[Int, Long], u: TreeMap[Int, Long]): + TreeMap[Int, Long] = { + t ++ u + } + }) + getFileWriterStreamCallback( + appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + } finally { + responseDelayContext.stop() + } + case _ => + super.handleStream(header, client, callback) + } + } + + protected override def getAllMetrics: MetricSet = metricSet + + private def getFileWriterStreamCallback( + appId: String, + shuffleId: Int, + mapId: Int, + extension: String, + fileType: FileWriterStreamCallback.FileType): StreamCallbackWithID = { + val file = getFile(appId, shuffleId, mapId, extension) + val streamCallback = + new FileWriterStreamCallback(appId, shuffleId, mapId, file, fileType) + streamCallback.open() + streamCallback + } + + private def getFile( + appId: String, + shuffleId: Int, + mapId: Int, + extension: String): File = { + Paths.get(shuffleDir.getAbsolutePath, appId, + s"shuffle_${shuffleId}_${mapId}_0.$extension").toFile + } + + /** An extractor object for matching BlockTransferMessages. 
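+ * For example, RegisterDriverParam matches a RegisterDriver message and yields the app id
+ * together with a fresh AppState built from the driver's heartbeat timeout.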
*/ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[(String, AppState)] = + Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime()))) + } + + private object Heartbeat { + def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId) + } + + private object UploadParam { + def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId, u.partitionId, u.partitionLength)) + } + + private object UploadIndexParam { + def unapply(u: UploadShuffleIndex): Option[(String, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId)) + } + + private object RegisterIndexParam { + def unapply(u: RegisterShuffleIndex): Option[(String, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId)) + } + + private object OpenParam { + def unapply(o: OpenShufflePartition): Option[(String, Int, Int, Int)] = + Some((o.appId, o.shuffleId, o.mapId, o.partitionId)) + } + + private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) + + private class CleanerThread extends Runnable { + override def run(): Unit = { + val now = System.nanoTime() + connectedApps.asScala.foreach { case (appId, appState) => + if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { + logInfo(s"Application $appId timed out. Removing shuffle files.") + connectedApps.remove(appId) + applicationRemoved(appId, false) + try { + val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile + logInfo(s"Driver dir is: ${driverDir.getAbsolutePath}") + driverDir.delete() + } catch { + case e: Exception => logError("Unable to delete files", e) + } + } + } + } + } + private class RemoteShuffleMetrics extends MetricSet { + private val allMetrics = new util.HashMap[String, Metric]() + // Time latency for write request in ms + private val _writeBlockRequestLatencyMillis = new Timer() + def writeBlockRequestLatencyMillis: Timer = _writeBlockRequestLatencyMillis + // Time latency for write index file in ms + private val _writeIndexRequestLatencyMillis = new Timer() + def writeIndexRequestLatencyMillis: Timer = _writeIndexRequestLatencyMillis + // Time latency for read request in ms + private val _openBlockRequestLatencyMillis = new Timer() + def openBlockRequestLatencyMillis: Timer = _openBlockRequestLatencyMillis + // Time latency for executor registration latency in ms + private val _registerDriverRequestLatencyMillis = new Timer() + def registerDriverRequestLatencyMillis: Timer = _registerDriverRequestLatencyMillis + // Block transfer rate in byte per second + private val _blockTransferRateBytes = new Meter() + def blockTransferRateBytes: Meter = _blockTransferRateBytes + + allMetrics.put("writeBlockRequestLatencyMillis", _writeBlockRequestLatencyMillis) + allMetrics.put("writeIndexRequestLatencyMillis", _writeIndexRequestLatencyMillis) + allMetrics.put("openBlockRequestLatencyMillis", _openBlockRequestLatencyMillis) + allMetrics.put("registerDriverRequestLatencyMillis", _registerDriverRequestLatencyMillis) + allMetrics.put("blockTransferRateBytes", _blockTransferRateBytes) + + override def getMetrics: util.Map[String, Metric] = allMetrics + } + +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. 
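+ * It can be started standalone through sbin/start-k8s-shuffle-service.sh, which launches this
+ * class via spark-daemon.sh.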
+ */ +private[spark] class KubernetesExternalShuffleService( + conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + transportConf: TransportConf): ExternalShuffleBlockHandler = { + val cleanerIntervals = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) + val indexCacheSize = conf.get("spark.shuffle.service.index.cache.size", "100m") + new KubernetesExternalShuffleBlockHandler(transportConf, cleanerIntervals, indexCacheSize) + } +} + +private[spark] object KubernetesExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new KubernetesExternalShuffleService(conf, sm)) + } +} + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 77bd66b608e7c..44e843cdb73aa 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -21,11 +21,13 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.Config._ import io.fabric8.kubernetes.client.utils.HttpClientUtils import okhttp3.Dispatcher import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.util.ThreadUtils /** @@ -34,6 +36,35 @@ import org.apache.spark.util.ThreadUtils * options for different components. 
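 * The getDriverKubernetesClient helper below builds the driver-side client and is shared by
 * KubernetesClusterManager and KubernetesShuffleServiceAddressProviderFactory.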
*/ private[spark] object SparkKubernetesClientFactory { + def getDriverKubernetesClient(conf: SparkConf, masterURL: String): KubernetesClient = { + val wasSparkSubmittedInClusterMode = conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + + val kubernetesClient = createKubernetesClient( + apiServerUri, + Some(conf.get(KUBERNETES_NAMESPACE)), + authConfPrefix, + conf, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) + kubernetesClient + } def createKubernetesClient( master: String, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala index 435a5f1461c92..aa4ce28aeb6ba 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging /** * An immutable view of the current executor pods that are running in the cluster. 
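 * Pod status is modelled with the SparkPodState hierarchy so the same states can also be used
 * for shuffle service pods (see KubernetesShuffleServiceAddressProvider).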
*/ -private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, SparkPodState]) { import ExecutorPodsSnapshot._ @@ -42,15 +42,15 @@ object ExecutorPodsSnapshot extends Logging { ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) } - def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, SparkPodState]) - private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, SparkPodState] = { executorPods.map { pod => (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) }.toMap } - private def toState(pod: Pod): ExecutorPodState = { + private def toState(pod: Pod): SparkPodState = { if (isDeleted(pod)) { PodDeleted(pod) } else { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b31fbb420ed6d..fd245f8f5f31b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -42,32 +42,8 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) - val (authConfPrefix, - apiServerUri, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { - require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, - "If the application is deployed using spark-submit in cluster mode, the driver pod name " + - "must be provided.") - (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - KUBERNETES_MASTER_INTERNAL_URL, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - } else { - (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, - KubernetesUtils.parseMasterUrl(masterURL), - None, - None) - } - - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - apiServerUri, - Some(sc.conf.get(KUBERNETES_NAMESPACE)), - authConfPrefix, - sc.conf, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + sc.conf, masterURL) if (sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { KubernetesUtils.loadPodFromTemplate( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala new file mode 100644 index 0000000000000..e5d9594fc3d5e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.Locale + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.internal.Logging + +sealed trait SparkPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends SparkPodState + +case class PodPending(pod: Pod) extends SparkPodState + +sealed trait FinalPodState extends SparkPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends SparkPodState + +object SparkPodState extends Logging { + def toState(pod: Pod): SparkPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase(Locale.ROOT) + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala new file mode 100644 index 0000000000000..f0bdb35216655 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.k8s + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.cluster.k8s.{SparkPodState, _} +import org.apache.spark.shuffle.ShuffleServiceAddressProvider +import org.apache.spark.util.Utils + + +class KubernetesShuffleServiceAddressProvider( + kubernetesClient: KubernetesClient, + pollForPodsExecutor: ScheduledExecutorService, + podLabels: Map[String, String], + namespace: String, + portNumber: Int) + extends ShuffleServiceAddressProvider with Logging { + + // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, + // but we don't need all "in-between" lists of all executor pods, just the latest known list + // when we query in getShuffleServiceAddresses. + + private val podsUpdateLock = new ReentrantReadWriteLock() + + private val shuffleServicePods = mutable.HashMap.empty[String, Pod] + + private var shuffleServicePodsWatch: Watch = _ + private var pollForPodsTask: ScheduledFuture[_] = _ + + override def start(): Unit = { + pollForPods() + pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( + () => pollForPods(), 0, 10, TimeUnit.SECONDS) + shuffleServicePodsWatch = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) + } + + override def stop(): Unit = { + Utils.tryLogNonFatalError { + if (pollForPodsTask != null) { + pollForPodsTask.cancel(false) + } + } + + Utils.tryLogNonFatalError { + if (shuffleServicePodsWatch != null) { + shuffleServicePodsWatch.close() + } + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() + } + } + + override def getShuffleServiceAddresses(): List[(String, Int)] = { + val readLock = podsUpdateLock.readLock() + readLock.lock() + try { + val addresses = shuffleServicePods.values.map(pod => { + (pod.getStatus.getPodIP, portNumber) + }).toList + logInfo(s"Found remote shuffle service addresses at $addresses.") + addresses + } finally { + readLock.unlock() + } + } + + // TODO: Re-register with found shuffle service instances + private def pollForPods(): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + val allPods = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava) + .list() + shuffleServicePods.clear() + allPods.getItems.asScala.foreach(updatePod) + } finally { + writeLock.unlock() + } + } + + private def updatePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") + val state = SparkPodState.toState(pod) + state match { + case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => + shuffleServicePods.remove(pod.getMetadata.getName) + case PodRunning(_) => + shuffleServicePods.put(pod.getMetadata.getName, pod) + case _ => + logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") + } + } + + private def deletePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") + shuffleServicePods.remove(pod.getMetadata.getName) + } + + private class PutPodsInCacheWatcher extends Watcher[Pod] { + override def 
eventReceived(action: Watcher.Action, pod: Pod): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + updatePod(pod) + } finally { + writeLock.unlock() + } + } + + override def onClose(e: KubernetesClientException): Unit = {} + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..57e68d4053291 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.k8s + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.internal.{config => C, Logging} +import org.apache.spark.shuffle._ +import org.apache.spark.util.ThreadUtils + +class KubernetesShuffleServiceAddressProviderFactory + extends ShuffleServiceAddressProviderFactory with Logging { + override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") + + override def create(conf: SparkConf): ShuffleServiceAddressProvider = { + if (conf.get(C.K8S_SHUFFLE_SERVICE_ENABLED)) { + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + conf, conf.get("spark.master")) + val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( + "poll-shuffle-service-pods", 1) + logInfo("Beginning to search for K8S pods that act as an External Shuffle Service") + val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS) + val shuffleServicePodsNamespace = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE) + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the external" + + s" shuffle service must be defined by" + + s" ${KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") + require(shuffleServiceLabels.nonEmpty, "Requires labels for external shuffle service pods") + + val port: Int = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT) + new KubernetesShuffleServiceAddressProvider( + kubernetesClient, + pollForPodsExecutor, + shuffleServiceLabels.toMap, + shuffleServicePodsNamespace.get, + port) + } else DefaultShuffleServiceAddressProvider + } +} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 0843040324707..c37e17f92eda4 100644 --- 
a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -28,7 +28,7 @@ ARG spark_uid=185 RUN set -ex && \ apk upgrade --no-cache && \ - apk add --no-cache bash tini libc6-compat linux-pam krb5 krb5-libs && \ + apk add --no-cache bash tini libc6-compat linux-pam krb5 krb5-libs procps && \ mkdir -p /opt/spark && \ mkdir -p /opt/spark/examples && \ mkdir -p /opt/spark/work-dir && \ diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 859aa836a3157..6d94b9efd1d29 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -28,8 +28,7 @@ import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf import org.apache.spark.util.ThreadUtils diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index da71f8f9e407c..a69b0d3050351 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Mesos scheduler and backend @@ -60,5 +61,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = { + DefaultShuffleServiceAddressProvider + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index 64cd1bd088001..8e83d49d2332e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, 
TaskSchedulerImpl} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Yarn scheduler and backend @@ -53,4 +54,8 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = { + DefaultShuffleServiceAddressProvider + } } diff --git a/sbin/start-k8s-shuffle-service.sh b/sbin/start-k8s-shuffle-service.sh new file mode 100644 index 0000000000000..84ee303202200 --- /dev/null +++ b/sbin/start-k8s-shuffle-service.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the K8S external shuffle server on the machine this script is executed on. +# TODO: Describe K8s ESS +# +# Usage: start-k8s-shuffle-service.sh +# +# + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" + +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.k8s.KubernetesExternalShuffleService 1