-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
3rd PR for #3429, add transport socket implementation. Risk Level: Low (not enabled in main) Testing: bazel test //test/... Docs Changes: N/A Release Notes: N/A Signed-off-by: Lizan Zhou <zlizan@google.com>
- Loading branch information
Showing
8 changed files
with
811 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
#include "extensions/transport_sockets/alts/tsi_socket.h" | ||
|
||
#include "common/common/assert.h" | ||
#include "common/common/cleanup.h" | ||
#include "common/common/empty_string.h" | ||
#include "common/common/enum_to_int.h" | ||
|
||
namespace Envoy { | ||
namespace Extensions { | ||
namespace TransportSockets { | ||
namespace Alts { | ||
|
||
TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, | ||
Network::TransportSocketPtr&& raw_socket) | ||
: handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator), | ||
raw_buffer_socket_(std::move(raw_socket)) {} | ||
|
||
TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator) | ||
: TsiSocket(handshaker_factory, handshake_validator, | ||
std::make_unique<Network::RawBufferSocket>()) {} | ||
|
||
TsiSocket::~TsiSocket() { ASSERT(!handshaker_); } | ||
|
||
void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) { | ||
callbacks_ = &callbacks; | ||
|
||
noop_callbacks_ = std::make_unique<NoOpTransportSocketCallbacks>(callbacks); | ||
raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_); | ||
} | ||
|
||
std::string TsiSocket::protocol() const { | ||
// TSI doesn't have a generic way to indicate application layer protocol. | ||
// TODO(lizan): support application layer protocol from TSI for known TSIs. | ||
return EMPTY_STRING; | ||
} | ||
|
||
Network::PostIoAction TsiSocket::doHandshake() { | ||
ASSERT(!handshake_complete_); | ||
ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection()); | ||
|
||
if (!handshaker_) { | ||
handshaker_ = handshaker_factory_(callbacks_->connection().dispatcher(), | ||
callbacks_->connection().localAddress(), | ||
callbacks_->connection().remoteAddress()); | ||
handshaker_->setHandshakerCallbacks(*this); | ||
} | ||
|
||
if (!handshaker_next_calling_) { | ||
doHandshakeNext(); | ||
} | ||
return Network::PostIoAction::KeepOpen; | ||
} | ||
|
||
void TsiSocket::doHandshakeNext() { | ||
ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(), | ||
raw_read_buffer_.length()); | ||
handshaker_next_calling_ = true; | ||
Buffer::OwnedImpl handshaker_buffer; | ||
handshaker_buffer.move(raw_read_buffer_); | ||
handshaker_->next(handshaker_buffer); | ||
} | ||
|
||
Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) { | ||
ASSERT(next_result); | ||
|
||
ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}", | ||
callbacks_->connection(), next_result->status_, next_result->to_send_->length()); | ||
|
||
tsi_result status = next_result->status_; | ||
tsi_handshaker_result* handshaker_result = next_result->result_.get(); | ||
|
||
if (status != TSI_INCOMPLETE_DATA && status != TSI_OK) { | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status); | ||
return Network::PostIoAction::Close; | ||
} | ||
|
||
if (next_result->to_send_->length() > 0) { | ||
raw_write_buffer_.move(*next_result->to_send_); | ||
} | ||
|
||
if (status == TSI_OK && handshaker_result != nullptr) { | ||
tsi_peer peer; | ||
// returns TSI_OK assuming there is no fatal error. Asserting OK. | ||
status = tsi_handshaker_result_extract_peer(handshaker_result, &peer); | ||
ASSERT(status == TSI_OK); | ||
Cleanup peer_cleanup([&peer]() { tsi_peer_destruct(&peer); }); | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}", | ||
callbacks_->connection(), peer.property_count); | ||
for (size_t i = 0; i < peer.property_count; ++i) { | ||
ENVOY_CONN_LOG(debug, " {}: {}", callbacks_->connection(), peer.properties[i].name, | ||
std::string(peer.properties[i].value.data, peer.properties[i].value.length)); | ||
} | ||
if (handshake_validator_) { | ||
std::string err; | ||
const bool peer_validated = handshake_validator_(peer, err); | ||
if (peer_validated) { | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection()); | ||
} else { | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(), | ||
err); | ||
return Network::PostIoAction::Close; | ||
} | ||
} else { | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection()); | ||
} | ||
|
||
const unsigned char* unused_bytes; | ||
size_t unused_byte_size; | ||
|
||
// returns TSI_OK assuming there is no fatal error. Asserting OK. | ||
status = | ||
tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size); | ||
ASSERT(status == TSI_OK); | ||
if (unused_byte_size > 0) { | ||
raw_read_buffer_.prepend( | ||
absl::string_view{reinterpret_cast<const char*>(unused_bytes), unused_byte_size}); | ||
} | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(), | ||
unused_byte_size); | ||
|
||
// returns TSI_OK assuming there is no fatal error. Asserting OK. | ||
tsi_frame_protector* frame_protector; | ||
status = | ||
tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector); | ||
ASSERT(status == TSI_OK); | ||
frame_protector_ = std::make_unique<TsiFrameProtector>(frame_protector); | ||
|
||
handshake_complete_ = true; | ||
callbacks_->raiseEvent(Network::ConnectionEvent::Connected); | ||
} | ||
|
||
if (read_error_ || (!handshake_complete_ && end_stream_read_)) { | ||
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data", | ||
callbacks_->connection()); | ||
return Network::PostIoAction::Close; | ||
} | ||
|
||
if (raw_read_buffer_.length() > 0) { | ||
callbacks_->setReadBufferReady(); | ||
} | ||
|
||
// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack. | ||
if (raw_write_buffer_.length() > 0) { | ||
return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_; | ||
} | ||
|
||
return Network::PostIoAction::KeepOpen; | ||
} | ||
|
||
Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) { | ||
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false}; | ||
if (!end_stream_read_ && !read_error_) { | ||
result = raw_buffer_socket_->doRead(raw_read_buffer_); | ||
ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}", | ||
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, | ||
result.end_stream_read_); | ||
if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) { | ||
return result; | ||
} | ||
|
||
end_stream_read_ = result.end_stream_read_; | ||
read_error_ = result.action_ == Network::PostIoAction::Close; | ||
} | ||
|
||
if (!handshake_complete_) { | ||
Network::PostIoAction action = doHandshake(); | ||
if (action == Network::PostIoAction::Close || !handshake_complete_) { | ||
return {action, 0, false}; | ||
} | ||
} | ||
|
||
if (handshake_complete_) { | ||
ASSERT(frame_protector_); | ||
|
||
uint64_t read_size = raw_read_buffer_.length(); | ||
ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(), | ||
raw_read_buffer_.length()); | ||
tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer); | ||
ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(), | ||
raw_read_buffer_.length(), tsi_result_to_string(status)); | ||
result.bytes_processed_ = read_size - raw_read_buffer_.length(); | ||
} | ||
|
||
ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}", | ||
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_, | ||
result.end_stream_read_); | ||
return result; | ||
} | ||
|
||
Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { | ||
if (!handshake_complete_) { | ||
Network::PostIoAction action = doHandshake(); | ||
ASSERT(action == Network::PostIoAction::KeepOpen); | ||
// TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it. | ||
} | ||
|
||
if (handshake_complete_) { | ||
ASSERT(frame_protector_); | ||
ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(), | ||
buffer.length()); | ||
tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_); | ||
ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(), | ||
buffer.length(), tsi_result_to_string(status)); | ||
} | ||
|
||
if (raw_write_buffer_.length() > 0) { | ||
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(), | ||
raw_write_buffer_.length(), end_stream); | ||
return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0)); | ||
} | ||
return {Network::PostIoAction::KeepOpen, 0, false}; | ||
} | ||
|
||
void TsiSocket::closeSocket(Network::ConnectionEvent) { | ||
if (handshaker_) { | ||
handshaker_.release()->deferredDelete(); | ||
} | ||
} | ||
|
||
void TsiSocket::onConnected() { ASSERT(!handshake_complete_); } | ||
|
||
void TsiSocket::onNextDone(NextResultPtr&& result) { | ||
handshaker_next_calling_ = false; | ||
|
||
Network::PostIoAction action = doHandshakeNextDone(std::move(result)); | ||
if (action == Network::PostIoAction::Close) { | ||
callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); | ||
} | ||
} | ||
|
||
TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory, | ||
HandshakeValidator handshake_validator) | ||
: handshaker_factory_(std::move(handshaker_factory)), | ||
handshake_validator_(std::move(handshake_validator)) {} | ||
|
||
bool TsiSocketFactory::implementsSecureTransport() const { return true; } | ||
|
||
Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const { | ||
return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_); | ||
} | ||
|
||
} // namespace Alts | ||
} // namespace TransportSockets | ||
} // namespace Extensions | ||
} // namespace Envoy |
Oops, something went wrong.