Skip to content

Commit

Permalink
concurrent-api: cleanup AsyncContext operations (#3181)
Browse files Browse the repository at this point in the history
Motivation:

Perhaps the main goal of the AsyncContext is to allow a way to bundle up thread-local (ish) state and restore it when an async operation occurs. However, restore operation is essentially inlined in all the context operations which makes the mechanism more difficult to understand than it needs to be.

Modifications:

Add a few new methods to the AsyncContextProvider

- captureContext() the goal of this method is to package up context for propagation to the next continuation. This stands in contrast to the existing context() method which is used for reading the async ContextMap.
- attachContext(ContextMap) this method is used to restore the local state for temporary use. It returns a new Scope type, which then reverts this local environment to what it was before the restore.
Result:

Cleaner async context operations.
  • Loading branch information
bryce-anderson authored Feb 4, 2025
1 parent a7e83b4 commit c6d55e2
Show file tree
Hide file tree
Showing 32 changed files with 366 additions and 586 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright © 2025 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.context.api.ContextMap;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.TimeUnit;
import java.util.function.Function;

@Fork(1)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 3)
@Measurement(iterations = 5, time = 3)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Mode.AverageTime)
public class AsyncContextProviderBenchmark {

/**
* gc profiling of the DefaultAsyncContextProvider shows that the Scope based detachment can be stack allocated
* at least under some conditions.
*
* Benchmark Mode Cnt Score Error Units
* AsyncContextProviderBenchmark.contextRestoreCost avgt 5 3.932 ± 0.022 ns/op
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁶ B/op
* AsyncContextProviderBenchmark.contextRestoreCost:gc.count avgt 5 ≈ 0 counts
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost avgt 5 1.712 ± 0.005 ns/op
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁷ B/op
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.count avgt 5 ≈ 0 counts
*/

private static final ContextMap.Key<String> KEY = ContextMap.Key.newKey("test-key", String.class);
private static final String EXPECTED = "hello, world!";

private Function<String, String> wrappedFunction;

@Setup
public void setup() {
// This will capture the current context
wrappedFunction = AsyncContext.wrapFunction(ignored -> AsyncContext.context().get(KEY));
AsyncContext.context().put(KEY, EXPECTED);
}

@Benchmark
public String contextRestoreCost() {
return wrappedFunction.apply("ignored");
}

@Benchmark
public String contextSaveAndRestoreCost() {
return AsyncContext.wrapFunction(Function.<String>identity()).apply("ignored");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public final class AsyncContext {
private static final int STATE_INIT = 0;
private static final int STATE_AUTO_ENABLED = 1;
private static final int STATE_ENABLED = 2;

/**
* Note this mechanism is racy. Currently only the {@link #disable()} method is exposed publicly and
* {@link #STATE_DISABLED} is a terminal state. Because we favor going to the disabled state we don't have to worry
Expand Down Expand Up @@ -438,7 +439,7 @@ public static ScheduledExecutorService wrapJdkScheduledExecutorService(final Sch
*/
public static Runnable wrapRunnable(final Runnable runnable) {
AsyncContextProvider provider = provider();
return provider.wrapRunnable(runnable, provider.context());
return provider.wrapRunnable(runnable, provider.captureContext());
}

/**
Expand All @@ -449,7 +450,7 @@ public static Runnable wrapRunnable(final Runnable runnable) {
*/
public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
AsyncContextProvider provider = provider();
return provider.wrapCallable(callable, provider.context());
return provider.wrapCallable(callable, provider.captureContext());
}

/**
Expand All @@ -460,7 +461,7 @@ public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
*/
public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
AsyncContextProvider provider = provider();
return provider.wrapConsumer(consumer, provider.context());
return provider.wrapConsumer(consumer, provider.captureContext());
}

/**
Expand All @@ -472,7 +473,7 @@ public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
*/
public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
AsyncContextProvider provider = provider();
return provider.wrapFunction(func, provider.context());
return provider.wrapFunction(func, provider.captureContext());
}

/**
Expand All @@ -484,7 +485,7 @@ public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
*/
public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consumer) {
AsyncContextProvider provider = provider();
return provider.wrapBiConsumer(consumer, provider.context());
return provider.wrapBiConsumer(consumer, provider.captureContext());
}

