diff --git a/async-raft/src/core/client.rs b/async-raft/src/core/client.rs index d60dc78c7..0e4d066a5 100644 --- a/async-raft/src/core/client.rs +++ b/async-raft/src/core/client.rs @@ -445,7 +445,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork, S: RaftStorage } } // Apply this entry to the state machine and return its data response. - let res = self.core.storage.apply_entry_to_state_machine(entry).await.map_err(|err| { + let res = self.core.storage.replicate_to_state_machine(&[entry]).await.map_err(|err| { if err.downcast_ref::().is_some() { // If this is an instance of the storage impl's shutdown error, then trigger shutdown. self.core.map_fatal_storage_error(err) @@ -454,8 +454,13 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork, S: RaftStorage RaftError::RaftStorage(err) } }); + self.core.last_applied = *log_id; self.leader_report_metrics(); - res + let res = res?; + + // TODO(xp) merge this function to replication_to_state_machine? + + Ok(res.into_iter().next().unwrap()) } } diff --git a/async-raft/src/storage.rs b/async-raft/src/storage.rs index 44ef3c005..d0d6365b9 100644 --- a/async-raft/src/storage.rs +++ b/async-raft/src/storage.rs @@ -166,7 +166,7 @@ where /// Errors returned from this method will cause Raft to go into shutdown. async fn append_to_log(&self, entries: &[&Entry]) -> Result<()>; - /// Apply the given log entry to the state machine. + /// Apply the given payload of entries to the state machine. /// /// The Raft protocol guarantees that only logs which have been _committed_, that is, logs which /// have been replicated to a majority of the cluster, will be applied to the state machine. @@ -174,30 +174,22 @@ where /// This is where the business logic of interacting with your application's state machine /// should live. This is 100% application specific. Perhaps this is where an application /// specific transaction is being started, or perhaps committed. This may be where a key/value - /// is being stored. This may be where an entry is being appended to an immutable log. + /// is being stored. /// /// An impl should do: /// - Deal with the EntryPayload::Normal() log, which is business logic log. - /// - Optionally, deal with EntryPayload::ConfigChange or EntryPayload::SnapshotPointer log if they are concerned. - /// E.g. when an impl need to track the membership changing. + /// - Deal with EntryPayload::ConfigChange + /// - A EntryPayload::SnapshotPointer log should never be seen. /// + /// TODO(xp): choose one of the following policy: /// Error handling for this method is note worthy. If an error is returned from a call to this /// method, the error will be inspected, and if the error is an instance of /// `RaftStorage::ShutdownError`, then Raft will go into shutdown in order to preserve the /// safety of the data and avoid corruption. Any other errors will be propagated back up to the /// `Raft.client_write` call point. /// - /// It is important to note that even in cases where an application specific error is returned, - /// implementations should still record that the entry has been applied to the state machine. - async fn apply_entry_to_state_machine(&self, data: &Entry) -> Result; - - /// Apply the given payload of entries to the state machine, as part of replication. - /// - /// The Raft protocol guarantees that only logs which have been _committed_, that is, logs which - /// have been replicated to a majority of the cluster, will be applied to the state machine. - /// /// Errors returned from this method will cause Raft to go into shutdown. - async fn replicate_to_state_machine(&self, entries: &[&Entry]) -> Result<()>; + async fn replicate_to_state_machine(&self, entries: &[&Entry]) -> Result>; /// Perform log compaction, returning a handle to the generated snapshot. /// diff --git a/memstore/src/lib.rs b/memstore/src/lib.rs index 5721e2e5f..039f58d58 100644 --- a/memstore/src/lib.rs +++ b/memstore/src/lib.rs @@ -274,39 +274,11 @@ impl RaftStorage for MemStore { Ok(()) } - #[tracing::instrument(level = "trace", skip(self, entry))] - async fn apply_entry_to_state_machine(&self, entry: &Entry) -> Result { - let mut sm = self.sm.write().await; - - tracing::debug!("id:{} apply to sm index:{}", self.id, entry.log_id.index); - assert_eq!(sm.last_applied_log.index + 1, entry.log_id.index); - - sm.last_applied_log = entry.log_id; - - return match entry.payload { - EntryPayload::Blank => return Ok(ClientResponse(None)), - EntryPayload::SnapshotPointer(_) => return Ok(ClientResponse(None)), - EntryPayload::Normal(ref norm) => { - let data = &norm.data; - if let Some((serial, res)) = sm.client_serial_responses.get(&data.client) { - if serial == &data.serial { - return Ok(ClientResponse(res.clone())); - } - } - let previous = sm.client_status.insert(data.client.clone(), data.status.clone()); - sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone())); - Ok(ClientResponse(previous)) - } - EntryPayload::ConfigChange(ref mem) => { - sm.last_membership = Some(mem.membership.clone()); - return Ok(ClientResponse(None)); - } - }; - } - #[tracing::instrument(level = "trace", skip(self, entries))] - async fn replicate_to_state_machine(&self, entries: &[&Entry]) -> Result<()> { + async fn replicate_to_state_machine(&self, entries: &[&Entry]) -> Result> { let mut sm = self.sm.write().await; + let mut res = Vec::with_capacity(entries.len()); + for entry in entries { tracing::debug!("id:{} replicate to sm index:{}", self.id, entry.log_id.index); @@ -316,24 +288,27 @@ impl RaftStorage for MemStore { sm.last_applied_log = entry.log_id; match entry.payload { - EntryPayload::Blank => {} - EntryPayload::SnapshotPointer(_) => {} + EntryPayload::Blank => res.push(ClientResponse(None)), + EntryPayload::SnapshotPointer(_) => res.push(ClientResponse(None)), EntryPayload::Normal(ref norm) => { let data = &norm.data; - if let Some((serial, _)) = sm.client_serial_responses.get(&data.client) { + if let Some((serial, r)) = sm.client_serial_responses.get(&data.client) { if serial == &data.serial { + res.push(ClientResponse(r.clone())); continue; } } let previous = sm.client_status.insert(data.client.clone(), data.status.clone()); sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone())); + res.push(ClientResponse(previous)); } EntryPayload::ConfigChange(ref mem) => { sm.last_membership = Some(mem.membership.clone()); + res.push(ClientResponse(None)) } }; } - Ok(()) + Ok(res) } #[tracing::instrument(level = "trace", skip(self))] diff --git a/memstore/src/test.rs b/memstore/src/test.rs index 5b6e2ef49..2a0cf6370 100644 --- a/memstore/src/test.rs +++ b/memstore/src/test.rs @@ -270,7 +270,7 @@ async fn test_apply_entry_to_state_machine() -> Result<()> { }, }), }; - store.apply_entry_to_state_machine(&entry).await?; + store.replicate_to_state_machine(&[&entry]).await?; let sm = store.get_state_machine().await; assert_eq!(