Skip to content

Commit

Permalink
alts: add gRPC TSI socket (#4153)
Browse files Browse the repository at this point in the history
3rd PR for #3429, add transport socket implementation.

Risk Level: Low (not enabled in main)
Testing: bazel test //test/...
Docs Changes: N/A
Release Notes: N/A

Signed-off-by: Lizan Zhou <zlizan@google.com>
  • Loading branch information
lizan authored Aug 31, 2018
1 parent f0363ae commit 1212423
Show file tree
Hide file tree
Showing 8 changed files with 811 additions and 2 deletions.
23 changes: 22 additions & 1 deletion source/extensions/transport_sockets/alts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ envoy_cc_library(
hdrs = [
"tsi_frame_protector.h",
],
repository = "@envoy",
deps = [
":grpc_tsi_wrapper",
"//source/common/buffer:buffer_lib",
Expand All @@ -54,6 +53,28 @@ envoy_cc_library(
],
)

envoy_cc_library(
name = "tsi_socket",
srcs = [
"tsi_socket.cc",
],
hdrs = [
"tsi_socket.h",
],
deps = [
":noop_transport_socket_callbacks_lib",
":tsi_frame_protector",
":tsi_handshaker",
"//include/envoy/network:transport_socket_interface",
"//source/common/buffer:buffer_lib",
"//source/common/common:cleanup_lib",
"//source/common/common:empty_string",
"//source/common/common:enum_to_int",
"//source/common/network:raw_buffer_socket_lib",
"//source/common/protobuf:utility_lib",
],
)

envoy_cc_library(
name = "noop_transport_socket_callbacks_lib",
hdrs = ["noop_transport_socket_callbacks.h"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class NoOpTransportSocketCallbacks : public Network::TransportSocketCallbacks {
Network::TransportSocketCallbacks& parent_;
};

typedef std::unique_ptr<NoOpTransportSocketCallbacks> NoOpTransportSocketCallbacksPtr;

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
Expand Down
245 changes: 245 additions & 0 deletions source/extensions/transport_sockets/alts/tsi_socket.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#include "extensions/transport_sockets/alts/tsi_socket.h"

#include "common/common/assert.h"
#include "common/common/cleanup.h"
#include "common/common/empty_string.h"
#include "common/common/enum_to_int.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Alts {

TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
Network::TransportSocketPtr&& raw_socket)
: handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator),
raw_buffer_socket_(std::move(raw_socket)) {}

TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator)
: TsiSocket(handshaker_factory, handshake_validator,
std::make_unique<Network::RawBufferSocket>()) {}

TsiSocket::~TsiSocket() { ASSERT(!handshaker_); }

void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) {
callbacks_ = &callbacks;

noop_callbacks_ = std::make_unique<NoOpTransportSocketCallbacks>(callbacks);
raw_buffer_socket_->setTransportSocketCallbacks(*noop_callbacks_);
}

std::string TsiSocket::protocol() const {
// TSI doesn't have a generic way to indicate application layer protocol.
// TODO(lizan): support application layer protocol from TSI for known TSIs.
return EMPTY_STRING;
}

Network::PostIoAction TsiSocket::doHandshake() {
ASSERT(!handshake_complete_);
ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection());

if (!handshaker_) {
handshaker_ = handshaker_factory_(callbacks_->connection().dispatcher(),
callbacks_->connection().localAddress(),
callbacks_->connection().remoteAddress());
handshaker_->setHandshakerCallbacks(*this);
}

if (!handshaker_next_calling_) {
doHandshakeNext();
}
return Network::PostIoAction::KeepOpen;
}

void TsiSocket::doHandshakeNext() {
ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(),
raw_read_buffer_.length());
handshaker_next_calling_ = true;
Buffer::OwnedImpl handshaker_buffer;
handshaker_buffer.move(raw_read_buffer_);
handshaker_->next(handshaker_buffer);
}

Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) {
ASSERT(next_result);

ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}",
callbacks_->connection(), next_result->status_, next_result->to_send_->length());

tsi_result status = next_result->status_;
tsi_handshaker_result* handshaker_result = next_result->result_.get();

if (status != TSI_INCOMPLETE_DATA && status != TSI_OK) {
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
return Network::PostIoAction::Close;
}

if (next_result->to_send_->length() > 0) {
raw_write_buffer_.move(*next_result->to_send_);
}

