From 2811a54a77d42822723df1ff5155378969e43cc9 Mon Sep 17 00:00:00 2001
From: David Li
Date: Sun, 13 Oct 2019 15:25:06 -0400
Subject: [PATCH] ARROW-5875: [FlightRPC] integration tests for Flight features
---
cpp/src/arrow/flight/CMakeLists.txt | 1 +
cpp/src/arrow/flight/test_integration.cc | 116 ++++++++++++++++
cpp/src/arrow/flight/test_integration.h | 49 +++++++
.../arrow/flight/test_integration_client.cc | 126 +++++++++++-------
.../arrow/flight/test_integration_server.cc | 43 +++++-
cpp/src/arrow/flight/test_util.h | 2 +
dev/archery/archery/integration/runner.py | 34 +++--
dev/archery/archery/integration/scenario.py | 29 ++++
dev/archery/archery/integration/tester.py | 4 +-
dev/archery/archery/integration/tester_cpp.py | 14 +-
.../archery/integration/tester_java.py | 16 ++-
.../flight/auth/ServerAuthInterceptor.java | 28 +++-
.../integration/AuthBasicProtoScenario.java | 96 +++++++++++++
.../integration/IntegrationAssertions.java | 74 ++++++++++
.../integration/IntegrationTestClient.java | 20 ++-
.../integration/IntegrationTestServer.java | 23 +++-
.../flight/example/integration/Scenario.java | 45 +++++++
.../flight/example/integration/Scenarios.java | 89 +++++++++++++
18 files changed, 727 insertions(+), 82 deletions(-)
create mode 100644 cpp/src/arrow/flight/test_integration.cc
create mode 100644 cpp/src/arrow/flight/test_integration.h
create mode 100644 dev/archery/archery/integration/scenario.py
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index d65bbef7179f7..8642bb4253823 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -119,6 +119,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS OR ARROW_BUILD_INTEGRATION)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
+ test_integration.cc
test_util.cc
DEPENDENCIES
GTest::gtest
diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/test_integration.cc
new file mode 100644
index 0000000000000..34eb4d6bcc931
--- /dev/null
+++ b/cpp/src/arrow/flight/test_integration.cc
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/test_integration.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/types.h"
+#include "arrow/ipc/dictionary.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace arrow {
+namespace flight {
+
+/// \brief The server for the basic auth integration test.
+class AuthBasicProtoServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr* result) override {
+ // Respond with the authenticated username.
+ std::shared_ptr buf;
+ RETURN_NOT_OK(Buffer::FromString(context.peer_identity(), &buf));
+ *result = std::unique_ptr(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
+/// Validate the result of a DoAction.
+Status CheckActionResults(FlightClient* client, const Action& action,
+ std::vector results) {
+ std::unique_ptr stream;
+ RETURN_NOT_OK(client->DoAction(action, &stream));
+ std::unique_ptr result;
+ for (const std::string& expected : results) {
+ RETURN_NOT_OK(stream->Next(&result));
+ if (!result) {
+ return Status::Invalid("Action result stream ended early");
+ }
+ const auto actual = result->body->ToString();
+ if (expected != actual) {
+ return Status::Invalid("Got wrong result; expected", expected, "but got", actual);
+ }
+ }
+ RETURN_NOT_OK(stream->Next(&result));
+ if (result) {
+ return Status::Invalid("Action result stream had too many entries");
+ }
+ return Status::OK();
+}
+
+// The expected username for the basic auth integration test.
+constexpr auto kAuthUsername = "arrow";
+// The expected password for the basic auth integration test.
+constexpr auto kAuthPassword = "flight";
+
+/// \brief A scenario testing the basic auth protobuf.
+class AuthBasicProtoScenario : public Scenario {
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ server->reset(new AuthBasicProtoServer());
+ options->auth_handler =
+ std::make_shared(kAuthUsername, kAuthPassword);
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+ Status RunClient(std::unique_ptr client) override {
+ Action action;
+ std::unique_ptr stream;
+ std::shared_ptr detail;
+ const auto& status = client->DoAction(action, &stream);
+ detail = FlightStatusDetail::UnwrapStatus(status);
+ // This client is unauthenticated and should fail.
+ if (detail == nullptr) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString());
+ }
+ if (detail->code() != FlightStatusCode::Unauthenticated) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString());
+ }
+
+ auto client_handler = std::unique_ptr(
+ new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword));
+ RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler)));
+ return CheckActionResults(client.get(), action, {kAuthUsername});
+ }
+};
+
+Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) {
+ if (scenario_name == "auth:basic_proto") {
+ *out = std::make_shared();
+ return Status::OK();
+ }
+ return Status::KeyError("Scenario not found: ", scenario_name);
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/test_integration.h
new file mode 100644
index 0000000000000..5d9bd7fd7bd74
--- /dev/null
+++ b/cpp/src/arrow/flight/test_integration.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Integration test scenarios for Arrow Flight.
+
+#include "arrow/flight/visibility.h"
+
+#include
+#include
+
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief An integration test for Arrow Flight.
+class ARROW_FLIGHT_EXPORT Scenario {
+ public:
+ virtual ~Scenario() = default;
+ /// \brief Set up the server.
+ virtual Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) = 0;
+ /// \brief Set up the client.
+ virtual Status MakeClient(FlightClientOptions* options) = 0;
+ /// \brief Run the scenario as the client.
+ virtual Status RunClient(std::unique_ptr client) = 0;
+};
+
+/// \brief Get the implementation of an integration test scenario by name.
+Status GetScenario(const std::string& scenario_name, std::shared_ptr* out);
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/test_integration_client.cc
index a8847303abb5f..bd9022bf761cb 100644
--- a/cpp/src/arrow/flight/test_integration_client.cc
+++ b/cpp/src/arrow/flight/test_integration_client.cc
@@ -38,11 +38,13 @@
#include "arrow/util/logging.h"
#include "arrow/flight/api.h"
+#include "arrow/flight/test_integration.h"
#include "arrow/flight/test_util.h"
DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");
+DEFINE_string(scenario, "", "Integration test scenario to run");
/// \brief Helper to read all batches from a JsonReader
arrow::Status ReadBatches(std::unique_ptr& reader,
@@ -125,60 +127,92 @@ arrow::Status ConsumeFlightLocation(
return arrow::Status::OK();
}
+class IntegrationTestScenario : public arrow::flight::Scenario {
+ public:
+ arrow::Status MakeServer(std::unique_ptr* server,
+ arrow::flight::FlightServerOptions* options) override {
+ ARROW_UNUSED(server);
+ ARROW_UNUSED(options);
+ return arrow::Status::NotImplemented(
+ "Not implemented, see test_integration_server.cc");
+ }
+
+ arrow::Status MakeClient(arrow::flight::FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return arrow::Status::OK();
+ }
+
+ arrow::Status RunClient(std::unique_ptr client) override {
+ arrow::flight::FlightDescriptor descr{
+ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}};
+
+ // 1. Put the data to the server.
+ std::unique_ptr reader;
+ std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl;
+ auto in_file = *arrow::io::ReadableFile::Open(FLAGS_path);
+ ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(
+ arrow::default_memory_pool(), in_file, &reader));
+
+ std::shared_ptr original_schema = reader->schema();
+ std::vector> original_data;
+ ABORT_NOT_OK(ReadBatches(reader, &original_data));
+
+ std::unique_ptr write_stream;
+ std::unique_ptr metadata_reader;
+ ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
+ ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
+
+ // 2. Get the ticket for the data.
+ std::unique_ptr info;
+ ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
+
+ std::shared_ptr schema;
+ arrow::ipc::DictionaryMemo dict_memo;
+ ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema));
+
+ if (info->endpoints().size() == 0) {
+ std::cerr << "No endpoints returned from Flight server." << std::endl;
+ return arrow::Status::IOError("No endpoints returned from Flight server.");
+ }
+
+ for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) {
+ const auto& ticket = endpoint.ticket;
+
+ auto locations = endpoint.locations;
+ if (locations.size() == 0) {
+ return arrow::Status::IOError("No locations returned from Flight server.");
+ }
+
+ for (const auto location : locations) {
+ std::cout << "Verifying location " << location.ToString() << std::endl;
+ ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
+ }
+ }
+ return arrow::Status::OK();
+ }
+};
+
int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing client for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
+ std::shared_ptr scenario;
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared();
+ }
+
+ arrow::flight::FlightClientOptions options;
std::unique_ptr client;
+
+ ABORT_NOT_OK(scenario->MakeClient(&options));
+
arrow::flight::Location location;
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
- ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client));
-
- arrow::flight::FlightDescriptor descr{
- arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}};
-
- // 1. Put the data to the server.
- std::unique_ptr reader;
- std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl;
- auto in_file = *arrow::io::ReadableFile::Open(FLAGS_path);
- ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(),
- in_file, &reader));
-
- std::shared_ptr original_schema = reader->schema();
- std::vector> original_data;
- ABORT_NOT_OK(ReadBatches(reader, &original_data));
-
- std::unique_ptr write_stream;
- std::unique_ptr metadata_reader;
- ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
- ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
-
- // 2. Get the ticket for the data.
- std::unique_ptr info;
- ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
-
- std::shared_ptr schema;
- arrow::ipc::DictionaryMemo dict_memo;
- ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema));
-
- if (info->endpoints().size() == 0) {
- std::cerr << "No endpoints returned from Flight server." << std::endl;
- return -1;
- }
-
- for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) {
- const auto& ticket = endpoint.ticket;
+ ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
- auto locations = endpoint.locations;
- if (locations.size() == 0) {
- locations = {location};
- }
+ ABORT_NOT_OK(scenario->RunClient(std::move(client)));
- for (const auto location : locations) {
- std::cout << "Verifying location " << location.ToString() << std::endl;
- // 3. Stream data from the server, comparing individual batches.
- ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
- }
- }
return 0;
}
diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/test_integration_server.cc
index 108c38473ce45..4dced8a691c59 100644
--- a/cpp/src/arrow/flight/test_integration_server.cc
+++ b/cpp/src/arrow/flight/test_integration_server.cc
@@ -15,7 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-// Example server implementation for integration testing purposes
+// Server for integration testing.
+
+// Integration testing covers files and scenarios. The former
+// validates that Arrow data survives a round-trip through a Flight
+// service. The latter tests specific features of Arrow Flight.
#include
#include
@@ -33,9 +37,11 @@
#include "arrow/flight/internal.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
+#include "arrow/flight/test_integration.h"
#include "arrow/flight/test_util.h"
DEFINE_int32(port, 31337, "Server port to listen on");
+DEFINE_string(scenario, "", "Integration test senario to run");
namespace arrow {
namespace flight {
@@ -81,7 +87,9 @@ class FlightIntegrationTestServer : public FlightServerBase {
}
auto flight = data->second;
- FlightEndpoint endpoint1({{request.path[0]}, {}});
+ Location server_location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("127.0.0.1", port(), &server_location));
+ FlightEndpoint endpoint1({{request.path[0]}, {server_location}});
FlightInfo::Data flight_data;
RETURN_NOT_OK(internal::SchemaToString(*flight.schema, &flight_data.schema));
@@ -148,20 +156,47 @@ class FlightIntegrationTestServer : public FlightServerBase {
std::unordered_map uploaded_chunks;
};
+class IntegrationTestScenario : public Scenario {
+ public:
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ server->reset(new FlightIntegrationTestServer());
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+
+ Status RunClient(std::unique_ptr client) override {
+ ARROW_UNUSED(client);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+};
+
} // namespace flight
} // namespace arrow
-std::unique_ptr g_server;
+std::unique_ptr g_server;
int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing server for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
- g_server.reset(new arrow::flight::FlightIntegrationTestServer);
+ std::shared_ptr scenario;
+
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared();
+ }
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
arrow::flight::FlightServerOptions options(location);
+ ARROW_CHECK_OK(scenario->MakeServer(&g_server, &options));
+
ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h
index b78573139763e..e012c23c4e71d 100644
--- a/cpp/src/arrow/flight/test_util.h
+++ b/cpp/src/arrow/flight/test_util.h
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+#pragma once
+
#include
#include
#include
diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py
index e6d3c42f6acf6..9b506a2416eb1 100644
--- a/dev/archery/archery/integration/runner.py
+++ b/dev/archery/archery/integration/runner.py
@@ -26,6 +26,7 @@
import tempfile
import traceback
+from .scenario import Scenario
from .tester_cpp import CPPTester
from .tester_go import GoTester
from .tester_java import JavaTester
@@ -49,10 +50,11 @@ def __init__(self):
class IntegrationRunner(object):
- def __init__(self, json_files, testers, tempdir=None, debug=False,
- stop_on_error=True, gold_dirs=None, serial=False,
- match=None, **unused_kwargs):
+ def __init__(self, json_files, flight_scenarios, testers, tempdir=None,
+ debug=False, stop_on_error=True, gold_dirs=None,
+ serial=False, match=None, **unused_kwargs):
self.json_files = json_files
+ self.flight_scenarios = flight_scenarios
self.testers = testers
self.temp_dir = tempdir or tempfile.mkdtemp()
self.debug = debug
@@ -251,7 +253,8 @@ def _compare_flight_implementations(self, producer, consumer):
log('##########################################################')
case_runner = partial(self._run_flight_test_case, producer, consumer)
- self._run_test_cases(producer, consumer, case_runner, self.json_files)
+ self._run_test_cases(producer, consumer, case_runner,
+ self.json_files + self.flight_scenarios)
def _run_flight_test_case(self, producer, consumer, test_case):
"""
@@ -259,9 +262,8 @@ def _run_flight_test_case(self, producer, consumer, test_case):
"""
outcome = Outcome()
- json_path = test_case.path
log('=' * 58)
- log('Testing file {0}'.format(json_path))
+ log('Testing file {0}'.format(test_case.name))
log('=' * 58)
if producer.name in test_case.skip:
@@ -281,10 +283,17 @@ def _run_flight_test_case(self, producer, consumer, test_case):
else:
try:
port = find_unused_port()
- with producer.flight_server(port):
+ if isinstance(test_case, Scenario):
+ server = producer.flight_server(port, test_case.name)
+ client_args = {'scenario_name': test_case.name}
+ else:
+ server = producer.flight_server(port)
+ client_args = {'json_path': test_case.path}
+
+ with server:
# Have the client upload the file, then download and
# compare
- consumer.flight_request(port, json_path)
+ consumer.flight_request(port, **client_args)
except Exception:
traceback.print_exc(file=printer.stdout)
outcome.failure = Failure(test_case, producer, consumer,
@@ -329,7 +338,14 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
)
json_files = static_json_files + generated_json_files
- runner = IntegrationRunner(json_files, testers, **kwargs)
+ # Additional integration test cases for Arrow Flight.
+ flight_scenarios = [
+ Scenario(
+ "auth:basic_proto",
+ description="Authenticate using the BasicAuth protobuf."),
+ ]
+
+ runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs)
runner.run()
if run_flight:
runner.run_flight()
diff --git a/dev/archery/archery/integration/scenario.py b/dev/archery/archery/integration/scenario.py
new file mode 100644
index 0000000000000..1fcbca64e6a1f
--- /dev/null
+++ b/dev/archery/archery/integration/scenario.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class Scenario:
+ """
+ An integration test scenario for Arrow Flight.
+
+ Does not correspond to a particular IPC JSON file.
+ """
+
+ def __init__(self, name, description, skip=None):
+ self.name = name
+ self.description = description
+ self.skip = skip or set()
diff --git a/dev/archery/archery/integration/tester.py b/dev/archery/archery/integration/tester.py
index 85ed2214d4bb4..97cfe09206e3b 100644
--- a/dev/archery/archery/integration/tester.py
+++ b/dev/archery/archery/integration/tester.py
@@ -50,8 +50,8 @@ def file_to_stream(self, file_path, stream_path):
def validate(self, json_path, arrow_path):
raise NotImplementedError
- def flight_server(self, port):
+ def flight_server(self, port, scenario_name=None):
raise NotImplementedError
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
raise NotImplementedError
diff --git a/dev/archery/archery/integration/tester_cpp.py b/dev/archery/archery/integration/tester_cpp.py
index f0fd3fc286b77..ad836e574f404 100644
--- a/dev/archery/archery/integration/tester_cpp.py
+++ b/dev/archery/archery/integration/tester_cpp.py
@@ -76,8 +76,10 @@ def file_to_stream(self, file_path, stream_path):
self.run_shell_command(cmd)
@contextlib.contextmanager
- def flight_server(self, port):
+ def flight_server(self, port, scenario_name=None):
cmd = self.FLIGHT_SERVER_CMD + ['-port=' + str(port)]
+ if scenario_name:
+ cmd = cmd + ["-scenario", scenario_name]
if self.debug:
log(' '.join(cmd))
server = subprocess.Popen(cmd,
@@ -97,11 +99,17 @@ def flight_server(self, port):
server.kill()
server.wait(5)
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
cmd = self.FLIGHT_CLIENT_CMD + [
'-port=' + str(port),
- '-path=' + json_path,
]
+ if json_path:
+ cmd.extend(('-path', json_path))
+ elif scenario_name:
+ cmd.extend(('-scenario', scenario_name))
+ else:
+ raise TypeError("Must provide one of json_path or scenario_name")
+
if self.debug:
log(' '.join(cmd))
run_cmd(cmd)
diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py
index bdfe3429ce2b2..ddc110c4f7d30 100644
--- a/dev/archery/archery/integration/tester_java.py
+++ b/dev/archery/archery/integration/tester_java.py
@@ -96,19 +96,29 @@ def file_to_stream(self, file_path, stream_path):
log(' '.join(cmd))
run_cmd(cmd)
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
cmd = ['java'] + self.JAVA_OPTS + \
['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT,
- '-port', str(port), '-j', json_path]
+ '-port', str(port)]
+
+ if json_path:
+ cmd.extend(('-j', json_path))
+ elif scenario_name:
+ cmd.extend(('-scenario', scenario_name))
+ else:
+ raise TypeError("Must provide one of json_path or scenario_name")
+
if self.debug:
log(' '.join(cmd))
run_cmd(cmd)
@contextlib.contextmanager
- def flight_server(self, port):
+ def flight_server(self, port, scenario_name=None):
cmd = ['java'] + self.JAVA_OPTS + \
['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_SERVER,
'-port', str(port)]
+ if scenario_name:
+ cmd.extend(('-scenario', scenario_name))
if self.debug:
log(' '.join(cmd))
server = subprocess.Popen(cmd, stdout=subprocess.PIPE,
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
index 4ebd7424cb888..32ca5c5e2033b 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
@@ -19,6 +19,9 @@
import java.util.Optional;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
@@ -27,6 +30,7 @@
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
/**
* GRPC Interceptor for performing authentication.
@@ -43,10 +47,28 @@ public ServerAuthInterceptor(ServerAuthHandler authHandler) {
public Listener interceptCall(ServerCall call, Metadata headers,
ServerCallHandler next) {
if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
- final Optional peerIdentity = isValid(headers);
+ final Optional peerIdentity;
+
+ // Allow customizing the response code by throwing FlightRuntimeException
+ try {
+ peerIdentity = isValid(headers);
+ } catch (FlightRuntimeException e) {
+ final Status grpcStatus = StatusUtils.toGrpcStatus(e.status());
+ call.close(grpcStatus, new Metadata());
+ return new NoopServerCallListener<>();
+ } catch (StatusRuntimeException e) {
+ Metadata trailers = e.getTrailers();
+ call.close(e.getStatus(), trailers == null ? new Metadata() : trailers);
+ return new NoopServerCallListener<>();
+ } catch (RuntimeException e) {
+ call.close(Status.UNAUTHENTICATED.withDescription("Unauthenticated: " + e), new Metadata());
+ return new NoopServerCallListener<>();
+ }
+
if (!peerIdentity.isPresent()) {
- call.close(Status.UNAUTHENTICATED, new Metadata());
- // TODO: we should actually terminate here instead of causing an exception below.
+ // Send back a description along with the status code
+ call.close(Status.UNAUTHENTICATED
+ .withDescription("Unauthenticated (invalid or missing auth token)"), new Metadata());
return new NoopServerCallListener<>();
}
return Contexts.interceptCall(Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()),
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
new file mode 100644
index 0000000000000..6a46206796841
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Optional;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.auth.BasicClientAuthHandler;
+import org.apache.arrow.flight.auth.BasicServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A scenario testing the built-in basic authentication Protobuf.
+ */
+final class AuthBasicProtoScenario implements Scenario {
+
+ static final String USERNAME = "arrow";
+ static final String PASSWORD = "flight";
+
+ @Override
+ public FlightProducer producer(BufferAllocator allocator, Location location) {
+ return new NoOpFlightProducer() {
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener listener) {
+ listener.onNext(new Result(context.peerIdentity().getBytes(StandardCharsets.UTF_8)));
+ listener.onCompleted();
+ }
+ };
+ }
+
+ @Override
+ public void buildServer(FlightServer.Builder builder) {
+ builder.authHandler(new BasicServerAuthHandler(new BasicServerAuthHandler.BasicAuthValidator() {
+ @Override
+ public byte[] getToken(String username, String password) throws Exception {
+ if (!USERNAME.equals(username) || !PASSWORD.equals(password)) {
+ throw new IllegalArgumentException("Username or password is invalid.");
+ }
+ return ("valid:" + username).getBytes(StandardCharsets.UTF_8);
+ }
+
+ @Override
+ public Optional isValid(byte[] token) {
+ if (token != null) {
+ final String credential = new String(token, StandardCharsets.UTF_8);
+ if (credential.startsWith("valid:")) {
+ return Optional.of(credential.substring(6));
+ }
+ }
+ return Optional.empty();
+ }
+ }));
+ }
+
+ @Override
+ public void client(BufferAllocator allocator, Location location, FlightClient client) {
+ final FlightRuntimeException e = IntegrationAssertions.assertThrows(FlightRuntimeException.class, () -> {
+ client.listActions().forEach(act -> {
+ });
+ });
+ if (!FlightStatusCode.UNAUTHENTICATED.equals(e.status().code())) {
+ throw new AssertionError("Expected UNAUTHENTICATED but found " + e.status().code(), e);
+ }
+
+ client.authenticate(new BasicClientAuthHandler(USERNAME, PASSWORD));
+ final Result result = client.doAction(new Action("")).next();
+ if (!USERNAME.equals(new String(result.getBody(), StandardCharsets.UTF_8))) {
+ throw new AssertionError("Expected " + USERNAME + " but got " + Arrays.toString(result.getBody()));
+ }
+ }
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
new file mode 100644
index 0000000000000..576d1887f3905
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Objects;
+
+/**
+ * Utility methods to implement integration tests without using JUnit assertions.
+ */
+final class IntegrationAssertions {
+
+ /**
+ * Assert that the given code throws the given exception or subclass thereof.
+ *
+ * @param clazz The exception type.
+ * @param body The code to run.
+ * @param The exception type.
+ * @return The thrown exception.
+ */
+ @SuppressWarnings("unchecked")
+ static T assertThrows(Class clazz, AssertThrows body) {
+ try {
+ body.run();
+ } catch (Throwable t) {
+ if (clazz.isInstance(t)) {
+ return (T) t;
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but got " + t.getClass(), t);
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but did not throw.");
+ }
+
+ /**
+ * Assert that the two (non-array) objects are equal.
+ */
+ static void assertEquals(Object expected, Object actual) {
+ if (!Objects.equals(expected, actual)) {
+ throw new AssertionError("Expected:\n" + expected + "\nbut got:\n" + actual);
+ }
+ }
+
+ /**
+ * Assert that the value is false, using the given message as an error otherwise.
+ */
+ static void assertFalse(String message, boolean value) {
+ if (value) {
+ throw new AssertionError("Expected false: " + message);
+ }
+ }
+
+ /**
+ * An interface used with {@link #assertThrows(Class, AssertThrows)}.
+ */
+ @FunctionalInterface
+ interface AssertThrows {
+
+ void run() throws Throwable;
+ }
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index 1b28b2d86a300..f96309434aaab 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -22,7 +22,6 @@
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
-import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.AsyncPutListener;
@@ -51,7 +50,7 @@
import io.netty.buffer.ArrowBuf;
/**
- * An Example Flight Server that provides access to the InMemoryStore.
+ * A Flight client for integration testing.
*/
class IntegrationTestClient {
private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class);
@@ -60,6 +59,7 @@ class IntegrationTestClient {
private IntegrationTestClient() {
options = new Options();
options.addOption("j", "json", true, "json file");
+ options.addOption("scenario", true, "The integration test scenario.");
options.addOption("host", true, "The host to connect to.");
options.addOption("port", true, "The port to connect to.");
}
@@ -71,6 +71,8 @@ public static void main(String[] args) {
fatalError("Invalid parameters", e);
} catch (IOException e) {
fatalError("Error accessing files", e);
+ } catch (Exception e) {
+ fatalError("Unknown error", e);
}
}
@@ -81,7 +83,7 @@ private static void fatalError(String message, Throwable e) {
System.exit(1);
}
- private void run(String[] args) throws ParseException, IOException {
+ private void run(String[] args) throws Exception {
final CommandLineParser parser = new DefaultParser();
final CommandLine cmd = parser.parse(options, args, false);
@@ -92,8 +94,12 @@ private void run(String[] args) throws ParseException, IOException {
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {
- final String inputPath = cmd.getOptionValue("j");
- testStream(allocator, defaultLocation, client, inputPath);
+ if (cmd.hasOption("scenario")) {
+ Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client);
+ } else {
+ final String inputPath = cmd.getOptionValue("j");
+ testStream(allocator, defaultLocation, client, inputPath);
+ }
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -146,8 +152,8 @@ public void onNext(PutResult val) {
for (FlightEndpoint endpoint : info.getEndpoints()) {
// 3. Download the data from the server.
List locations = endpoint.getLocations();
- if (locations.size() == 0) {
- locations = Collections.singletonList(server);
+ if (locations.isEmpty()) {
+ throw new RuntimeException("No locations returned from Flight server.");
}
for (Location location : locations) {
System.out.println("Verifying location " + location.getUri());
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
index 5db2957784dae..5a9245d463df3 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
@@ -17,8 +17,9 @@
package org.apache.arrow.flight.example.integration;
+import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.example.ExampleFlightServer;
+import org.apache.arrow.flight.example.InMemoryStore;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
@@ -38,29 +39,41 @@ class IntegrationTestServer {
private IntegrationTestServer() {
options = new Options();
options.addOption("port", true, "The port to serve on.");
+ options.addOption("scenario", true, "The integration test scenario.");
}
private void run(String[] args) throws Exception {
CommandLineParser parser = new DefaultParser();
CommandLine cmd = parser.parse(options, args, false);
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
+ final Location location = Location.forGrpcInsecure("localhost", port);
final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
- final ExampleFlightServer efs = new ExampleFlightServer(allocator, Location.forGrpcInsecure("localhost", port));
- efs.start();
+ final FlightServer.Builder builder = FlightServer.builder().allocator(allocator).location(location);
+
+ final FlightServer server;
+ if (cmd.hasOption("scenario")) {
+ final Scenario scenario = Scenarios.getScenario(cmd.getOptionValue("scenario"));
+ scenario.buildServer(builder);
+ server = builder.producer(scenario.producer(allocator, location)).build();
+ server.start();
+ } else {
+ server = FlightServer.builder(allocator, location, new InMemoryStore(allocator, location)).build().start();
+ }
+
// Print out message for integration test script
System.out.println("Server listening on localhost:" + port);
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
System.out.println("\nExiting...");
- AutoCloseables.close(efs, allocator);
+ AutoCloseables.close(server, allocator);
} catch (Exception e) {
e.printStackTrace();
}
}));
- efs.awaitTermination();
+ server.awaitTermination();
}
public static void main(String[] args) {
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
new file mode 100644
index 0000000000000..b3b962d2e734b
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A particular scenario in integration testing.
+ */
+interface Scenario {
+
+ /**
+ * Construct the FlightProducer for a server in this scenario.
+ */
+ FlightProducer producer(BufferAllocator allocator, Location location) throws Exception;
+
+ /**
+ * Set any other server options.
+ */
+ void buildServer(FlightServer.Builder builder) throws Exception;
+
+ /**
+ * Run as the client in the scenario.
+ */
+ void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception;
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
new file mode 100644
index 0000000000000..3cc65829f630c
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+/**
+ * Scenarios for integration testing.
+ */
+final class Scenarios {
+
+ private static Scenarios INSTANCE;
+
+ private final Map> scenarios;
+
+ private Scenarios() {
+ scenarios = new TreeMap<>();
+ scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new);
+ }
+
+ private static Scenarios getInstance() {
+ if (INSTANCE == null) {
+ INSTANCE = new Scenarios();
+ }
+ return INSTANCE;
+ }
+
+ static Scenario getScenario(String scenario) {
+ final Supplier ctor = getInstance().scenarios.get(scenario);
+ if (ctor == null) {
+ throw new IllegalArgumentException("Unknown integration test scenario: " + scenario);
+ }
+ return ctor.get();
+ }
+
+ // Utility methods for implementing tests.
+
+ public static void main(String[] args) {
+ // Run scenarios one after the other
+ final Location location = Location.forGrpcInsecure("localhost", 31337);
+ for (final Map.Entry> entry : getInstance().scenarios.entrySet()) {
+ System.out.println("Running test scenario: " + entry.getKey());
+ final Scenario scenario = entry.getValue().get();
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightServer.Builder builder = FlightServer
+ .builder(allocator, location, scenario.producer(allocator, location));
+ scenario.buildServer(builder);
+ try (final FlightServer server = builder.build()) {
+ server.start();
+
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ scenario.client(allocator, location, client);
+ }
+
+ server.shutdown();
+ server.awaitTermination(1, TimeUnit.SECONDS);
+ System.out.println("Ran scenario " + entry.getKey());
+ }
+ } catch (Exception e) {
+ System.out.println("Exception while running scenario " + entry.getKey());
+ e.printStackTrace();
+ }
+ }
+ }
+}