diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java new file mode 100644 index 00000000000000..de26b8db8b2107 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -0,0 +1,178 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.util; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +/** + * A cache which de-duplicates the executions and stores the results of asynchronous tasks. Each + * task is identified by a key of type {@link KeyT} and has the result of type {@link ValueT}. + * + *

Use {@link #executeIfNot} or {@link #execute} and subscribe the returned {@link Single} to + * start executing a task. The {@link Single} turns to completed once the task is {@code finished}. + * Errors are propagated if any. + * + *

Calling {@code execute[IfNot]} multiple times with the same task key can get an {@link Single} + * which connects to the same underlying execution if the task is still executing, or get a + * completed {@link Single} if the task is already finished. Set {@code force} to {@code true } to + * re-execute a finished task. + * + *

Dispose the {@link Single} to cancel to task execution. + */ +@ThreadSafe +public final class AsyncTaskCache { + @GuardedBy("this") + private final Map finished; + + @GuardedBy("this") + private final Map> inProgress; + + public static AsyncTaskCache create() { + return new AsyncTaskCache<>(); + } + + private AsyncTaskCache() { + this.finished = new HashMap<>(); + this.inProgress = new HashMap<>(); + } + + /** Returns a set of keys for tasks which is finished. */ + public ImmutableSet getFinishedTasks() { + synchronized (this) { + return ImmutableSet.copyOf(finished.keySet()); + } + } + + /** Returns a set of keys for tasks which is still executing. */ + public ImmutableSet getInProgressTasks() { + synchronized (this) { + return ImmutableSet.copyOf(inProgress.keySet()); + } + } + + /** + * Executes a task if it hasn't been executed. + * + * @param key identifies the task. + * @return a {@link Single} which turns to completed once the task is finished or propagates the + * error if any. + */ + public Single executeIfNot(KeyT key, Single task) { + return execute(key, task, false); + } + + /** + * Executes a task. + * + * @param key identifies the task. + * @param force re-execute a finished task if set to {@code true}. + * @return a {@link Single} which turns to completed once the task is finished or propagates the + * error if any. + */ + public Single execute(KeyT key, Single task, boolean force) { + return Single.defer( + () -> { + synchronized (this) { + if (!force && finished.containsKey(key)) { + return Single.just(finished.get(key)); + } + + finished.remove(key); + + Observable execution = + inProgress.computeIfAbsent( + key, + missingKey -> { + AtomicInteger subscribeTimes = new AtomicInteger(0); + return Single.defer( + () -> { + int times = subscribeTimes.incrementAndGet(); + Preconditions.checkState( + times == 1, "Subscribed more than once to the task"); + return task; + }) + .doOnSuccess( + value -> { + synchronized (this) { + finished.put(key, value); + inProgress.remove(key); + } + }) + .doOnError( + error -> { + synchronized (this) { + inProgress.remove(key); + } + }) + .doOnDispose( + () -> { + synchronized (this) { + inProgress.remove(key); + } + }) + .toObservable() + .publish() + .refCount(); + }); + + return Single.fromObservable(execution); + } + }); + } + + /** An {@link AsyncTaskCache} without result. */ + public static final class NoResult { + private final AsyncTaskCache> cache; + + public static AsyncTaskCache.NoResult create() { + return new AsyncTaskCache.NoResult<>(AsyncTaskCache.create()); + } + + public NoResult(AsyncTaskCache> cache) { + this.cache = cache; + } + + /** Same as {@link AsyncTaskCache#executeIfNot} but operates on {@link Completable}. */ + public Completable executeIfNot(KeyT key, Completable task) { + return Completable.fromSingle( + cache.executeIfNot(key, task.toSingleDefault(Optional.empty()))); + } + + /** Same as {@link AsyncTaskCache#executeIfNot} but operates on {@link Completable}. */ + public Completable execute(KeyT key, Completable task, boolean force) { + return Completable.fromSingle( + cache.execute(key, task.toSingleDefault(Optional.empty()), force)); + } + + /** Returns a set of keys for tasks which is finished. */ + public ImmutableSet getFinishedTasks() { + return cache.getFinishedTasks(); + } + + /** Returns a set of keys for tasks which is still executing. */ + public ImmutableSet getInProgressTasks() { + return cache.getInProgressTasks(); + } + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java new file mode 100644 index 00000000000000..9e2d641f1edc1b --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java @@ -0,0 +1,299 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.util; + +import static com.google.common.truth.Truth.assertThat; + +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleEmitter; +import io.reactivex.rxjava3.observers.TestObserver; +import io.reactivex.rxjava3.plugins.RxJavaPlugins; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link AsyncTaskCache}. */ +@RunWith(JUnit4.class) +public class AsyncTaskCacheTest { + + private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null); + + @Before + public void setUp() { + RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); + } + + @After + public void tearDown() throws Throwable { + // Make sure rxjava didn't receive global errors + Throwable t = rxGlobalThrowable.getAndSet(null); + if (t != null) { + throw t; + } + } + + @Test + public void execute_noSubscription_noExecution() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicBoolean executed = new AtomicBoolean(false); + + cache.executeIfNot( + "key1", + Single.create( + emitter -> { + executed.set(true); + emitter.onSuccess("value1"); + })); + + assertThat(executed.get()).isFalse(); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_taskFinished_completed() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + TestObserver observer = + cache.executeIfNot("key1", Single.create(emitterRef::set)).test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + + emitter.onSuccess("value1"); + + observer.assertValue("value1"); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + } + + @Test + public void execute_taskHasError_propagateError() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + TestObserver observer = + cache.executeIfNot("key1", Single.create(emitterRef::set)).test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + Throwable error = new IllegalStateException("error"); + + emitter.onError(error); + + observer.assertError(error); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_taskInProgress_noReExecution() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + AtomicInteger executionTimes = new AtomicInteger(0); + Single single = + cache.executeIfNot( + "key1", + Single.create( + emitter -> { + executionTimes.incrementAndGet(); + emitterRef.set(emitter); + })); + TestObserver ob1 = single.test(); + ob1.assertEmpty(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + assertThat(cache.getInProgressTasks()).containsExactly("key1"); + assertThat(cache.getFinishedTasks()).isEmpty(); + + TestObserver ob2 = single.test(); + ob2.assertEmpty(); + emitter.onSuccess("value1"); + + ob1.assertValue("value1"); + ob2.assertValue("value1"); + assertThat(executionTimes.get()).isEqualTo(1); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + } + + @Test + public void executeForcibly_taskInProgress_noReExecution() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + AtomicInteger executionTimes = new AtomicInteger(0); + Single single = + cache.execute( + "key1", + Single.create( + emitter -> { + executionTimes.incrementAndGet(); + emitterRef.set(emitter); + }), + /* force= */ true); + TestObserver ob1 = single.test(); + ob1.assertEmpty(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + assertThat(cache.getInProgressTasks()).containsExactly("key1"); + assertThat(cache.getFinishedTasks()).isEmpty(); + + TestObserver ob2 = single.test(); + ob2.assertEmpty(); + emitter.onSuccess("value1"); + + ob1.assertValue("value1"); + ob2.assertValue("value1"); + assertThat(executionTimes.get()).isEqualTo(1); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + } + + @Test + public void execute_taskFinished_noReExecution() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + AtomicInteger executionTimes = new AtomicInteger(0); + Single single = + cache.executeIfNot( + "key1", + Single.create( + emitter -> { + executionTimes.incrementAndGet(); + emitterRef.set(emitter); + })); + TestObserver ob1 = single.test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + emitter.onSuccess("value1"); + ob1.assertValue("value1"); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + + TestObserver ob2 = single.test(); + + ob2.assertValue("value1"); + assertThat(executionTimes.get()).isEqualTo(1); + } + + @Test + public void executeForcibly_taskFinished_reExecution() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + AtomicInteger executionTimes = new AtomicInteger(0); + Single single = + cache.execute( + "key1", + Single.create( + emitter -> { + executionTimes.incrementAndGet(); + emitterRef.set(emitter); + }), + /* force= */ true); + TestObserver ob1 = single.test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + emitter.onSuccess("value1"); + ob1.assertValue("value1"); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + + TestObserver ob2 = single.test(); + + ob2.assertEmpty(); + assertThat(executionTimes.get()).isEqualTo(2); + assertThat(cache.getInProgressTasks()).containsExactly("key1"); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_dispose_cancelled() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + TestObserver observer = + cache.executeIfNot("key1", Single.create(emitterRef::set)).test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + AtomicBoolean disposed = new AtomicBoolean(false); + emitter.setCancellable(() -> disposed.set(true)); + + observer.dispose(); + + assertThat(disposed.get()).isTrue(); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_disposeWhenMultipleSubscriptions_notCancelled() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + Single single = cache.executeIfNot("key1", Single.create(emitterRef::set)); + TestObserver ob1 = single.test(); + TestObserver ob2 = single.test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + AtomicBoolean disposed = new AtomicBoolean(false); + emitter.setCancellable(() -> disposed.set(true)); + + ob1.dispose(); + + ob2.assertEmpty(); + assertThat(disposed.get()).isFalse(); + assertThat(cache.getInProgressTasks()).containsExactly("key1"); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_disposeWhenMultipleSubscriptions_cancelled() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + Single single = cache.executeIfNot("key1", Single.create(emitterRef::set)); + TestObserver ob1 = single.test(); + TestObserver ob2 = single.test(); + SingleEmitter emitter = emitterRef.get(); + assertThat(emitter).isNotNull(); + AtomicBoolean disposed = new AtomicBoolean(false); + emitter.setCancellable(() -> disposed.set(true)); + + ob1.dispose(); + ob2.dispose(); + + assertThat(disposed.get()).isTrue(); + assertThat(cache.getInProgressTasks()).isEmpty(); + assertThat(cache.getFinishedTasks()).isEmpty(); + } + + @Test + public void execute_multipleTasks_completeOne() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef1 = new AtomicReference<>(null); + TestObserver observer1 = + cache.executeIfNot("key1", Single.create(emitterRef1::set)).test(); + SingleEmitter emitter1 = emitterRef1.get(); + assertThat(emitter1).isNotNull(); + AtomicReference> emitterRef2 = new AtomicReference<>(null); + TestObserver observer2 = + cache.executeIfNot("key2", Single.create(emitterRef2::set)).test(); + SingleEmitter emitter2 = emitterRef1.get(); + assertThat(emitter2).isNotNull(); + + emitter1.onSuccess("value1"); + + observer1.assertValue("value1"); + observer2.assertEmpty(); + assertThat(cache.getInProgressTasks()).containsExactly("key2"); + assertThat(cache.getFinishedTasks()).containsExactly("key1"); + } +}