Skip to content

Commit

Permalink
Pass scheme name to #create method of DetachedSocketFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
ok2c committed Feb 10, 2025
1 parent daaa08f commit bed9c65
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public void connect(
if (LOG.isDebugEnabled()) {
LOG.debug("{} connecting {}->{} ({})", endpointHost, localAddress, remoteAddress, connectTimeout);
}
final Socket socket = detachedSocketFactory.create(socksProxy);
final Socket socket = detachedSocketFactory.create(endpointHost.getSchemeName(), socksProxy);
try {
// Always bind to the local address if it's provided.
if (localAddress != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,11 @@ public interface DetachedSocketFactory {

Socket create(Proxy proxy) throws IOException;

/**
* @since 5.5
*/
default Socket create(String schemeName, Proxy proxy) throws IOException {
return create(proxy);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ void testTargetConnect() throws Exception {

Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(tlsSocketStrategy.upgrade(
Expand All @@ -401,15 +401,15 @@ void testTargetConnect() throws Exception {

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create("https", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", 443, tlsConfig, context);

mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);

Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create("https", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
Mockito.verify(tlsSocketStrategy, Mockito.times(2)).upgrade(socket, "somehost", 443, tlsConfig, context);
}
Expand Down Expand Up @@ -446,13 +446,13 @@ void testProxyConnectAndUpgrade() throws Exception {
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create("http", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);

Mockito.when(conn.getSocket()).thenReturn(socket);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void testConnect() throws Exception {
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);

Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

final SocketConfig socketConfig = SocketConfig.custom()
.setSoKeepAlive(true)
Expand Down Expand Up @@ -143,7 +143,7 @@ void testConnectWithTLSUpgrade() throws Exception {
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);

Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
final SSLSocket upgradedSocket = Mockito.mock(SSLSocket.class);
Expand Down Expand Up @@ -178,7 +178,7 @@ void testConnectTimeout() throws Exception {
);
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);
Mockito.doThrow(new SocketTimeoutException()).when(socket).connect(Mockito.any(), Mockito.anyInt());
Assertions.assertThrows(ConnectTimeoutException.class, () ->
connectionOperator.connect(
Expand All @@ -200,7 +200,7 @@ void testConnectFailure() throws Exception {
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);

Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);
Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.any(), Mockito.anyInt());

Assertions.assertThrows(HttpHostConnectException.class, () ->
Expand All @@ -218,7 +218,7 @@ void testConnectFailover() throws Exception {

Mockito.when(dnsResolver.resolve("somehost", 80)).thenReturn(Arrays.asList(ipAddress1, ipAddress2));
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);
Mockito.doThrow(new ConnectException()).when(socket).connect(
Mockito.eq(ipAddress1),
Mockito.anyInt());
Expand All @@ -243,7 +243,7 @@ void testConnectExplicitAddress() throws Exception {
final HttpHost host = new HttpHost(ip);

Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

final InetSocketAddress localAddress = new InetSocketAddress(local, 0);
final TlsConfig tlsConfig = TlsConfig.custom()
Expand Down Expand Up @@ -316,7 +316,7 @@ void testConnectWithDisableDnsResolution() throws Exception {
Mockito.when(dnsResolver.resolve(host.getHostName(), port)).thenReturn(resolvedAddresses);

Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

final SocketConfig socketConfig = SocketConfig.custom()
.setSoKeepAlive(true)
Expand Down Expand Up @@ -359,7 +359,7 @@ void testConnectWithDnsResolutionAndFallback() throws Exception {
);
Mockito.when(dnsResolver.resolve("fallbackhost.com", port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

// Simulate failure to connect to the first resolved address
Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.eq(new InetSocketAddress(ip1, port)), Mockito.anyInt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void testTargetConnect() throws Exception {

Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(tlsSocketStrategy.upgrade(
Expand All @@ -281,15 +281,15 @@ void testTargetConnect() throws Exception {

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create("https", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", 443, tlsConfig, context);

mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);

Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create("https", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
Mockito.verify(tlsSocketStrategy, Mockito.times(2)).upgrade(socket, "somehost", 443, tlsConfig, context);
}
Expand Down Expand Up @@ -336,13 +336,13 @@ void testProxyConnectAndUpgrade() throws Exception {
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(detachedSocketFactory.create(Mockito.any(), Mockito.any())).thenReturn(socket);

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create("http", null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);

Mockito.when(conn.isOpen()).thenReturn(true);
Expand Down

0 comments on commit bed9c65

Please sign in to comment.