From 31784d14aeef0f175e2467642e2c5a465256d676 Mon Sep 17 00:00:00 2001
From: David Li
Date: Sun, 13 Oct 2019 15:25:06 -0400
Subject: [PATCH] ARROW-5875: [FlightRPC] integration tests for Flight features
---
cpp/src/arrow/flight/CMakeLists.txt | 1 +
cpp/src/arrow/flight/test_integration.cc | 115 ++++++++++++++++
cpp/src/arrow/flight/test_integration.h | 49 +++++++
.../arrow/flight/test_integration_client.cc | 129 +++++++++++-------
.../arrow/flight/test_integration_server.cc | 39 +++++-
dev/archery/archery/integration/runner.py | 34 +++--
dev/archery/archery/integration/scenario.py | 29 ++++
dev/archery/archery/integration/tester.py | 4 +-
dev/archery/archery/integration/tester_cpp.py | 14 +-
.../archery/integration/tester_java.py | 16 ++-
.../flight/auth/ServerAuthInterceptor.java | 25 +++-
.../integration/AuthBasicProtoScenario.java | 97 +++++++++++++
.../integration/IntegrationAssertions.java | 74 ++++++++++
.../integration/IntegrationTestClient.java | 20 ++-
.../integration/IntegrationTestServer.java | 27 +++-
.../flight/example/integration/Scenario.java | 45 ++++++
.../flight/example/integration/Scenarios.java | 89 ++++++++++++
17 files changed, 719 insertions(+), 88 deletions(-)
create mode 100644 cpp/src/arrow/flight/test_integration.cc
create mode 100644 cpp/src/arrow/flight/test_integration.h
create mode 100644 dev/archery/archery/integration/scenario.py
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 9d460becb1dc2..95d05a64f9028 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -119,6 +119,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS OR ARROW_BUILD_INTEGRATION)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
+ test_integration.cc
test_util.cc
DEPENDENCIES
GTest::gtest
diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/test_integration.cc
new file mode 100644
index 0000000000000..1e2f4fa577823
--- /dev/null
+++ b/cpp/src/arrow/flight/test_integration.cc
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/test_integration.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/types.h"
+#include "arrow/ipc/dictionary.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace arrow {
+namespace flight {
+
+/// \brief The server for the basic auth integration test.
+class AuthBasicProtoServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr* result) override {
+ // Respond with the authenticated username.
+ auto buf = Buffer::FromString(context.peer_identity());
+ *result = std::unique_ptr(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
+/// Validate the result of a DoAction.
+Status CheckActionResults(FlightClient* client, const Action& action,
+ std::vector results) {
+ std::unique_ptr stream;
+ RETURN_NOT_OK(client->DoAction(action, &stream));
+ std::unique_ptr result;
+ for (const std::string& expected : results) {
+ RETURN_NOT_OK(stream->Next(&result));
+ if (!result) {
+ return Status::Invalid("Action result stream ended early");
+ }
+ const auto actual = result->body->ToString();
+ if (expected != actual) {
+ return Status::Invalid("Got wrong result; expected", expected, "but got", actual);
+ }
+ }
+ RETURN_NOT_OK(stream->Next(&result));
+ if (result) {
+ return Status::Invalid("Action result stream had too many entries");
+ }
+ return Status::OK();
+}
+
+// The expected username for the basic auth integration test.
+constexpr auto kAuthUsername = "arrow";
+// The expected password for the basic auth integration test.
+constexpr auto kAuthPassword = "flight";
+
+/// \brief A scenario testing the basic auth protobuf.
+class AuthBasicProtoScenario : public Scenario {
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ server->reset(new AuthBasicProtoServer());
+ options->auth_handler =
+ std::make_shared(kAuthUsername, kAuthPassword);
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+ Status RunClient(std::unique_ptr client) override {
+ Action action;
+ std::unique_ptr stream;
+ std::shared_ptr detail;
+ const auto& status = client->DoAction(action, &stream);
+ detail = FlightStatusDetail::UnwrapStatus(status);
+ // This client is unauthenticated and should fail.
+ if (detail == nullptr) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString());
+ }
+ if (detail->code() != FlightStatusCode::Unauthenticated) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString());
+ }
+
+ auto client_handler = std::unique_ptr(
+ new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword));
+ RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler)));
+ return CheckActionResults(client.get(), action, {kAuthUsername});
+ }
+};
+
+Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) {
+ if (scenario_name == "auth:basic_proto") {
+ *out = std::make_shared();
+ return Status::OK();
+ }
+ return Status::KeyError("Scenario not found: ", scenario_name);
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/test_integration.h
new file mode 100644
index 0000000000000..5d9bd7fd7bd74
--- /dev/null
+++ b/cpp/src/arrow/flight/test_integration.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Integration test scenarios for Arrow Flight.
+
+#include "arrow/flight/visibility.h"
+
+#include
+#include
+
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief An integration test for Arrow Flight.
+class ARROW_FLIGHT_EXPORT Scenario {
+ public:
+ virtual ~Scenario() = default;
+ /// \brief Set up the server.
+ virtual Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) = 0;
+ /// \brief Set up the client.
+ virtual Status MakeClient(FlightClientOptions* options) = 0;
+ /// \brief Run the scenario as the client.
+ virtual Status RunClient(std::unique_ptr client) = 0;
+};
+
+/// \brief Get the implementation of an integration test scenario by name.
+Status GetScenario(const std::string& scenario_name, std::shared_ptr* out);
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/test_integration_client.cc
index ee6fa5e3dd970..f9bc1d3d19e58 100644
--- a/cpp/src/arrow/flight/test_integration_client.cc
+++ b/cpp/src/arrow/flight/test_integration_client.cc
@@ -39,11 +39,13 @@
#include "arrow/util/logging.h"
#include "arrow/flight/api.h"
+#include "arrow/flight/test_integration.h"
#include "arrow/flight/test_util.h"
DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");
+DEFINE_string(scenario, "", "Integration test scenario to run");
namespace arrow {
namespace flight {
@@ -126,63 +128,73 @@ Status ConsumeFlightLocation(
return Status::OK();
}
-int RunIntegrationClient() {
- // Make sure the required extension types are registered.
- ExtensionTypeGuard uuid_ext_guard(uuid());
- ExtensionTypeGuard dict_ext_guard(dict_extension_type());
-
- std::unique_ptr client;
- Location location;
- ABORT_NOT_OK(Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
- ABORT_NOT_OK(FlightClient::Connect(location, &client));
-
- FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}};
-
- // 1. Put the data to the server.
- std::unique_ptr reader;
- std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl;
- auto in_file = *io::ReadableFile::Open(FLAGS_path);
- ABORT_NOT_OK(
- ipc::internal::json::JsonReader::Open(default_memory_pool(), in_file, &reader));
-
- std::shared_ptr original_schema = reader->schema();
- std::vector> original_data;
- ABORT_NOT_OK(ReadBatches(reader, &original_data));
-
- std::unique_ptr write_stream;
- std::unique_ptr metadata_reader;
- ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
- ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
-
- // 2. Get the ticket for the data.
- std::unique_ptr info;
- ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
-
- std::shared_ptr schema;
- ipc::DictionaryMemo dict_memo;
- ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema));
-
- if (info->endpoints().size() == 0) {
- std::cerr << "No endpoints returned from Flight server." << std::endl;
- return -1;
+class IntegrationTestScenario : public flight::Scenario {
+ public:
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ ARROW_UNUSED(server);
+ ARROW_UNUSED(options);
+ return Status::NotImplemented("Not implemented, see test_integration_server.cc");
}
- for (const FlightEndpoint& endpoint : info->endpoints()) {
- const auto& ticket = endpoint.ticket;
+ Status MakeClient(FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return Status::OK();
+ }
+
+ Status RunClient(std::unique_ptr client) override {
+ // Make sure the required extension types are registered.
+ ExtensionTypeGuard uuid_ext_guard(uuid());
+ ExtensionTypeGuard dict_ext_guard(dict_extension_type());
+
+ FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}};
+
+ // 1. Put the data to the server.
+ std::unique_ptr reader;
+ std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl;
+ auto in_file = *io::ReadableFile::Open(FLAGS_path);
+ ABORT_NOT_OK(
+ ipc::internal::json::JsonReader::Open(default_memory_pool(), in_file, &reader));
+
+ std::shared_ptr original_schema = reader->schema();
+ std::vector> original_data;
+ ABORT_NOT_OK(ReadBatches(reader, &original_data));
- auto locations = endpoint.locations;
- if (locations.size() == 0) {
- locations = {location};
+ std::unique_ptr write_stream;
+ std::unique_ptr metadata_reader;
+ ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
+ ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
+
+ // 2. Get the ticket for the data.
+ std::unique_ptr info;
+ ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
+
+ std::shared_ptr schema;
+ ipc::DictionaryMemo dict_memo;
+ ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema));
+
+ if (info->endpoints().size() == 0) {
+ std::cerr << "No endpoints returned from Flight server." << std::endl;
+ return Status::IOError("No endpoints returned from Flight server.");
}
- for (const auto& location : locations) {
- std::cout << "Verifying location " << location.ToString() << std::endl;
- // 3. Stream data from the server, comparing individual batches.
- ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
+ for (const FlightEndpoint& endpoint : info->endpoints()) {
+ const auto& ticket = endpoint.ticket;
+
+ auto locations = endpoint.locations;
+ if (locations.size() == 0) {
+ return Status::IOError("No locations returned from Flight server.");
+ }
+
+ for (const auto& location : locations) {
+ std::cout << "Verifying location " << location.ToString() << std::endl;
+ // 3. Stream data from the server, comparing individual batches.
+ ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
+ }
}
+ return Status::OK();
}
- return 0;
-}
+};
} // namespace flight
} // namespace arrow
@@ -190,5 +202,20 @@ int RunIntegrationClient() {
int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing client for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
- return arrow::flight::RunIntegrationClient();
+ std::shared_ptr scenario;
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared();
+ }
+
+ arrow::flight::FlightClientOptions options;
+ std::unique_ptr client;
+
+ ABORT_NOT_OK(scenario->MakeClient(&options));
+
+ arrow::flight::Location location;
+ ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
+ ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
+ return 0;
}
diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/test_integration_server.cc
index 9d4d91d8b75b3..9da42ae009207 100644
--- a/cpp/src/arrow/flight/test_integration_server.cc
+++ b/cpp/src/arrow/flight/test_integration_server.cc
@@ -15,7 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-// Example server implementation for integration testing purposes
+// Server for integration testing.
+
+// Integration testing covers files and scenarios. The former
+// validates that Arrow data survives a round-trip through a Flight
+// service. The latter tests specific features of Arrow Flight.
#include
#include
@@ -33,9 +37,11 @@
#include "arrow/flight/internal.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
+#include "arrow/flight/test_integration.h"
#include "arrow/flight/test_util.h"
DEFINE_int32(port, 31337, "Server port to listen on");
+DEFINE_string(scenario, "", "Integration test senario to run");
namespace arrow {
namespace flight {
@@ -150,20 +156,47 @@ class FlightIntegrationTestServer : public FlightServerBase {
std::unordered_map uploaded_chunks;
};
+class IntegrationTestScenario : public Scenario {
+ public:
+ Status MakeServer(std::unique_ptr* server,
+ FlightServerOptions* options) override {
+ server->reset(new FlightIntegrationTestServer());
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+
+ Status RunClient(std::unique_ptr client) override {
+ ARROW_UNUSED(client);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+};
+
} // namespace flight
} // namespace arrow
-std::unique_ptr g_server;
+std::unique_ptr g_server;
int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing server for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);
- g_server.reset(new arrow::flight::FlightIntegrationTestServer);
+ std::shared_ptr scenario;
+
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared();
+ }
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
arrow::flight::FlightServerOptions options(location);
+ ARROW_CHECK_OK(scenario->MakeServer(&g_server, &options));
+
ARROW_CHECK_OK(g_server->Init(options));
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py
index f66df96992c04..b9bb94deba28a 100644
--- a/dev/archery/archery/integration/runner.py
+++ b/dev/archery/archery/integration/runner.py
@@ -26,6 +26,7 @@
import tempfile
import traceback
+from .scenario import Scenario
from .tester_cpp import CPPTester
from .tester_go import GoTester
from .tester_java import JavaTester
@@ -49,10 +50,11 @@ def __init__(self):
class IntegrationRunner(object):
- def __init__(self, json_files, testers, tempdir=None, debug=False,
- stop_on_error=True, gold_dirs=None, serial=False,
- match=None, **unused_kwargs):
+ def __init__(self, json_files, flight_scenarios, testers, tempdir=None,
+ debug=False, stop_on_error=True, gold_dirs=None,
+ serial=False, match=None, **unused_kwargs):
self.json_files = json_files
+ self.flight_scenarios = flight_scenarios
self.testers = testers
self.temp_dir = tempdir or tempfile.mkdtemp()
self.debug = debug
@@ -251,7 +253,8 @@ def _compare_flight_implementations(self, producer, consumer):
log('##########################################################')
case_runner = partial(self._run_flight_test_case, producer, consumer)
- self._run_test_cases(producer, consumer, case_runner, self.json_files)
+ self._run_test_cases(producer, consumer, case_runner,
+ self.json_files + self.flight_scenarios)
def _run_flight_test_case(self, producer, consumer, test_case):
"""
@@ -259,9 +262,8 @@ def _run_flight_test_case(self, producer, consumer, test_case):
"""
outcome = Outcome()
- json_path = test_case.path
log('=' * 58)
- log('Testing file {0}'.format(json_path))
+ log('Testing file {0}'.format(test_case.name))
log('=' * 58)
if producer.name in test_case.skip:
@@ -280,10 +282,17 @@ def _run_flight_test_case(self, producer, consumer, test_case):
else:
try:
- with producer.flight_server() as port:
+ if isinstance(test_case, Scenario):
+ server = producer.flight_server(test_case.name)
+ client_args = {'scenario_name': test_case.name}
+ else:
+ server = producer.flight_server()
+ client_args = {'json_path': test_case.path}
+
+ with server as port:
# Have the client upload the file, then download and
# compare
- consumer.flight_request(port, json_path)
+ consumer.flight_request(port, **client_args)
except Exception:
traceback.print_exc(file=printer.stdout)
outcome.failure = Failure(test_case, producer, consumer,
@@ -328,7 +337,14 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
)
json_files = static_json_files + generated_json_files
- runner = IntegrationRunner(json_files, testers, **kwargs)
+ # Additional integration test cases for Arrow Flight.
+ flight_scenarios = [
+ Scenario(
+ "auth:basic_proto",
+ description="Authenticate using the BasicAuth protobuf."),
+ ]
+
+ runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs)
runner.run()
if run_flight:
runner.run_flight()
diff --git a/dev/archery/archery/integration/scenario.py b/dev/archery/archery/integration/scenario.py
new file mode 100644
index 0000000000000..1fcbca64e6a1f
--- /dev/null
+++ b/dev/archery/archery/integration/scenario.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class Scenario:
+ """
+ An integration test scenario for Arrow Flight.
+
+ Does not correspond to a particular IPC JSON file.
+ """
+
+ def __init__(self, name, description, skip=None):
+ self.name = name
+ self.description = description
+ self.skip = skip or set()
diff --git a/dev/archery/archery/integration/tester.py b/dev/archery/archery/integration/tester.py
index 298298e4ef042..122e4f2e4a78b 100644
--- a/dev/archery/archery/integration/tester.py
+++ b/dev/archery/archery/integration/tester.py
@@ -50,7 +50,7 @@ def file_to_stream(self, file_path, stream_path):
def validate(self, json_path, arrow_path):
raise NotImplementedError
- def flight_server(self):
+ def flight_server(self, scenario_name=None):
"""Start the Flight server on a free port.
This should be a context manager that returns the port as the
@@ -58,5 +58,5 @@ def flight_server(self):
"""
raise NotImplementedError
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
raise NotImplementedError
diff --git a/dev/archery/archery/integration/tester_cpp.py b/dev/archery/archery/integration/tester_cpp.py
index fed3f0904a3a7..d35c9550e58ea 100644
--- a/dev/archery/archery/integration/tester_cpp.py
+++ b/dev/archery/archery/integration/tester_cpp.py
@@ -76,8 +76,10 @@ def file_to_stream(self, file_path, stream_path):
self.run_shell_command(cmd)
@contextlib.contextmanager
- def flight_server(self):
+ def flight_server(self, scenario_name=None):
cmd = self.FLIGHT_SERVER_CMD + ['-port=0']
+ if scenario_name:
+ cmd = cmd + ["-scenario", scenario_name]
if self.debug:
log(' '.join(cmd))
server = subprocess.Popen(cmd,
@@ -98,11 +100,17 @@ def flight_server(self):
server.kill()
server.wait(5)
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
cmd = self.FLIGHT_CLIENT_CMD + [
'-port=' + str(port),
- '-path=' + json_path,
]
+ if json_path:
+ cmd.extend(('-path', json_path))
+ elif scenario_name:
+ cmd.extend(('-scenario', scenario_name))
+ else:
+ raise TypeError("Must provide one of json_path or scenario_name")
+
if self.debug:
log(' '.join(cmd))
run_cmd(cmd)
diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py
index 1c8e960cd9843..3656a2cd1a329 100644
--- a/dev/archery/archery/integration/tester_java.py
+++ b/dev/archery/archery/integration/tester_java.py
@@ -96,19 +96,29 @@ def file_to_stream(self, file_path, stream_path):
log(' '.join(cmd))
run_cmd(cmd)
- def flight_request(self, port, json_path):
+ def flight_request(self, port, json_path=None, scenario_name=None):
cmd = ['java'] + self.JAVA_OPTS + \
['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT,
- '-port', str(port), '-j', json_path]
+ '-port', str(port)]
+
+ if json_path:
+ cmd.extend(('-j', json_path))
+ elif scenario_name:
+ cmd.extend(('-scenario', scenario_name))
+ else:
+ raise TypeError("Must provide one of json_path or scenario_name")
+
if self.debug:
log(' '.join(cmd))
run_cmd(cmd)
@contextlib.contextmanager
- def flight_server(self):
+ def flight_server(self, scenario_name=None):
cmd = ['java'] + self.JAVA_OPTS + \
['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_SERVER,
'-port', '0']
+ if scenario_name:
+ cmd.extend(('-scenario', scenario_name))
if self.debug:
log(' '.join(cmd))
server = subprocess.Popen(cmd, stdout=subprocess.PIPE,
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
index 4ebd7424cb888..5bff3784e1730 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
@@ -19,6 +19,9 @@
import java.util.Optional;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
@@ -27,6 +30,7 @@
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
/**
* GRPC Interceptor for performing authentication.
@@ -43,10 +47,25 @@ public ServerAuthInterceptor(ServerAuthHandler authHandler) {
public Listener interceptCall(ServerCall call, Metadata headers,
ServerCallHandler next) {
if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
- final Optional peerIdentity = isValid(headers);
+ final Optional peerIdentity;
+
+ // Allow customizing the response code by throwing FlightRuntimeException
+ try {
+ peerIdentity = isValid(headers);
+ } catch (FlightRuntimeException e) {
+ final Status grpcStatus = StatusUtils.toGrpcStatus(e.status());
+ call.close(grpcStatus, new Metadata());
+ return new NoopServerCallListener<>();
+ } catch (StatusRuntimeException e) {
+ Metadata trailers = e.getTrailers();
+ call.close(e.getStatus(), trailers == null ? new Metadata() : trailers);
+ return new NoopServerCallListener<>();
+ }
+
if (!peerIdentity.isPresent()) {
- call.close(Status.UNAUTHENTICATED, new Metadata());
- // TODO: we should actually terminate here instead of causing an exception below.
+ // Send back a description along with the status code
+ call.close(Status.UNAUTHENTICATED
+ .withDescription("Unauthenticated (invalid or missing auth token)"), new Metadata());
return new NoopServerCallListener<>();
}
return Contexts.interceptCall(Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()),
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
new file mode 100644
index 0000000000000..3955d7d21bfcd
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Optional;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.auth.BasicClientAuthHandler;
+import org.apache.arrow.flight.auth.BasicServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A scenario testing the built-in basic authentication Protobuf.
+ */
+final class AuthBasicProtoScenario implements Scenario {
+
+ static final String USERNAME = "arrow";
+ static final String PASSWORD = "flight";
+
+ @Override
+ public FlightProducer producer(BufferAllocator allocator, Location location) {
+ return new NoOpFlightProducer() {
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener listener) {
+ listener.onNext(new Result(context.peerIdentity().getBytes(StandardCharsets.UTF_8)));
+ listener.onCompleted();
+ }
+ };
+ }
+
+ @Override
+ public void buildServer(FlightServer.Builder builder) {
+ builder.authHandler(new BasicServerAuthHandler(new BasicServerAuthHandler.BasicAuthValidator() {
+ @Override
+ public byte[] getToken(String username, String password) throws Exception {
+ if (!USERNAME.equals(username) || !PASSWORD.equals(password)) {
+ throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException();
+ }
+ return ("valid:" + username).getBytes(StandardCharsets.UTF_8);
+ }
+
+ @Override
+ public Optional isValid(byte[] token) {
+ if (token != null) {
+ final String credential = new String(token, StandardCharsets.UTF_8);
+ if (credential.startsWith("valid:")) {
+ return Optional.of(credential.substring(6));
+ }
+ }
+ return Optional.empty();
+ }
+ }));
+ }
+
+ @Override
+ public void client(BufferAllocator allocator, Location location, FlightClient client) {
+ final FlightRuntimeException e = IntegrationAssertions.assertThrows(FlightRuntimeException.class, () -> {
+ client.listActions().forEach(act -> {
+ });
+ });
+ if (!FlightStatusCode.UNAUTHENTICATED.equals(e.status().code())) {
+ throw new AssertionError("Expected UNAUTHENTICATED but found " + e.status().code(), e);
+ }
+
+ client.authenticate(new BasicClientAuthHandler(USERNAME, PASSWORD));
+ final Result result = client.doAction(new Action("")).next();
+ if (!USERNAME.equals(new String(result.getBody(), StandardCharsets.UTF_8))) {
+ throw new AssertionError("Expected " + USERNAME + " but got " + Arrays.toString(result.getBody()));
+ }
+ }
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
new file mode 100644
index 0000000000000..576d1887f3905
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Objects;
+
+/**
+ * Utility methods to implement integration tests without using JUnit assertions.
+ */
+final class IntegrationAssertions {
+
+ /**
+ * Assert that the given code throws the given exception or subclass thereof.
+ *
+ * @param clazz The exception type.
+ * @param body The code to run.
+ * @param The exception type.
+ * @return The thrown exception.
+ */
+ @SuppressWarnings("unchecked")
+ static T assertThrows(Class clazz, AssertThrows body) {
+ try {
+ body.run();
+ } catch (Throwable t) {
+ if (clazz.isInstance(t)) {
+ return (T) t;
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but got " + t.getClass(), t);
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but did not throw.");
+ }
+
+ /**
+ * Assert that the two (non-array) objects are equal.
+ */
+ static void assertEquals(Object expected, Object actual) {
+ if (!Objects.equals(expected, actual)) {
+ throw new AssertionError("Expected:\n" + expected + "\nbut got:\n" + actual);
+ }
+ }
+
+ /**
+ * Assert that the value is false, using the given message as an error otherwise.
+ */
+ static void assertFalse(String message, boolean value) {
+ if (value) {
+ throw new AssertionError("Expected false: " + message);
+ }
+ }
+
+ /**
+ * An interface used with {@link #assertThrows(Class, AssertThrows)}.
+ */
+ @FunctionalInterface
+ interface AssertThrows {
+
+ void run() throws Throwable;
+ }
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index fc0081369f984..27a545f84fd5b 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -22,7 +22,6 @@
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
-import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.AsyncPutListener;
@@ -50,7 +49,7 @@
import org.apache.commons.cli.ParseException;
/**
- * An Example Flight Server that provides access to the InMemoryStore.
+ * A Flight client for integration testing.
*/
class IntegrationTestClient {
private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class);
@@ -59,6 +58,7 @@ class IntegrationTestClient {
private IntegrationTestClient() {
options = new Options();
options.addOption("j", "json", true, "json file");
+ options.addOption("scenario", true, "The integration test scenario.");
options.addOption("host", true, "The host to connect to.");
options.addOption("port", true, "The port to connect to.");
}
@@ -70,6 +70,8 @@ public static void main(String[] args) {
fatalError("Invalid parameters", e);
} catch (IOException e) {
fatalError("Error accessing files", e);
+ } catch (Exception e) {
+ fatalError("Unknown error", e);
}
}
@@ -80,7 +82,7 @@ private static void fatalError(String message, Throwable e) {
System.exit(1);
}
- private void run(String[] args) throws ParseException, IOException {
+ private void run(String[] args) throws Exception {
final CommandLineParser parser = new DefaultParser();
final CommandLine cmd = parser.parse(options, args, false);
@@ -91,8 +93,12 @@ private void run(String[] args) throws ParseException, IOException {
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {
- final String inputPath = cmd.getOptionValue("j");
- testStream(allocator, defaultLocation, client, inputPath);
+ if (cmd.hasOption("scenario")) {
+ Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client);
+ } else {
+ final String inputPath = cmd.getOptionValue("j");
+ testStream(allocator, defaultLocation, client, inputPath);
+ }
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -145,8 +151,8 @@ public void onNext(PutResult val) {
for (FlightEndpoint endpoint : info.getEndpoints()) {
// 3. Download the data from the server.
List locations = endpoint.getLocations();
- if (locations.size() == 0) {
- locations = Collections.singletonList(server);
+ if (locations.isEmpty()) {
+ throw new RuntimeException("No locations returned from Flight server.");
}
for (Location location : locations) {
System.out.println("Verifying location " + location.getUri());
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
index 2ee104f76b4c9..da336c5024aa2 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
@@ -17,8 +17,9 @@
package org.apache.arrow.flight.example.integration;
+import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.example.ExampleFlightServer;
+import org.apache.arrow.flight.example.InMemoryStore;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
@@ -38,30 +39,42 @@ class IntegrationTestServer {
private IntegrationTestServer() {
options = new Options();
options.addOption("port", true, "The port to serve on.");
+ options.addOption("scenario", true, "The integration test scenario.");
}
private void run(String[] args) throws Exception {
CommandLineParser parser = new DefaultParser();
CommandLine cmd = parser.parse(options, args, false);
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
+ final Location location = Location.forGrpcInsecure("localhost", port);
final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
- final ExampleFlightServer efs = new ExampleFlightServer(allocator, Location.forGrpcInsecure("localhost", port));
- efs.start();
- efs.getStore().setLocation(Location.forGrpcInsecure("localhost", efs.getPort()));
+ final FlightServer.Builder builder = FlightServer.builder().allocator(allocator).location(location);
+
+ final FlightServer server;
+ if (cmd.hasOption("scenario")) {
+ final Scenario scenario = Scenarios.getScenario(cmd.getOptionValue("scenario"));
+ scenario.buildServer(builder);
+ server = builder.producer(scenario.producer(allocator, location)).build();
+ server.start();
+ } else {
+ final InMemoryStore store = new InMemoryStore(allocator, location);
+ server = FlightServer.builder(allocator, location, store).build().start();
+ store.setLocation(Location.forGrpcInsecure("localhost", server.getPort()));
+ }
// Print out message for integration test script
- System.out.println("Server listening on localhost:" + efs.getPort());
+ System.out.println("Server listening on localhost:" + server.getPort());
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
System.out.println("\nExiting...");
- AutoCloseables.close(efs, allocator);
+ AutoCloseables.close(server, allocator);
} catch (Exception e) {
e.printStackTrace();
}
}));
- efs.awaitTermination();
+ server.awaitTermination();
}
public static void main(String[] args) {
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
new file mode 100644
index 0000000000000..b3b962d2e734b
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A particular scenario in integration testing.
+ */
+interface Scenario {
+
+ /**
+ * Construct the FlightProducer for a server in this scenario.
+ */
+ FlightProducer producer(BufferAllocator allocator, Location location) throws Exception;
+
+ /**
+ * Set any other server options.
+ */
+ void buildServer(FlightServer.Builder builder) throws Exception;
+
+ /**
+ * Run as the client in the scenario.
+ */
+ void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception;
+}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
new file mode 100644
index 0000000000000..3cc65829f630c
--- /dev/null
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+/**
+ * Scenarios for integration testing.
+ */
+final class Scenarios {
+
+ private static Scenarios INSTANCE;
+
+ private final Map> scenarios;
+
+ private Scenarios() {
+ scenarios = new TreeMap<>();
+ scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new);
+ }
+
+ private static Scenarios getInstance() {
+ if (INSTANCE == null) {
+ INSTANCE = new Scenarios();
+ }
+ return INSTANCE;
+ }
+
+ static Scenario getScenario(String scenario) {
+ final Supplier ctor = getInstance().scenarios.get(scenario);
+ if (ctor == null) {
+ throw new IllegalArgumentException("Unknown integration test scenario: " + scenario);
+ }
+ return ctor.get();
+ }
+
+ // Utility methods for implementing tests.
+
+ public static void main(String[] args) {
+ // Run scenarios one after the other
+ final Location location = Location.forGrpcInsecure("localhost", 31337);
+ for (final Map.Entry> entry : getInstance().scenarios.entrySet()) {
+ System.out.println("Running test scenario: " + entry.getKey());
+ final Scenario scenario = entry.getValue().get();
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightServer.Builder builder = FlightServer
+ .builder(allocator, location, scenario.producer(allocator, location));
+ scenario.buildServer(builder);
+ try (final FlightServer server = builder.build()) {
+ server.start();
+
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ scenario.client(allocator, location, client);
+ }
+
+ server.shutdown();
+ server.awaitTermination(1, TimeUnit.SECONDS);
+ System.out.println("Ran scenario " + entry.getKey());
+ }
+ } catch (Exception e) {
+ System.out.println("Exception while running scenario " + entry.getKey());
+ e.printStackTrace();
+ }
+ }
+ }
+}