Commit
1. Fix bug in trace validation spec
2. Add random fault tests
3. Update README to include instructions for validation.sh
joshuazh-x committed Jan 17, 2024
1 parent 02e685a commit 4ab49b4
Showing 19 changed files with 5,223 additions and 2,923 deletions.
2 changes: 2 additions & 0 deletions go.mod
@@ -13,6 +13,8 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.5.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+	go.uber.org/multierr v1.11.0 // indirect
+	go.uber.org/zap v1.26.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
@@ -18,6 +18,10 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+	go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+	go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+	go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
+	go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
2 changes: 2 additions & 0 deletions node.go
@@ -270,6 +270,7 @@ func setupNode(c *Config, peers []Peer) *node {
// Peers must not be zero length; call RestartNode in that case.
func StartNode(c *Config, peers []Peer) Node {
n := setupNode(c, peers)
+	traceInitState(n.rn.raft)
go n.run()
return n
}
@@ -284,6 +285,7 @@ func RestartNode(c *Config) Node {
panic(err)
}
n := newNode(rn)
+	traceInitState(n.rn.raft)
go n.run()
return &n
}
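The two traceInitState calls above make both startup paths emit the node's initial state before the run loop begins, instead of deferring it to the first Step call (see the raft.go hunks below). A minimal sketch of a node started with tracing enabled — the Config fields are the library's public API, while the tl argument stands in for any raft.TraceLogger implementation, which this commit does not show:

func startTracedNode(tl raft.TraceLogger) raft.Node {
	c := &raft.Config{
		ID:              1,
		ElectionTick:    10,
		HeartbeatTick:   1,
		Storage:         raft.NewMemoryStorage(),
		MaxSizePerMsg:   4096,
		MaxInflightMsgs: 256,
		TraceLogger:     tl, // when non-nil, traceInitState has a sink to write to
	}
	// StartNode now records the initial state up front, so the first
	// traced event precedes any message handling in n.run().
	return raft.StartNode(c, []raft.Peer{{ID: 1}})
}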
5 changes: 2 additions & 3 deletions raft.go
@@ -463,6 +463,7 @@ func newRaft(c *Config) *raft {
disableConfChangeValidation: c.DisableConfChangeValidation,
stepDownOnRemoval: c.StepDownOnRemoval,
traceLogger: c.TraceLogger,
+	initStateTraced: false,
}

cfg, prs, err := confchange.Restore(confchange.Changer{
@@ -1069,7 +1070,6 @@ func (r *raft) poll(id uint64, t pb.MessageType, v bool) (granted int, rejected
}

func (r *raft) Step(m pb.Message) error {
-	traceInitStateOnce(r)
traceReceiveMessage(r, &m)

// Handle the message term, which may result in our stepping down to a follower.
@@ -1296,8 +1296,6 @@ func stepLeader(r *raft, m pb.Message) error {
cc = ccc
}
if cc != nil {
-	traceChangeConfEvent(cc, r)
-
alreadyPending := r.pendingConfIndex > r.raftLog.applied
alreadyJoint := len(r.prs.Config.Voters[1]) > 0
wantsLeaveJoint := len(cc.AsV2().Changes) == 0
@@ -1316,6 +1314,7 @@ func stepLeader(r *raft, m pb.Message) error {
m.Entries[i] = pb.Entry{Type: pb.EntryNormal}
} else {
r.pendingConfIndex = r.raftLog.lastIndex() + uint64(i) + 1
+	traceChangeConfEvent(cc, r)
}
}
}
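With this move, traceChangeConfEvent fires only in the branch where the leader actually admits the conf-change entry and bumps pendingConfIndex; a refused proposal, which is rewritten as an empty EntryNormal, no longer produces a trace event. A hedged sketch of reaching that branch from the public API — proposeAddNode is a hypothetical helper, the raftpb types and Node.ProposeConfChange are real:

// proposeAddNode proposes adding a voter; if the leader admits the entry,
// the traced branch above runs when stepLeader processes the proposal.
func proposeAddNode(ctx context.Context, n raft.Node, id uint64) error {
	cc := raftpb.ConfChangeV2{
		Changes: []raftpb.ConfChangeSingle{{
			Type:   raftpb.ConfChangeAddNode,
			NodeID: id,
		}},
	}
	return n.ProposeConfChange(ctx, cc)
}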
260 changes: 260 additions & 0 deletions rafttest/cluster.go
@@ -0,0 +1,260 @@
package rafttest

import (
"context"
"fmt"
"time"

"go.etcd.io/raft/v3"
"go.etcd.io/raft/v3/raftpb"
)

type clusterConfig struct {
size int
traceLogger raft.TraceLogger
tickInterval time.Duration
}

type endpoint struct {
index int
node *node
}

type getEndpoint struct {
i int
c chan endpoint
}

// cluster runs a set of rafttest nodes over an in-memory network and
// serializes membership changes, endpoint lookup, and fault injection
// through its management loop.
type cluster struct {
nodes map[uint64]*node
network *raftNetwork

stopc chan struct{}
removec chan uint64
addc chan uint64
getc chan getEndpoint
faultc chan func(*cluster)

traceLogger raft.TraceLogger
tickInterval time.Duration
}

// newCluster starts c.size nodes on a shared in-memory network, waits for a
// leader to emerge, and then runs the management loop in the background.
func newCluster(c clusterConfig) *cluster {
tickInterval := c.tickInterval
if tickInterval == 0 {
tickInterval = 100 * time.Millisecond
}
ids := make([]uint64, c.size)
peers := make([]raft.Peer, c.size)
for i := 0; i < c.size; i++ {
peers[i].ID = uint64(i + 1)
ids[i] = peers[i].ID
}
network := newRaftNetwork(ids...)
nodes := make(map[uint64]*node, c.size)
for i := 0; i < c.size; i++ {
nodes[uint64(i+1)] = startNodeWithConfig(nodeConfig{
id: uint64(i + 1),
peers: peers,
iface: network.nodeNetwork(uint64(i + 1)),
traceLogger: c.traceLogger,
tickInterval: tickInterval,
})
}

cl := &cluster{
nodes: nodes,
network: network,
stopc: make(chan struct{}),
removec: make(chan uint64),
addc: make(chan uint64),
getc: make(chan getEndpoint),
faultc: make(chan func(*cluster)),
traceLogger: c.traceLogger,
tickInterval: tickInterval,
}

cl.waitLeader()

go cl.mgmtLoop()

return cl
}

// mgmtLoop owns the node map and peer list; add, remove, get, fault, and
// stop requests are funneled through channels so no locking is needed.
func (cl *cluster) mgmtLoop() {
peers := []raft.Peer{}
for _, n := range cl.nodes {
peers = append(peers, raft.Peer{ID: n.id, Context: []byte{}})
}
for {
select {
case id := <-cl.addc:
cl.network.changeFace(id, true)
peers = append(peers, raft.Peer{ID: id, Context: []byte{}})
node := startNodeWithConfig(nodeConfig{
id: id,
peers: nil,
iface: cl.network.nodeNetwork(id),
traceLogger: cl.traceLogger,
tickInterval: cl.tickInterval,
})
cl.nodes[id] = node
case id := <-cl.removec:
cl.network.changeFace(id, false)
cl.nodes[id].stop()
delete(cl.nodes, id)
for i, p := range peers {
if p.ID == id {
peers = append(peers[:i], peers[i+1:]...)
break
}
}
case gn := <-cl.getc:
i := gn.i % len(peers)
nid := peers[i].ID
gn.c <- endpoint{index: i, node: cl.nodes[nid]}
case <-cl.stopc:
for _, n := range cl.nodes {
n.stop()
}
close(cl.stopc)
return
case f := <-cl.faultc:
// heal the previous fault and restart any stopped nodes before
// applying the next fault function
cl.network.clearFault()
for _, n := range cl.nodes {
if n.stopped {
n.restart()
}
}
f(cl)
}
}
}

func (cl *cluster) stop() {
cl.stopc <- struct{}{}
<-cl.stopc
}

func (cl *cluster) removeNode(id uint64) {
cl.removec <- id
}

func (cl *cluster) addNode(id uint64) {
cl.addc <- id
}

func (cl *cluster) newClient() *client {
return &client{cluster: cl, epc: make(chan endpoint)}
}

// waitLeader busy-polls node status until every node that reports a leader
// agrees on the same one, then returns that leader's ID.
func (cl *cluster) waitLeader() uint64 {
var l map[uint64]struct{}
var lindex uint64

for {
l = make(map[uint64]struct{})

for i, n := range cl.nodes {
lead := n.Status().SoftState.Lead
if lead != 0 {
l[lead] = struct{}{}
if n.id == lead {
lindex = i
}
}
}

if len(l) == 1 {
return lindex
}
}
}

// client spreads proposals and conf changes across cluster members,
// round-robin via getEndpoint.
type client struct {
cluster *cluster
epi int
epc chan endpoint
}

func (cl *client) propose(ctx context.Context, data []byte) error {
ep := cl.getEndpoint()
return ep.Propose(ctx, data)
}

func (cl *client) addNode(ctx context.Context, n uint64) error {
change := raftpb.ConfChangeSingle{
Type: raftpb.ConfChangeAddNode,
NodeID: n,
}
cc := raftpb.ConfChangeV2{
Transition: 0,
Changes: []raftpb.ConfChangeSingle{change},
Context: []byte{},
}

ep := cl.getEndpoint()
if err := ep.ProposeConfChange(ctx, cc); err != nil {
return err
}

// poll the endpoint's status until the new voter shows up in the
// incoming config, giving up after ~50 ticks
toc := time.After(cl.cluster.tickInterval * 50)
for {
select {
case <-toc:
return fmt.Errorf("addNode timeout")
default:
}
st := ep.Status()
if _, exist := st.Config.Voters[0][n]; exist {
break
}
}
cl.cluster.addNode(n)
return nil
}

func (cl *client) removeNode(ctx context.Context, n uint64) error {
change := raftpb.ConfChangeSingle{
Type: raftpb.ConfChangeRemoveNode,
NodeID: n,
}
cc := raftpb.ConfChangeV2{
Transition: 0,
Changes: []raftpb.ConfChangeSingle{change},
Context: []byte{},
}

ep := cl.getEndpoint()
if err := ep.ProposeConfChange(ctx, cc); err != nil {
return err
}

// poll the endpoint's status until the voter disappears from the
// incoming config, giving up after ~50 ticks
toc := time.After(cl.cluster.tickInterval * 50)
for {
select {
case <-toc:
return fmt.Errorf("removeNode timeout")
default:
}
st := ep.Status()
if _, exist := st.Config.Voters[0][n]; !exist {
break
}
}

cl.cluster.removeNode(n)
return nil
}

// getEndpoint asks the management loop for the next live node, round-robin.
func (cl *client) getEndpoint() *node {
ge := getEndpoint{
i: cl.epi + 1,
c: cl.epc,
}
cl.cluster.getc <- ge
ep := <-ge.c
cl.epi = ep.index
return ep.node
}
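Taken together, this file gives the random fault tests a self-contained harness. A hypothetical usage sketch — the surrounding test plumbing and the trace logger are assumptions; everything else is defined above:

func exampleClusterUsage(tl raft.TraceLogger) error {
	// Three nodes on an in-memory network; newCluster blocks until a
	// leader is elected.
	cl := newCluster(clusterConfig{size: 3, traceLogger: tl})
	defer cl.stop()

	c := cl.newClient()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Proposals are spread round-robin across live members.
	if err := c.propose(ctx, []byte("hello")); err != nil {
		return err
	}

	// Grow the cluster: addNode proposes the conf change, waits for the
	// voter to appear in the config, then starts the new member.
	if err := c.addNode(ctx, 4); err != nil {
		return err
	}

	// Inject a fault; the management loop heals the previous fault and
	// restarts stopped nodes before applying this one.
	cl.faultc <- func(*cluster) { /* e.g. partition nodes via cl.network */ }
	return nil
}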
