From 31784d14aeef0f175e2467642e2c5a465256d676 Mon Sep 17 00:00:00 2001 From: David Li Date: Sun, 13 Oct 2019 15:25:06 -0400 Subject: [PATCH] ARROW-5875: [FlightRPC] integration tests for Flight features --- cpp/src/arrow/flight/CMakeLists.txt | 1 + cpp/src/arrow/flight/test_integration.cc | 115 ++++++++++++++++ cpp/src/arrow/flight/test_integration.h | 49 +++++++ .../arrow/flight/test_integration_client.cc | 129 +++++++++++------- .../arrow/flight/test_integration_server.cc | 39 +++++- dev/archery/archery/integration/runner.py | 34 +++-- dev/archery/archery/integration/scenario.py | 29 ++++ dev/archery/archery/integration/tester.py | 4 +- dev/archery/archery/integration/tester_cpp.py | 14 +- .../archery/integration/tester_java.py | 16 ++- .../flight/auth/ServerAuthInterceptor.java | 25 +++- .../integration/AuthBasicProtoScenario.java | 97 +++++++++++++ .../integration/IntegrationAssertions.java | 74 ++++++++++ .../integration/IntegrationTestClient.java | 20 ++- .../integration/IntegrationTestServer.java | 27 +++- .../flight/example/integration/Scenario.java | 45 ++++++ .../flight/example/integration/Scenarios.java | 89 ++++++++++++ 17 files changed, 719 insertions(+), 88 deletions(-) create mode 100644 cpp/src/arrow/flight/test_integration.cc create mode 100644 cpp/src/arrow/flight/test_integration.h create mode 100644 dev/archery/archery/integration/scenario.py create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 9d460becb1dc2..95d05a64f9028 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -119,6 +119,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS OR ARROW_BUILD_INTEGRATION) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES + test_integration.cc test_util.cc DEPENDENCIES GTest::gtest diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/test_integration.cc new file mode 100644 index 0000000000000..1e2f4fa577823 --- /dev/null +++ b/cpp/src/arrow/flight/test_integration.cc @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/test_integration.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/types.h" +#include "arrow/ipc/dictionary.h" + +#include +#include +#include +#include +#include + +namespace arrow { +namespace flight { + +/// \brief The server for the basic auth integration test. +class AuthBasicProtoServer : public FlightServerBase { + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* result) override { + // Respond with the authenticated username. + auto buf = Buffer::FromString(context.peer_identity()); + *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); + return Status::OK(); + } +}; + +/// Validate the result of a DoAction. +Status CheckActionResults(FlightClient* client, const Action& action, + std::vector results) { + std::unique_ptr stream; + RETURN_NOT_OK(client->DoAction(action, &stream)); + std::unique_ptr result; + for (const std::string& expected : results) { + RETURN_NOT_OK(stream->Next(&result)); + if (!result) { + return Status::Invalid("Action result stream ended early"); + } + const auto actual = result->body->ToString(); + if (expected != actual) { + return Status::Invalid("Got wrong result; expected", expected, "but got", actual); + } + } + RETURN_NOT_OK(stream->Next(&result)); + if (result) { + return Status::Invalid("Action result stream had too many entries"); + } + return Status::OK(); +} + +// The expected username for the basic auth integration test. +constexpr auto kAuthUsername = "arrow"; +// The expected password for the basic auth integration test. +constexpr auto kAuthPassword = "flight"; + +/// \brief A scenario testing the basic auth protobuf. +class AuthBasicProtoScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new AuthBasicProtoServer()); + options->auth_handler = + std::make_shared(kAuthUsername, kAuthPassword); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status RunClient(std::unique_ptr client) override { + Action action; + std::unique_ptr stream; + std::shared_ptr detail; + const auto& status = client->DoAction(action, &stream); + detail = FlightStatusDetail::UnwrapStatus(status); + // This client is unauthenticated and should fail. + if (detail == nullptr) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString()); + } + if (detail->code() != FlightStatusCode::Unauthenticated) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString()); + } + + auto client_handler = std::unique_ptr( + new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword)); + RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler))); + return CheckActionResults(client.get(), action, {kAuthUsername}); + } +}; + +Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { + if (scenario_name == "auth:basic_proto") { + *out = std::make_shared(); + return Status::OK(); + } + return Status::KeyError("Scenario not found: ", scenario_name); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/test_integration.h new file mode 100644 index 0000000000000..5d9bd7fd7bd74 --- /dev/null +++ b/cpp/src/arrow/flight/test_integration.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Integration test scenarios for Arrow Flight. + +#include "arrow/flight/visibility.h" + +#include +#include + +#include "arrow/flight/client.h" +#include "arrow/flight/server.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { + +/// \brief An integration test for Arrow Flight. +class ARROW_FLIGHT_EXPORT Scenario { + public: + virtual ~Scenario() = default; + /// \brief Set up the server. + virtual Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) = 0; + /// \brief Set up the client. + virtual Status MakeClient(FlightClientOptions* options) = 0; + /// \brief Run the scenario as the client. + virtual Status RunClient(std::unique_ptr client) = 0; +}; + +/// \brief Get the implementation of an integration test scenario by name. +Status GetScenario(const std::string& scenario_name, std::shared_ptr* out); + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/test_integration_client.cc index ee6fa5e3dd970..f9bc1d3d19e58 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/test_integration_client.cc @@ -39,11 +39,13 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" +#include "arrow/flight/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); +DEFINE_string(scenario, "", "Integration test scenario to run"); namespace arrow { namespace flight { @@ -126,63 +128,73 @@ Status ConsumeFlightLocation( return Status::OK(); } -int RunIntegrationClient() { - // Make sure the required extension types are registered. - ExtensionTypeGuard uuid_ext_guard(uuid()); - ExtensionTypeGuard dict_ext_guard(dict_extension_type()); - - std::unique_ptr client; - Location location; - ABORT_NOT_OK(Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); - ABORT_NOT_OK(FlightClient::Connect(location, &client)); - - FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}}; - - // 1. Put the data to the server. - std::unique_ptr reader; - std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; - auto in_file = *io::ReadableFile::Open(FLAGS_path); - ABORT_NOT_OK( - ipc::internal::json::JsonReader::Open(default_memory_pool(), in_file, &reader)); - - std::shared_ptr original_schema = reader->schema(); - std::vector> original_data; - ABORT_NOT_OK(ReadBatches(reader, &original_data)); - - std::unique_ptr write_stream; - std::unique_ptr metadata_reader; - ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader)); - ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader)); - - // 2. Get the ticket for the data. - std::unique_ptr info; - ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); - - std::shared_ptr schema; - ipc::DictionaryMemo dict_memo; - ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema)); - - if (info->endpoints().size() == 0) { - std::cerr << "No endpoints returned from Flight server." << std::endl; - return -1; +class IntegrationTestScenario : public flight::Scenario { + public: + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + ARROW_UNUSED(server); + ARROW_UNUSED(options); + return Status::NotImplemented("Not implemented, see test_integration_server.cc"); } - for (const FlightEndpoint& endpoint : info->endpoints()) { - const auto& ticket = endpoint.ticket; + Status MakeClient(FlightClientOptions* options) override { + ARROW_UNUSED(options); + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + // Make sure the required extension types are registered. + ExtensionTypeGuard uuid_ext_guard(uuid()); + ExtensionTypeGuard dict_ext_guard(dict_extension_type()); + + FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}}; + + // 1. Put the data to the server. + std::unique_ptr reader; + std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + auto in_file = *io::ReadableFile::Open(FLAGS_path); + ABORT_NOT_OK( + ipc::internal::json::JsonReader::Open(default_memory_pool(), in_file, &reader)); + + std::shared_ptr original_schema = reader->schema(); + std::vector> original_data; + ABORT_NOT_OK(ReadBatches(reader, &original_data)); - auto locations = endpoint.locations; - if (locations.size() == 0) { - locations = {location}; + std::unique_ptr write_stream; + std::unique_ptr metadata_reader; + ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader)); + ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader)); + + // 2. Get the ticket for the data. + std::unique_ptr info; + ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); + + std::shared_ptr schema; + ipc::DictionaryMemo dict_memo; + ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema)); + + if (info->endpoints().size() == 0) { + std::cerr << "No endpoints returned from Flight server." << std::endl; + return Status::IOError("No endpoints returned from Flight server."); } - for (const auto& location : locations) { - std::cout << "Verifying location " << location.ToString() << std::endl; - // 3. Stream data from the server, comparing individual batches. - ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data)); + for (const FlightEndpoint& endpoint : info->endpoints()) { + const auto& ticket = endpoint.ticket; + + auto locations = endpoint.locations; + if (locations.size() == 0) { + return Status::IOError("No locations returned from Flight server."); + } + + for (const auto& location : locations) { + std::cout << "Verifying location " << location.ToString() << std::endl; + // 3. Stream data from the server, comparing individual batches. + ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data)); + } } + return Status::OK(); } - return 0; -} +}; } // namespace flight } // namespace arrow @@ -190,5 +202,20 @@ int RunIntegrationClient() { int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - return arrow::flight::RunIntegrationClient(); + std::shared_ptr scenario; + if (!FLAGS_scenario.empty()) { + ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + } else { + scenario = std::make_shared(); + } + + arrow::flight::FlightClientOptions options; + std::unique_ptr client; + + ABORT_NOT_OK(scenario->MakeClient(&options)); + + arrow::flight::Location location; + ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client)); + return 0; } diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/test_integration_server.cc index 9d4d91d8b75b3..9da42ae009207 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/test_integration_server.cc @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. -// Example server implementation for integration testing purposes +// Server for integration testing. + +// Integration testing covers files and scenarios. The former +// validates that Arrow data survives a round-trip through a Flight +// service. The latter tests specific features of Arrow Flight. #include #include @@ -33,9 +37,11 @@ #include "arrow/flight/internal.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" +#include "arrow/flight/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_int32(port, 31337, "Server port to listen on"); +DEFINE_string(scenario, "", "Integration test senario to run"); namespace arrow { namespace flight { @@ -150,20 +156,47 @@ class FlightIntegrationTestServer : public FlightServerBase { std::unordered_map uploaded_chunks; }; +class IntegrationTestScenario : public Scenario { + public: + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new FlightIntegrationTestServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { + ARROW_UNUSED(options); + return Status::NotImplemented("Not implemented, see test_integration_client.cc"); + } + + Status RunClient(std::unique_ptr client) override { + ARROW_UNUSED(client); + return Status::NotImplemented("Not implemented, see test_integration_client.cc"); + } +}; + } // namespace flight } // namespace arrow -std::unique_ptr g_server; +std::unique_ptr g_server; int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing server for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - g_server.reset(new arrow::flight::FlightIntegrationTestServer); + std::shared_ptr scenario; + + if (!FLAGS_scenario.empty()) { + ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + } else { + scenario = std::make_shared(); + } arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); arrow::flight::FlightServerOptions options(location); + ARROW_CHECK_OK(scenario->MakeServer(&g_server, &options)); + ARROW_CHECK_OK(g_server->Init(options)); // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index f66df96992c04..b9bb94deba28a 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -26,6 +26,7 @@ import tempfile import traceback +from .scenario import Scenario from .tester_cpp import CPPTester from .tester_go import GoTester from .tester_java import JavaTester @@ -49,10 +50,11 @@ def __init__(self): class IntegrationRunner(object): - def __init__(self, json_files, testers, tempdir=None, debug=False, - stop_on_error=True, gold_dirs=None, serial=False, - match=None, **unused_kwargs): + def __init__(self, json_files, flight_scenarios, testers, tempdir=None, + debug=False, stop_on_error=True, gold_dirs=None, + serial=False, match=None, **unused_kwargs): self.json_files = json_files + self.flight_scenarios = flight_scenarios self.testers = testers self.temp_dir = tempdir or tempfile.mkdtemp() self.debug = debug @@ -251,7 +253,8 @@ def _compare_flight_implementations(self, producer, consumer): log('##########################################################') case_runner = partial(self._run_flight_test_case, producer, consumer) - self._run_test_cases(producer, consumer, case_runner, self.json_files) + self._run_test_cases(producer, consumer, case_runner, + self.json_files + self.flight_scenarios) def _run_flight_test_case(self, producer, consumer, test_case): """ @@ -259,9 +262,8 @@ def _run_flight_test_case(self, producer, consumer, test_case): """ outcome = Outcome() - json_path = test_case.path log('=' * 58) - log('Testing file {0}'.format(json_path)) + log('Testing file {0}'.format(test_case.name)) log('=' * 58) if producer.name in test_case.skip: @@ -280,10 +282,17 @@ def _run_flight_test_case(self, producer, consumer, test_case): else: try: - with producer.flight_server() as port: + if isinstance(test_case, Scenario): + server = producer.flight_server(test_case.name) + client_args = {'scenario_name': test_case.name} + else: + server = producer.flight_server() + client_args = {'json_path': test_case.path} + + with server as port: # Have the client upload the file, then download and # compare - consumer.flight_request(port, json_path) + consumer.flight_request(port, **client_args) except Exception: traceback.print_exc(file=printer.stdout) outcome.failure = Failure(test_case, producer, consumer, @@ -328,7 +337,14 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, ) json_files = static_json_files + generated_json_files - runner = IntegrationRunner(json_files, testers, **kwargs) + # Additional integration test cases for Arrow Flight. + flight_scenarios = [ + Scenario( + "auth:basic_proto", + description="Authenticate using the BasicAuth protobuf."), + ] + + runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) runner.run() if run_flight: runner.run_flight() diff --git a/dev/archery/archery/integration/scenario.py b/dev/archery/archery/integration/scenario.py new file mode 100644 index 0000000000000..1fcbca64e6a1f --- /dev/null +++ b/dev/archery/archery/integration/scenario.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +class Scenario: + """ + An integration test scenario for Arrow Flight. + + Does not correspond to a particular IPC JSON file. + """ + + def __init__(self, name, description, skip=None): + self.name = name + self.description = description + self.skip = skip or set() diff --git a/dev/archery/archery/integration/tester.py b/dev/archery/archery/integration/tester.py index 298298e4ef042..122e4f2e4a78b 100644 --- a/dev/archery/archery/integration/tester.py +++ b/dev/archery/archery/integration/tester.py @@ -50,7 +50,7 @@ def file_to_stream(self, file_path, stream_path): def validate(self, json_path, arrow_path): raise NotImplementedError - def flight_server(self): + def flight_server(self, scenario_name=None): """Start the Flight server on a free port. This should be a context manager that returns the port as the @@ -58,5 +58,5 @@ def flight_server(self): """ raise NotImplementedError - def flight_request(self, port, json_path): + def flight_request(self, port, json_path=None, scenario_name=None): raise NotImplementedError diff --git a/dev/archery/archery/integration/tester_cpp.py b/dev/archery/archery/integration/tester_cpp.py index fed3f0904a3a7..d35c9550e58ea 100644 --- a/dev/archery/archery/integration/tester_cpp.py +++ b/dev/archery/archery/integration/tester_cpp.py @@ -76,8 +76,10 @@ def file_to_stream(self, file_path, stream_path): self.run_shell_command(cmd) @contextlib.contextmanager - def flight_server(self): + def flight_server(self, scenario_name=None): cmd = self.FLIGHT_SERVER_CMD + ['-port=0'] + if scenario_name: + cmd = cmd + ["-scenario", scenario_name] if self.debug: log(' '.join(cmd)) server = subprocess.Popen(cmd, @@ -98,11 +100,17 @@ def flight_server(self): server.kill() server.wait(5) - def flight_request(self, port, json_path): + def flight_request(self, port, json_path=None, scenario_name=None): cmd = self.FLIGHT_CLIENT_CMD + [ '-port=' + str(port), - '-path=' + json_path, ] + if json_path: + cmd.extend(('-path', json_path)) + elif scenario_name: + cmd.extend(('-scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + if self.debug: log(' '.join(cmd)) run_cmd(cmd) diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 1c8e960cd9843..3656a2cd1a329 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -96,19 +96,29 @@ def file_to_stream(self, file_path, stream_path): log(' '.join(cmd)) run_cmd(cmd) - def flight_request(self, port, json_path): + def flight_request(self, port, json_path=None, scenario_name=None): cmd = ['java'] + self.JAVA_OPTS + \ ['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT, - '-port', str(port), '-j', json_path] + '-port', str(port)] + + if json_path: + cmd.extend(('-j', json_path)) + elif scenario_name: + cmd.extend(('-scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + if self.debug: log(' '.join(cmd)) run_cmd(cmd) @contextlib.contextmanager - def flight_server(self): + def flight_server(self, scenario_name=None): cmd = ['java'] + self.JAVA_OPTS + \ ['-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_SERVER, '-port', '0'] + if scenario_name: + cmd.extend(('-scenario', scenario_name)) if self.debug: log(' '.join(cmd)) server = subprocess.Popen(cmd, stdout=subprocess.PIPE, diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java index 4ebd7424cb888..5bff3784e1730 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java @@ -19,6 +19,9 @@ import java.util.Optional; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.grpc.StatusUtils; + import io.grpc.Context; import io.grpc.Contexts; import io.grpc.Metadata; @@ -27,6 +30,7 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; +import io.grpc.StatusRuntimeException; /** * GRPC Interceptor for performing authentication. @@ -43,10 +47,25 @@ public ServerAuthInterceptor(ServerAuthHandler authHandler) { public Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) { - final Optional peerIdentity = isValid(headers); + final Optional peerIdentity; + + // Allow customizing the response code by throwing FlightRuntimeException + try { + peerIdentity = isValid(headers); + } catch (FlightRuntimeException e) { + final Status grpcStatus = StatusUtils.toGrpcStatus(e.status()); + call.close(grpcStatus, new Metadata()); + return new NoopServerCallListener<>(); + } catch (StatusRuntimeException e) { + Metadata trailers = e.getTrailers(); + call.close(e.getStatus(), trailers == null ? new Metadata() : trailers); + return new NoopServerCallListener<>(); + } + if (!peerIdentity.isPresent()) { - call.close(Status.UNAUTHENTICATED, new Metadata()); - // TODO: we should actually terminate here instead of causing an exception below. + // Send back a description along with the status code + call.close(Status.UNAUTHENTICATED + .withDescription("Unauthenticated (invalid or missing auth token)"), new Metadata()); return new NoopServerCallListener<>(); } return Contexts.interceptCall(Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()), diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java new file mode 100644 index 0000000000000..3955d7d21bfcd --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Optional; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStatusCode; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.auth.BasicClientAuthHandler; +import org.apache.arrow.flight.auth.BasicServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; + +/** + * A scenario testing the built-in basic authentication Protobuf. + */ +final class AuthBasicProtoScenario implements Scenario { + + static final String USERNAME = "arrow"; + static final String PASSWORD = "flight"; + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) { + return new NoOpFlightProducer() { + @Override + public void doAction(CallContext context, Action action, StreamListener listener) { + listener.onNext(new Result(context.peerIdentity().getBytes(StandardCharsets.UTF_8))); + listener.onCompleted(); + } + }; + } + + @Override + public void buildServer(FlightServer.Builder builder) { + builder.authHandler(new BasicServerAuthHandler(new BasicServerAuthHandler.BasicAuthValidator() { + @Override + public byte[] getToken(String username, String password) throws Exception { + if (!USERNAME.equals(username) || !PASSWORD.equals(password)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); + } + return ("valid:" + username).getBytes(StandardCharsets.UTF_8); + } + + @Override + public Optional isValid(byte[] token) { + if (token != null) { + final String credential = new String(token, StandardCharsets.UTF_8); + if (credential.startsWith("valid:")) { + return Optional.of(credential.substring(6)); + } + } + return Optional.empty(); + } + })); + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) { + final FlightRuntimeException e = IntegrationAssertions.assertThrows(FlightRuntimeException.class, () -> { + client.listActions().forEach(act -> { + }); + }); + if (!FlightStatusCode.UNAUTHENTICATED.equals(e.status().code())) { + throw new AssertionError("Expected UNAUTHENTICATED but found " + e.status().code(), e); + } + + client.authenticate(new BasicClientAuthHandler(USERNAME, PASSWORD)); + final Result result = client.doAction(new Action("")).next(); + if (!USERNAME.equals(new String(result.getBody(), StandardCharsets.UTF_8))) { + throw new AssertionError("Expected " + USERNAME + " but got " + Arrays.toString(result.getBody())); + } + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java new file mode 100644 index 0000000000000..576d1887f3905 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import java.util.Objects; + +/** + * Utility methods to implement integration tests without using JUnit assertions. + */ +final class IntegrationAssertions { + + /** + * Assert that the given code throws the given exception or subclass thereof. + * + * @param clazz The exception type. + * @param body The code to run. + * @param The exception type. + * @return The thrown exception. + */ + @SuppressWarnings("unchecked") + static T assertThrows(Class clazz, AssertThrows body) { + try { + body.run(); + } catch (Throwable t) { + if (clazz.isInstance(t)) { + return (T) t; + } + throw new AssertionError("Expected exception of class " + clazz + " but got " + t.getClass(), t); + } + throw new AssertionError("Expected exception of class " + clazz + " but did not throw."); + } + + /** + * Assert that the two (non-array) objects are equal. + */ + static void assertEquals(Object expected, Object actual) { + if (!Objects.equals(expected, actual)) { + throw new AssertionError("Expected:\n" + expected + "\nbut got:\n" + actual); + } + } + + /** + * Assert that the value is false, using the given message as an error otherwise. + */ + static void assertFalse(String message, boolean value) { + if (value) { + throw new AssertionError("Expected false: " + message); + } + } + + /** + * An interface used with {@link #assertThrows(Class, AssertThrows)}. + */ + @FunctionalInterface + interface AssertThrows { + + void run() throws Throwable; + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index fc0081369f984..27a545f84fd5b 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -22,7 +22,6 @@ import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.apache.arrow.flight.AsyncPutListener; @@ -50,7 +49,7 @@ import org.apache.commons.cli.ParseException; /** - * An Example Flight Server that provides access to the InMemoryStore. + * A Flight client for integration testing. */ class IntegrationTestClient { private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class); @@ -59,6 +58,7 @@ class IntegrationTestClient { private IntegrationTestClient() { options = new Options(); options.addOption("j", "json", true, "json file"); + options.addOption("scenario", true, "The integration test scenario."); options.addOption("host", true, "The host to connect to."); options.addOption("port", true, "The port to connect to."); } @@ -70,6 +70,8 @@ public static void main(String[] args) { fatalError("Invalid parameters", e); } catch (IOException e) { fatalError("Error accessing files", e); + } catch (Exception e) { + fatalError("Unknown error", e); } } @@ -80,7 +82,7 @@ private static void fatalError(String message, Throwable e) { System.exit(1); } - private void run(String[] args) throws ParseException, IOException { + private void run(String[] args) throws Exception { final CommandLineParser parser = new DefaultParser(); final CommandLine cmd = parser.parse(options, args, false); @@ -91,8 +93,12 @@ private void run(String[] args) throws ParseException, IOException { try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { - final String inputPath = cmd.getOptionValue("j"); - testStream(allocator, defaultLocation, client, inputPath); + if (cmd.hasOption("scenario")) { + Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); + } else { + final String inputPath = cmd.getOptionValue("j"); + testStream(allocator, defaultLocation, client, inputPath); + } } catch (InterruptedException e) { throw new RuntimeException(e); } @@ -145,8 +151,8 @@ public void onNext(PutResult val) { for (FlightEndpoint endpoint : info.getEndpoints()) { // 3. Download the data from the server. List locations = endpoint.getLocations(); - if (locations.size() == 0) { - locations = Collections.singletonList(server); + if (locations.isEmpty()) { + throw new RuntimeException("No locations returned from Flight server."); } for (Location location : locations) { System.out.println("Verifying location " + location.getUri()); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index 2ee104f76b4c9..da336c5024aa2 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -17,8 +17,9 @@ package org.apache.arrow.flight.example.integration; +import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.example.ExampleFlightServer; +import org.apache.arrow.flight.example.InMemoryStore; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; @@ -38,30 +39,42 @@ class IntegrationTestServer { private IntegrationTestServer() { options = new Options(); options.addOption("port", true, "The port to serve on."); + options.addOption("scenario", true, "The integration test scenario."); } private void run(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args, false); final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); + final Location location = Location.forGrpcInsecure("localhost", port); final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - final ExampleFlightServer efs = new ExampleFlightServer(allocator, Location.forGrpcInsecure("localhost", port)); - efs.start(); - efs.getStore().setLocation(Location.forGrpcInsecure("localhost", efs.getPort())); + final FlightServer.Builder builder = FlightServer.builder().allocator(allocator).location(location); + + final FlightServer server; + if (cmd.hasOption("scenario")) { + final Scenario scenario = Scenarios.getScenario(cmd.getOptionValue("scenario")); + scenario.buildServer(builder); + server = builder.producer(scenario.producer(allocator, location)).build(); + server.start(); + } else { + final InMemoryStore store = new InMemoryStore(allocator, location); + server = FlightServer.builder(allocator, location, store).build().start(); + store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); + } // Print out message for integration test script - System.out.println("Server listening on localhost:" + efs.getPort()); + System.out.println("Server listening on localhost:" + server.getPort()); Runtime.getRuntime().addShutdownHook(new Thread(() -> { try { System.out.println("\nExiting..."); - AutoCloseables.close(efs, allocator); + AutoCloseables.close(server, allocator); } catch (Exception e) { e.printStackTrace(); } })); - efs.awaitTermination(); + server.awaitTermination(); } public static void main(String[] args) { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java new file mode 100644 index 0000000000000..b3b962d2e734b --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; + +/** + * A particular scenario in integration testing. + */ +interface Scenario { + + /** + * Construct the FlightProducer for a server in this scenario. + */ + FlightProducer producer(BufferAllocator allocator, Location location) throws Exception; + + /** + * Set any other server options. + */ + void buildServer(FlightServer.Builder builder) throws Exception; + + /** + * Run as the client in the scenario. + */ + void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception; +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java new file mode 100644 index 0000000000000..3cc65829f630c --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +/** + * Scenarios for integration testing. + */ +final class Scenarios { + + private static Scenarios INSTANCE; + + private final Map> scenarios; + + private Scenarios() { + scenarios = new TreeMap<>(); + scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); + } + + private static Scenarios getInstance() { + if (INSTANCE == null) { + INSTANCE = new Scenarios(); + } + return INSTANCE; + } + + static Scenario getScenario(String scenario) { + final Supplier ctor = getInstance().scenarios.get(scenario); + if (ctor == null) { + throw new IllegalArgumentException("Unknown integration test scenario: " + scenario); + } + return ctor.get(); + } + + // Utility methods for implementing tests. + + public static void main(String[] args) { + // Run scenarios one after the other + final Location location = Location.forGrpcInsecure("localhost", 31337); + for (final Map.Entry> entry : getInstance().scenarios.entrySet()) { + System.out.println("Running test scenario: " + entry.getKey()); + final Scenario scenario = entry.getValue().get(); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final FlightServer.Builder builder = FlightServer + .builder(allocator, location, scenario.producer(allocator, location)); + scenario.buildServer(builder); + try (final FlightServer server = builder.build()) { + server.start(); + + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + scenario.client(allocator, location, client); + } + + server.shutdown(); + server.awaitTermination(1, TimeUnit.SECONDS); + System.out.println("Ran scenario " + entry.getKey()); + } + } catch (Exception e) { + System.out.println("Exception while running scenario " + entry.getKey()); + e.printStackTrace(); + } + } + } +}