Skip to content

Commit

Permalink
feature: add defensive check to MemStore
Browse files Browse the repository at this point in the history
  • Loading branch information
drmingdrmer committed Aug 28, 2021
1 parent fb602ac commit 420cdd7
Show file tree
Hide file tree
Showing 4 changed files with 768 additions and 29 deletions.
9 changes: 8 additions & 1 deletion async-raft/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ where S: AsyncRead + AsyncSeek + Send + Unpin + 'static
///
/// This model derives serde's traits for easily (de)serializing this
/// model for storage & retrieval.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Default)]
pub struct HardState {
/// The last recorded term observed by this system.
pub current_term: u64,
Expand Down Expand Up @@ -112,6 +112,13 @@ where
/// For all other methods of this trait, returning an error will cause Raft to shutdown.
type ShutdownError: Error + Send + Sync + 'static;

/// Set if to turn on defensive check to unexpected input.
/// E.g. discontinuous log appending.
/// The default impl returns `false` to indicate it does impl any defensive check.
async fn defensive(&self, _d: bool) -> bool {
false
}

/// Get the latest membership config found in the log.
///
/// This must always be implemented as a reverse search through the log to find the most
Expand Down
1 change: 1 addition & 0 deletions memstore/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ readme = "README.md"
[dependencies]
anyhow = "1.0.32"
async-raft = { version="0.6", path="../async-raft" }
async-trait = "0.1.36"
serde = { version="1.0.114", features=["derive"] }
serde_json = "1.0.57"
thiserror = "1.0.20"
Expand Down
208 changes: 204 additions & 4 deletions memstore/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ pub struct MemStoreStateMachine {

/// An in-memory storage system implementing the `async_raft::RaftStorage` trait.
pub struct MemStore {
/// Turn on defensive check for inputs.
defensive: RwLock<bool>,

/// The ID of the Raft node for which this memory storage instances is configured.
id: NodeId,
/// The Raft log.
Expand All @@ -107,6 +110,7 @@ impl MemStore {
let hs = RwLock::new(None);
let current_snapshot = RwLock::new(None);
Self {
defensive: RwLock::new(false),
id,
log,
sm,
Expand All @@ -130,6 +134,7 @@ impl MemStore {
let hs = RwLock::new(hs);
let current_snapshot = RwLock::new(current_snapshot);
Self {
defensive: RwLock::new(false),
id,
log,
sm,
Expand All @@ -140,6 +145,181 @@ impl MemStore {
}
}

// TODO(xp): elaborate errors
impl MemStore {
/// Ensure that logs that have greater index than last_applied should have greater log_id.
/// Invariant must hold: `log.log_id.index > last_applied.index` implies `log.log_id > last_applied`.
pub async fn defensive_no_dirty_log(&self) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let log = self.log.read().await;
let sm = self.sm.read().await;
let last_log_id = log.iter().last().map(|(_index, entry)| entry.log_id).unwrap_or_default();
let last_applied = sm.last_applied_log;

if last_log_id.index > last_applied.index && last_log_id < last_applied {
return Err(anyhow::anyhow!("greater index log is smaller than last_applied"));
}

Ok(())
}

/// Ensure that current_term must increment for every update, and for every term there could be only one value for
/// voted_for.
pub async fn defensive_incremental_hard_state(&self, hs: &HardState) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let h = self.hs.write().await;
let curr = h.clone().unwrap_or_default();
if hs.current_term < curr.current_term {
return Err(anyhow::anyhow!("smaller term is now allowed"));
}

if hs.current_term == curr.current_term && hs.voted_for != curr.voted_for {
return Err(anyhow::anyhow!("voted_for can not change in one term"));
}

Ok(())
}

pub async fn defensive_consecutive_input<D: AppData>(&self, entries: &[&Entry<D>]) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

if entries.is_empty() {
return Ok(());
}

let mut prev_log_id = entries[0].log_id;

for e in entries.iter().skip(1) {
if e.log_id.index != prev_log_id.index + 1 {
return Err(anyhow::anyhow!(
"nonconsecutive input log index: {}, {}",
prev_log_id,
e.log_id
));
}

prev_log_id = e.log_id;
}

Ok(())
}

pub async fn defensive_nonempty_input<D: AppData>(&self, entries: &[&Entry<D>]) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

if entries.is_empty() {
return Err(anyhow::anyhow!("append empty entries"));
}

Ok(())
}

pub async fn defensive_append_log_index_is_last_plus_one<D: AppData>(
&self,
entries: &[&Entry<D>],
) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let last_id = self.last_log_id().await;

let first_id = entries[0].log_id;
if last_id.index + 1 != first_id.index {
return Err(anyhow::anyhow!(
"first input log index({}) is not last({}) + 1",
first_id.index,
last_id.index,
));
}

Ok(())
}

pub async fn defensive_append_log_id_gt_last<D: AppData>(&self, entries: &[&Entry<D>]) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let last_id = self.last_log_id().await;

let first_id = entries[0].log_id;
if first_id < last_id {
return Err(anyhow::anyhow!(
"first input log id({}) is not > last id({})",
first_id,
last_id,
));
}

Ok(())
}

/// Find the last known log id from log or state machine
/// If no log id found, the default one `0,0` is returned.
pub async fn last_log_id(&self) -> LogId {
let log_last_id = {
let log_last = self.log.read().await;
log_last.iter().last().map(|(_k, v)| v.log_id).unwrap_or_default()
};

let sm_last_id = self.sm.read().await.last_applied_log;

std::cmp::max(log_last_id, sm_last_id)
}

pub async fn defensive_apply_index_is_last_applied_plus_one<D: AppData>(
&self,
entries: &[&Entry<D>],
) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let last_id = self.sm.read().await.last_applied_log;

let first_id = entries[0].log_id;
if last_id.index + 1 != first_id.index {
return Err(anyhow::anyhow!(
"first input log index({}) is not last({}) + 1",
first_id.index,
last_id.index,
));
}

Ok(())
}