if (status == TSI_OK && handshaker_result != nullptr) {
tsi_peer peer;
// returns TSI_OK assuming there is no fatal error. Asserting OK.
status = tsi_handshaker_result_extract_peer(handshaker_result, &peer);
ASSERT(status == TSI_OK);
Cleanup peer_cleanup([&peer]() { tsi_peer_destruct(&peer); });
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: peer properties: {}",
callbacks_->connection(), peer.property_count);
for (size_t i = 0; i < peer.property_count; ++i) {
ENVOY_CONN_LOG(debug, " {}: {}", callbacks_->connection(), peer.properties[i].name,
std::string(peer.properties[i].value.data, peer.properties[i].value.length));
}
if (handshake_validator_) {
std::string err;
const bool peer_validated = handshake_validator_(peer, err);
if (peer_validated) {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection());
} else {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(),
err);
return Network::PostIoAction::Close;
}
} else {
ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection());
}

const unsigned char* unused_bytes;
size_t unused_byte_size;

// returns TSI_OK assuming there is no fatal error. Asserting OK.
status =
tsi_handshaker_result_get_unused_bytes(handshaker_result, &unused_bytes, &unused_byte_size);
ASSERT(status == TSI_OK);
if (unused_byte_size > 0) {
raw_read_buffer_.prepend(
absl::string_view{reinterpret_cast<const char*>(unused_bytes), unused_byte_size});
}
ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(),
unused_byte_size);

// returns TSI_OK assuming there is no fatal error. Asserting OK.
tsi_frame_protector* frame_protector;
status =
tsi_handshaker_result_create_frame_protector(handshaker_result, NULL, &frame_protector);
ASSERT(status == TSI_OK);
frame_protector_ = std::make_unique<TsiFrameProtector>(frame_protector);

handshake_complete_ = true;
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}

if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data",
callbacks_->connection());
return Network::PostIoAction::Close;
}

if (raw_read_buffer_.length() > 0) {
callbacks_->setReadBufferReady();
}

// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
if (raw_write_buffer_.length() > 0) {
return raw_buffer_socket_->doWrite(raw_write_buffer_, false).action_;
}

return Network::PostIoAction::KeepOpen;
}

Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
if (!end_stream_read_ && !read_error_) {
result = raw_buffer_socket_->doRead(raw_read_buffer_);
ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}",
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
result.end_stream_read_);
if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) {
return result;
}

end_stream_read_ = result.end_stream_read_;
read_error_ = result.action_ == Network::PostIoAction::Close;
}

if (!handshake_complete_) {
Network::PostIoAction action = doHandshake();
if (action == Network::PostIoAction::Close || !handshake_complete_) {
return {action, 0, false};
}
}

if (handshake_complete_) {
ASSERT(frame_protector_);

uint64_t read_size = raw_read_buffer_.length();
ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(),
raw_read_buffer_.length());
tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer);
ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(),
raw_read_buffer_.length(), tsi_result_to_string(status));
result.bytes_processed_ = read_size - raw_read_buffer_.length();
}

ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}",
callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
result.end_stream_read_);
return result;
}

Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
if (!handshake_complete_) {
Network::PostIoAction action = doHandshake();
ASSERT(action == Network::PostIoAction::KeepOpen);
// TODO(lizan): Handle synchronous handshake when TsiHandshaker supports it.
}

if (handshake_complete_) {
ASSERT(frame_protector_);
ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
buffer.length());
tsi_result status = frame_protector_->protect(buffer, raw_write_buffer_);
ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
buffer.length(), tsi_result_to_string(status));
}

if (raw_write_buffer_.length() > 0) {
ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
return raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
}
return {Network::PostIoAction::KeepOpen, 0, false};
}

void TsiSocket::closeSocket(Network::ConnectionEvent) {
if (handshaker_) {
handshaker_.release()->deferredDelete();
}
}

void TsiSocket::onConnected() { ASSERT(!handshake_complete_); }

void TsiSocket::onNextDone(NextResultPtr&& result) {
handshaker_next_calling_ = false;

Network::PostIoAction action = doHandshakeNextDone(std::move(result));
if (action == Network::PostIoAction::Close) {
callbacks_->connection().close(Network::ConnectionCloseType::NoFlush);
}
}

TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory,
HandshakeValidator handshake_validator)
: handshaker_factory_(std::move(handshaker_factory)),
handshake_validator_(std::move(handshake_validator)) {}

bool TsiSocketFactory::implementsSecureTransport() const { return true; }

Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const {
return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_);
}

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy
Loading

0 comments on commit 1212423

Please sign in to comment.