diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java index b2b2ab45e1b2b7..14c3cc8c59f056 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java @@ -23,6 +23,7 @@ import build.bazel.remote.execution.v2.Digest; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.flogger.GoogleLogger; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; @@ -31,7 +32,6 @@ import com.google.common.util.concurrent.SettableFuture; import com.google.devtools.build.lib.concurrent.ThreadSafety; import com.google.devtools.build.lib.exec.SpawnProgressEvent; -import com.google.devtools.build.lib.exec.SpawnRunner.SpawnExecutionContext; import com.google.devtools.build.lib.remote.common.LazyFileOutputStream; import com.google.devtools.build.lib.remote.common.OutputDigestMismatchException; import com.google.devtools.build.lib.remote.common.ProgressStatusListener; @@ -40,7 +40,9 @@ import com.google.devtools.build.lib.remote.common.RemoteCacheClient.ActionKey; import com.google.devtools.build.lib.remote.common.RemoteCacheClient.CachedActionResult; import com.google.devtools.build.lib.remote.options.RemoteOptions; +import com.google.devtools.build.lib.remote.util.AsyncTaskCache; import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.RxFutures; import com.google.devtools.build.lib.server.FailureDetails.FailureDetail; import com.google.devtools.build.lib.server.FailureDetails.RemoteExecution; import com.google.devtools.build.lib.server.FailureDetails.RemoteExecution.Code; @@ -48,6 +50,7 @@ import com.google.devtools.build.lib.vfs.FileSystemUtils; import com.google.devtools.build.lib.vfs.Path; import com.google.protobuf.ByteString; +import io.reactivex.rxjava3.core.Completable; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -62,15 +65,11 @@ public class RemoteCache implements AutoCloseable { private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); - /** See {@link SpawnExecutionContext#lockOutputFiles()}. */ - @FunctionalInterface - interface OutputFilesLocker { - void lock() throws InterruptedException; - } - private static final ListenableFuture COMPLETED_SUCCESS = immediateFuture(null); private static final ListenableFuture EMPTY_BYTES = immediateFuture(new byte[0]); + protected final AsyncTaskCache.NoResult casUploadCache = AsyncTaskCache.NoResult.create(); + protected final RemoteCacheClient cacheProtocol; protected final RemoteOptions options; protected final DigestUtil digestUtil; @@ -88,11 +87,19 @@ public CachedActionResult downloadActionResult( return getFromFuture(cacheProtocol.downloadActionResult(context, actionKey, inlineOutErr)); } + /** + * Returns a set of digests that the remote cache does not know about. The returned set is + * guaranteed to be a subset of {@code digests}. + */ public ListenableFuture> findMissingDigests( RemoteActionExecutionContext context, Iterable digests) { + if (Iterables.isEmpty(digests)) { + return Futures.immediateFuture(ImmutableSet.of()); + } return cacheProtocol.findMissingDigests(context, digests); } + /** Upload the action result to the remote cache. */ public ListenableFuture uploadActionResult( RemoteActionExecutionContext context, ActionKey actionKey, ActionResult actionResult) { return cacheProtocol.uploadActionResult(context, actionKey, actionResult); @@ -101,6 +108,9 @@ public ListenableFuture uploadActionResult( /** * Upload a local file to the remote cache. * + *

Trying to upload the same file multiple times concurrently, results in only one upload being + * performed. + * * @param context the context for the action. * @param digest the digest of the file. * @param file the file to upload. @@ -111,12 +121,20 @@ public final ListenableFuture uploadFile( return COMPLETED_SUCCESS; } - return cacheProtocol.uploadFile(context, digest, file); + Completable upload = + casUploadCache.executeIfNot( + digest, + RxFutures.toCompletable( + () -> cacheProtocol.uploadFile(context, digest, file), directExecutor())); + return RxFutures.toListenableFuture(upload); } /** * Upload sequence of bytes to the remote cache. * + *

Trying to upload the same BLOB multiple times concurrently, results in only one upload being + * performed. + * * @param context the context for the action. * @param digest the digest of the file. * @param data the BLOB to upload. @@ -127,7 +145,12 @@ public final ListenableFuture uploadBlob( return COMPLETED_SUCCESS; } - return cacheProtocol.uploadBlob(context, digest, data); + Completable upload = + casUploadCache.executeIfNot( + digest, + RxFutures.toCompletable( + () -> cacheProtocol.uploadBlob(context, digest, data), directExecutor())); + return RxFutures.toListenableFuture(upload); } public static void waitForBulkTransfer( diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java index c0db565594bd55..4a4135c34d6131 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java @@ -19,20 +19,24 @@ import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.Directory; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient; import com.google.devtools.build.lib.remote.merkletree.MerkleTree; import com.google.devtools.build.lib.remote.merkletree.MerkleTree.PathOrBytes; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.RxFutures; import com.google.protobuf.Message; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.subjects.AsyncSubject; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** A {@link RemoteCache} with additional functionality needed for remote execution. */ public class RemoteExecutionCache extends RemoteCache { @@ -57,16 +61,56 @@ public RemoteExecutionCache( public void ensureInputsPresent( RemoteActionExecutionContext context, MerkleTree merkleTree, - Map additionalInputs) + Map additionalInputs, + boolean force) throws IOException, InterruptedException { - Iterable allDigests = - Iterables.concat(merkleTree.getAllDigests(), additionalInputs.keySet()); - ImmutableSet missingDigests = - getFromFuture(cacheProtocol.findMissingDigests(context, allDigests)); + ImmutableSet allDigests = + ImmutableSet.builder() + .addAll(merkleTree.getAllDigests()) + .addAll(additionalInputs.keySet()) + .build(); + + // Collect digests that are not being or already uploaded + ConcurrentHashMap> missingDigestSubjects = + new ConcurrentHashMap<>(); List> uploadFutures = new ArrayList<>(); - for (Digest missingDigest : missingDigests) { - uploadFutures.add(uploadBlob(context, missingDigest, merkleTree, additionalInputs)); + for (Digest digest : allDigests) { + Completable upload = + casUploadCache.execute( + digest, + Completable.defer( + () -> { + // The digest hasn't been processed, add it to the collection which will be used + // later for findMissingDigests call + AsyncSubject missingDigestSubject = AsyncSubject.create(); + missingDigestSubjects.put(digest, missingDigestSubject); + + return missingDigestSubject.flatMapCompletable( + missing -> { + if (!missing) { + return Completable.complete(); + } + return RxFutures.toCompletable( + () -> uploadBlob(context, digest, merkleTree, additionalInputs), + MoreExecutors.directExecutor()); + }); + }), + force); + uploadFutures.add(RxFutures.toListenableFuture(upload)); + } + + ImmutableSet missingDigests = + getFromFuture(findMissingDigests(context, missingDigestSubjects.keySet())); + for (Map.Entry> entry : missingDigestSubjects.entrySet()) { + AsyncSubject missingSubject = entry.getValue(); + if (missingDigests.contains(entry.getKey())) { + missingSubject.onNext(true); + } else { + // The digest is already existed in the remote cache, skip the upload. + missingSubject.onNext(false); + } + missingSubject.onComplete(); } waitForBulkTransfer(uploadFutures, /* cancelRemainingOnInterrupt=*/ false); diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java index 36659dc46896e7..32adb2612e1755 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java @@ -1039,7 +1039,7 @@ public void uploadOutputs(RemoteAction action, SpawnResult spawnResult) * *

Must be called before calling {@link #executeRemotely}. */ - public void uploadInputsIfNotPresent(RemoteAction action) + public void uploadInputsIfNotPresent(RemoteAction action, boolean force) throws IOException, InterruptedException { checkState(mayBeExecutedRemotely(action.spawn), "spawn can't be executed remotely"); @@ -1049,7 +1049,7 @@ public void uploadInputsIfNotPresent(RemoteAction action) additionalInputs.put(action.actionKey.getDigest(), action.action); additionalInputs.put(action.commandHash, action.command); remoteExecutionCache.ensureInputsPresent( - action.remoteActionExecutionContext, action.merkleTree, additionalInputs); + action.remoteActionExecutionContext, action.merkleTree, additionalInputs, force); } /** diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteRepositoryRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteRepositoryRemoteExecutor.java index 546ec79e50fb8f..646a4dc9ebf060 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteRepositoryRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteRepositoryRemoteExecutor.java @@ -152,7 +152,7 @@ public ExecutionResult execute( additionalInputs.put(actionDigest, action); additionalInputs.put(commandHash, command); - remoteCache.ensureInputsPresent(context, merkleTree, additionalInputs); + remoteCache.ensureInputsPresent(context, merkleTree, additionalInputs, /*force=*/ true); } try (SilentCloseable c = diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnRunner.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnRunner.java index 8fc35bc3c57c3f..fe744480f702c9 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnRunner.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnRunner.java @@ -233,6 +233,7 @@ public SpawnResult exec(Spawn spawn, SpawnExecutionContext context) } AtomicBoolean useCachedResult = new AtomicBoolean(acceptCachedResult); + AtomicBoolean forceUploadInput = new AtomicBoolean(false); try { return retrier.execute( () -> { @@ -240,7 +241,10 @@ public SpawnResult exec(Spawn spawn, SpawnExecutionContext context) try (SilentCloseable c = prof.profile(UPLOAD_TIME, "upload missing inputs")) { Duration networkTimeStart = action.getNetworkTime().getDuration(); Stopwatch uploadTime = Stopwatch.createStarted(); - remoteExecutionService.uploadInputsIfNotPresent(action); + // Upon retry, we force upload inputs + remoteExecutionService.uploadInputsIfNotPresent( + action, forceUploadInput.getAndSet(true)); + // subtract network time consumed here to ensure wall clock during upload is not // double // counted, and metrics time computation does not exceed total time diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java index 06d9951d855176..b719889e9768d0 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java @@ -324,7 +324,7 @@ public void onError(Throwable t) { }); // Upload all missing inputs (that is, the virtual action input from above) - client.ensureInputsPresent(context, merkleTree, ImmutableMap.of()); + client.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), /*force=*/ true); } @Test diff --git a/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java b/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java index 8b809252a07ead..88f3160f8be3a7 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java +++ b/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java @@ -26,7 +26,7 @@ import java.io.IOException; import java.util.Map; -class InMemoryRemoteCache extends RemoteCache { +class InMemoryRemoteCache extends RemoteExecutionCache { InMemoryRemoteCache( Map casEntries, RemoteOptions options, DigestUtil digestUtil) { @@ -74,6 +74,10 @@ int getNumFailedDownloads() { return ((InMemoryCacheClient) cacheProtocol).getNumFailedDownloads(); } + Map getNumFindMissingDigests() { + return ((InMemoryCacheClient) cacheProtocol).getNumFindMissingDigests(); + } + @Override public void close() { cacheProtocol.close(); diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java index db11adca9d5900..4ba4c423dfed38 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java @@ -61,6 +61,7 @@ import com.google.devtools.build.lib.actions.cache.MetadataInjector; import com.google.devtools.build.lib.actions.util.ActionsTestUtil; import com.google.devtools.build.lib.clock.JavaClock; +import com.google.devtools.build.lib.collect.nestedset.NestedSet; import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder; import com.google.devtools.build.lib.collect.nestedset.Order; import com.google.devtools.build.lib.exec.util.FakeOwner; @@ -68,6 +69,7 @@ import com.google.devtools.build.lib.remote.RemoteExecutionService.RemoteActionResult; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient.CachedActionResult; +import com.google.devtools.build.lib.remote.common.RemoteExecutionClient; import com.google.devtools.build.lib.remote.common.RemotePathResolver; import com.google.devtools.build.lib.remote.common.RemotePathResolver.DefaultRemotePathResolver; import com.google.devtools.build.lib.remote.common.RemotePathResolver.SiblingRepositoryLayoutResolver; @@ -75,6 +77,7 @@ import com.google.devtools.build.lib.remote.options.RemoteOutputsMode; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.FakeSpawnExecutionContext; +import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.remote.util.Utils.InMemoryOutput; import com.google.devtools.build.lib.skyframe.TreeArtifactValue; @@ -91,7 +94,12 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -99,6 +107,8 @@ /** Tests for {@link RemoteExecutionService}. */ @RunWith(JUnit4.class) public class RemoteExecutionServiceTest { + @Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule(); + private final DigestUtil digestUtil = new DigestUtil(DigestHashFunction.SHA256); RemoteOptions remoteOptions; @@ -108,6 +118,7 @@ public class RemoteExecutionServiceTest { private RemotePathResolver remotePathResolver; private FileOutErr outErr; private InMemoryRemoteCache cache; + private RemoteExecutionClient executor; private RemoteActionExecutionContext remoteActionExecutionContext; @Before @@ -130,6 +141,7 @@ public final void setUp() throws Exception { outErr = new FileOutErr(stdout, stderr); cache = new InMemoryRemoteCache(remoteOptions, digestUtil); + executor = mock(RemoteExecutionClient.class); RequestMetadata metadata = TracingMetadataUtils.buildMetadata("none", "none", "action-id", null); @@ -1258,6 +1270,48 @@ public void uploadOutputs_emptyOutputs_doNotPerformUpload() throws Exception { .containsExactly(emptyDigest); } + @Test + public void uploadInputsIfNotPresent_deduplicateFindMissingBlobCalls() throws Exception { + int taskCount = 100; + ExecutorService executorService = Executors.newFixedThreadPool(taskCount); + AtomicReference error = new AtomicReference<>(null); + Semaphore semaphore = new Semaphore(0); + ActionInput input = ActionInputHelper.fromPath("inputs/foo"); + Digest inputDigest = fakeFileCache.createScratchInput(input, "input-foo"); + RemoteExecutionService service = newRemoteExecutionService(); + + for (int i = 0; i < taskCount; ++i) { + executorService.execute( + () -> { + try { + Spawn spawn = + newSpawn( + ImmutableMap.of(), + ImmutableSet.of(), + NestedSetBuilder.create(Order.STABLE_ORDER, input)); + FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); + RemoteAction action = service.buildRemoteAction(spawn, context); + + service.uploadInputsIfNotPresent(action, /*force=*/ false); + } catch (Throwable e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + error.set(e); + } finally { + semaphore.release(); + } + }); + } + semaphore.acquire(taskCount); + + assertThat(error.get()).isNull(); + assertThat(cache.getNumFindMissingDigests()).containsEntry(inputDigest, 1); + for (Integer num : cache.getNumFindMissingDigests().values()) { + assertThat(num).isEqualTo(1); + } + } + private Spawn newSpawnFromResult(RemoteActionResult result) { return newSpawnFromResult(ImmutableMap.of(), result); } @@ -1304,12 +1358,19 @@ private Spawn newSpawnFromResultWithInMemoryOutput( private Spawn newSpawn( ImmutableMap executionInfo, ImmutableSet outputs) { + return newSpawn(executionInfo, outputs, NestedSetBuilder.emptySet(Order.STABLE_ORDER)); + } + + private Spawn newSpawn( + ImmutableMap executionInfo, + ImmutableSet outputs, + NestedSet inputs) { return new SimpleSpawn( new FakeOwner("foo", "bar", "//dummy:label"), /*arguments=*/ ImmutableList.of(), /*environment=*/ ImmutableMap.of(), /*executionInfo=*/ executionInfo, - /*inputs=*/ NestedSetBuilder.emptySet(Order.STABLE_ORDER), + /*inputs=*/ inputs, /*outputs=*/ outputs, ResourceSet.ZERO); } @@ -1346,7 +1407,7 @@ private RemoteExecutionService newRemoteExecutionService( digestUtil, remoteOptions, cache, - null, + executor, ImmutableSet.copyOf(topLevelOutputs), null); } diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerTest.java index eec46d990802b6..c33b522401128a 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerTest.java @@ -17,6 +17,7 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; @@ -259,7 +260,7 @@ public void nonCachableSpawnsShouldNotBeCached_localFallback() throws Exception runner.exec(spawn, policy); verify(localRunner).exec(spawn, policy); - verify(cache).ensureInputsPresent(any(), any(), any()); + verify(cache).ensureInputsPresent(any(), any(), any(), anyBoolean()); verifyNoMoreInteractions(cache); } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java index 065c6e3ab4c044..c26629f2b3cf14 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java @@ -27,10 +27,12 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.AbstractMap.SimpleEntry; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; /** A {@link RemoteCacheClient} that stores its contents in memory. */ public final class InMemoryCacheClient implements RemoteCacheClient { @@ -41,6 +43,8 @@ public final class InMemoryCacheClient implements RemoteCacheClient { private AtomicInteger numSuccess = new AtomicInteger(); private AtomicInteger numFailures = new AtomicInteger(); + private final ConcurrentMap numFindMissingDigests = + new ConcurrentHashMap<>(); public InMemoryCacheClient(Map casEntries) { this.cas = new ConcurrentHashMap<>(); @@ -65,6 +69,12 @@ public int getNumFailedDownloads() { return numFailures.get(); } + public Map getNumFindMissingDigests() { + return numFindMissingDigests.entrySet().stream() + .map(entry -> new SimpleEntry<>(entry.getKey(), entry.getValue().get())) + .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + @Override public ListenableFuture downloadBlob( RemoteActionExecutionContext context, Digest digest, OutputStream out) { @@ -134,6 +144,9 @@ public ListenableFuture> findMissingDigests( RemoteActionExecutionContext context, Iterable digests) { ImmutableSet.Builder missingBuilder = ImmutableSet.builder(); for (Digest digest : digests) { + numFindMissingDigests + .computeIfAbsent(digest, (key) -> new AtomicInteger(0)) + .incrementAndGet(); if (!cas.containsKey(digest)) { missingBuilder.add(digest); }