-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ARROW-10487 [FlightRPC][C++] Header-based auth in clients #8724
Changes from 8 commits
c2ba2c7
ab7d6a3
ecb533f
73be3e7
74ef8ea
29a3192
fac0bd0
7fd1279
0b5a08e
6861cf0
01a134b
de78c6b
c44698f
e975fd8
37889fa
e7ac27c
ba7cb9f
1426252
516d993
3000ecb
1de10fa
065af4a
d4da03b
911fcc7
f41edce
6b6fbbe
199b655
47aa581
477d865
1cc3fdb
d27465d
d21006f
6cd8a45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,7 @@ | |
|
||
#include "arrow/flight/client_auth.h" | ||
#include "arrow/flight/client_middleware.h" | ||
#include "arrow/flight/client_header_auth_middleware.h" | ||
#include "arrow/flight/internal.h" | ||
#include "arrow/flight/middleware.h" | ||
#include "arrow/flight/middleware_internal.h" | ||
|
@@ -104,6 +105,9 @@ struct ClientRpc { | |
std::chrono::system_clock::now() + options.timeout); | ||
context.set_deadline(deadline); | ||
} | ||
for (auto metadata : options.metadata) { | ||
context.AddMetadata(metadata.first, metadata.second); | ||
} | ||
} | ||
|
||
/// \brief Add an auth token via an auth handler | ||
|
@@ -328,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory | |
: public grpc::experimental::ClientInterceptorFactoryInterface { | ||
public: | ||
GrpcClientInterceptorAdapterFactory( | ||
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware) | ||
std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be const? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should either be const reference or a pointer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, I'd rather we store a reference to this interceptor factory instead of a reference to the middleware with a mutable pointer back, i.e. I'd rather have this class own the middleware as it currently does, and have FlightClient call a method of this class to add more middleware at runtime. |
||
: middleware_(middleware) {} | ||
|
||
grpc::experimental::Interceptor* CreateClientInterceptor( | ||
|
@@ -371,7 +375,7 @@ class GrpcClientInterceptorAdapterFactory | |
} | ||
|
||
private: | ||
std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware_; | ||
std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Intention is to keep the reference and not a copy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be a pointer if not a copy. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was to keep the reference, but you got me thinking more and I think I should just take a copy and then make a function to allow adding to the middleware on the fly, since the issue is that I need to be able to add and remove middleware from the interceptor. I will rejig this. |
||
}; | ||
|
||
class GrpcClientAuthSender : public ClientAuthSender { | ||
|
@@ -963,8 +967,9 @@ class FlightClient::FlightClientImpl { | |
|
||
std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> | ||
interceptors; | ||
middleware = std::move(options.middleware); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Goal was so that middleware was retained in this class so I could push and pop it. I am going to adjust the implementation though. |
||
interceptors.emplace_back( | ||
new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); | ||
new GrpcClientInterceptorAdapterFactory(middleware)); | ||
|
||
stub_ = pb::FlightService::NewStub( | ||
grpc::experimental::CreateCustomChannelWithInterceptors( | ||
|
@@ -993,6 +998,30 @@ class FlightClient::FlightClientImpl { | |
return Status::OK(); | ||
} | ||
|
||
Status AuthenticateBasicToken(std::string username, std::string password, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should username/password be passed by const ref? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
std::pair<std::string, std::string>* bearer_token) { | ||
// Add bearer token factory to middleware so it can intercept the bearer token. | ||
middleware.push_back(std::make_shared<ClientBearerTokenFactory>(bearer_token)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks odd to create the shared pointer after you've passed in the raw pointer....it seems like the method itself should take a shared pointer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The raw pointer is unpopulated, so it's passed to the BearerTokenFactory's constructor, which stores it and populated it when it receives the bearer token. I could make the client pass the whole factory in with the bearer token already inside it, but it's more work and requires they understand what's going on more than they otherwise would need to. |
||
ClientRpc rpc({}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may actually want to allow passing call options so things like timeouts can be set. |
||
AddBasicAuthHeaders(&rpc.context, username, password); | ||
std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> | ||
stream = stub_->Handshake(&rpc.context); | ||
|
||
GrpcClientAuthSender outgoing{stream}; | ||
GrpcClientAuthReader incoming{stream}; | ||
// Explicitly close our side of the connection | ||
bool finished_writes = stream->WritesDone(); | ||
middleware.pop_back(); | ||
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); | ||
if (!finished_writes) { | ||
return MakeFlightError(FlightStatusCode::Internal, | ||
"Could not finish writing before closing"); | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
|
||
|
||
Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, | ||
std::unique_ptr<FlightListing>* listing) { | ||
pb::Criteria pb_criteria; | ||
|
@@ -1174,6 +1203,7 @@ class FlightClient::FlightClientImpl { | |
GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> | ||
noop_auth_check_; | ||
#endif | ||
std::vector<std::shared_ptr<arrow::flight::ClientMiddlewareFactory>> middleware; | ||
int64_t write_size_limit_bytes_; | ||
}; | ||
|
||
|
@@ -1197,6 +1227,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, | |
return impl_->Authenticate(options, std::move(auth_handler)); | ||
} | ||
|
||
Status FlightClient::AuthenticateBasicToken( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we make this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I will make this change. |
||
std::string username, std::string password, | ||
std::pair<std::string, std::string>* bearer_token) { | ||
return impl_->AuthenticateBasicToken(username, password, bearer_token); | ||
} | ||
|
||
Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, | ||
std::unique_ptr<ResultStream>* results) { | ||
return impl_->DoAction(options, action, results); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { | |
|
||
/// \brief IPC writer options, if applicable for the call. | ||
ipc::IpcWriteOptions write_options; | ||
|
||
/// \brief Metadata for client to add to context. | ||
std::vector<std::pair<std::string, std::string>> metadata; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we call it headers elsewhere, so this should stay consistent with that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will correct this. |
||
}; | ||
|
||
/// \brief Indicate that the client attempted to write a message | ||
|
@@ -191,6 +194,14 @@ class ARROW_FLIGHT_EXPORT FlightClient { | |
Status Authenticate(const FlightCallOptions& options, | ||
std::unique_ptr<ClientAuthHandler> auth_handler); | ||
|
||
/// \brief Authenticate to the server using the given handler. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no handler in play here. |
||
/// \param[in] username Username to use | ||
/// \param[in] password Password to use | ||
/// \param[in] bearer_token Bearer token retreived if applicable | ||
/// \return Status OK if the client authenticated successfully | ||
Status AuthenticateBasicToken(std::string username, std::string password, | ||
std::pair<std::string, std::string>* bearer_token); | ||
|
||
/// \brief Perform the indicated action, returning an iterator to the stream | ||
/// of results, if any | ||
/// \param[in] options Per-RPC options | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
// Interfaces for defining middleware for Flight clients. Currently | ||
// experimental. | ||
|
||
#include "client_header_auth_middleware.h" | ||
#include "client_middleware.h" | ||
#include "client_auth.h" | ||
#include "client.h" | ||
|
||
namespace arrow { | ||
namespace flight { | ||
|
||
std::string base64_encode(const std::string& input); | ||
|
||
ClientBearerTokenMiddleware::ClientBearerTokenMiddleware( | ||
std::pair<std::string, std::string>* bearer_token_) | ||
: bearer_token(bearer_token_) { } | ||
|
||
void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } | ||
|
||
void ClientBearerTokenMiddleware::ReceivedHeaders( | ||
const CallHeaders& incoming_headers) { | ||
// Grab the auth token if one exists. | ||
auto bearer_iter = incoming_headers.find(AUTH_HEADER); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const? |
||
if (bearer_iter == incoming_headers.end()) { | ||
return; | ||
} | ||
|
||
// Check if the value of the auth token starts with the bearer prefix, latch the token. | ||
std::string bearer_val = bearer_iter->second.to_string(); | ||
if (bearer_val.size() > BEARER_PREFIX.size()) { | ||
bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like they use snake_case instead of camelCase. |
||
[] (const char& char1, const char& char2) { | ||
return (std::toupper(char1) == std::toupper(char2)); | ||
} | ||
); | ||
if (hasPrefix) { | ||
*bearer_token = std::make_pair(AUTH_HEADER, bearer_val); | ||
} | ||
} | ||
} | ||
|
||
void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } | ||
|
||
void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be better to pass a reference instead of a pointer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't this is a method of the base class, I think it's done this way to allow you to assign a new unique pointer to it without exposing the other middlewares they are already holding in their vector. |
||
*middleware = std::unique_ptr<ClientBearerTokenMiddleware>(new ClientBearerTokenMiddleware(bearer_token)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::make_unique? |
||
} | ||
|
||
void ClientBearerTokenFactory::Reset() { | ||
*bearer_token = std::make_pair("", ""); | ||
} | ||
|
||
template<typename ... Args> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need an entire templated function to concatenate two strings. |
||
std::string string_format(const std::string& format, const Args... args) { | ||
// Check size requirement for new string and increment by 1 for null terminator. | ||
size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; | ||
if(size <= 0){ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: spacing between if (, ){ |
||
throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And Arrow disallows exceptions. |
||
} | ||
|
||
// Create buffer for new string and write string in. | ||
std::unique_ptr<char[]> buf(new char[size]); | ||
std::snprintf(buf.get(), size, format.c_str(), args...); | ||
|
||
// Convert to std::string, subtracting size by 1 to trim null terminator. | ||
return std::string(buf.get(), buf.get() + size - 1); | ||
} | ||
|
||
void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { | ||
const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); | ||
context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); | ||
} | ||
|
||
std::string base64_encode(const std::string& input) { | ||
static const std::string base64_chars = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't exist in the codebase already? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I found it. Will remove this. |
||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | ||
auto get_encoded_length = [] (const std::string& in) { | ||
return 4 * ((in.size() + 2) / 3); | ||
}; | ||
auto get_overwrite_count = [] (const std::string& in) { | ||
const std::string::size_type remainder = in.length() % 3; | ||
return (remainder > 0) ? (3 - (remainder % 3)) : 0; | ||
}; | ||
|
||
// Generate string with required length for encoding. | ||
std::string encoded; | ||
encoded.reserve(get_encoded_length(input)); | ||
|
||
// Loop through input writing base64 characters to string. | ||
for (int i = 0; i < input.length();) { | ||
uint32_t octet_1 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
uint32_t octet_2 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
uint32_t octet_3 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
uint32_t octriple = (octet_1 << 0x10) + (octet_2 << 0x08) + octet_3; | ||
for (int j = 3; j >= 0; j--) { | ||
encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); | ||
} | ||
} | ||
|
||
// Round up to nearest multiple of 3 and replace characters at end based on rounding. | ||
int overwrite_count = get_overwrite_count(input); | ||
encoded.replace(encoded.length() - overwrite_count, | ||
encoded.length(), | ||
overwrite_count, '='); | ||
return encoded; | ||
} | ||
} // namespace flight | ||
} // namespace arrow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably be an _internal.h header since I don't think any of this is intended to be directly used outside of the implementation here. |
||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
// Interfaces for defining middleware for Flight clients. Currently | ||
// experimental. | ||
|
||
#pragma once | ||
|
||
#include "arrow/flight/client_middleware.h" | ||
#include "arrow/flight/client_auth.h" | ||
#include "arrow/flight/client.h" | ||
|
||
#ifdef GRPCPP_PP_INCLUDE | ||
#include <grpcpp/grpcpp.h> | ||
#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) | ||
#include <grpcpp/security/tls_credentials_options.h> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need this include? |
||
#endif | ||
#else | ||
#include <grpc++/grpc++.h> | ||
#endif | ||
|
||
#include <algorithm> | ||
#include <iostream> | ||
#include <cctype> | ||
#include <string> | ||
|
||
const std::string AUTH_HEADER = "authorization"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do these need to be in the header, or should they be only in the CC. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should be named kAuthHeader etc. and should go in the .cc file unless we actually want to expose these to users. Or alternatively, the other constants are currently in one of the _internal.h headers. |
||
const std::string BEARER_PREFIX = "Bearer "; | ||
const std::string BASIC_PREFIX = "Basic "; | ||
|
||
namespace arrow { | ||
namespace flight { | ||
|
||
// TODO: Need to add documentation in this file. | ||
void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is internal - it shouldn't be in a public header. (Ditto for the grpc include.) |
||
const std::string& username, | ||
const std::string& password); | ||
|
||
class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { | ||
public: | ||
explicit ClientBearerTokenMiddleware( | ||
std::pair<std::string, std::string>* bearer_token_); | ||
|
||
void SendingHeaders(AddCallHeaders* outgoing_headers); | ||
void ReceivedHeaders(const CallHeaders& incoming_headers); | ||
void CallCompleted(const Status& status); | ||
|
||
private: | ||
std::pair<std::string, std::string>* bearer_token; | ||
}; | ||
|
||
class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { | ||
public: | ||
explicit ClientBearerTokenFactory(std::pair<std::string, std::string>* bearer_token_) | ||
: bearer_token(bearer_token_) {} | ||
|
||
void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware); | ||
void Reset(); | ||
|
||
private: | ||
std::pair<std::string, std::string>* bearer_token; | ||
}; | ||
} // namespace flight | ||
} // namespace arrow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ordering alphabetical? If so this should be up one...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, was waiting for feedback on design before proper formatting. I will correct this.