diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 9a53f9fcafdd2..bd6bdaabf3c67 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -17,7 +17,7 @@ package org.apache.arrow.driver.jdbc; -import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue; +import static org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue.createNewQueue; import java.sql.ResultSet; import java.sql.ResultSetMetaData; @@ -26,7 +26,8 @@ import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; +import org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue; import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; @@ -47,8 +48,8 @@ public final class ArrowFlightJdbcFlightStreamResultSet extends ArrowFlightJdbcVectorSchemaRootResultSet { private final ArrowFlightConnection connection; - private FlightStream currentFlightStream; - private FlightStreamQueue flightStreamQueue; + private CloseableEndpointStreamPair currentEndpointData; + private FlightEndpointDataQueue flightEndpointDataQueue; private VectorSchemaRootTransformer transformer; private VectorSchemaRoot currentVectorSchemaRoot; @@ -102,20 +103,20 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( resultSet.transformer = transformer; - resultSet.execute(flightInfo); + resultSet.populateData(flightInfo); return resultSet; } private void loadNewQueue() { - Optional.ofNullable(flightStreamQueue).ifPresent(AutoCloseables::closeNoChecked); - flightStreamQueue = createNewQueue(connection.getExecutorService()); + Optional.ofNullable(flightEndpointDataQueue).ifPresent(AutoCloseables::closeNoChecked); + flightEndpointDataQueue = createNewQueue(connection.getExecutorService()); } private void loadNewFlightStream() throws SQLException { - if (currentFlightStream != null) { - AutoCloseables.closeNoChecked(currentFlightStream); + if (currentEndpointData != null) { + AutoCloseables.closeNoChecked(currentEndpointData); } - this.currentFlightStream = getNextFlightStream(true); + this.currentEndpointData = getNextEndpointStream(true); } @Override @@ -124,24 +125,24 @@ protected AvaticaResultSet execute() throws SQLException { if (flightInfo != null) { schema = flightInfo.getSchemaOptional().orElse(null); - execute(flightInfo); + populateData(flightInfo); } return this; } - private void execute(final FlightInfo flightInfo) throws SQLException { + private void populateData(final FlightInfo flightInfo) throws SQLException { loadNewQueue(); - flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); + flightEndpointDataQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); loadNewFlightStream(); // Ownership of the root will be passed onto the cursor. - if (currentFlightStream != null) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + populateDataForCurrentFlightStream(); } } - private void executeForCurrentFlightStream() throws SQLException { - final VectorSchemaRoot originalRoot = currentFlightStream.getRoot(); + private void populateDataForCurrentFlightStream() throws SQLException { + final VectorSchemaRoot originalRoot = currentEndpointData.getStream().getRoot(); if (transformer != null) { try { @@ -154,9 +155,9 @@ private void executeForCurrentFlightStream() throws SQLException { } if (schema != null) { - execute(currentVectorSchemaRoot, schema); + populateData(currentVectorSchemaRoot, schema); } else { - execute(currentVectorSchemaRoot); + populateData(currentVectorSchemaRoot); } } @@ -179,20 +180,23 @@ public boolean next() throws SQLException { return true; } - if (currentFlightStream != null) { - currentFlightStream.getRoot().clear(); - if (currentFlightStream.next()) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + if (currentEndpointData.getStream().next()) { + populateDataForCurrentFlightStream(); continue; } - flightStreamQueue.enqueue(currentFlightStream); + try { + AutoCloseables.close(currentEndpointData); + } catch (Exception ex) { + throw new RuntimeException(ex); + } } - currentFlightStream = getNextFlightStream(false); + currentEndpointData = getNextEndpointStream(false); - if (currentFlightStream != null) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + populateDataForCurrentFlightStream(); continue; } @@ -207,14 +211,19 @@ public boolean next() throws SQLException { @Override protected void cancel() { super.cancel(); - final FlightStream currentFlightStream = this.currentFlightStream; + final CloseableEndpointStreamPair currentFlightStream = this.currentEndpointData; if (currentFlightStream != null) { - currentFlightStream.cancel("Cancel", null); + currentFlightStream.getStream().cancel("Cancel", null); + try { + currentFlightStream.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } } - if (flightStreamQueue != null) { + if (flightEndpointDataQueue != null) { try { - flightStreamQueue.close(); + flightEndpointDataQueue.close(); } catch (final Exception e) { throw new RuntimeException(e); } @@ -224,12 +233,12 @@ protected void cancel() { @Override public synchronized void close() { try { - if (flightStreamQueue != null) { + if (flightEndpointDataQueue != null) { // flightStreamQueue should close currentFlightStream internally - flightStreamQueue.close(); - } else if (currentFlightStream != null) { + flightEndpointDataQueue.close(); + } else if (currentEndpointData != null) { // close is only called for currentFlightStream if there's no queue - currentFlightStream.close(); + currentEndpointData.close(); } } catch (final Exception e) { throw new RuntimeException(e); @@ -238,13 +247,13 @@ public synchronized void close() { } } - private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException { - if (isExecution) { + private CloseableEndpointStreamPair getNextEndpointStream(final boolean canTimeout) throws SQLException { + if (canTimeout) { final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; return statementTimeout != 0 ? - flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); + flightEndpointDataQueue.next(statementTimeout, TimeUnit.SECONDS) : flightEndpointDataQueue.next(); } else { - return flightStreamQueue.next(); + return flightEndpointDataQueue.next(); } } } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java index 9e377e51decc9..20a2af6a84aa4 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -83,7 +83,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( new ArrowFlightJdbcVectorSchemaRootResultSet(null, state, signature, resultSetMetaData, timeZone, null); - resultSet.execute(vectorSchemaRoot); + resultSet.populateData(vectorSchemaRoot); return resultSet; } @@ -92,7 +92,7 @@ protected AvaticaResultSet execute() throws SQLException { throw new RuntimeException("Can only execute with execute(VectorSchemaRoot)"); } - void execute(final VectorSchemaRoot vectorSchemaRoot) { + void populateData(final VectorSchemaRoot vectorSchemaRoot) { final List fields = vectorSchemaRoot.getSchema().getFields(); final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields); signature.columns.clear(); @@ -102,7 +102,7 @@ void execute(final VectorSchemaRoot vectorSchemaRoot) { execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); } - void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { + void populateData(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields()); signature.columns.clear(); signature.columns.addAll(columns); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index bb1d524aca008..bc19820e2aa8a 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -20,12 +20,13 @@ import java.io.IOException; import java.security.GeneralSecurityException; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; +import java.util.concurrent.Callable; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; @@ -35,7 +36,6 @@ import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.auth2.BearerCredentialWriter; import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; @@ -58,13 +58,18 @@ */ public final class ArrowFlightSqlClientHandler implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class); + private final FlightSqlClient sqlClient; private final Set options = new HashSet<>(); + private final Builder builder; ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient, - final Collection options) { - this.options.addAll(options); + final Builder builder, + final Collection credentialOptions) { + this.options.addAll(builder.options); + this.options.addAll(credentialOptions); this.sqlClient = Preconditions.checkNotNull(sqlClient); + this.builder = builder; } /** @@ -75,8 +80,9 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { * @return a new {@link ArrowFlightSqlClientHandler}. */ public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client, + final Builder builder, final Collection options) { - return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), options); + return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options); } /** @@ -95,11 +101,34 @@ private CallOption[] getOptions() { * @param flightInfo The {@link FlightInfo} instance from which to fetch results. * @return a {@code FlightStream} of results. */ - public List getStreams(final FlightInfo flightInfo) { - return flightInfo.getEndpoints().stream() - .map(FlightEndpoint::getTicket) - .map(ticket -> sqlClient.getStream(ticket, getOptions())) - .collect(Collectors.toList()); + public List> getStreams(final FlightInfo flightInfo) { + final ArrayList> lazyStreams = + new ArrayList<>(flightInfo.getEndpoints().size()); + for (FlightEndpoint endpoint : flightInfo.getEndpoints()) { + lazyStreams.add(() -> { + final CloseableEndpointStreamPair resultPair; + if (endpoint.getLocations().isEmpty()) { + // Create a stream using the current client only and do not close the client at the end. + resultPair = new CloseableEndpointStreamPair( + sqlClient.getStream(endpoint.getTicket(), getOptions()), null); + } else { + final ArrowFlightSqlClientHandler handler = ArrowFlightSqlClientHandler.this.builder.build(); + try { + resultPair = new CloseableEndpointStreamPair( + handler.sqlClient.getStream(endpoint.getTicket(), handler.getOptions()), handler.sqlClient); + } catch (Exception ex) { + AutoCloseables.close(handler); + throw ex; + } + } + if (resultPair.getStream().next()) { + return resultPair; + } + resultPair.close(); + return null; + }); + } + return lazyStreams; } /** @@ -535,18 +564,21 @@ public Builder withCallOptions(final Collection options) { * @throws SQLException on error. */ public ArrowFlightSqlClientHandler build() throws SQLException { + // Copy middlewares so that the build method doesn't change the state of the builder fields itself. + Set buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories); FlightClient client = null; + try { ClientIncomingAuthHeaderMiddleware.Factory authFactory = null; // Token should take priority since some apps pass in a username/password even when a token is provided if (username != null && token == null) { authFactory = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); - withMiddlewareFactories(authFactory); + buildTimeMiddlewareFactories.add(authFactory); } final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); - withMiddlewareFactories(new ClientCookieMiddleware.Factory()); - middlewareFactories.forEach(clientBuilder::intercept); + buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory()); + buildTimeMiddlewareFactories.forEach(clientBuilder::intercept); Location location; if (useEncryption) { location = Location.forGrpcTls(host, port); @@ -571,17 +603,18 @@ public ArrowFlightSqlClientHandler build() throws SQLException { } client = clientBuilder.build(); + final ArrayList credentialOptions = new ArrayList<>(); if (authFactory != null) { - options.add( + credentialOptions.add( ClientAuthenticationUtils.getAuthenticate( client, username, password, authFactory, options.toArray(new CallOption[0]))); } else if (token != null) { - options.add( + credentialOptions.add( ClientAuthenticationUtils.getAuthenticate( client, new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray( new CallOption[0]))); } - return ArrowFlightSqlClientHandler.createNewHandler(client, options); + return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions); } catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) { final SQLException originalException = new SQLException(e); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java new file mode 100644 index 0000000000000..6c37a5b0c626c --- /dev/null +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; + +/** + * Represents a connection to a {@link org.apache.arrow.flight.FlightEndpoint}. + */ +public class CloseableEndpointStreamPair implements AutoCloseable { + + private final FlightStream stream; + private final FlightSqlClient client; + + public CloseableEndpointStreamPair(FlightStream stream, FlightSqlClient client) { + this.stream = Preconditions.checkNotNull(stream); + this.client = client; + } + + public FlightStream getStream() { + return stream; + } + + @Override + public void close() throws Exception { + AutoCloseables.close(stream, client); + } +} diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java similarity index 68% rename from java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java rename to java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java index e1d770800e40c..db62b6cad0301 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java @@ -27,7 +27,7 @@ import java.util.Collection; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.CancellationException; +import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; @@ -36,9 +36,11 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.util.AutoCloseables; import org.apache.calcite.avatica.AvaticaConnection; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,28 +57,28 @@ *
  • Repeat from (3) until next() returns null.
  • * */ -public class FlightStreamQueue implements AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(FlightStreamQueue.class); - private final CompletionService completionService; - private final Set> futures = synchronizedSet(new HashSet<>()); - private final Set allStreams = synchronizedSet(new HashSet<>()); +public class FlightEndpointDataQueue implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightEndpointDataQueue.class); + private final CompletionService completionService; + private final Set> futures = synchronizedSet(new HashSet<>()); + private final Set endpointsToClose = synchronizedSet(new HashSet<>()); private final AtomicBoolean closed = new AtomicBoolean(); /** * Instantiate a new FlightStreamQueue. */ - protected FlightStreamQueue(final CompletionService executorService) { + protected FlightEndpointDataQueue(final CompletionService executorService) { completionService = checkNotNull(executorService); } /** - * Creates a new {@link FlightStreamQueue} from the provided {@link ExecutorService}. + * Creates a new {@link FlightEndpointDataQueue} from the provided {@link ExecutorService}. * * @param service the service from which to create a new queue. * @return a new queue. */ - public static FlightStreamQueue createNewQueue(final ExecutorService service) { - return new FlightStreamQueue(new ExecutorCompletionService<>(service)); + public static FlightEndpointDataQueue createNewQueue(final ExecutorService service) { + return new FlightEndpointDataQueue(new ExecutorCompletionService<>(service)); } /** @@ -93,20 +95,38 @@ public boolean isClosed() { */ @FunctionalInterface interface FlightStreamSupplier { - Future get() throws SQLException; + Future get() throws SQLException; } - private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { + private CloseableEndpointStreamPair next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { checkOpen(); while (!futures.isEmpty()) { - final Future future = flightStreamSupplier.get(); + final Future future = flightStreamSupplier.get(); futures.remove(future); try { - final FlightStream stream = future.get(); - if (stream.getRoot().getRowCount() > 0) { - return stream; + CloseableEndpointStreamPair endpoint = future.get(); + // Get the next FlightStream with content. + while (true) { + // The stream is non-empty. + if (endpoint != null && endpoint.getStream().getRoot().getRowCount() > 0) { + // Caller is now responsible for cleaning up this endpoint. + endpointsToClose.remove(endpoint); + return endpoint; + } + + if (endpoint == null) { + // Endpoint was entirely empty. + break; + } + + if (!endpoint.getStream().next()) { + // The endpoint was valid but had an empty result. Clean up. + AutoCloseables.close(endpoint); + endpointsToClose.remove(endpoint); + break; + } } - } catch (final ExecutionException | InterruptedException | CancellationException e) { + } catch (final Exception e) { throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); } } @@ -120,11 +140,11 @@ private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throw * @param timeoutUnit the timeoutValue time unit * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. */ - public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) + public CloseableEndpointStreamPair next(final long timeoutValue, final TimeUnit timeoutUnit) throws SQLException { return next(() -> { try { - final Future future = completionService.poll(timeoutValue, timeoutUnit); + final Future future = completionService.poll(timeoutValue, timeoutUnit); if (future != null) { return future; } @@ -142,7 +162,7 @@ public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) * * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. */ - public FlightStream next() throws SQLException { + public CloseableEndpointStreamPair next() throws SQLException { return next(() -> { try { return completionService.take(); @@ -162,21 +182,20 @@ public synchronized void checkOpen() { /** * Readily adds given {@link FlightStream}s to the queue. */ - public void enqueue(final Collection flightStreams) { - flightStreams.forEach(this::enqueue); + public void enqueue(final Collection> endpointRequests) { + endpointRequests.forEach(this::enqueue); } /** * Adds given {@link FlightStream} to the queue. */ - public synchronized void enqueue(final FlightStream flightStream) { - checkNotNull(flightStream); + public synchronized void enqueue(final Callable endpointRequest) { + checkNotNull(endpointRequest); checkOpen(); - allStreams.add(flightStream); futures.add(completionService.submit(() -> { - // `FlightStream#next` will block until new data can be read or stream is over. - flightStream.next(); - return flightStream; + CloseableEndpointStreamPair result = endpointRequest.call(); + endpointsToClose.add(result); + return result; })); } @@ -192,9 +211,9 @@ public synchronized void close() throws SQLException { return; } try { - for (final FlightStream flightStream : allStreams) { + for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) { try { - flightStream.cancel("Cancelling this FlightStream.", null); + endpointToClose.getStream().cancel("Cancelling this FlightStream.", null); } catch (final Exception e) { final String errorMsg = "Failed to cancel a FlightStream."; LOGGER.error(errorMsg, e); @@ -214,9 +233,9 @@ public synchronized void close() throws SQLException { } } }); - for (final FlightStream flightStream : allStreams) { + for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) { try { - flightStream.close(); + endpointToClose.close(); } catch (final Exception e) { final String errorMsg = "Failed to close a FlightStream."; LOGGER.error(errorMsg, e); @@ -224,7 +243,7 @@ public synchronized void close() throws SQLException { } } } finally { - allStreams.clear(); + endpointsToClose.clear(); futures.clear(); closed.set(true); } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java similarity index 85% rename from java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java rename to java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java index b474da55a7f1f..a04b3ecb3723b 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java @@ -21,9 +21,10 @@ import static org.hamcrest.CoreMatchers.nullValue; import static org.mockito.Mockito.mock; +import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; -import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -33,20 +34,20 @@ import org.mockito.junit.MockitoJUnitRunner; /** - * Tests for {@link FlightStreamQueue}. + * Tests for {@link FlightEndpointDataQueue}. */ @RunWith(MockitoJUnitRunner.class) -public class FlightStreamQueueTest { +public class FlightEndpointDataQueueTest { @Rule public final ErrorCollector collector = new ErrorCollector(); @Mock - private CompletionService mockedService; - private FlightStreamQueue queue; + private CompletionService mockedService; + private FlightEndpointDataQueue queue; @Before public void setUp() { - queue = new FlightStreamQueue(mockedService); + queue = new FlightEndpointDataQueue(mockedService); } @Test @@ -64,7 +65,7 @@ public void testNextShouldThrowExceptionUponClose() throws Exception { public void testEnqueueShouldThrowExceptionUponClose() throws Exception { queue.close(); ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, - () -> queue.enqueue(mock(FlightStream.class))); + () -> queue.enqueue(mock(Callable.class))); } @Test