diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index d5d8c4d9e45..c2f06c8046d 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -350,7 +350,6 @@ public boolean isEquivalentTo(RoundRobinPicker picker) { @VisibleForTesting static final class StaticStrideScheduler { private final short[] scaledWeights; - private final int sizeDivisor; private final AtomicInteger sequence; private static final int K_MAX_WEIGHT = 0xFFFF; @@ -373,7 +372,7 @@ static final class StaticStrideScheduler { if (numWeightedChannels > 0) { meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels); } else { - meanWeight = 1; + meanWeight = (short) Math.round(scalingFactor); } // scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly @@ -387,7 +386,6 @@ static final class StaticStrideScheduler { } this.scaledWeights = scaledWeights; - this.sizeDivisor = numChannels; this.sequence = new AtomicInteger(random.nextInt()); } @@ -433,15 +431,18 @@ long getSequence() { * an offset that varies per backend index is also included to the calculation. */ int pick() { + int i = 0; while (true) { + i++; long sequence = this.nextSequence(); - int backendIndex = (int) (sequence % this.sizeDivisor); - long generation = sequence / this.sizeDivisor; - int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]); + int backendIndex = (int) (sequence % scaledWeights.length); + long generation = sequence / scaledWeights.length; + int weight = Short.toUnsignedInt(scaledWeights[backendIndex]); long offset = (long) K_MAX_WEIGHT / 2 * backendIndex; if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) { continue; } + assert i <= scaledWeights.length : "scheduler has more than one pass through"; return backendIndex; } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 879cac871b5..505b2f7dff6 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -330,11 +330,11 @@ weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalt } assertThat(pickCount.size()).isEqualTo(3); assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio)) - .isAtMost(0.001); + .isAtMost(0.0001); assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio )) - .isAtMost(0.001); + .isAtMost(0.0001); assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio )) - .isAtMost(0.001); + .isAtMost(0.0001); } @Test @@ -751,12 +751,12 @@ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalt } assertThat(pickCount.size()).isEqualTo(3); assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) - .isAtMost(0.002); + .isAtMost(0.001); assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) - .isAtMost(0.002); + .isAtMost(0.001); // subchannel3's weight is average of subchannel1 and subchannel2 assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) - .isAtMost(0.002); + .isAtMost(0.001); } @Test @@ -947,7 +947,7 @@ public void testStaticStrideSchedulerNonIntegers1() { } for (int i = 0; i < 3; i++) { assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) - .isAtMost(0.01); + .isAtMost(0.001); } } @@ -964,7 +964,7 @@ public void testStaticStrideSchedulerNonIntegers2() { } for (int i = 0; i < 3; i++) { assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) - .isAtMost(0.01); + .isAtMost(0.001); } } @@ -981,7 +981,7 @@ public void testTwoWeights() { } for (int i = 0; i < 2; i++) { assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) - .isAtMost(0.01); + .isAtMost(0.001); } } @@ -1015,7 +1015,7 @@ public void testManyComplexWeights() { } for (int i = 0; i < 8; i++) { assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) - .isAtMost(0.01); + .isAtMost(0.004); } }