diff --git a/raft/log.go b/raft/log.go index 93c80034901..6ece348b7f8 100644 --- a/raft/log.go +++ b/raft/log.go @@ -67,11 +67,7 @@ func (l *raftLog) maybeAppend(index, logTerm, committed uint64, ents ...pb.Entry default: l.append(ci-1, ents[ci-from:]...) } - tocommit := min(committed, lastnewi) - // if toCommit > commitIndex, set commitIndex = toCommit - if l.committed < tocommit { - l.committed = tocommit - } + l.commitTo(min(committed, lastnewi)) return lastnewi, true } return 0, false @@ -125,6 +121,16 @@ func (l *raftLog) nextEnts() (ents []pb.Entry) { return nil } +func (l *raftLog) commitTo(tocommit uint64) { + // never decrease commit + if l.committed < tocommit { + if l.lastIndex() < tocommit { + panic("committed out of range") + } + l.committed = tocommit + } +} + func (l *raftLog) appliedTo(i uint64) { if i == 0 { return @@ -179,7 +185,7 @@ func (l *raftLog) matchTerm(i, term uint64) bool { func (l *raftLog) maybeCommit(maxIndex, term uint64) bool { if maxIndex > l.committed && l.term(maxIndex) == term { - l.committed = maxIndex + l.commitTo(maxIndex) return true } return false diff --git a/raft/log_test.go b/raft/log_test.go index eb479b48db5..6ce4f5db952 100644 --- a/raft/log_test.go +++ b/raft/log_test.go @@ -386,6 +386,38 @@ func TestUnstableEnts(t *testing.T) { } } +func TestCommitTo(t *testing.T) { + previousEnts := []pb.Entry{{Term: 1, Index: 1}, {Term: 2, Index: 2}, {Term: 3, Index: 3}} + commit := uint64(2) + tests := []struct { + commit uint64 + wcommit uint64 + wpanic bool + }{ + {3, 3, false}, + {1, 2, false}, // never decrease + {4, 0, true}, // commit out of range -> panic + } + for i, tt := range tests { + func() { + defer func() { + if r := recover(); r != nil { + if tt.wpanic != true { + t.Errorf("%d: panic = %v, want %v", i, true, tt.wpanic) + } + } + }() + raftLog := newLog() + raftLog.append(0, previousEnts...) + raftLog.committed = commit + raftLog.commitTo(tt.commit) + if raftLog.committed != tt.wcommit { + t.Errorf("#%d: committed = %d, want %d", i, raftLog.committed, tt.wcommit) + } + }() + } +} + func TestStableTo(t *testing.T) { tests := []struct { stable uint64 diff --git a/raft/raft.go b/raft/raft.go index 67152968b6e..a11850fd63d 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -387,6 +387,10 @@ func (r *raft) handleAppendEntries(m pb.Message) { } } +func (r *raft) handleHeartbeat(m pb.Message) { + r.raftLog.commitTo(m.Commit) +} + func (r *raft) handleSnapshot(m pb.Message) { if r.restore(m.Snapshot) { r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: r.raftLog.lastIndex()}) @@ -482,7 +486,11 @@ func stepFollower(r *raft, m pb.Message) { case pb.MsgApp: r.elapsed = 0 r.lead = m.From - r.handleAppendEntries(m) + if m.LogTerm == 0 && m.Index == 0 && len(m.Entries) == 0 { + r.handleHeartbeat(m) + } else { + r.handleAppendEntries(m) + } case pb.MsgSnap: r.elapsed = 0 r.handleSnapshot(m) diff --git a/raft/raft_paper_test.go b/raft/raft_paper_test.go index 71e8a5cf56e..2ed44fabe6f 100644 --- a/raft/raft_paper_test.go +++ b/raft/raft_paper_test.go @@ -593,7 +593,6 @@ func TestFollowerCheckMsgApp(t *testing.T) { index uint64 wreject bool }{ - {ents[0].Term, ents[0].Index, false}, {ents[1].Term, ents[1].Index, false}, {ents[2].Term, ents[2].Index, false}, {ents[1].Term, ents[1].Index + 1, true}, diff --git a/raft/raft_test.go b/raft/raft_test.go index 8e1265e930d..c48519175e4 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -674,6 +674,35 @@ func TestHandleMsgApp(t *testing.T) { } } +// TestHandleHeartbeat ensures that the follower commits to the commit in the message. +func TestHandleHeartbeat(t *testing.T) { + commit := uint64(2) + tests := []struct { + m pb.Message + wCommit uint64 + }{ + {pb.Message{Type: pb.MsgApp, Term: 2, Commit: commit + 1}, commit + 1}, + {pb.Message{Type: pb.MsgApp, Term: 2, Commit: commit - 1}, commit}, // do not decrease commit + } + + for i, tt := range tests { + sm := &raft{ + state: StateFollower, + HardState: pb.HardState{Term: 2}, + raftLog: &raftLog{committed: 0, ents: []pb.Entry{{}, {Term: 1}, {Term: 2}, {Term: 3}}}, + } + sm.raftLog.commitTo(commit) + sm.handleHeartbeat(tt.m) + if sm.raftLog.committed != tt.wCommit { + t.Errorf("#%d: committed = %d, want %d", i, sm.raftLog.committed, tt.wCommit) + } + m := sm.readMessages() + if len(m) != 0 { + t.Fatalf("#%d: msg = nil, want 0", i) + } + } +} + func TestRecvMsgVote(t *testing.T) { tests := []struct { state StateType