ARROW-10487 [FlightRPC][C++] Header-based auth in clients #8724

Closed
Changes from 8 commits
33 commits
c2ba2c7
[1] Initial commit to support client header authentication in C++
lyndonbauto Nov 18, 2020
ab7d6a3
[1] Added integration test for client header authentication in C++ an…
lyndonbauto Nov 19, 2020
ecb533f
Merge branch 'lyndon/flight-auth-redesign-cpp' of https://github.com/…
lyndonbauto Nov 19, 2020
73be3e7
[1] Updates for linting
lyndonbauto Nov 19, 2020
74ef8ea
[1] Adding missed file.
lyndonbauto Nov 19, 2020
29a3192
[1] Adding fix for Java lint errors.
lyndonbauto Nov 19, 2020
fac0bd0
[1] Added a couple comments
lyndonbauto Nov 19, 2020
7fd1279
[1] Minor comment fixes
lyndonbauto Nov 19, 2020
0b5a08e
[1] Addressed pull request comments, still need to address comments a…
lyndonbauto Nov 23, 2020
6861cf0
[1] Added unit test.
lyndonbauto Nov 23, 2020
01a134b
[1] Fixed linting issues
lyndonbauto Nov 23, 2020
de78c6b
[1] Fixing linting issues
lyndonbauto Nov 23, 2020
c44698f
[1] Removed some extra spaces at the end of some lines.
lyndonbauto Nov 23, 2020
e975fd8
[1] Correcting linting issues.
lyndonbauto Nov 24, 2020
37889fa
[1] Minor cmake fix
lyndonbauto Nov 24, 2020
e7ac27c
[1] Trying different cmake spacing.
lyndonbauto Nov 24, 2020
ba7cb9f
[1] Trying different cmake
lyndonbauto Nov 24, 2020
1426252
[1] Addressed code review comments.
lyndonbauto Nov 24, 2020
516d993
[1] Removing integration test and reverting some cmake changes.
lyndonbauto Nov 24, 2020
3000ecb
[1] Removed some no longer used functionality.
lyndonbauto Nov 24, 2020
1de10fa
[1] Added improved testing and fixed linting.
lyndonbauto Nov 24, 2020
065af4a
[1] Fixed lint issue
lyndonbauto Nov 24, 2020
d4da03b
Merge branch 'master' of https://github.com/apache/arrow into jduo/ly…
lyndonbauto Nov 25, 2020
911fcc7
[1] Updating submodule
lyndonbauto Nov 25, 2020
f41edce
[1] Minor documentation fixes.
lyndonbauto Nov 25, 2020
6b6fbbe
[1] Fixed casting issue on some builds
lyndonbauto Nov 25, 2020
199b655
[1] Added missing parameter for documentation
lyndonbauto Nov 25, 2020
47aa581
[1] Fixing cast.
lyndonbauto Nov 25, 2020
477d865
[1] Moving std:: from toupper call because it causes break in some bu…
lyndonbauto Nov 25, 2020
1cc3fdb
[1] Adding missed std remove
lyndonbauto Nov 25, 2020
d27465d
[1] Fixed linting issue.
lyndonbauto Nov 25, 2020
d21006f
[1] Updated test return error properly and to check for error
lyndonbauto Nov 25, 2020
6cd8a45
[1] Fixed linting issue.
lyndonbauto Nov 25, 2020
11 changes: 9 additions & 2 deletions cpp/src/arrow/flight/CMakeLists.txt
@@ -123,6 +123,7 @@ set(ARROW_FLIGHT_SRCS
serialization_internal.cc
server.cc
server_auth.cc
client_header_auth_middleware.cc
types.cc)

add_arrow_lib(arrow_flight
@@ -213,10 +214,16 @@ if(ARROW_BUILD_INTEGRATION)
target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS}
${GFLAGS_LIBRARIES} GTest::gtest)

add_executable(flight-test-integration-client-header-auth test_integration_client_header_auth.cc)
target_link_libraries(flight-test-integration-client-header-auth ${ARROW_FLIGHT_TEST_LINK_LIBS}
${GFLAGS_LIBRARIES} GTest::gtest)

add_dependencies(arrow_flight flight-test-integration-client
flight-test-integration-server)
flight-test-integration-server
flight-test-integration-client-header-auth)
add_dependencies(arrow-integration flight-test-integration-client
flight-test-integration-server)
flight-test-integration-server
flight-test-integration-client-header-auth)
endif()

if(ARROW_BUILD_BENCHMARKS)
42 changes: 39 additions & 3 deletions cpp/src/arrow/flight/client.cc
@@ -51,6 +51,7 @@

#include "arrow/flight/client_auth.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/client_header_auth_middleware.h"

Is this ordering alphabetical? If so this should be up one...

Contributor Author
Yes, I was waiting for feedback on the design before fixing the formatting. I will correct this.

#include "arrow/flight/internal.h"
#include "arrow/flight/middleware.h"
#include "arrow/flight/middleware_internal.h"
@@ -104,6 +105,9 @@ struct ClientRpc {
std::chrono::system_clock::now() + options.timeout);
context.set_deadline(deadline);
}
for (auto metadata : options.metadata) {
context.AddMetadata(metadata.first, metadata.second);
}
}

/// \brief Add an auth token via an auth handler
@@ -328,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory
: public grpc::experimental::ClientInterceptorFactoryInterface {
public:
GrpcClientInterceptorAdapterFactory(
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware)
std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware)

Should this be const?

Member
This should either be const reference or a pointer.

Member
So, I'd rather we store a reference to this interceptor factory instead of a reference to the middleware with a mutable pointer back, i.e. I'd rather have this class own the middleware as it currently does, and have FlightClient call a method of this class to add more middleware at runtime.

: middleware_(middleware) {}

grpc::experimental::Interceptor* CreateClientInterceptor(
@@ -371,7 +375,7 @@ class GrpcClientInterceptorAdapterFactory
}

private:
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware_;
std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware_;

Intention is to keep the reference and not a copy?

Member
This should be a pointer if not a copy.

Contributor Author
It was to keep the reference, but you got me thinking more. I think I should just take a copy and add a function that allows adding to the middleware on the fly, since the issue is that I need to be able to add and remove middleware from the interceptor. I will rejig this.
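
A minimal sketch of the ownership model discussed in this thread, assuming the adapter factory keeps its own copy of the middleware list and exposes methods to add and remove entries at runtime. MiddlewareFactory, InterceptorAdapterFactory, AddMiddleware and RemoveMiddleware are hypothetical names used only for illustration; they are not from the PR or the Flight API. The mutex is also an assumption, in case calls can start from several threads.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical stand-in for arrow::flight::ClientMiddlewareFactory.
struct MiddlewareFactory {
  virtual ~MiddlewareFactory() = default;
};

// Hypothetical adapter that owns its middleware list instead of holding a
// reference to a vector owned by the client.
class InterceptorAdapterFactory {
 public:
  explicit InterceptorAdapterFactory(
      std::vector<std::shared_ptr<MiddlewareFactory>> middleware)
      : middleware_(std::move(middleware)) {}

  // Register an additional middleware factory after construction.
  void AddMiddleware(std::shared_ptr<MiddlewareFactory> factory) {
    std::lock_guard<std::mutex> lock(mutex_);
    middleware_.push_back(std::move(factory));
  }

  // Remove a previously registered factory; returns false if it was not found.
  bool RemoveMiddleware(const std::shared_ptr<MiddlewareFactory>& factory) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = std::find(middleware_.begin(), middleware_.end(), factory);
    if (it == middleware_.end()) return false;
    middleware_.erase(it);
    return true;
  }

  std::size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return middleware_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::vector<std::shared_ptr<MiddlewareFactory>> middleware_;
};
```

With this shape the client would only hold a pointer to the adapter and call AddMiddleware before the handshake and RemoveMiddleware after it, rather than pushing to and popping from a shared vector.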

};

class GrpcClientAuthSender : public ClientAuthSender {
@@ -963,8 +967,9 @@ class FlightClient::FlightClientImpl {

std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>>
interceptors;
middleware = std::move(options.middleware);

Why this change?

Contributor Author
Goal was so that middleware was retained in this class so I could push and pop it. I am going to adjust the implementation though.

interceptors.emplace_back(
new GrpcClientInterceptorAdapterFactory(std::move(options.middleware)));
new GrpcClientInterceptorAdapterFactory(middleware));

stub_ = pb::FlightService::NewStub(
grpc::experimental::CreateCustomChannelWithInterceptors(
@@ -993,6 +998,30 @@ class FlightClient::FlightClientImpl {
return Status::OK();
}

Status AuthenticateBasicToken(std::string username, std::string password,

Should username/password be passed by const ref?

Contributor Author
yes

std::pair<std::string, std::string>* bearer_token) {
// Add bearer token factory to middleware so it can intercept the bearer token.
middleware.push_back(std::make_shared<ClientBearerTokenFactory>(bearer_token));

It looks odd to create the shared pointer after you've passed in the raw pointer. It seems like the method itself should take a shared pointer.

Contributor Author
The raw pointer is unpopulated, so it's passed to the BearerTokenFactory's constructor, which stores it and populates it when it receives the bearer token. I could make the client pass in the whole factory with the bearer token already inside it, but that's more work and requires them to understand more of what's going on than they otherwise would need to.

ClientRpc rpc({});
Member
You may actually want to allow passing call options so things like timeouts can be set.

AddBasicAuthHeaders(&rpc.context, username, password);
std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
stream = stub_->Handshake(&rpc.context);

GrpcClientAuthSender outgoing{stream};
GrpcClientAuthReader incoming{stream};
// Explicitly close our side of the connection
bool finished_writes = stream->WritesDone();
middleware.pop_back();
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
if (!finished_writes) {
return MakeFlightError(FlightStatusCode::Internal,
"Could not finish writing before closing");
}
return Status::OK();
}



Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
std::unique_ptr<FlightListing>* listing) {
pb::Criteria pb_criteria;
@@ -1174,6 +1203,7 @@ class FlightClient::FlightClientImpl {
GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig>
noop_auth_check_;
#endif
std::vector<std::shared_ptr<arrow::flight::ClientMiddlewareFactory>> middleware;
int64_t write_size_limit_bytes_;
};

@@ -1197,6 +1227,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options,
return impl_->Authenticate(options, std::move(auth_handler));
}

Status FlightClient::AuthenticateBasicToken(
Member
Could we make this arrow::Result<std::pair<>>?

Contributor Author
Yep, I will make this change.

std::string username, std::string password,
std::pair<std::string, std::string>* bearer_token) {
return impl_->AuthenticateBasicToken(username, password, bearer_token);
}

Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
return impl_->DoAction(options, action, results);
11 changes: 11 additions & 0 deletions cpp/src/arrow/flight/client.h
@@ -65,6 +65,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {

/// \brief IPC writer options, if applicable for the call.
ipc::IpcWriteOptions write_options;

/// \brief Metadata for client to add to context.
std::vector<std::pair<std::string, std::string>> metadata;
Member
I believe we call it headers elsewhere, so this should stay consistent with that.

Contributor Author
Will correct this.

};

/// \brief Indicate that the client attempted to write a message
@@ -191,6 +194,14 @@ class ARROW_FLIGHT_EXPORT FlightClient {
Status Authenticate(const FlightCallOptions& options,
std::unique_ptr<ClientAuthHandler> auth_handler);

/// \brief Authenticate to the server using the given handler.
Member
There's no handler in play here.

/// \param[in] username Username to use
/// \param[in] password Password to use
/// \param[in] bearer_token Bearer token retrieved if applicable
/// \return Status OK if the client authenticated successfully
Status AuthenticateBasicToken(std::string username, std::string password,
std::pair<std::string, std::string>* bearer_token);

/// \brief Perform the indicated action, returning an iterator to the stream
/// of results, if any
/// \param[in] options Per-RPC options
124 changes: 124 additions & 0 deletions cpp/src/arrow/flight/client_header_auth_middleware.cc
@@ -0,0 +1,124 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Interfaces for defining middleware for Flight clients. Currently
// experimental.

#include "client_header_auth_middleware.h"
#include "client_middleware.h"
#include "client_auth.h"
#include "client.h"

namespace arrow {
namespace flight {

std::string base64_encode(const std::string& input);

ClientBearerTokenMiddleware::ClientBearerTokenMiddleware(
std::pair<std::string, std::string>* bearer_token_)
: bearer_token(bearer_token_) { }

void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { }

void ClientBearerTokenMiddleware::ReceivedHeaders(
const CallHeaders& incoming_headers) {
// Grab the auth token if one exists.
auto bearer_iter = incoming_headers.find(AUTH_HEADER);

const?

if (bearer_iter == incoming_headers.end()) {
return;
}

// Check if the value of the auth token starts with the bearer prefix, latch the token.
std::string bearer_val = bearer_iter->second.to_string();
if (bearer_val.size() > BEARER_PREFIX.size()) {
bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(),

Seems like they use snake_case instead of camelCase.

[] (const char& char1, const char& char2) {
return (std::toupper(char1) == std::toupper(char2));
}
);
if (hasPrefix) {
*bearer_token = std::make_pair(AUTH_HEADER, bearer_val);
}
}
}
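
A minimal sketch of the case-insensitive prefix check that ReceivedHeaders performs, factored into a helper with a snake_case name. has_prefix_ignore_case is a hypothetical name and not from the PR. The lambda takes unsigned char because passing a negative char value to std::toupper is undefined behavior.

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Hypothetical helper: returns true when `value` starts with `prefix`,
// comparing characters without regard to case.
bool has_prefix_ignore_case(const std::string& value, const std::string& prefix) {
  if (value.size() < prefix.size()) {
    return false;
  }
  return std::equal(prefix.begin(), prefix.end(), value.begin(),
                    [](unsigned char a, unsigned char b) {
                      return std::toupper(a) == std::toupper(b);
                    });
}
```

Note one difference from the code under review: the diff requires the header value to be strictly longer than the prefix, so a bare "Bearer " with no token is ignored there, while this helper on its own would accept it. A caller that wants the original behavior would check for a non-empty remainder as well.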

void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { }

void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {

Would it be better to pass a reference instead of a pointer?

Contributor Author
I can't; this is a method of the base class. I think it's done this way to allow you to assign a new unique pointer to it without exposing the other middleware already held in the vector.

*middleware = std::unique_ptr<ClientBearerTokenMiddleware>(new ClientBearerTokenMiddleware(bearer_token));

std::make_unique?

}
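
A minimal sketch of the std::make_unique suggestion, assuming a C++14 or later toolchain. If the project targets C++11, std::make_unique is not available and a project-local equivalent or reset(new ...) would be needed instead. Middleware and BearerTokenMiddleware are simplified, hypothetical stand-ins for the classes in this diff.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for arrow::flight::ClientMiddleware.
struct Middleware {
  virtual ~Middleware() = default;
  virtual std::string name() const = 0;
};

// Hypothetical stand-in for ClientBearerTokenMiddleware.
struct BearerTokenMiddleware : Middleware {
  explicit BearerTokenMiddleware(std::pair<std::string, std::string>* token)
      : token_(token) {}
  std::string name() const override { return "bearer"; }
  std::pair<std::string, std::string>* token_;
};

// The out-parameter keeps the shape of the base-class StartCall signature;
// unique_ptr<Derived> converts implicitly to unique_ptr<Base> on assignment.
void StartCall(std::pair<std::string, std::string>* token,
               std::unique_ptr<Middleware>* out) {
  *out = std::make_unique<BearerTokenMiddleware>(token);
}
```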

void ClientBearerTokenFactory::Reset() {
*bearer_token = std::make_pair("", "");
}

template<typename ... Args>
Member
I don't think we need an entire templated function to concatenate two strings.

std::string string_format(const std::string& format, const Args... args) {
// Check size requirement for new string and increment by 1 for null terminator.
size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1;
if(size <= 0){

Nit: spacing. This should be written as if (size <= 0) {

throw std::runtime_error("Error during string formatting. Format: '" + format + "'.");
Member
And Arrow disallows exceptions.

}

// Create buffer for new string and write string in.
std::unique_ptr<char[]> buf(new char[size]);
std::snprintf(buf.get(), size, format.c_str(), args...);

// Convert to std::string, subtracting size by 1 to trim null terminator.
return std::string(buf.get(), buf.get() + size - 1);
}
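
A minimal sketch of what the two review comments on string_format point toward: the credentials string can be built with plain std::string concatenation, which needs no template, no snprintf buffer, and no throw. format_basic_credentials is a hypothetical name and not from the PR.

```cpp
#include <string>

// Hypothetical helper: joins username and password in the "user:pass" form
// that is base64-encoded for the Basic scheme.
std::string format_basic_credentials(const std::string& username,
                                     const std::string& password) {
  return username + ":" + password;
}
```

One caveat that applies to either version: under RFC 7617 a user-id containing a colon cannot be represented, so a caller may want to reject such usernames before encoding.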

void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) {
const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str());
context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials));
}

std::string base64_encode(const std::string& input) {
static const std::string base64_chars =

This doesn't exist in the codebase already?

Contributor Author
Yes, I found it. Will remove this.

"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
auto get_encoded_length = [] (const std::string& in) {
return 4 * ((in.size() + 2) / 3);
};
auto get_overwrite_count = [] (const std::string& in) {
const std::string::size_type remainder = in.length() % 3;
return (remainder > 0) ? (3 - (remainder % 3)) : 0;
};

// Generate string with required length for encoding.
std::string encoded;
encoded.reserve(get_encoded_length(input));

// Loop through input writing base64 characters to string.
for (int i = 0; i < input.length();) {
uint32_t octet_1 = i < input.length() ? (unsigned char)input[i++] : 0;
uint32_t octet_2 = i < input.length() ? (unsigned char)input[i++] : 0;
uint32_t octet_3 = i < input.length() ? (unsigned char)input[i++] : 0;
uint32_t octriple = (octet_1 << 0x10) + (octet_2 << 0x08) + octet_3;
for (int j = 3; j >= 0; j--) {
encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]);
}
}

// Round up to nearest multiple of 3 and replace characters at end based on rounding.
int overwrite_count = get_overwrite_count(input);
encoded.replace(encoded.length() - overwrite_count,
encoded.length(),
overwrite_count, '=');
return encoded;
}
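
Whichever encoder ends up being used, its output can be checked against the RFC 4648 test vectors. The sketch below is a compact standalone encoder written for that purpose. It is not the existing codebase implementation that the reply refers to, and base64_encode_sketch is a hypothetical name. It uses std::size_t indices to avoid the signed/unsigned comparisons in the loop above.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical standalone encoder using the standard base64 alphabet.
std::string base64_encode_sketch(const std::string& input) {
  static const char kAlphabet[] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  std::string encoded;
  encoded.reserve(4 * ((input.size() + 2) / 3));

  for (std::size_t i = 0; i < input.size(); i += 3) {
    const std::size_t remaining = input.size() - i;
    // Pack up to three octets into the low 24 bits.
    std::uint32_t triple = static_cast<unsigned char>(input[i]) << 16;
    if (remaining > 1) triple |= static_cast<unsigned char>(input[i + 1]) << 8;
    if (remaining > 2) triple |= static_cast<unsigned char>(input[i + 2]);

    // Emit four sextets, padding with '=' where input octets were missing.
    encoded.push_back(kAlphabet[(triple >> 18) & 0x3F]);
    encoded.push_back(kAlphabet[(triple >> 12) & 0x3F]);
    encoded.push_back(remaining > 1 ? kAlphabet[(triple >> 6) & 0x3F] : '=');
    encoded.push_back(remaining > 2 ? kAlphabet[triple & 0x3F] : '=');
  }
  return encoded;
}
```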
} // namespace flight
} // namespace arrow
78 changes: 78 additions & 0 deletions cpp/src/arrow/flight/client_header_auth_middleware.h
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
Member
This should probably be an _internal.h header since I don't think any of this is intended to be directly used outside of the implementation here.

// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Interfaces for defining middleware for Flight clients. Currently
// experimental.

#pragma once

#include "arrow/flight/client_middleware.h"
#include "arrow/flight/client_auth.h"
#include "arrow/flight/client.h"

#ifdef GRPCPP_PP_INCLUDE
#include <grpcpp/grpcpp.h>
#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
#include <grpcpp/security/tls_credentials_options.h>
Member
I don't think we need this include?

#endif
#else
#include <grpc++/grpc++.h>
#endif

#include <algorithm>
#include <iostream>
#include <cctype>
#include <string>

const std::string AUTH_HEADER = "authorization";

Do these need to be in the header, or should they be only in the .cc file?
If they stay in the header, I think they're created anew for each compilation unit, so they should be declared extern with the actual value defined in the .cc file to avoid multiple instantiations.

Member
These should be named kAuthHeader etc. and should go in the .cc file unless we actually want to expose these to users.

Or alternatively, the other constants are currently in one of the _internal.h headers.
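
A minimal sketch of the ".cc file only" option from this thread, assuming the constants are not meant to be exposed to users. The k-prefixed names follow the naming suggested above. The three accessor functions are hypothetical and exist only to show the constants in use. An anonymous namespace gives the constants internal linkage, so nothing leaks into the public header.

```cpp
#include <string>

namespace {
// File-local constants: visible only inside this translation unit.
constexpr char kAuthHeader[] = "authorization";
constexpr char kBearerPrefix[] = "Bearer ";
constexpr char kBasicPrefix[] = "Basic ";
}  // namespace

// Hypothetical helpers showing how the .cc file would use the constants.
std::string auth_header_name() { return kAuthHeader; }

std::string make_basic_value(const std::string& encoded_credentials) {
  return std::string(kBasicPrefix) + encoded_credentials;
}

std::string make_bearer_value(const std::string& token) {
  return std::string(kBearerPrefix) + token;
}
```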

const std::string BEARER_PREFIX = "Bearer ";
const std::string BASIC_PREFIX = "Basic ";

namespace arrow {
namespace flight {

// TODO: Need to add documentation in this file.
void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context,
Member
This is internal - it shouldn't be in a public header. (Ditto for the grpc include.)

const std::string& username,
const std::string& password);

class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware {
public:
explicit ClientBearerTokenMiddleware(
std::pair<std::string, std::string>* bearer_token_);

void SendingHeaders(AddCallHeaders* outgoing_headers);
void ReceivedHeaders(const CallHeaders& incoming_headers);
void CallCompleted(const Status& status);

private:
std::pair<std::string, std::string>* bearer_token;
};

class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory {
public:
explicit ClientBearerTokenFactory(std::pair<std::string, std::string>* bearer_token_)
: bearer_token(bearer_token_) {}

void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware);
void Reset();

private:
std::pair<std::string, std::string>* bearer_token;
};
} // namespace flight
} // namespace arrow