Skip to content

Commit

Permalink
Speed up SkyframeExecutor#collectActionLookupValues by increasing p…
Browse files Browse the repository at this point in the history
…arallelism.

Make `Sharder` and `ActionLookupValueTraversal` thread-safe. For non-incremental builds, use `InMemoryGraph#parallelForEach`. For incremental builds, accumulate values during the parallel graph traversal instead of afterwards.

PiperOrigin-RevId: 669420833
Change-Id: I838bf8a375eda6e286317ac43006659e2411084d
  • Loading branch information
justinhorvitz authored and copybara-github committed Aug 30, 2024
1 parent f5db726 commit e3795f9
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 96 deletions.
33 changes: 22 additions & 11 deletions src/main/java/com/google/devtools/build/lib/concurrent/Sharder.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
package com.google.devtools.build.lib.concurrent;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.Collections;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A class to build shards (work queues) for a given task.
Expand All @@ -28,33 +30,42 @@
* @param <T> the type of collection over which we're sharding
*/
public final class Sharder<T> implements Iterable<List<T>> {
private final List<List<T>> shards;
private int nextShard = 0;
private final ImmutableList<List<T>> shards;
private final AtomicInteger count = new AtomicInteger();

public Sharder(int maxNumShards, int expectedTotalSize) {
Preconditions.checkArgument(maxNumShards > 0);
Preconditions.checkArgument(expectedTotalSize >= 0);
this.shards = immutableListOfLists(maxNumShards, expectedTotalSize / maxNumShards);
}

/**
* Adds an item to a shard.
*
* <p>May safely be called concurrently by multiple threads.
*/
@ThreadSafe
public void add(T item) {
shards.get(nextShard).add(item);
nextShard = (nextShard + 1) % shards.size();
int nextShardIndex = count.incrementAndGet() % shards.size();
List<T> shard = shards.get(nextShardIndex);
synchronized (shard) {
shard.add(item);
}
}

/**
* Returns an immutable list of mutable lists.
*
* @param numLists the number of top-level lists.
* @param expectedSize the exepected size of each mutable list.
* @param expectedSize the expected size of each mutable list.
* @return a list of lists.
*/
private static <T> List<List<T>> immutableListOfLists(int numLists, int expectedSize) {
List<List<T>> list = Lists.newArrayListWithCapacity(numLists);
private static <T> ImmutableList<List<T>> immutableListOfLists(int numLists, int expectedSize) {
var outerList = ImmutableList.<List<T>>builderWithExpectedSize(numLists);
for (int i = 0; i < numLists; i++) {
list.add(Lists.<T>newArrayListWithExpectedSize(expectedSize));
outerList.add(new ArrayList<>(expectedSize));
}
return Collections.unmodifiableList(list);
return outerList.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,34 @@

import com.google.devtools.build.lib.actions.ActionLookupKey;
import com.google.devtools.build.lib.actions.ActionLookupValue;
import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.ConfiguredTargetValue;
import com.google.devtools.build.lib.analysis.configuredtargets.InputFileConfiguredTarget;
import com.google.devtools.build.lib.analysis.configuredtargets.OutputFileConfiguredTarget;
import com.google.devtools.build.lib.bugreport.BugReport;
import com.google.devtools.build.lib.buildeventstream.BuildEventStreamProtos;
import com.google.devtools.build.lib.concurrent.Sharder;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.skyframe.SkyValue;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;

/** Represents the traversal of the ActionLookupValues in a build. */
public class ActionLookupValuesTraversal {
public final class ActionLookupValuesTraversal {
// Some metrics indicate this is a rough average # of ALVs in a build.
private final Sharder<ActionLookupValue> actionLookupValueShards =
new Sharder<>(NUM_JOBS, /* expectedTotalSize= */ 200_000);

// Metrics.
private int configuredObjectCount = 0;
private int configuredTargetCount = 0;
private int actionCount = 0;
private int actionCountNotIncludingAspects = 0;
private int inputFileConfiguredTargetCount = 0;
private int outputFileConfiguredTargetCount = 0;
private int otherConfiguredTargetCount = 0;

public ActionLookupValuesTraversal() {}
private final AtomicInteger configuredObjectCount = new AtomicInteger();
private final AtomicInteger configuredTargetCount = new AtomicInteger();
private final LongAdder actionCount = new LongAdder();
private final LongAdder actionCountNotIncludingAspects = new LongAdder();
private final AtomicInteger inputFileConfiguredTargetCount = new AtomicInteger();
private final AtomicInteger outputFileConfiguredTargetCount = new AtomicInteger();
private final AtomicInteger otherConfiguredTargetCount = new AtomicInteger();

@ThreadSafe
void accumulate(ActionLookupKey key, SkyValue value) {
boolean isConfiguredTarget = value instanceof ConfiguredTargetValue;
boolean isActionLookupValue = value instanceof ActionLookupValue;
Expand All @@ -65,53 +66,51 @@ void accumulate(ActionLookupKey key, SkyValue value) {
// will show up again under its own key. Avoids double counting by skipping accumulation.
return;
}
configuredObjectCount++;
configuredObjectCount.incrementAndGet();
if (isConfiguredTarget) {
configuredTargetCount++;
configuredTargetCount.incrementAndGet();
}
if (isActionLookupValue) {
ActionLookupValue alv = (ActionLookupValue) value;
int numActions = alv.getNumActions();
actionCount += numActions;
actionCount.add(numActions);
if (isConfiguredTarget) {
actionCountNotIncludingAspects += numActions;
actionCountNotIncludingAspects.add(numActions);
}
actionLookupValueShards.add(alv);
return;
}
if (!(value instanceof NonRuleConfiguredTargetValue)) {
if (!(value instanceof NonRuleConfiguredTargetValue nonRuleVal)) {
BugReport.sendBugReport(
new IllegalStateException(
String.format("Unexpected value type: %s %s %s", value.getClass(), key, value)));
return;
}
ConfiguredTarget configuredTarget =
((NonRuleConfiguredTargetValue) value).getConfiguredTarget();
if (configuredTarget instanceof InputFileConfiguredTarget) {
inputFileConfiguredTargetCount++;
} else if (configuredTarget instanceof OutputFileConfiguredTarget) {
outputFileConfiguredTargetCount++;
} else {
otherConfiguredTargetCount++;
}
AtomicInteger counter =
switch (nonRuleVal.getConfiguredTarget()) {
case InputFileConfiguredTarget input -> inputFileConfiguredTargetCount;
case OutputFileConfiguredTarget output -> outputFileConfiguredTargetCount;
default -> otherConfiguredTargetCount;
};
counter.incrementAndGet();
}

Sharder<ActionLookupValue> getActionLookupValueShards() {
return actionLookupValueShards;
}

int getActionCount() {
return actionCount;
return actionCount.intValue();
}

BuildEventStreamProtos.BuildMetrics.BuildGraphMetrics.Builder getMetrics() {
return BuildEventStreamProtos.BuildMetrics.BuildGraphMetrics.newBuilder()
.setActionLookupValueCount(configuredObjectCount)
.setActionLookupValueCountNotIncludingAspects(configuredTargetCount)
.setActionCount(actionCount)
.setActionCountNotIncludingAspects(actionCountNotIncludingAspects)
.setInputFileConfiguredTargetCount(inputFileConfiguredTargetCount)
.setOutputFileConfiguredTargetCount(outputFileConfiguredTargetCount)
.setOtherConfiguredTargetCount(otherConfiguredTargetCount);
.setActionLookupValueCount(configuredObjectCount.get())
.setActionLookupValueCountNotIncludingAspects(configuredTargetCount.get())
.setActionCount(actionCount.intValue())
.setActionCountNotIncludingAspects(actionCountNotIncludingAspects.intValue())
.setInputFileConfiguredTargetCount(inputFileConfiguredTargetCount.get())
.setOutputFileConfiguredTargetCount(outputFileConfiguredTargetCount.get())
.setOtherConfiguredTargetCount(otherConfiguredTargetCount.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3167,21 +3167,28 @@ final ActionLookupValuesTraversal collectActionLookupValuesInBuild(
Profiler.instance().profile("skyframeExecutor.collectActionLookupValuesInBuild")) {
ActionLookupValuesTraversal alvTraversal = new ActionLookupValuesTraversal();
if (!tracksStateForIncrementality()) {
// If we do not have graph edges, we cannot traverse the graph and find only actions in the
// current build. In this case we can simply return all ActionLookupValues in the graph,
// since the graph's lifetime is a single build anyway.
for (Map.Entry<SkyKey, SkyValue> entry : memoizingEvaluator.getDoneValues().entrySet()) {
if ((entry.getKey() instanceof ActionLookupKey) && entry.getValue() != null) {
alvTraversal.accumulate((ActionLookupKey) entry.getKey(), entry.getValue());
}
}
return alvTraversal;
// For non-incremental builds, do a parallel sweep over the whole graph.
memoizingEvaluator
.getInMemoryGraph()
.parallelForEach(
e -> {
if (!(e.getKey() instanceof ActionLookupKey key) || !e.isDone()) {
return;
}
SkyValue value = e.getValue();
if (value == null) {
return; // Error.
}
alvTraversal.accumulate(key, value);
});
} else {
// When incrementality is enabled, traverse the analysis graph top-down. This is slower, but
// is necessary to avoid collecting nodes that are in the graph from a previous build, but
// unnecessary for this build.
// TODO: jhorvitz - We could use the faster parallel sweep on clean builds.
new TransitiveActionLookupKeysCollector(SkyframeExecutorWrappingWalkableGraph.of(this))
.collect(Iterables.concat(topLevelCtKeys, aspectKeys), alvTraversal);
}

Map<ActionLookupKey, SkyValue> foundActions =
new TransitiveActionLookupKeysCollector(SkyframeExecutorWrappingWalkableGraph.of(this))
.collect(Iterables.concat(topLevelCtKeys, aspectKeys));
foundActions.forEach(alvTraversal::accumulate);
return alvTraversal;
}
}
Expand Down Expand Up @@ -3662,26 +3669,13 @@ private static int getNumberOfModifiedFiles(Iterable<SkyKey> modifiedValues) {
Iterables.filter(modifiedValues, SkyFunctionName.functionIs(FileStateKey.FILE_STATE)));
}

/**
* A sentinel used in {@link TransitiveActionLookupKeysCollector.VisitActionLookupKey#collected}.
*
* <p>Since the traversal is concurrent and {@link ActionLookupKey}s can have many reverse
* dependencies, it's better to short-circuit before recursively creating a subtask. The presence
* of this value indicates that another thread already intends to visit the key.
*/
private static final class ClaimedLookupValueSentinel implements SkyValue {
private static final ClaimedLookupValueSentinel INSTANCE = new ClaimedLookupValueSentinel();

private ClaimedLookupValueSentinel() {}
}

/**
* Collects the {@link ActionLookupKey} transitive closure of given {@link ActionLookupKey}s.
*
* <p>In the non-Skymeld case, this class is constructed and performs one traversal before
* shutdown at the end of analysis.
*/
private static class TransitiveActionLookupKeysCollector {
private static final class TransitiveActionLookupKeysCollector {
private final WalkableGraph walkableGraph;

private TransitiveActionLookupKeysCollector(WalkableGraph walkableGraph) {
Expand All @@ -3692,23 +3686,23 @@ private TransitiveActionLookupKeysCollector(WalkableGraph walkableGraph) {
* Traverses the transitive closure of {@code visitationRoots} and returns an {@link
* ActionLookupKey} keyed map to corresponding values for all visited keys.
*/
private Map<ActionLookupKey, SkyValue> collect(Iterable<ActionLookupKey> visitationRoots)
private void collect(
Iterable<ActionLookupKey> visitationRoots, ActionLookupValuesTraversal alvTraversal)
throws InterruptedException {
ForkJoinPool executorService =
NamedForkJoinPool.newNamedPool(
"find-action-lookup-values-in-build", Runtime.getRuntime().availableProcessors());
var collected = new ConcurrentHashMap<ActionLookupKey, SkyValue>();
var seen = Sets.<ActionLookupKey>newConcurrentHashSet();
List<Future<?>> futures = Lists.newArrayListWithCapacity(Iterables.size(visitationRoots));
for (ActionLookupKey key : visitationRoots) {
if (tryClaimVisitation(key, collected)) {
futures.add(executorService.submit(new VisitActionLookupKey(key, collected)));
if (seen.add(key)) {
futures.add(executorService.submit(new VisitActionLookupKey(key, seen, alvTraversal)));
}
}
try {
for (Future<?> future : futures) {
future.get();
}
return collected;
} catch (ExecutionException e) {
throw new IllegalStateException("Error collecting transitive ActionLookupValues", e);
} finally {
Expand All @@ -3719,25 +3713,18 @@ private Map<ActionLookupKey, SkyValue> collect(Iterable<ActionLookupKey> visitat
}
}

/**
* Attempts to claim ownership of {@code key}'s visitation.
*
* @return false if {@code key} is already included in {@link #globalVisitedSet}, was already
* claimed or has a value.
*/
private boolean tryClaimVisitation(
ActionLookupKey key, ConcurrentHashMap<ActionLookupKey, SkyValue> collected) {
return collected.putIfAbsent(key, ClaimedLookupValueSentinel.INSTANCE) == null;
}

protected final class VisitActionLookupKey extends RecursiveAction {
private final class VisitActionLookupKey extends RecursiveAction {
private final ActionLookupKey key;
private final ConcurrentHashMap<ActionLookupKey, SkyValue> collected;
private final Set<ActionLookupKey> seen;
private final ActionLookupValuesTraversal alvTraversal;

private VisitActionLookupKey(
ActionLookupKey key, ConcurrentHashMap<ActionLookupKey, SkyValue> collected) {
ActionLookupKey key,
Set<ActionLookupKey> seen,
ActionLookupValuesTraversal alvTraversal) {
this.key = key;
this.collected = collected;
this.seen = seen;
this.alvTraversal = alvTraversal;
}

@Override
Expand All @@ -3749,10 +3736,11 @@ public void compute() {
Thread.currentThread().interrupt();
}
if (value == null) { // The value failed to evaluate.
collected.remove(key);
return;
}
collected.put(key, value);

alvTraversal.accumulate(key, value);

Iterable<SkyKey> directDeps;
try {
directDeps = walkableGraph.getDirectDeps(key);
Expand All @@ -3772,8 +3760,8 @@ public void compute() {
if (!(dep instanceof ActionLookupKey depKey)) {
continue;
}
if (tryClaimVisitation(depKey, collected)) {
subtasks.add(new VisitActionLookupKey(depKey, collected));
if (seen.add(depKey)) {
subtasks.add(new VisitActionLookupKey(depKey, seen, alvTraversal));
}
}
invokeAll(subtasks);
Expand Down

0 comments on commit e3795f9

Please sign in to comment.