/**
Expand All @@ -497,7 +498,7 @@ public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consu
*/
public static <T, U, V> BiFunction<T, U, V> wrapBiFunction(BiFunction<T, U, V> func) {
AsyncContextProvider provider = provider();
return provider.wrapBiFunction(func, provider.context());
return provider.wrapBiFunction(func, provider.captureContext());
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,30 @@ interface AsyncContextProvider {
/**
* Get the current context.
*
* Note that this method is for getting the {@link ContextMap} for use by the application code. For saving the
* current state for crossing an async boundary see the {@link AsyncContextProvider#captureContext()} method.
*
* @return The current context.
*/
ContextMap context();

/**
* Capture existing context in preparation for an asynchronous thread jump.
*
* Note that this can do more than just package up the ServiceTalk {@link AsyncContext} and could be enhanced or
* wrapped to bundle up additional contexts such as the OpenTelemetry or grpc contexts.
* @return the saved context state that may be restored later.
*/
ContextMap captureContext();

/**
* Restore the previously saved {@link ContextMap} to the local state.
* @param contextMap representing the state previously saved via {@link AsyncContextProvider#captureContext()} and
* that is intended to be restored.
* @return a {@link Scope} that must be closed at the end of the attachment.
*/
Scope attachContext(ContextMap contextMap);

/**
* Wrap the {@link Cancellable} to ensure it is able to track {@link AsyncContext} correctly.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1730,7 +1730,7 @@ public final Future<Void> toFuture() {
*/
ContextMap contextForSubscribe(AsyncContextProvider provider) {
// the default behavior is to copy the map. Some operators may want to use shared map
return provider.context().copy();
return provider.captureContext().copy();
}

/**
Expand All @@ -1740,9 +1740,19 @@ ContextMap contextForSubscribe(AsyncContextProvider provider) {
* @param subscriber {@link Subscriber} to subscribe for the result.
*/
protected final void subscribeInternal(Subscriber subscriber) {
requireNonNull(subscriber);
AsyncContextProvider contextProvider = AsyncContext.provider();
ContextMap contextMap = contextForSubscribe(contextProvider);
subscribeWithContext(subscriber, contextProvider, contextMap);
Subscriber wrapped = contextProvider.wrapCancellable(subscriber, contextMap);
if (contextProvider.context() == contextMap) {
// No need to wrap as we are sharing the AsyncContext
handleSubscribe(wrapped, contextMap, contextProvider);
} else {
// Ensure that AsyncContext used for handleSubscribe() is the contextMap for the subscribe()
try (Scope unused = contextProvider.attachContext(contextMap)) {
handleSubscribe(wrapped, contextMap, contextProvider);
}
}
}

/**
Expand Down Expand Up @@ -2262,19 +2272,6 @@ final void delegateSubscribe(Subscriber subscriber,
handleSubscribe(subscriber, contextMap, contextProvider);
}

private void subscribeWithContext(Subscriber subscriber,
AsyncContextProvider contextProvider, ContextMap contextMap) {
requireNonNull(subscriber);
Subscriber wrapped = contextProvider.wrapCancellable(subscriber, contextMap);
if (contextProvider.context() == contextMap) {
// No need to wrap as we are sharing the AsyncContext
handleSubscribe(wrapped, contextMap, contextProvider);
} else {
// Ensure that AsyncContext used for handleSubscribe() is the contextMap for the subscribe()
contextProvider.wrapRunnable(() -> handleSubscribe(wrapped, contextMap, contextProvider), contextMap).run();
}
}

/**
* Override for {@link #handleSubscribe(CompletableSource.Subscriber)}.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ final class CompletableShareContextOnSubscribe extends AbstractNoHandleSubscribe

@Override
ContextMap contextForSubscribe(AsyncContextProvider provider) {
return provider.context();
return provider.captureContext();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.util.List;
import java.util.concurrent.Callable;

import static io.servicetalk.concurrent.api.DefaultAsyncContextProvider.INSTANCE;

final class ContextAwareExecutorUtils {

private ContextAwareExecutorUtils() {
Expand All @@ -32,7 +30,7 @@ private ContextAwareExecutorUtils() {

static <X> Collection<? extends Callable<X>> wrap(Collection<? extends Callable<X>> tasks) {
List<Callable<X>> wrappedTasks = new ArrayList<>(tasks.size());
ContextMap contextMap = INSTANCE.context();
ContextMap contextMap = AsyncContext.provider().captureContext();
for (Callable<X> task : tasks) {
wrappedTasks.add(new ContextPreservingCallable<>(task, contextMap));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
package io.servicetalk.concurrent.api;

import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMapHolder;

import java.util.function.BiConsumer;

import static io.servicetalk.concurrent.api.AsyncContextMapThreadLocal.CONTEXT_THREAD_LOCAL;
import static java.util.Objects.requireNonNull;

final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {
Expand All @@ -34,28 +32,8 @@ final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {

@Override
public void accept(T t, U u) {
final Thread currentThread = Thread.currentThread();
if (currentThread instanceof ContextMapHolder) {
final ContextMapHolder asyncContextMapHolder = (ContextMapHolder) currentThread;
ContextMap prev = asyncContextMapHolder.context();
try {
asyncContextMapHolder.context(saved);
delegate.accept(t, u);
} finally {
asyncContextMapHolder.context(prev);
}
} else {
slowPath(t, u);
}
}

private void slowPath(T t, U u) {
ContextMap prev = CONTEXT_THREAD_LOCAL.get();
try {
CONTEXT_THREAD_LOCAL.set(saved);
try (Scope ignored = AsyncContext.provider().attachContext(saved)) {
delegate.accept(t, u);
} finally {
CONTEXT_THREAD_LOCAL.set(prev);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
package io.servicetalk.concurrent.api;

import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMapHolder;

import java.util.function.BiFunction;

import static io.servicetalk.concurrent.api.AsyncContextMapThreadLocal.CONTEXT_THREAD_LOCAL;
import static java.util.Objects.requireNonNull;

final class ContextPreservingBiFunction<T, U, V> implements BiFunction<T, U, V> {
Expand All @@ -34,28 +32,8 @@ final class ContextPreservingBiFunction<T, U, V> implements BiFunction<T, U, V>

@Override
public V apply(T t, U u) {
final Thread currentThread = Thread.currentThread();
if (currentThread instanceof ContextMapHolder) {
final ContextMapHolder asyncContextMapHolder = (ContextMapHolder) currentThread;
ContextMap prev = asyncContextMapHolder.context();
try {
asyncContextMapHolder.context(saved);
return delegate.apply(t, u);
} finally {
asyncContextMapHolder.context(prev);
}
} else {
return slowPath(t, u);
}
}

private V slowPath(T t, U u) {
ContextMap prev = CONTEXT_THREAD_LOCAL.get();
try {
CONTEXT_THREAD_LOCAL.set(saved);
try (Scope ignored = AsyncContext.provider().attachContext(saved)) {
return delegate.apply(t, u);
} finally {
CONTEXT_THREAD_LOCAL.set(prev);
}
}
}
Loading

0 comments on commit c6d55e2

Please sign in to comment.