pub async fn defensive_apply_log_id_gt_last<D: AppData>(&self, entries: &[&Entry<D>]) -> anyhow::Result<()> {
if !*self.defensive.read().await {
return Ok(());
}

let last_id = self.sm.read().await.last_applied_log;

let first_id = entries[0].log_id;
if first_id < last_id {
return Err(anyhow::anyhow!(
"first input log id({}) is not > last id({})",
first_id,
last_id,
));
}

Ok(())
}
}

#[async_trait]
impl RaftStorageDebug<MemStoreStateMachine> for MemStore {
/// Get a handle to the state machine for testing purposes.
Expand Down Expand Up @@ -168,6 +348,8 @@ impl MemStore {
/// Go backwards through the log to find the most recent membership config <= `upto_index`.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn get_membership_from_log(&self, upto_index: Option<u64>) -> Result<MembershipConfig> {
self.defensive_no_dirty_log().await?;

let membership = {
let log = self.log.read().await;

Expand Down Expand Up @@ -213,13 +395,21 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
type SnapshotData = Cursor<Vec<u8>>;
type ShutdownError = ShutdownError;

async fn defensive(&self, d: bool) -> bool {
let mut defensive_flag = self.defensive.write().await;
*defensive_flag = d;
d
}

#[tracing::instrument(level = "trace", skip(self))]
async fn get_membership_config(&self) -> Result<MembershipConfig> {
self.get_membership_from_log(None).await
}

#[tracing::instrument(level = "trace", skip(self))]
async fn get_initial_state(&self) -> Result<InitialState> {
self.defensive_no_dirty_log().await?;

let membership = self.get_membership_config().await?;
let mut hs = self.hs.write().await;
let log = self.log.read().await;
Expand Down Expand Up @@ -257,7 +447,11 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {

#[tracing::instrument(level = "trace", skip(self))]
async fn save_hard_state(&self, hs: &HardState) -> Result<()> {
*self.hs.write().await = Some(hs.clone());
self.defensive_incremental_hard_state(hs).await?;

let mut h = self.hs.write().await;

*h = Some(hs.clone());
Ok(())
}

Expand Down Expand Up @@ -295,6 +489,11 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {

#[tracing::instrument(level = "trace", skip(self, entries))]
async fn append_to_log(&self, entries: &[&Entry<ClientRequest>]) -> Result<()> {
self.defensive_nonempty_input(entries).await?;
self.defensive_consecutive_input(entries).await?;
self.defensive_append_log_index_is_last_plus_one(entries).await?;
self.defensive_append_log_id_gt_last(entries).await?;

let mut log = self.log.write().await;
for entry in entries {
log.insert(entry.log_id.index, (*entry).clone());
Expand All @@ -304,15 +503,16 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {

#[tracing::instrument(level = "trace", skip(self, entries))]
async fn apply_to_state_machine(&self, entries: &[&Entry<ClientRequest>]) -> Result<Vec<ClientResponse>> {
self.defensive_nonempty_input(entries).await?;
self.defensive_apply_index_is_last_applied_plus_one(entries).await?;
self.defensive_apply_log_id_gt_last(entries).await?;

let mut sm = self.sm.write().await;
let mut res = Vec::with_capacity(entries.len());

for entry in entries {
tracing::debug!("id:{} replicate to sm index:{}", self.id, entry.log_id.index);

// TODO(xp) return error if there is out of order apply
assert_eq!(sm.last_applied_log.index + 1, entry.log_id.index);

sm.last_applied_log = entry.log_id;

match entry.payload {
Expand Down
Loading

0 comments on commit 420cdd7

Please sign in to comment.