
Commit

change: pass all logs to apply_entry_to_state_machine(), not just Normal logs.

Pass the whole `Entry<D>` to `apply_entry_to_state_machine()`, instead of
only the `EntryPayload::Normal(normal_log)` payload.

Thus the state machine is able to record the membership changes if it
prefers to.

Why:

In practice, a snapshot contains info about all applied logs, including
the membership-config logs.
Before this change, the state machine did not receive any membership
log, so when building a snapshot one had to walk through all applied
logs to find the last membership included in the state machine.

By letting the state machine remember the membership logs it applied,
snapshot creation becomes more convenient and intuitive: it no longer
needs to scan the applied logs.
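For illustration only (a sketch, not code from this commit): once the state
machine tracks the membership it applied, a snapshot can embed that config
directly instead of scanning the applied logs. `StateMachine` and
`snapshot_membership` are hypothetical names; `MembershipConfig::new_initial`
is the existing constructor that memstore uses below.

```rust
use async_raft::raft::MembershipConfig;
use async_raft::NodeId;

/// Illustrative state machine that remembers the last applied membership.
#[derive(Default)]
struct StateMachine {
    /// Updated each time an `EntryPayload::ConfigChange` entry is applied.
    last_membership: Option<MembershipConfig>,
}

/// Hypothetical helper: pick the membership config to embed in a snapshot.
fn snapshot_membership(sm: &StateMachine, id: NodeId) -> MembershipConfig {
    // No walk over the applied logs is needed any more.
    sm.last_membership
        .clone()
        .unwrap_or_else(|| MembershipConfig::new_initial(id))
}
```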
drmingdrmer committed Aug 16, 2021
1 parent 8e0cca5 commit adc24f5
Showing 7 changed files with 202 additions and 71 deletions.
15 changes: 3 additions & 12 deletions async-raft/src/core/append_entries.rs
@@ -225,10 +225,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
.filter_map(|idx| {
if let Some(entry) = self.entries_cache.remove(&idx) {
last_entry_seen = Some(entry.log_id);
match entry.payload {
EntryPayload::Normal(inner) => Some((entry.log_id, inner.data)),
_ => None,
}
Some(entry)
} else {
None
}
@@ -251,7 +248,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
let handle = tokio::spawn(async move {
// Create a new vector of references to the entries data ... might have to change this
// interface a bit before 1.0.
let entries_refs: Vec<_> = entries.iter().map(|(k, v)| (k, v)).collect();
let entries_refs: Vec<_> = entries.iter().collect();
storage.replicate_to_state_machine(&entries_refs).await?;
Ok(last_entry_seen)
});
@@ -280,13 +277,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
if let Some(entry) = entries.last() {
new_last_applied = Some(entry.log_id);
}
let data_entries: Vec<_> = entries
.iter()
.filter_map(|entry| match &entry.payload {
EntryPayload::Normal(inner) => Some((&entry.log_id, &inner.data)),
_ => None,
})
.collect();
let data_entries: Vec<_> = entries.iter().collect();
if data_entries.is_empty() {
return Ok(new_last_applied);
}
49 changes: 25 additions & 24 deletions async-raft/src/core/client.rs
@@ -307,22 +307,22 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
/// Handle the post-commit logic for a client request.
#[tracing::instrument(level = "trace", skip(self, req))]
pub(super) async fn client_request_post_commit(&mut self, req: ClientRequestEntry<D, R>) {
let entry = &req.entry;

match req.tx {
ClientOrInternalResponseTx::Client(tx) => {
match &req.entry.payload {
EntryPayload::Normal(inner) => {
match self.apply_entry_to_state_machine(&req.entry.log_id, &inner.data).await {
Ok(data) => {
let _ = tx.send(Ok(ClientWriteResponse {
index: req.entry.log_id.index,
data,
}));
}
Err(err) => {
let _ = tx.send(Err(ClientWriteError::RaftError(err)));
}
match &entry.payload {
EntryPayload::Normal(_) => match self.apply_entry_to_state_machine(&entry).await {
Ok(data) => {
let _ = tx.send(Ok(ClientWriteResponse {
index: req.entry.log_id.index,
data,
}));
}
}
Err(err) => {
let _ = tx.send(Err(ClientWriteError::RaftError(err)));
}
},
_ => {
// Why is this a bug, and why are we shutting down? This is because we can not easily
// encode these constraints in the type system, and client requests should be the only
@@ -334,9 +334,15 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
}
}
ClientOrInternalResponseTx::Internal(tx) => {
self.core.last_applied = req.entry.log_id;
// TODO(xp): copied from above, need refactor.
let res = match self.apply_entry_to_state_machine(&entry).await {
Ok(_data) => Ok(entry.log_id.index),
Err(err) => Err(err),
};

self.core.last_applied = entry.log_id;
self.leader_report_metrics();
let _ = tx.send(Ok(req.entry.log_id.index));
let _ = tx.send(res);
}
}

@@ -346,13 +352,14 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>

/// Apply the given log entry to the state machine.
#[tracing::instrument(level = "trace", skip(self, entry))]
pub(super) async fn apply_entry_to_state_machine(&mut self, log_id: &LogId, entry: &D) -> RaftResult<R> {
pub(super) async fn apply_entry_to_state_machine(&mut self, entry: &Entry<D>) -> RaftResult<R> {
// First, we just ensure that we apply any outstanding up to, but not including, the index
// of the given entry. We need to be able to return the data response from applying this
// entry to the state machine.
//
// Note that this would only ever happen if a node had unapplied logs from before becoming leader.

let log_id = &entry.log_id;
let index = log_id.index;

let expected_next_index = self.core.last_applied.index + 1;
@@ -368,13 +375,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
self.core.last_applied = entry.log_id;
}

let data_entries: Vec<_> = entries
.iter()
.filter_map(|entry| match &entry.payload {
EntryPayload::Normal(inner) => Some((&entry.log_id, &inner.data)),
_ => None,
})
.collect();
let data_entries: Vec<_> = entries.iter().collect();
if !data_entries.is_empty() {
self.core
.storage
@@ -393,7 +394,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
}
}
// Apply this entry to the state machine and return its data response.
let res = self.core.storage.apply_entry_to_state_machine(&log_id, entry).await.map_err(|err| {
let res = self.core.storage.apply_entry_to_state_machine(entry).await.map_err(|err| {
if err.downcast_ref::<S::ShutdownError>().is_some() {
// If this is an instance of the storage impl's shutdown error, then trigger shutdown.
self.core.map_fatal_storage_error(err)
9 changes: 7 additions & 2 deletions async-raft/src/storage.rs
@@ -178,6 +178,11 @@ where
/// specific transaction is being started, or perhaps committed. This may be where a key/value
/// is being stored. This may be where an entry is being appended to an immutable log.
///
/// An impl should:
/// - Deal with the EntryPayload::Normal() log, which is the business-logic log.
/// - Optionally, deal with EntryPayload::ConfigChange or EntryPayload::SnapshotPointer logs if it is
///   concerned with them, e.g. when an impl needs to track membership changes.
///
/// Error handling for this method is noteworthy. If an error is returned from a call to this
/// method, the error will be inspected, and if the error is an instance of
/// `RaftStorage::ShutdownError`, then Raft will go into shutdown in order to preserve the
@@ -186,15 +191,15 @@
///
/// It is important to note that even in cases where an application specific error is returned,
/// implementations should still record that the entry has been applied to the state machine.
async fn apply_entry_to_state_machine(&self, index: &LogId, data: &D) -> Result<R>;
async fn apply_entry_to_state_machine(&self, data: &Entry<D>) -> Result<R>;

/// Apply the given payload of entries to the state machine, as part of replication.
///
/// The Raft protocol guarantees that only logs which have been _committed_, that is, logs which
/// have been replicated to a majority of the cluster, will be applied to the state machine.
///
/// Errors returned from this method will cause Raft to go into shutdown.
async fn replicate_to_state_machine(&self, entries: &[(&LogId, &D)]) -> Result<()>;
async fn replicate_to_state_machine(&self, entries: &[&Entry<D>]) -> Result<()>;

/// Perform log compaction, returning a handle to the generated snapshot.
///
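To illustrate the updated storage API above: a minimal sketch (not this
commit's code — the real memstore implementation appears further down) of the
dispatch an implementation of the new apply_entry_to_state_machine() is
expected to perform. `SketchStateMachine` and `apply` are made-up names for
the illustration.

```rust
use async_raft::raft::{Entry, EntryPayload, MembershipConfig};
use async_raft::AppData;

/// Illustrative state machine; `last_membership` mirrors the new memstore field.
#[derive(Default)]
struct SketchStateMachine {
    last_membership: Option<MembershipConfig>,
}

impl SketchStateMachine {
    /// Handle every payload variant, as the new trait docs suggest.
    fn apply<D: AppData>(&mut self, entry: &Entry<D>) {
        match entry.payload {
            // Blank and snapshot-pointer entries carry no application data.
            EntryPayload::Blank => {}
            EntryPayload::SnapshotPointer(_) => {}
            // Business-logic entries: apply `norm.data` to the application state.
            EntryPayload::Normal(ref norm) => {
                let _data = &norm.data;
                // ... application-specific apply and response would go here ...
            }
            // Membership entries: remember the config so snapshots can embed it.
            EntryPayload::ConfigChange(ref cc) => {
                self.last_membership = Some(cc.membership.clone());
            }
        }
    }
}
```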
94 changes: 94 additions & 0 deletions async-raft/tests/state_machien_apply_membership.rs
@@ -0,0 +1,94 @@
mod fixtures;

use std::sync::Arc;

use anyhow::Result;
use async_raft::raft::MembershipConfig;
use async_raft::Config;
use async_raft::State;
use fixtures::RaftRouter;
use futures::stream::StreamExt;
use maplit::hashset;

/// All logs should be applied to the state machine.
///
/// What does this test do?
///
/// - bring up a cluster with 3 voters and 2 non-voters.
/// - check last_membership in state machine.
///
/// RUST_LOG=async_raft,memstore,state_machine_apply_membership=trace cargo test -p async-raft --test
/// state_machine_apply_membership
#[tokio::test(flavor = "multi_thread", worker_threads = 6)]
async fn state_machine_apply_membership() -> Result<()> {
fixtures::init_tracing();

// Setup test dependencies.
let config = Arc::new(Config::build("test".into()).validate().expect("failed to build Raft config"));
let router = Arc::new(RaftRouter::new(config.clone()));
router.new_raft_node(0).await;

let mut want = 0;

// Assert all nodes are in non-voter state & have no entries.
router.wait_for_log(&hashset![0], want, None, "empty").await?;
router.wait_for_state(&hashset![0], State::NonVoter, None, "empty").await?;
router.assert_pristine_cluster().await;

// Initialize the cluster, then assert that a stable cluster was formed & held.
tracing::info!("--- initializing cluster");
router.initialize_from_single_node(0).await?;
want += 1;

router.wait_for_log(&hashset![0], want, None, "init").await?;
router.assert_stable_cluster(Some(1), Some(want)).await;

for i in 0..=0 {
let sto = router.get_storage_handle(&i).await?;
let sm = sto.get_state_machine().await;
assert_eq!(
Some(MembershipConfig {
members: hashset![0],
members_after_consensus: None
}),
sm.last_membership
);
}

// Sync some new nodes.
router.new_raft_node(1).await;
router.new_raft_node(2).await;
router.new_raft_node(3).await;
router.new_raft_node(4).await;

tracing::info!("--- adding new nodes to cluster");
let mut new_nodes = futures::stream::FuturesUnordered::new();
new_nodes.push(router.add_non_voter(0, 1));
new_nodes.push(router.add_non_voter(0, 2));
new_nodes.push(router.add_non_voter(0, 3));
new_nodes.push(router.add_non_voter(0, 4));
while let Some(inner) = new_nodes.next().await {
inner?;
}

tracing::info!("--- changing cluster config");
router.change_membership(0, hashset![0, 1, 2]).await?;
want += 2;

router.wait_for_log(&hashset![0, 1, 2, 3, 4], want, None, "cluster of 5 candidates").await?;

tracing::info!("--- check applied membership config");
for i in 0..5 {
let sto = router.get_storage_handle(&i).await?;
let sm = sto.get_state_machine().await;
assert_eq!(
Some(MembershipConfig {
members: hashset![0, 1, 2],
members_after_consensus: None
}),
sm.last_membership
);
}

Ok(())
}
2 changes: 1 addition & 1 deletion async-raft/tests/stepdown.rs
@@ -128,7 +128,7 @@ async fn stepdown() -> Result<()> {
assert!(metrics.current_term >= 2, "term incr when leader changes");
router.assert_stable_cluster(Some(metrics.current_term), Some(want)).await;
router
.assert_storage_state(metrics.current_term, want, None, LogId { term: 0, index: 0 }, None)
.assert_storage_state(metrics.current_term, want, None, LogId { term: 2, index: 4 }, None)
.await;
// ----------------------------------- ^^^ this is `0` instead of `4` because blank payloads from new leaders
// and config change entries are never applied to the state machine.
72 changes: 49 additions & 23 deletions memstore/src/lib.rs
@@ -74,6 +74,9 @@ pub struct MemStoreSnapshot {
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct MemStoreStateMachine {
pub last_applied_log: LogId,

pub last_membership: Option<MembershipConfig>,

/// A mapping of client IDs to their state info.
pub client_serial_responses: HashMap<String, (u64, Option<String>)>,
/// The current status of a client by ID.
@@ -233,7 +236,7 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
async fn get_log_entries(&self, start: u64, stop: u64) -> Result<Vec<Entry<ClientRequest>>> {
// Invalid request, return empty vec.
if start > stop {
tracing::error!("invalid request, start > stop");
tracing::error!("get_log_entries: invalid request, start({}) > stop({})", start, stop);
return Ok(vec![]);
}
let log = self.log.read().await;
@@ -243,7 +246,7 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
#[tracing::instrument(level = "trace", skip(self))]
async fn delete_logs_from(&self, start: u64, stop: Option<u64>) -> Result<()> {
if stop.as_ref().map(|stop| &start > stop).unwrap_or(false) {
tracing::error!("invalid request, start > stop");
tracing::error!("delete_logs_from: invalid request, start({}) > stop({:?})", start, stop);
return Ok(());
}
let mut log = self.log.write().await;
@@ -276,50 +279,73 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
Ok(())
}

#[tracing::instrument(level = "trace", skip(self, data))]
async fn apply_entry_to_state_machine(&self, index: &LogId, data: &ClientRequest) -> Result<ClientResponse> {
#[tracing::instrument(level = "trace", skip(self, entry))]
async fn apply_entry_to_state_machine(&self, entry: &Entry<ClientRequest>) -> Result<ClientResponse> {
let mut sm = self.sm.write().await;
sm.last_applied_log = *index;
if let Some((serial, res)) = sm.client_serial_responses.get(&data.client) {
if serial == &data.serial {
return Ok(ClientResponse(res.clone()));
sm.last_applied_log = entry.log_id;

return match entry.payload {
EntryPayload::Blank => return Ok(ClientResponse(None)),
EntryPayload::SnapshotPointer(_) => return Ok(ClientResponse(None)),
EntryPayload::Normal(ref norm) => {
let data = &norm.data;
if let Some((serial, res)) = sm.client_serial_responses.get(&data.client) {
if serial == &data.serial {
return Ok(ClientResponse(res.clone()));
}
}
let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
Ok(ClientResponse(previous))
}
}
let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
Ok(ClientResponse(previous))
EntryPayload::ConfigChange(ref mem) => {
sm.last_membership = Some(mem.membership.clone());
return Ok(ClientResponse(None));
}
};
}

#[tracing::instrument(level = "trace", skip(self, entries))]
async fn replicate_to_state_machine(&self, entries: &[(&LogId, &ClientRequest)]) -> Result<()> {
async fn replicate_to_state_machine(&self, entries: &[&Entry<ClientRequest>]) -> Result<()> {
let mut sm = self.sm.write().await;
for (index, data) in entries {
sm.last_applied_log = **index;
if let Some((serial, _)) = sm.client_serial_responses.get(&data.client) {
if serial == &data.serial {
continue;
for entry in entries {
sm.last_applied_log = entry.log_id;

match entry.payload {
EntryPayload::Blank => {}
EntryPayload::SnapshotPointer(_) => {}
EntryPayload::Normal(ref norm) => {
let data = &norm.data;
if let Some((serial, _)) = sm.client_serial_responses.get(&data.client) {
if serial == &data.serial {
continue;
}
}
let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
}
}
let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
EntryPayload::ConfigChange(ref mem) => {
sm.last_membership = Some(mem.membership.clone());
}
};
}
Ok(())
}

#[tracing::instrument(level = "trace", skip(self))]
async fn do_log_compaction(&self) -> Result<CurrentSnapshotData<Self::Snapshot>> {
let (data, last_applied_log);
let membership_config;
{
// Serialize the data of the state machine.
let sm = self.sm.read().await;
data = serde_json::to_vec(&*sm)?;
last_applied_log = sm.last_applied_log;
membership_config = sm.last_membership.clone().unwrap_or_else(|| MembershipConfig::new_initial(self.id));
} // Release state machine read lock.

let snapshot_size = data.len();

let membership_config = self.get_membership_from_log(Some(last_applied_log.index)).await?;

let snapshot_idx = {
let mut l = self.snapshot_idx.lock().unwrap();
*l += 1;