Skip to content

Commit

Permalink
Improve connection termination
Browse files Browse the repository at this point in the history
  • Loading branch information
AhmedSoliman committed Feb 3, 2025
1 parent c708a94 commit 6b320a4
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 79 deletions.
11 changes: 8 additions & 3 deletions crates/core/src/network/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::time::Instant;
use enum_map::{enum_map, EnumMap};
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tracing::{debug, trace};

use restate_types::net::codec::Targeted;
use restate_types::net::codec::{serialize_message, WireEncode};
Expand Down Expand Up @@ -146,11 +147,16 @@ impl OwnedConnection {

/// Best-effort delivery of signals on the connection.
pub fn send_control_frame(&self, control: message::ConnectionControl) {
let signal = control.signal();
let msg = Message {
header: None,
body: Some(control.into()),
};
let _ = self.sender.try_send(msg);

debug!(?msg, "Sending control frame to peer");
if self.sender.try_send(msg).is_ok() {
trace!(?signal, "Control frame was written to connection");
}
}

/// A handle that sends messages through that connection. This hides the
Expand Down Expand Up @@ -709,9 +715,8 @@ pub mod test_util {
break;
};

// if it's a control signal, handle it, otherwise, route with message router.
// If it's a control signal, we terminate the connection
if let message::Body::ConnectionControl(ctrl_msg) = &body {
// do something
info!(
"Terminating connection based on signal from peer: {:?} {}",
ctrl_msg.signal(),
Expand Down
134 changes: 73 additions & 61 deletions crates/core/src/network/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::sync::{Arc, Weak};
use std::time::Instant;

use enum_map::EnumMap;
use futures::{Stream, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use opentelemetry::global;
use parking_lot::Mutex;
use rand::seq::SliceRandom;
Expand Down Expand Up @@ -500,13 +500,22 @@ where
let mut cancellation = std::pin::pin!(current_task.cancellation_token().cancelled());
let mut seen_versions = MetadataVersions::default();

let mut needs_drain = false;
// Receive loop
loop {
// read a message from the stream
let msg = tokio::select! {
biased;
_ = &mut cancellation => {
connection.send_control_frame(ConnectionControl::shutdown());
if TaskCenter::is_shutdown_requested() {
// We want to make the distinction between whether we are terminating the
// connection, or whether the node is shutting down.
connection.send_control_frame(ConnectionControl::shutdown());
} else {
connection.send_control_frame(ConnectionControl::connection_reset());
}
// we only drain the connection if we were the initiators of the termination
needs_drain = true;
break;
},
msg = incoming.next() => {
Expand All @@ -518,6 +527,7 @@ where
break;
}
None => {
// peer terminated the connection
// stream has terminated cleanly.
break;
}
Expand All @@ -528,6 +538,32 @@ where
MESSAGE_RECEIVED.increment(1);
let processing_started = Instant::now();

// body are not allowed to be empty.
let Some(body) = msg.body else {
connection
.send_control_frame(ConnectionControl::codec_error("Body is missing on message"));
break;
};

// Welcome and hello are not allowed after handshake
if body.is_welcome() || body.is_hello() {
connection.send_control_frame(ConnectionControl::codec_error(
"Hello/Welcome are not allowed after handshake",
));
break;
};

// if it's a control signal, handle it, otherwise, route with message router.
if let message::Body::ConnectionControl(ctrl_msg) = &body {
// do something
info!(
"Terminating connection based on signal from peer: {:?} {}",
ctrl_msg.signal(),
ctrl_msg.message
);
break;
}

// header is required on all messages
let Some(header) = msg.header else {
connection.send_control_frame(ConnectionControl::codec_error(
Expand Down Expand Up @@ -555,32 +591,6 @@ where
}
});

// body are not allowed to be empty.
let Some(body) = msg.body else {
connection
.send_control_frame(ConnectionControl::codec_error("Body is missing on message"));
break;
};

// Welcome and hello are not allowed after handshake
if body.is_welcome() || body.is_hello() {
connection.send_control_frame(ConnectionControl::codec_error(
"Hello/Welcome are not allowed after handshake",
));
break;
};

// if it's a control signal, handle it, otherwise, route with message router.
if let message::Body::ConnectionControl(ctrl_msg) = &body {
// do something
info!(
"Terminating connection based on signal from peer: {:?} {}",
ctrl_msg.signal(),
ctrl_msg.message
);
break;
}

match body.try_as_binary_body(connection.protocol_version) {
Ok(msg) => {
trace!(
Expand Down Expand Up @@ -632,41 +642,43 @@ where
drop(connection);

let drain_start = std::time::Instant::now();
trace!("Draining connection");
debug!("Draining connection");
let mut drain_counter = 0;
// Draining of incoming queue
while let Some(Ok(msg)) = incoming.next().await {
// ignore malformed messages
let Some(header) = msg.header else {
continue;
};
if let Some(body) = msg.body {
// we ignore non-deserializable messages (serde errors, or control signals in drain)
if let Ok(msg) = body.try_as_binary_body(protocol_version) {
drain_counter += 1;
let parent_context = header.span_context.as_ref().map(|span_ctx| {
global::get_text_map_propagator(|propagator| propagator.extract(span_ctx))
});

if let Err(e) = router
.call(
Incoming::from_parts(
msg,
// This is a dying connection, don't pass it down.
WeakConnection::new_closed(peer_node_id),
header.msg_id,
header.in_response_to,
PeerMetadataVersion::from(header),
if needs_drain {
// Draining of incoming queue
while let Some(Some(Ok(msg))) = incoming.next().now_or_never() {
// ignore malformed messages
let Some(header) = msg.header else {
continue;
};
if let Some(body) = msg.body {
// we ignore non-deserializable messages (serde errors, or control signals in drain)
if let Ok(msg) = body.try_as_binary_body(protocol_version) {
drain_counter += 1;
let parent_context = header.span_context.as_ref().map(|span_ctx| {
global::get_text_map_propagator(|propagator| propagator.extract(span_ctx))
});

if let Err(e) = router
.call(
Incoming::from_parts(
msg,
// This is a dying connection, don't pass it down.
WeakConnection::new_closed(peer_node_id),
header.msg_id,
header.in_response_to,
PeerMetadataVersion::from(header),
)
.with_parent_context(parent_context),
protocol_version,
)
.with_parent_context(parent_context),
protocol_version,
)
.await
{
debug!(
"Error processing message while draining connection: {:?}",
e
);
.await
{
debug!(
"Error processing message while draining connection: {:?}",
e
);
}
}
}
}
Expand Down
25 changes: 12 additions & 13 deletions crates/core/src/network/net_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use tokio::io;
use tokio::net::{TcpListener, UnixListener, UnixStream};
use tokio_util::net::Listener;
use tonic::transport::{Channel, Endpoint};
use tracing::{debug, info, instrument, trace, Instrument, Span};
use tracing::{debug, error_span, info, instrument, trace, Instrument, Span};

use restate_types::config::{Configuration, MetadataStoreClientOptions, NetworkingOptions};
use restate_types::errors::GenericError;
Expand Down Expand Up @@ -189,12 +189,12 @@ where
tokio::select! {
biased;
_ = &mut shutdown => {
debug!("Shutdown requested, will stop listening to new connection");
debug!("Shutdown requested, will stop listening to new connections");
drop(listener);
break;
}
incoming_connection = listener.accept() => {
let (stream, remote_addr) = incoming_connection?;
let (stream, peer_addr) = incoming_connection?;
let io = TokioIo::new(stream);

let network_options = &configuration.live_load().networking;
Expand All @@ -206,15 +206,13 @@ where
.keep_alive_interval(Some(network_options.http2_keep_alive_interval.into()))
.keep_alive_timeout(network_options.http2_keep_alive_timeout.into());


let socket_span = error_span!("SocketHandler", ?peer_addr);
let connection = graceful_shutdown.watch(builder
.serve_connection(io, service.clone()).into_owned())
.in_current_span();

// TaskCenter will wait for the parent task, we don't need individual connection
// handlers to be managed tasks. We just need to make sure that we actually try and
// shutdown connections, that's why H2Stream tasks are managed.
TaskCenter::spawn_unmanaged(TaskKind::SocketHandler, server_name, async move {
trace!("Connection accepted from {remote_addr:?}");
.serve_connection(io, service.clone()).into_owned());

TaskCenter::spawn(TaskKind::SocketHandler, server_name, async move {
debug!("New connection accepted");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
Expand All @@ -226,7 +224,8 @@ where
} else {
trace!("Connection completed cleanly");
}
})?;
Ok(())
}.instrument(socket_span))?;
}
}
}
Expand All @@ -238,7 +237,7 @@ where

},
_ = tokio::time::sleep(Duration::from_secs(5)) => {
debug!("Some connections are taking longer to drain, dropping them");
info!("Some connections are taking longer to drain, dropping them");
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/core/src/task_center.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,13 @@ impl TaskCenterInner {
} else {
info!(%reason, "** Shutdown requested");
}
self.cancel_tasks(Some(TaskKind::ClusterController), None)
.await;
tokio::join!(
self.cancel_tasks(Some(TaskKind::RpcServer), None),
self.cancel_tasks(Some(TaskKind::ConnectionReactor), None),
self.cancel_tasks(Some(TaskKind::SocketHandler), None)
);
self.initiate_managed_runtimes_shutdown();
self.cancel_tasks(None, None).await;
self.shutdown_managed_runtimes();
Expand Down
5 changes: 3 additions & 2 deletions crates/types/protobuf/restate/node.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

syntax = "proto3";

import "restate/common.proto";

package restate.node;

import "restate/common.proto";

//
// # Wire Protocol Of Streaming Connections
// -------------------------------------
Expand Down Expand Up @@ -56,6 +56,7 @@ message Welcome {
message Message {
enum Signal {
Signal_UNKNOWN = 0;
// Node is shutting down
SHUTDOWN = 1;
// Connection will be dropped
DRAIN_CONNECTION = 2;
Expand Down

0 comments on commit 6b320a4

Please sign in to comment.