Skip to content

Commit

Permalink
Return to a one-token-per-channel model
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Feb 26, 2022
1 parent fb8b2d4 commit 04ff236
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 319 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

package com.eatthepath.pushy.apns;

import com.eatthepath.pushy.apns.auth.AuthenticationTokenProvider;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
Expand Down Expand Up @@ -70,7 +69,6 @@ class ApnsChannelFactory implements PooledObjectFactory<Channel>, Closeable {
AttributeKey.valueOf(ApnsChannelFactory.class, "channelReadyPromise");

ApnsChannelFactory(final ApnsClientConfiguration clientConfiguration,
final AuthenticationTokenProvider authenticationTokenProvider,
final EventLoopGroup eventLoopGroup) {

this.sslContext = clientConfiguration.getSslContext();
Expand Down Expand Up @@ -114,9 +112,10 @@ protected void initChannel(final SocketChannel channel) {
{
final ApnsClientHandler.ApnsClientHandlerBuilder clientHandlerBuilder;

if (authenticationTokenProvider != null) {
if (clientConfiguration.getSigningKey().isPresent()) {
clientHandlerBuilder = new TokenAuthenticationApnsClientHandler.TokenAuthenticationApnsClientHandlerBuilder()
.authenticationTokenProvider(authenticationTokenProvider)
.signingKey(clientConfiguration.getSigningKey().get())
.tokenExpiration(clientConfiguration.getTokenExpiration())
.authority(authority)
.idlePingInterval(clientConfiguration.getIdlePingInterval());
} else {
Expand Down
17 changes: 1 addition & 16 deletions pushy/src/main/java/com/eatthepath/pushy/apns/ApnsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

package com.eatthepath.pushy.apns;

import com.eatthepath.pushy.apns.auth.AuthenticationTokenProvider;
import com.eatthepath.pushy.apns.util.concurrent.PushNotificationFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
Expand Down Expand Up @@ -80,8 +79,6 @@ public class ApnsClient {
private final EventLoopGroup eventLoopGroup;
private final boolean shouldShutDownEventLoopGroup;

private final AuthenticationTokenProvider authenticationTokenProvider;

private final ApnsChannelPool channelPool;

private final ApnsClientMetricsListener metricsListener;
Expand Down Expand Up @@ -135,15 +132,11 @@ protected ApnsClient(final ApnsClientConfiguration clientConfiguration, final Ev
this.shouldShutDownEventLoopGroup = true;
}

this.authenticationTokenProvider = clientConfiguration.getSigningKey()
.map(signingKey -> new AuthenticationTokenProvider(signingKey, clientConfiguration.getTokenExpiration(), this.eventLoopGroup))
.orElse(null);

this.metricsListener = clientConfiguration.getMetricsListener()
.orElseGet(NoopApnsClientMetricsListener::new);

final ApnsChannelFactory channelFactory =
new ApnsChannelFactory(clientConfiguration, this.authenticationTokenProvider, this.eventLoopGroup);
new ApnsChannelFactory(clientConfiguration, this.eventLoopGroup);

final ApnsChannelPoolMetricsListener channelPoolMetricsListener = new ApnsChannelPoolMetricsListener() {

Expand Down Expand Up @@ -229,10 +222,6 @@ public <T extends ApnsPushNotification> PushNotificationFuture<T, PushNotificati
return responseFuture;
}

AuthenticationTokenProvider getAuthenticationTokenProvider() {
return this.authenticationTokenProvider;
}

/**
* <p>Gracefully shuts down the client, closing all connections and releasing all persistent resources. The
* disconnection process will wait until notifications that have been sent to the APNs server have been either
Expand All @@ -257,10 +246,6 @@ public CompletableFuture<Void> close() {
final CompletableFuture<Void> closeFuture;

if (this.isClosed.compareAndSet(false, true)) {
if (this.authenticationTokenProvider != null) {
this.authenticationTokenProvider.close();
}

closeFuture = new CompletableFuture<>();

this.channelPool.close().addListener((GenericFutureListener<Future<Void>>) closePoolFuture -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,30 @@

package com.eatthepath.pushy.apns;

import com.eatthepath.pushy.apns.auth.ApnsSigningKey;
import com.eatthepath.pushy.apns.auth.AuthenticationToken;
import com.eatthepath.pushy.apns.auth.AuthenticationTokenProvider;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http2.*;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.util.AsciiString;
import io.netty.util.collection.IntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.Map;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

class TokenAuthenticationApnsClientHandler extends ApnsClientHandler {

private final AuthenticationTokenProvider authenticationTokenProvider;
private final ApnsSigningKey signingKey;
private AuthenticationToken authenticationToken;

private final Http2Connection.PropertyKey authenticationTokenPropertyKey;
private final Map<Integer, AuthenticationToken> unattachedAuthenticationTokensByStreamId = new IntObjectHashMap<>();
private final Duration tokenExpiration;
private ScheduledFuture<?> tokenExpirationFuture;

private static final AsciiString APNS_AUTHORIZATION_HEADER = new AsciiString("authorization");

Expand All @@ -49,68 +54,68 @@ class TokenAuthenticationApnsClientHandler extends ApnsClientHandler {
private static final Logger log = LoggerFactory.getLogger(TokenAuthenticationApnsClientHandler.class);

public static class TokenAuthenticationApnsClientHandlerBuilder extends ApnsClientHandlerBuilder {
private AuthenticationTokenProvider authenticationTokenProvider;
private ApnsSigningKey signingKey;
private Duration tokenExpiration;

public TokenAuthenticationApnsClientHandlerBuilder authenticationTokenProvider(final AuthenticationTokenProvider authenticationTokenProvider) {
this.authenticationTokenProvider = authenticationTokenProvider;
public TokenAuthenticationApnsClientHandlerBuilder signingKey(final ApnsSigningKey signingKey) {
this.signingKey = signingKey;
return this;
}

public AuthenticationTokenProvider authenticationTokenProvider() {
return this.authenticationTokenProvider;
public ApnsSigningKey signingKey() {
return this.signingKey;
}

public TokenAuthenticationApnsClientHandlerBuilder tokenExpiration(final Duration tokenExpiration) {
this.tokenExpiration = tokenExpiration;
return this;
}

public Duration tokenExpiration() {
return this.tokenExpiration;
}

@Override
public ApnsClientHandler build(final Http2ConnectionDecoder decoder, final Http2ConnectionEncoder encoder, final Http2Settings initialSettings) {
Objects.requireNonNull(this.authority(), "Authority must be set before building a TokenAuthenticationApnsClientHandler.");
Objects.requireNonNull(this.authenticationTokenProvider(), "Authentication token provider must be set before building a TokenAuthenticationApnsClientHandler.");
Objects.requireNonNull(this.signingKey(), "Signing key must be set before building a TokenAuthenticationApnsClientHandler.");
Objects.requireNonNull(this.tokenExpiration(), "Token expiration duration must be set before building a TokenAuthenticationApnsClientHandler.");

final ApnsClientHandler handler = new TokenAuthenticationApnsClientHandler(decoder, encoder, initialSettings, this.authority(), this.idlePingInterval(), this.authenticationTokenProvider());
final ApnsClientHandler handler = new TokenAuthenticationApnsClientHandler(decoder, encoder, initialSettings, this.authority(), this.idlePingInterval(), this.signingKey(), this.tokenExpiration());
this.frameListener(handler);
return handler;
}
}

protected TokenAuthenticationApnsClientHandler(final Http2ConnectionDecoder decoder, final Http2ConnectionEncoder encoder, final Http2Settings initialSettings, final String authority, final Duration idlePingInterval, final AuthenticationTokenProvider authenticationTokenProvider) {
protected TokenAuthenticationApnsClientHandler(final Http2ConnectionDecoder decoder, final Http2ConnectionEncoder encoder, final Http2Settings initialSettings, final String authority, final Duration idlePingInterval, final ApnsSigningKey signingKey, final Duration tokenExpiration) {
super(decoder, encoder, initialSettings, authority, idlePingInterval);

this.authenticationTokenProvider = Objects.requireNonNull(authenticationTokenProvider, "Authentication token provider must not be null for token-based client handlers.");
this.authenticationTokenPropertyKey = this.connection().newKey();
}

@Override
public void onStreamAdded(final Http2Stream stream) {
super.onStreamAdded(stream);

stream.setProperty(this.authenticationTokenPropertyKey, this.unattachedAuthenticationTokensByStreamId.remove(stream.id()));
}

@Override
public void onStreamRemoved(final Http2Stream stream) {
super.onStreamRemoved(stream);

stream.removeProperty(this.authenticationTokenPropertyKey);
this.signingKey = Objects.requireNonNull(signingKey, "Signing key must not be null for token-based client handlers.");
this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "Token expiration must not be null for token-based client handlers");
}

@Override
protected Http2Headers getHeadersForPushNotification(final ApnsPushNotification pushNotification, final ChannelHandlerContext context, final int streamId) {
final AuthenticationToken authenticationToken = this.authenticationTokenProvider.getAuthenticationToken();

this.unattachedAuthenticationTokensByStreamId.put(streamId, authenticationToken);
if (this.authenticationToken == null) {
log.debug("Generated a new authentication token for channel {} at stream {}", context.channel(), streamId);
this.authenticationToken = new AuthenticationToken(this.signingKey, Instant.now());

tokenExpirationFuture = context.executor().schedule(() -> {
log.debug("Authentication token for channel {} has expired", context.channel());
TokenAuthenticationApnsClientHandler.this.authenticationToken = null;
}, this.tokenExpiration.toMillis(), TimeUnit.MILLISECONDS);
}

return super.getHeadersForPushNotification(pushNotification, context, streamId)
.add(APNS_AUTHORIZATION_HEADER, authenticationToken.getAuthorizationHeader());
.add(APNS_AUTHORIZATION_HEADER, this.authenticationToken.getAuthorizationHeader());
}

@Override
protected void handleErrorResponse(final ChannelHandlerContext context, final int streamId, final Http2Headers headers, final ApnsPushNotification pushNotification, final ErrorResponse errorResponse) {
super.handleErrorResponse(context, streamId, headers, pushNotification, errorResponse);

if (EXPIRED_AUTH_TOKEN_REASON.equals(errorResponse.getReason())) {
log.warn("APNs server reports token for channel {} has expired.", context.channel());

this.authenticationTokenProvider.expireAuthenticationToken(
this.connection().stream(streamId).getProperty(this.authenticationTokenPropertyKey));
log.warn("APNs server reports token for channel {} has expired; will close channel", context.channel());

// Once the server thinks our token has expired, it will "wedge" the connection. There's no way to recover
// from this situation, and all we can do is close the connection and create a new one.
Expand All @@ -120,7 +125,10 @@ protected void handleErrorResponse(final ChannelHandlerContext context, final in

@Override
public void channelInactive(final ChannelHandlerContext context) throws Exception {
this.unattachedAuthenticationTokensByStreamId.clear();
if (this.tokenExpirationFuture != null) {
this.tokenExpirationFuture.cancel(false);
this.tokenExpirationFuture = null;
}

super.channelInactive(context);
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

package com.eatthepath.pushy.apns;

import com.eatthepath.pushy.apns.auth.AuthenticationToken;
import com.eatthepath.pushy.apns.server.*;
import com.eatthepath.pushy.apns.util.SimpleApnsPushNotification;
import com.eatthepath.pushy.apns.util.concurrent.PushNotificationFuture;
Expand Down Expand Up @@ -416,8 +415,6 @@ void testSendNotificationWithExpiredAuthenticationToken() throws Exception {
final TestClientMetricsListener metricsListener = new TestClientMetricsListener();
final ApnsClient client = this.buildTokenAuthenticationClient(metricsListener);

final AuthenticationToken initialToken = client.getAuthenticationTokenProvider().getAuthenticationToken();

try {
server.start(PORT).get();

Expand All @@ -432,8 +429,6 @@ void testSendNotificationWithExpiredAuthenticationToken() throws Exception {
client.close().get();
server.shutdown().get();
}

assertNotEquals(initialToken, client.getAuthenticationTokenProvider().getAuthenticationToken());
}

@ParameterizedTest
Expand Down
Loading

0 comments on commit 04ff236

Please sign in to comment.