diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index a1c5250ba66fa..57f4f3e030420 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -353,6 +353,11 @@ TEST(FlightTypes, LocationUnknownScheme) { ASSERT_OK(Location::Parse("https://example.com/foo")); } +TEST(FlightTypes, LocationFallback) { + EXPECT_EQ("arrow-flight-reuse-connection://?", Location::ReuseConnection().ToString()); + EXPECT_EQ("arrow-flight-reuse-connection", Location::ReuseConnection().scheme()); +} + TEST(FlightTypes, RoundtripStatus) { // Make sure status codes round trip through our conversions diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc index 6f3115cc5ab8a..92c088b7fae08 100644 --- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -71,6 +71,10 @@ TEST(FlightIntegration, ExpirationTimeRenewFlightEndpoint) { ASSERT_OK(RunScenario("expiration_time:renew_flight_endpoint")); } +TEST(FlightIntegration, LocationReuseConnection) { + ASSERT_OK(RunScenario("location:reuse_connection")); +} + TEST(FlightIntegration, SessionOptions) { ASSERT_OK(RunScenario("session_options")); } TEST(FlightIntegration, PollFlightInfo) { ASSERT_OK(RunScenario("poll_flight_info")); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index d4e0a2cda5bd8..6ba5d9c352da1 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -2079,6 +2079,50 @@ class FlightSqlExtensionScenario : public FlightSqlScenario { return Status::OK(); } }; + +/// \brief The server for testing arrow-flight-reuse-connection://. 
+class ReuseConnectionServer : public FlightServerBase { + public: + Status GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& descriptor, + std::unique_ptr* info) override { + auto location = Location::ReuseConnection(); + auto endpoint = FlightEndpoint{{"reuse"}, {location}}; + ARROW_ASSIGN_OR_RAISE(auto info_data, FlightInfo::Make(arrow::Schema({}), descriptor, + {endpoint}, -1, -1)); + *info = std::make_unique(std::move(info_data)); + return Status::OK(); + } +}; + +/// \brief A scenario for testing arrow-flight-reuse-connection://. +class ReuseConnectionScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + *server = std::make_unique(); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status RunClient(std::unique_ptr client) override { + auto descriptor = FlightDescriptor::Command("reuse"); + ARROW_ASSIGN_OR_RAISE(auto info, client->GetFlightInfo(descriptor)); + if (info->endpoints().size() != 1) { + return Status::Invalid("Expected 1 endpoint, got ", info->endpoints().size()); + } + const auto& endpoint = info->endpoints().front(); + if (endpoint.locations.size() != 1) { + return Status::Invalid("Expected 1 location, got ", + info->endpoints().front().locations.size()); + } else if (endpoint.locations.front().ToString() != + "arrow-flight-reuse-connection://?") { + return Status::Invalid("Expected arrow-flight-reuse-connection://?, got ", + endpoint.locations.front().ToString()); + } + return Status::OK(); + } +}; } // namespace Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { @@ -2103,6 +2147,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* } else if (scenario_name == "expiration_time:renew_flight_endpoint") { *out = std::make_shared(); return Status::OK(); + } else if (scenario_name == "location:reuse_connection") { + *out = std::make_shared(); + 
return Status::OK(); } else if (scenario_name == "session_options") { *out = std::make_shared(); return Status::OK(); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 11b2baafad220..a1b799a3a069e 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -829,6 +829,12 @@ arrow::Result Location::Parse(const std::string& uri_string) { return location; } +const Location& Location::ReuseConnection() { + static Location kFallback = + Location::Parse("arrow-flight-reuse-connection://?").ValueOrDie(); + return kFallback; +} + arrow::Result Location::ForGrpcTcp(const std::string& host, const int port) { std::stringstream uri_string; uri_string << "grpc+tcp://" << host << ':' << port; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 4b17149aa2d46..c96aa428b054e 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -424,6 +424,14 @@ struct ARROW_FLIGHT_EXPORT Location { /// \brief Initialize a location by parsing a URI string static arrow::Result Parse(const std::string& uri_string); + /// \brief Get the fallback URI. + /// + /// arrow-flight-reuse-connection:// means that a client may attempt to + /// reuse an existing connection to a Flight service to fetch data instead + /// of creating a new connection to one of the other locations listed in a + /// FlightEndpoint response. 
+ static const Location& ReuseConnection(); + /// \brief Initialize a location for a non-TLS, gRPC-based Flight /// service from a host and port /// \param[in] host The hostname to connect to diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index e984468bc5052..7cdd5071b96c6 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -608,6 +608,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, "RenewFlightEndpoint are working as expected."), skip_testers={"JS", "C#", "Rust"}, ), + Scenario( + "location:reuse_connection", + description="Ensure arrow-flight-reuse-connection is accepted.", + skip_testers={"JS", "C#", "Rust"}, + ), Scenario( "session_options", description="Ensure Flight SQL Sessions work as expected.", diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst index 73ca848b5e996..7ee84952b4350 100644 --- a/docs/source/format/Flight.rst +++ b/docs/source/format/Flight.rst @@ -121,6 +121,13 @@ A client that wishes to download the data would: connection to the original server to fetch data. Otherwise, the client must connect to one of the indicated locations. + The server may list "itself" as a location alongside other server + locations. Normally this requires the server to know its public + address, but it may also use the special URI string + ``arrow-flight-reuse-connection://?`` to tell clients that they may + reuse an existing connection to the same server, without having to + be able to name itself. See `Connection Reuse`_ below. + In this way, the locations inside an endpoint can also be thought of as performing look-aside load balancing or service discovery functions. And the endpoints can represent data that is partitioned @@ -307,29 +314,58 @@ well, in which case any `authentication method supported by gRPC .. 
_Mutual TLS (mTLS): https://grpc.io/docs/guides/auth/#supported-auth-mechanisms -Transport Implementations -========================= +Location URIs +============= Flight is primarily defined in terms of its Protobuf and gRPC specification below, but Arrow implementations may also support -alternative transports (see :ref:`status-flight-rpc`). In that case, -implementations should use the following URI schemes for the given -transport implementations: - -+----------------------------+----------------------------+ -| Transport | URI Scheme | -+============================+============================+ -| gRPC (plaintext) | grpc: or grpc+tcp: | -+----------------------------+----------------------------+ -| gRPC (TLS) | grpc+tls: | -+----------------------------+----------------------------+ -| gRPC (Unix domain socket) | grpc+unix: | -+----------------------------+----------------------------+ -| UCX_ (plaintext) | ucx: | -+----------------------------+----------------------------+ +alternative transports (see :ref:`status-flight-rpc`). Clients and +servers need to know which transport to use for a given URI in a +Location, so Flight implementations should use the following URI +schemes for the given transports: + ++----------------------------+--------------------------------+ +| Transport | URI Scheme | ++============================+================================+ +| gRPC (plaintext) | grpc: or grpc+tcp: | ++----------------------------+--------------------------------+ +| gRPC (TLS) | grpc+tls: | ++----------------------------+--------------------------------+ +| gRPC (Unix domain socket) | grpc+unix: | ++----------------------------+--------------------------------+ +| (reuse connection) | arrow-flight-reuse-connection: | ++----------------------------+--------------------------------+ +| UCX_ (plaintext) | ucx: | ++----------------------------+--------------------------------+ .. 
_UCX: https://openucx.org/ +Connection Reuse +---------------- + +"Reuse connection" above is not a particular transport. Instead, it +means that the client may try to execute DoGet against the same server +(and through the same connection) that it originally obtained the +FlightInfo from (i.e., that it called GetFlightInfo against). This is +interpreted the same way as when no specific ``Location`` values are +returned. + +This allows the server to return "itself" as one possible location to +fetch data without having to know its own public address, which can be +useful in deployments where knowing this would be difficult or +impossible. For example, a developer may forward a remote service in +a cloud environment to their local machine; in this case, the remote +service would have no way to know the local hostname and port that it +is being accessed over. + +For compatibility reasons, the URI should always be +``arrow-flight-reuse-connection://?``, with the trailing empty query +string. Java's URI implementation does not accept ``scheme:`` or +``scheme://``, and C++'s implementation does not accept an empty +string, so the obvious candidates are not compatible. The chosen +representation can be parsed by both implementations, as well as Go's +``net/url`` and Python's ``urllib.parse``. + Error Handling ============== diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java index fe192aa0c3f9d..2eb3139c9dcdd 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java @@ -93,6 +93,19 @@ Flight.Location toProtocol() { return Flight.Location.newBuilder().setUri(uri.toString()).build(); } + /** + * Construct a special URI to indicate to clients that they may fetch data by reusing + * an existing connection to a Flight RPC server. 
+ */ + public static Location reuseConnection() { + try { + return new Location(new URI(LocationSchemes.REUSE_CONNECTION, "", "", "", null)); + } catch (URISyntaxException e) { + // This should never happen. + throw new IllegalArgumentException(e); + } + } + /** * Construct a URI for a Flight+gRPC server without transport security. * diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java index 872e5b1c22deb..f1dbfb95f237e 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java @@ -25,6 +25,7 @@ public final class LocationSchemes { public static final String GRPC_INSECURE = "grpc+tcp"; public static final String GRPC_DOMAIN_SOCKET = "grpc+unix"; public static final String GRPC_TLS = "grpc+tls"; + public static final String REUSE_CONNECTION = "arrow-flight-reuse-connection"; private LocationSchemes() { throw new AssertionError("Do not instantiate this class."); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index ae520ee9b991b..bc34b5e6d6074 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -79,6 +79,12 @@ public void fastPathDefaults() { Assertions.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE); } + @Test + public void fallbackLocation() { + Assertions.assertEquals("arrow-flight-reuse-connection://?", + Location.reuseConnection().getUri().toString()); + } + /** * ARROW-6017: we should be able to construct locations for unknown schemes. 
*/ diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 234820bd41823..a47cffc2fcf09 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -124,6 +124,13 @@ public List getStreams(final FlightInfo flightInfo) // It would also be good to identify when the reported location is the same as the original connection's // Location and skip creating a FlightClient in that scenario. final URI endpointUri = endpoint.getLocations().get(0).getUri(); + + if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) { + endpoints.add(new CloseableEndpointStreamPair( + sqlClient.getStream(endpoint.getTicket(), getOptions()), null)); + continue; + } + final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder) .withHost(endpointUri.getHost()) .withPort(endpointUri.getPort()) diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index 0e3e015a04636..ad01e8767b793 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -39,6 +39,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Random; @@ -46,6 +47,7 @@ import java.util.concurrent.CountDownLatch; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import 
org.apache.arrow.driver.jdbc.utils.FallbackFlightSqlProducer; import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightProducer; @@ -55,6 +57,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -63,6 +66,7 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.jupiter.api.Assertions; import org.junit.rules.ErrorCollector; import com.google.common.collect.ImmutableSet; @@ -455,6 +459,69 @@ allocator, forGrpcInsecure("localhost", 0), rootProducer) } } + @Test + public void testFallbackFlightServer() throws Exception { + final Schema schema = new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = DriverManager.getConnection(String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort())); + Statement newStatement = newConnection.createStatement(); + ResultSet result = newStatement.executeQuery("fallback")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + 
assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + } + + @Test + public void testFallbackSecondFlightServer() throws Exception { + final Schema schema = new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = DriverManager.getConnection(String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort())); + Statement newStatement = newConnection.createStatement()) { + + // TODO(https://github.com/apache/arrow/issues/38573) + // XXX: we could try to assert more structure but then we'd have to hardcode + // a particular exception chain which may be fragile + Assertions.assertThrows(SQLException.class, () -> { + try (ResultSet result = newStatement.executeQuery("fallback with error")) { + // Empty body + } + }); + + } + } + } + @Test public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception { try (Statement statement = connection.createStatement(); diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java new file mode 100644 index 0000000000000..2257220a4c845 --- /dev/null +++ 
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.BasicFlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +public class FallbackFlightSqlProducer extends BasicFlightSqlProducer { + private final VectorSchemaRoot data; + + public FallbackFlightSqlProducer(VectorSchemaRoot resultData) { + this.data = resultData; + } + + @Override + protected List determineEndpoints( + T request, FlightDescriptor flightDescriptor, 
Schema schema) { + return Collections.emptyList(); + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + final FlightSql.ActionCreatePreparedStatementResult.Builder resultBuilder = + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(request.getQueryBytes()); + + final ByteString datasetSchemaBytes = ByteString.copyFrom(data.getSchema().serializeAsMessage()); + + resultBuilder.setDatasetSchema(datasetSchemaBytes); + listener.onNext(new Result(Any.pack(resultBuilder.build()).toByteArray())); + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { + return getFlightInfo(descriptor, command.getQuery()); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return getFlightInfo(descriptor, command.getPreparedStatementHandle().toStringUtf8()); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener) { + listener.start(data); + listener.putNext(); + listener.completed(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + listener.onCompleted(); + } + + private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) { + final List endpoints; + final Ticket ticket = new Ticket( + Any.pack(FlightSql.TicketStatementQuery.getDefaultInstance()).toByteArray()); + if (query.equals("fallback")) { + endpoints = Collections.singletonList(FlightEndpoint.builder(ticket, Location.reuseConnection()).build()); + } else if (query.equals("fallback with error")) { + endpoints = 
Collections.singletonList( + FlightEndpoint.builder(ticket, + Location.forGrpcInsecure("localhost", 9999), + Location.reuseConnection()).build()); + } else { + throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException(); + } + return FlightInfo.builder(data.getSchema(), descriptor, endpoints).build(); + } +}