// Review notes for this patch (free text placed before the first "diff --git" header; not part of any hunk).
//
// What the patch does, per file:
// - MasterNotLeaderException.java: adds a constructor taking the suggested leader as a Tuple2 plus a
//   bindPreferIp flag; leaderPeer is set to _1 when bindPreferIp is true and to _2 otherwise. The existing
//   3-argument constructor now delegates to it with the same string in both slots and bindPreferIp=false.
// - Utils.scala: adds getIpHostAddressPair and addressToIpHostAddressPair, which turn an address into an
//   (ip:port, host:port) pair; if resolution throws, the input host is used for both elements.
// - HAHelper.java: checkShouldProcess and sendFailure gain a bindPreferIp parameter that is forwarded to
//   MasterNotLeaderException together with the cached leader endpoints pair.
// - HARaftServer.java: the constructor takes a MasterNode instead of separate ratisAddr/rpcEndpoint
//   arguments; the cached leader endpoint is now wrapped in the new LeaderPeerEndpoints holder class.
// - Master.scala: reads conf.bindPreferIP once and passes it to both HAHelper calls.
// - MasterNode.scala: adds rpcIpEndpoint (host address of rpcAddr plus rpcPort).
// - RatisMasterStatusSystemSuiteJ.java: renames testLeaderAvaiable to testLeaderAvailable and asserts that
//   a follower's cached leader endpoints contain the leader's rpc endpoint.
//
// NOTE(review): this copy looks flattened (line breaks collapsed) and generic type arguments on Tuple2,
//   Optional and List appear to be missing -- confirm against the original patch before applying.
// NOTE(review): the " Suggested leader is Master:%s." message is formatted with the whole Tuple2 rather
//   than the address selected by bindPreferIp -- confirm this is intended.
// NOTE(review): currentPeer is compared only with _1 of the pair (the ip:port element according to the
//   comments in this patch) -- verify behaviour when currentPeer is a host:port string.
// NOTE(review): getIpHostAddressPair catches Throwable and silently falls back to the input -- confirm
//   that swallowing fatal errors here is acceptable.
diff --git a/common/src/main/java/org/apache/celeborn/common/client/MasterNotLeaderException.java b/common/src/main/java/org/apache/celeborn/common/client/MasterNotLeaderException.java index 38198cf786e..4f9057a1953 100644 --- a/common/src/main/java/org/apache/celeborn/common/client/MasterNotLeaderException.java +++ b/common/src/main/java/org/apache/celeborn/common/client/MasterNotLeaderException.java @@ -21,6 +21,8 @@ import javax.annotation.Nullable; +import scala.Tuple2; + import org.apache.commons.lang3.StringUtils; public class MasterNotLeaderException extends IOException { @@ -33,18 +35,26 @@ public class MasterNotLeaderException extends IOException { public MasterNotLeaderException( String currentPeer, String suggestedLeaderPeer, @Nullable Throwable cause) { + this(currentPeer, Tuple2.apply(suggestedLeaderPeer, suggestedLeaderPeer), false, cause); + } + + public MasterNotLeaderException( + String currentPeer, + Tuple2 suggestedLeaderPeer, + boolean bindPreferIp, + @Nullable Throwable cause) { super( String.format( "Master:%s is not the leader.%s%s", currentPeer, - currentPeer.equals(suggestedLeaderPeer) + currentPeer.equals(suggestedLeaderPeer._1) ? StringUtils.EMPTY : String.format(" Suggested leader is Master:%s.", suggestedLeaderPeer), cause == null ? StringUtils.EMPTY : String.format(" Exception:%s.", cause.getMessage())), cause); - this.leaderPeer = suggestedLeaderPeer; + this.leaderPeer = bindPreferIp ? 
suggestedLeaderPeer._1 : suggestedLeaderPeer._2; } public String getSuggestedLeaderAddress() { diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index d4531adbdf7..96e57f8cd22 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -427,6 +427,27 @@ object Utils extends Logging { } } + private def getIpHostAddressPair(host: String): (String, String) = { + try { + val inetAddress = InetAddress.getByName(host) + val hostAddress = inetAddress.getHostAddress + if (host.equals(hostAddress)) { + (hostAddress, inetAddress.getCanonicalHostName) + } else { + (hostAddress, host) + } + } catch { + case _: Throwable => (host, host) // return original input + } + } + + // Convert address (ip:port or host:port) to (ip:port, host:port) pair + def addressToIpHostAddressPair(address: String): (String, String) = { + val (host, port) = Utils.parseHostPort(address) + val (_ip, _host) = Utils.getIpHostAddressPair(host) + (_ip + ":" + port, _host + ":" + port) + } + def checkHostPort(hostPort: String): Unit = { if (hostPort != null && hostPort.split(":").length > 2) { assert( diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAHelper.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAHelper.java index fdd22443e8c..b0afab4b137 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAHelper.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAHelper.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.util.Optional; import com.google.protobuf.InvalidProtocolBufferException; import org.apache.ratis.protocol.Message; @@ -34,27 +35,30 @@ public class HAHelper { public static boolean checkShouldProcess( - RpcCallContext 
context, AbstractMetaManager masterStatusSystem) { + RpcCallContext context, AbstractMetaManager masterStatusSystem, boolean bindPreferIp) { HARaftServer ratisServer = getRatisServer(masterStatusSystem); if (ratisServer != null) { if (ratisServer.isLeader()) { return true; } - sendFailure(context, ratisServer, null); + sendFailure(context, ratisServer, null, bindPreferIp); return false; } return true; } public static void sendFailure( - RpcCallContext context, HARaftServer ratisServer, Throwable cause) { + RpcCallContext context, HARaftServer ratisServer, Throwable cause, boolean bindPreferIp) { if (context != null) { if (ratisServer != null) { - if (ratisServer.getCachedLeaderPeerRpcEndpoint().isPresent()) { + Optional leaderPeer = + ratisServer.getCachedLeaderPeerRpcEndpoint(); + if (leaderPeer.isPresent()) { context.sendFailure( new MasterNotLeaderException( ratisServer.getRpcEndpoint(), - ratisServer.getCachedLeaderPeerRpcEndpoint().get(), + leaderPeer.get().rpcEndpoints, + bindPreferIp, cause)); } else { context.sendFailure( diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java index 9e502058f09..61222549c58 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java @@ -55,6 +55,7 @@ import org.apache.celeborn.common.client.MasterClient; import org.apache.celeborn.common.exception.CelebornRuntimeException; import org.apache.celeborn.common.util.ThreadUtils; +import org.apache.celeborn.common.util.Utils; import org.apache.celeborn.service.deploy.master.clustermeta.ResourceProtos; import org.apache.celeborn.service.deploy.master.clustermeta.ResourceProtos.ResourceResponse; @@ -72,6 +73,7 @@ static long nextCallId() { return CALL_ID_COUNTER.getAndIncrement() & 
Long.MAX_VALUE; } + private final MasterNode localNode; private final InetSocketAddress ratisAddr; private final String rpcEndpoint; private final RaftServer server; @@ -89,7 +91,8 @@ static long nextCallId() { private long roleCheckIntervalMs; private final ReentrantReadWriteLock roleCheckLock = new ReentrantReadWriteLock(); private Optional cachedPeerRole = Optional.empty(); - private Optional cachedLeaderPeerRpcEndpoint = Optional.empty(); + private Optional cachedLeaderPeerRpcEndpoints = Optional.empty(); + private final CelebornConf conf; private long workerTimeoutDeadline; private long appTimeoutDeadline; @@ -99,7 +102,7 @@ static long nextCallId() { * * @param conf configuration * @param localRaftPeerId raft peer id of this Ratis server - * @param ratisAddr address of the ratis server + * @param localNode local node of this Ratis server * @param raftPeers peer nodes in the raft ring * @throws IOException */ @@ -107,13 +110,13 @@ private HARaftServer( MetaHandler metaHandler, CelebornConf conf, RaftPeerId localRaftPeerId, - InetSocketAddress ratisAddr, - String rpcEndpoint, + MasterNode localNode, List raftPeers) throws IOException { this.metaHandler = metaHandler; - this.ratisAddr = ratisAddr; - this.rpcEndpoint = rpcEndpoint; + this.localNode = localNode; + this.ratisAddr = localNode.ratisAddr(); + this.rpcEndpoint = localNode.rpcEndpoint(); this.raftPeerId = localRaftPeerId; this.raftGroup = RaftGroup.valueOf(RAFT_GROUP_ID, raftPeers); this.masterStateMachine = getStateMachine(); @@ -192,8 +195,8 @@ public static HARaftServer newMasterRatisServer( // Add other nodes belonging to the same service to the Ratis ring raftPeers.add(raftPeer); }); - return new HARaftServer( - metaHandler, conf, localRaftPeerId, ratisAddr, localNode.rpcEndpoint(), raftPeers); + + return new HARaftServer(metaHandler, conf, localRaftPeerId, localNode, raftPeers); } public ResourceResponse submitRequest(ResourceProtos.ResourceRequest request) @@ -421,12 +424,12 @@ public boolean 
isLeader() { /** * Get the suggested leader peer id. * - * @return RaftPeerId of the suggested leader node. + * @return cached rpc endpoints of the suggested leader node; empty if none is cached */ - public Optional getCachedLeaderPeerRpcEndpoint() { + public Optional getCachedLeaderPeerRpcEndpoint() { this.roleCheckLock.readLock().lock(); try { - return cachedLeaderPeerRpcEndpoint; + return cachedLeaderPeerRpcEndpoints; } finally { this.roleCheckLock.readLock().unlock(); } @@ -440,18 +443,20 @@ public void updateServerRole() { GroupInfoReply groupInfo = getGroupInfo(); RaftProtos.RoleInfoProto roleInfoProto = groupInfo.getRoleInfoProto(); RaftProtos.RaftPeerRole thisNodeRole = roleInfoProto.getRole(); - + Tuple2 leaderPeerRpcEndpoint = null; if (thisNodeRole.equals(RaftProtos.RaftPeerRole.LEADER)) { - setServerRole(thisNodeRole, getRpcEndpoint()); + // Current Node always uses original rpcEndpoint/internalRpcEndpoint, as if something wrong + // they would never return to client. + setServerRole(thisNodeRole, Tuple2.apply(this.rpcEndpoint, this.rpcEndpoint)); } else if (thisNodeRole.equals(RaftProtos.RaftPeerRole.FOLLOWER)) { ByteString leaderNodeId = roleInfoProto.getFollowerInfo().getLeaderInfo().getId().getId(); // There may be a chance, here we get leaderNodeId as null. For // example, in 3 node Ratis, if 2 nodes are down, there will // be no leader. - String leaderPeerRpcEndpoint = null; if (leaderNodeId != null && !leaderNodeId.isEmpty()) { - leaderPeerRpcEndpoint = + String clientAddress = roleInfoProto.getFollowerInfo().getLeaderInfo().getId().getClientAddress(); + leaderPeerRpcEndpoint = Utils.addressToIpHostAddressPair(clientAddress); } setServerRole(thisNodeRole, leaderPeerRpcEndpoint); @@ -470,7 +475,8 @@ public void updateServerRole() { } /** Set the current server role and the leader peer rpc endpoint. 
*/ - private void setServerRole(RaftProtos.RaftPeerRole currentRole, String leaderPeerRpcEndpoint) { + private void setServerRole( + RaftProtos.RaftPeerRole currentRole, Tuple2 leaderPeerRpcEndpoint) { this.roleCheckLock.writeLock().lock(); try { boolean leaderChanged = false; @@ -490,7 +496,12 @@ private void setServerRole(RaftProtos.RaftPeerRole currentRole, String leaderPee } this.cachedPeerRole = Optional.ofNullable(currentRole); - this.cachedLeaderPeerRpcEndpoint = Optional.ofNullable(leaderPeerRpcEndpoint); + if (null != leaderPeerRpcEndpoint) { + this.cachedLeaderPeerRpcEndpoints = + Optional.of(new LeaderPeerEndpoints(leaderPeerRpcEndpoint)); + } else { + this.cachedLeaderPeerRpcEndpoints = Optional.empty(); + } } finally { this.roleCheckLock.writeLock().unlock(); } @@ -543,4 +554,13 @@ public long getWorkerTimeoutDeadline() { public long getAppTimeoutDeadline() { return appTimeoutDeadline; } + + public static class LeaderPeerEndpoints { + // the rpcEndpoints Tuple2 (ip:port, host:port) + public final Tuple2 rpcEndpoints; + + public LeaderPeerEndpoints(Tuple2 rpcEndpoints) { + this.rpcEndpoints = rpcEndpoints; + } + } } diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index d126edb4a51..e899990d9e6 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -70,6 +70,8 @@ private[celeborn] class Master( metricsSystem.registerSource(new JVMCPUSource(conf, MetricsSystem.ROLE_MASTER)) metricsSystem.registerSource(new SystemMiscSource(conf, MetricsSystem.ROLE_MASTER)) + private val bindPreferIP: Boolean = conf.bindPreferIP + override val rpcEnv: RpcEnv = RpcEnv.create( RpcNameConstants.MASTER_SYS, masterArgs.host, @@ -269,12 +271,12 @@ private[celeborn] class Master( } def executeWithLeaderChecker[T](context: RpcCallContext, f: 
=> T): Unit = - if (HAHelper.checkShouldProcess(context, statusSystem)) { + if (HAHelper.checkShouldProcess(context, statusSystem, bindPreferIP)) { try { f } catch { case e: Exception => - HAHelper.sendFailure(context, HAHelper.getRatisServer(statusSystem), e) + HAHelper.sendFailure(context, HAHelper.getRatisServer(statusSystem), e, bindPreferIP) } } diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala index a1ac2b67ea9..19a27ad133c 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala @@ -39,6 +39,8 @@ case class MasterNode( def rpcEndpoint: String = rpcHost + ":" + rpcPort + def rpcIpEndpoint: String = rpcAddr.getAddress.getHostAddress + ":" + rpcPort + lazy val ratisAddr = MasterNode.createSocketAddr(ratisHost, ratisPort) lazy val rpcAddr = MasterNode.createSocketAddr(rpcHost, rpcPort) diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java index f16d27bdf64..32ec1778059 100644 --- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java +++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java @@ -25,6 +25,8 @@ import java.util.*; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import org.junit.*; import org.mockito.Mockito; @@ -154,10 +156,35 @@ public static void resetRaftServer() throws IOException, InterruptedException { } @Test - public void testLeaderAvaiable() { + public void testLeaderAvailable() { boolean hasLeader = 
RATISSERVER1.isLeader() || RATISSERVER2.isLeader() || RATISSERVER3.isLeader(); Assert.assertTrue(hasLeader); + + // Check if the rpc endpoint of the leader is as expected. + + HARaftServer leader = + RATISSERVER1.isLeader() + ? RATISSERVER1 + : (RATISSERVER2.isLeader() ? RATISSERVER2 : RATISSERVER3); + // one of them must be the follower given the three servers we have + HARaftServer follower = RATISSERVER1.isLeader() ? RATISSERVER2 : RATISSERVER1; + + // This is expected to be false, but as a side effect, updates getCachedLeaderPeerRpcEndpoint + boolean isFollowerCurrentLeader = follower.isLeader(); + Assert.assertFalse(isFollowerCurrentLeader); + + Optional cachedLeaderPeerRpcEndpoint = + follower.getCachedLeaderPeerRpcEndpoint(); + + Assert.assertTrue(cachedLeaderPeerRpcEndpoint.isPresent()); + + Tuple2 rpcEndpointsPair = cachedLeaderPeerRpcEndpoint.get().rpcEndpoints; + + // The rpc endpoint may use a custom host name, so this UT checks both the ip and host forms + Assert.assertTrue( + leader.getRpcEndpoint().equals(rpcEndpointsPair._1) + || leader.getRpcEndpoint().equals(rpcEndpointsPair._2)); } private static final String HOSTNAME1 = "host1";