Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let ConnectionFactory see difference between proxied and direct connections #2169

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.api;
package io.servicetalk.concurrent.internal;

import io.servicetalk.concurrent.internal.ContextMapUtils;
import io.servicetalk.context.api.ContextMap;

import java.util.HashMap;
Expand All @@ -26,11 +25,19 @@

import static java.util.Objects.requireNonNull;

final class DefaultContextMap implements ContextMap {
/**
* Default implementation of {@link ContextMap}.
* <p>
* Note: it's not thread-safe!
*/
public final class DefaultContextMap implements ContextMap {

private final HashMap<Key<?>, Object> theMap;

DefaultContextMap() {
/**
* Creates a new instance.
*/
public DefaultContextMap() {
theMap = new HashMap<>(4); // start with a smaller table
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.servicetalk.http.api;

import io.servicetalk.concurrent.internal.DefaultContextMap;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.encoding.api.ContentCodec;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@

import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.transport.api.ConnectionInfo;

import static io.servicetalk.context.api.ContextMap.Key.newKey;

Expand All @@ -33,6 +34,19 @@ public final class HttpContextKeys {
public static final Key<HttpExecutionStrategy> HTTP_EXECUTION_STRATEGY_KEY =
newKey("HTTP_EXECUTION_STRATEGY_KEY", HttpExecutionStrategy.class);

/**
* When opening a connection to a proxy, this key tells what is the actual (unresolved) target address behind the
* proxy this connection will be established to.
* <p>
* To distinguish between a
* <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Proxy_servers_and_tunneling#http_tunneling">secure
* HTTP proxy tunneling</a> and a clear text HTTP proxy, check presence of {@link ConnectionInfo#sslConfig()}.
*
* @see SingleAddressHttpClientBuilder#proxyAddress(Object)
*/
public static final Key<Object> HTTP_TARGET_ADDRESS_BEHIND_PROXY =
newKey("HTTP_TARGET_ADDRESS_BEHIND_PROXY", Object.class);

private HttpContextKeys() {
// No instances
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019, 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,12 @@
import io.servicetalk.http.api.StreamingHttpRequester;
import io.servicetalk.http.api.StreamingHttpResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static io.servicetalk.concurrent.api.Single.defer;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY;
import static io.servicetalk.http.netty.ProxyConnectConnectionFactoryFilter.logUnexpectedAddress;
import static io.servicetalk.http.utils.HttpRequestUriUtils.getEffectiveRequestUri;
import static java.util.Objects.requireNonNull;

Expand All @@ -39,6 +44,9 @@
*/
final class AbsoluteAddressHttpRequesterFilter implements StreamingHttpClientFilterFactory,
StreamingHttpConnectionFilterFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(AbsoluteAddressHttpRequesterFilter.class);

private final String scheme;
private final String authority;

Expand Down Expand Up @@ -84,6 +92,7 @@ public HttpExecutionStrategy requiredOffloads() {
private Single<StreamingHttpResponse> request(final StreamingHttpRequester delegate,
final StreamingHttpRequest request) {
return defer(() -> {
logUnexpectedAddress(request.context().put(HTTP_TARGET_ADDRESS_BEHIND_PROXY, authority), authority, LOGGER);
final String effectiveRequestUri = getEffectiveRequestUri(request, scheme, authority, false);
request.requestTarget(effectiveRequestUri);
return delegate.request(request).shareContextOnSubscribe();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2020, 2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
import io.servicetalk.client.api.DelegatingConnectionFactory;
import io.servicetalk.concurrent.SingleSource;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.internal.DefaultContextMap;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategies;
Expand All @@ -33,12 +34,15 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Processors.newSingleProcessor;
import static io.servicetalk.concurrent.api.Single.failed;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY;
import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH;
import static io.servicetalk.http.api.HttpHeaderValues.ZERO;
import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX;
Expand All @@ -52,6 +56,8 @@
final class ProxyConnectConnectionFactoryFilter<ResolvedAddress, C extends FilterableStreamingHttpConnection>
implements ConnectionFactoryFilter<ResolvedAddress, C> {

private static final Logger LOGGER = LoggerFactory.getLogger(ProxyConnectConnectionFactoryFilter.class);

private final String connectAddress;

ProxyConnectConnectionFactoryFilter(final CharSequence connectAddress) {
Expand All @@ -71,17 +77,24 @@ private ProxyFilter(final ConnectionFactory<ResolvedAddress, C> delegate) {

@Override
public Single<C> newConnection(final ResolvedAddress resolvedAddress,
@Nullable final ContextMap context,
@Nullable ContextMap context,
@Nullable final TransportObserver observer) {
return delegate().newConnection(resolvedAddress, context, observer).flatMap(c -> {
try {
return c.request(c.connect(connectAddress).addHeader(CONTENT_LENGTH, ZERO))
.flatMap(response -> handleConnectResponse(c, response))
// Close recently created connection in case of any error while it connects to the proxy:
.onErrorResume(t -> c.closeAsync().concat(failed(t)));
} catch (Throwable t) {
return c.closeAsync().concat(failed(t));
}
return Single.defer(() -> {
final ContextMap contextMap = context != null ? context : new DefaultContextMap();
logUnexpectedAddress(contextMap.put(HTTP_TARGET_ADDRESS_BEHIND_PROXY, connectAddress),
connectAddress, LOGGER);
return delegate().newConnection(resolvedAddress, contextMap, observer).flatMap(c -> {
try {
return c.request(c.connect(connectAddress).addHeader(CONTENT_LENGTH, ZERO))
.flatMap(response -> handleConnectResponse(c, response))
// Close recently created connection in case of any error while it connects to the
// proxy:
.onErrorResume(t -> c.closeAsync().concat(failed(t)));
// We do not apply shareContextOnSubscribe() here to isolate a context for `CONNECT` request.
} catch (Throwable t) {
return c.closeAsync().concat(failed(t));
}
}).shareContextOnSubscribe();
});
}

Expand Down Expand Up @@ -121,6 +134,13 @@ public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt
}
}

static void logUnexpectedAddress(@Nullable final Object current, final Object expected, final Logger logger) {
if (current != null && !expected.equals(current)) {
logger.info("Observed unexpected value for {}: {}, overridden with: {}",
HTTP_TARGET_ADDRESS_BEHIND_PROXY, current, expected);
}
}

@Override
public HttpExecutionStrategy requiredOffloads() {
// No influence since we do not block.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019, 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.SingleAddressHttpClientBuilder;
import io.servicetalk.http.netty.HttpsProxyTest.TargetAddressCheckConnectionFactoryFilter;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;

Expand All @@ -30,6 +31,7 @@

import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import javax.annotation.Nullable;

Expand All @@ -42,6 +44,7 @@
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

class HttpProxyTest {
Expand All @@ -57,6 +60,7 @@ class HttpProxyTest {
@Nullable
private HostAndPort serverAddress;
private final AtomicInteger proxyRequestCount = new AtomicInteger();
private final AtomicReference<Object> targetAddress = new AtomicReference<>();

@BeforeEach
void setup() throws Exception {
Expand Down Expand Up @@ -108,12 +112,14 @@ void testRequest(ClientSource clientSource) throws Exception {

final BlockingHttpClient client = clientSource.clientBuilderFactory.apply(serverAddress)
.proxyAddress(proxyAddress)
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking();

final HttpResponse httpResponse = client.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
safeClose(client);
}

Expand All @@ -131,6 +137,7 @@ void testBuilderReuseEachClientUsesOwnProxy() throws Exception {
return otherProxyClient.request(request);
});
BlockingHttpClient otherClient = builder.proxyAddress(serverHostAndPort(otherProxyContext))
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking()) {

final HttpResponse httpResponse = otherClient.request(client.get("/path"));
Expand All @@ -143,5 +150,6 @@ void testBuilderReuseEachClientUsesOwnProxy() throws Exception {
assertThat(httpResponse.status(), is(OK));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019, 2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,22 +15,32 @@
*/
package io.servicetalk.http.netty;

import io.servicetalk.client.api.ConnectionFactory;
import io.servicetalk.client.api.ConnectionFactoryFilter;
import io.servicetalk.client.api.DelegatingConnectionFactory;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.IoExecutor;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.TransportObserver;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8;
Expand All @@ -40,12 +50,14 @@
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;

class HttpsProxyTest {

private final ProxyTunnel proxyTunnel = new ProxyTunnel();
private final AtomicReference<Object> targetAddress = new AtomicReference<>();

@Nullable
private HostAndPort proxyAddress;
Expand Down Expand Up @@ -105,6 +117,7 @@ void createClient() {
.proxyAddress(proxyAddress)
.sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem)
.peerHost(serverPemHostname()).build())
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, true))
.buildBlocking();
}

Expand All @@ -115,12 +128,41 @@ void testRequest() throws Exception {
assertThat(httpResponse.status(), is(OK));
assertThat(proxyTunnel.connectCount(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

@Test
void testBadProxyResponse() {
proxyTunnel.badResponseProxy();
assert client != null;
assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path")));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

static final class TargetAddressCheckConnectionFactoryFilter
implements ConnectionFactoryFilter<InetSocketAddress, FilterableStreamingHttpConnection> {

private final AtomicReference<Object> targetAddress;
private final boolean secure;

TargetAddressCheckConnectionFactoryFilter(AtomicReference<Object> targetAddress, boolean secure) {
this.targetAddress = targetAddress;
this.secure = secure;
}

@Override
public ConnectionFactory<InetSocketAddress, FilterableStreamingHttpConnection> create(
ConnectionFactory<InetSocketAddress, FilterableStreamingHttpConnection> original) {
return new DelegatingConnectionFactory<InetSocketAddress, FilterableStreamingHttpConnection>(original) {
@Override
public Single<FilterableStreamingHttpConnection> newConnection(InetSocketAddress address,
@Nullable ContextMap context, @Nullable TransportObserver observer) {
assert context != null;
targetAddress.set(context.get(HTTP_TARGET_ADDRESS_BEHIND_PROXY));
return delegate().newConnection(address, context, observer)
.whenOnSuccess(c -> assertThat(c.connectionContext().sslConfig() != null, is(secure)));
}
};
}
}
}