diff --git a/cmd/cmd.go b/cmd/cmd.go index 25123cce1..1209576fe 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -361,34 +361,14 @@ func CreateNode(numValidators int, chain genesis.ChainType, workingDir string, func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string, bool)) ( *node.Node, *wallet.Wallet, error, ) { - gen, err := genesis.LoadFromFile(PactusGenesisPath(workingDir)) - if err != nil { - return nil, nil, err - } - - if !gen.ChainType().IsMainnet() { - crypto.AddressHRP = "tpc" - crypto.PublicKeyHRP = "tpublic" - crypto.PrivateKeyHRP = "tsecret" - crypto.XPublicKeyHRP = "txpublic" - crypto.XPrivateKeyHRP = "txsecret" - } - - walletsDir := PactusWalletDir(workingDir) - confPath := PactusConfigPath(workingDir) - - conf, err := MakeConfig(gen, confPath, walletsDir) - if err != nil { - return nil, nil, err - } - - err = conf.BasicCheck() + conf, gen, err := MakeConfig(workingDir) if err != nil { return nil, nil, err } defaultWalletPath := PactusDefaultWalletPath(workingDir) - walletInstance, err := wallet.Open(defaultWalletPath, true) + walletInstance, err := wallet.Open(defaultWalletPath, true, + wallet.WithCustomServers([]string{conf.GRPC.Listen})) if err != nil { return nil, nil, err } @@ -451,13 +431,30 @@ func makeLocalGenesis(w wallet.Wallet) *genesis.Genesis { return gen } -// MakeConfig opens the given config file and creates the appropriate configuration per chain type. -// The chain type is determined from the genesis document. -// It also updates some private configurations, like "wallets directory". -// TODO: write test for me. -func MakeConfig(genDoc *genesis.Genesis, confPath, walletsDir string) (*config.Config, error) { +// MakeConfig attempts to load the configuration file and +// returns an instance of the configuration along with the genesis document. +// The genesis document is required to determine the chain type, which influences the configuration settings. +// The function sets various private configurations, such as the "wallets directory" and chain-specific HRP values. +// If the configuration file cannot be loaded, it tries to recover or restore the configuration. +func MakeConfig(workingDir string) (*config.Config, *genesis.Genesis, error) { + gen, err := genesis.LoadFromFile(PactusGenesisPath(workingDir)) + if err != nil { + return nil, nil, err + } + + if !gen.ChainType().IsMainnet() { + crypto.AddressHRP = "tpc" + crypto.PublicKeyHRP = "tpublic" + crypto.PrivateKeyHRP = "tsecret" + crypto.XPublicKeyHRP = "txpublic" + crypto.XPrivateKeyHRP = "txsecret" + } + + walletsDir := PactusWalletDir(workingDir) + confPath := PactusConfigPath(workingDir) + var defConf *config.Config - chainType := genDoc.ChainType() + chainType := gen.ChainType() switch chainType { case genesis.Mainnet: @@ -475,12 +472,12 @@ func MakeConfig(genDoc *genesis.Genesis, confPath, walletsDir string) (*config.C conf, err = RecoverConfig(confPath, defConf, chainType) if err != nil { - return nil, err + return nil, nil, err } } // Now we can update the private filed, if any - genParams := genDoc.Params() + genParams := gen.Params() conf.Store.TxCacheSize = genParams.TransactionToLiveInterval conf.Store.SortitionCacheSize = genParams.SortitionInterval @@ -493,7 +490,11 @@ func MakeConfig(genDoc *genesis.Genesis, confPath, walletsDir string) (*config.C conf.WalletManager.ChainType = chainType conf.WalletManager.WalletsDir = walletsDir - return conf, nil + if err := conf.BasicCheck(); err != nil { + return nil, nil, err + } + + return conf, gen, nil } func RecoverConfig(confPath string, defConf *config.Config, chainType genesis.ChainType) (*config.Config, error) { diff --git a/cmd/daemon/init.go b/cmd/daemon/init.go index 004811fe7..ad4f0e631 100644 --- a/cmd/daemon/init.go +++ b/cmd/daemon/init.go @@ -17,8 +17,8 @@ func buildInitCmd(parentCmd *cobra.Command) { Short: "initialize the Pactus Blockchain node", } parentCmd.AddCommand(initCmd) - workingDirOpt := initCmd.Flags().StringP("working-dir", "w", cmd.PactusDefaultHomeDir(), - "a path to the working directory to save the wallet and node files") + + workingDirOpt := addWorkingDirOption(initCmd) testnetOpt := initCmd.Flags().Bool("testnet", false, "initialize working directory for joining the testnet") diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 4cedc1ff4..f877ef29d 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -23,9 +23,15 @@ func main() { buildVersionCmd(rootCmd) buildInitCmd(rootCmd) buildStartCmd(rootCmd) + buildPruneCmd(rootCmd) err := rootCmd.Execute() if err != nil { cmd.PrintErrorMsgf("%s", err) } } + +func addWorkingDirOption(c *cobra.Command) *string { + return c.Flags().StringP("working-dir", "w", cmd.PactusDefaultHomeDir(), + "the path to the working directory that keeps the wallets and node files") +} diff --git a/cmd/daemon/prune.go b/cmd/daemon/prune.go new file mode 100644 index 000000000..3ccaa0c9d --- /dev/null +++ b/cmd/daemon/prune.go @@ -0,0 +1,115 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/gofrs/flock" + "github.com/pactus-project/pactus/cmd" + "github.com/pactus-project/pactus/store" + "github.com/spf13/cobra" +) + +func buildPruneCmd(parentCmd *cobra.Command) { + pruneCmd := &cobra.Command{ + Use: "prune", + Short: "prune old blocks and transactions from client", + Long: "The prune command optimizes blockchain storage by removing outdated blocks and transactions, " + + "freeing up disk space and enhancing client performance.", + } + parentCmd.AddCommand(pruneCmd) + + workingDirOpt := addWorkingDirOption(pruneCmd) + + pruneCmd.Run = func(_ *cobra.Command, _ []string) { + workingDir, _ := filepath.Abs(*workingDirOpt) + // change working directory + err := os.Chdir(workingDir) + cmd.FatalErrorCheck(err) + + // Define the lock file path + lockFilePath := filepath.Join(workingDir, ".pactus.lock") + fileLock := flock.New(lockFilePath) + + locked, err := fileLock.TryLock() + if err != nil { + // handle unable to attempt to acquire lock + cmd.FatalErrorCheck(err) + } + + if !locked { + cmd.PrintWarnMsgf("Could not lock '%s', another instance is running?", lockFilePath) + + return + } + + conf, _, err := cmd.MakeConfig(workingDir) + cmd.FatalErrorCheck(err) + + cmd.PrintLine() + cmd.PrintWarnMsgf("This command removes all the blocks and transactions up to %d days ago "+ + "and converts the node to prune mode.", conf.Store.RetentionDays) + cmd.PrintLine() + confirmed := cmd.PromptConfirm("Do you want to continue") + if !confirmed { + return + } + cmd.PrintLine() + + str, err := store.NewStore(conf.Store) + cmd.FatalErrorCheck(err) + + prunedCount := uint32(0) + skippedCount := uint32(0) + totalCount := uint32(0) + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) + + go func() { + <-interrupt + str.Close() + _ = fileLock.Unlock() + }() + + err = str.Prune(func(pruned, skipped, pruningHeight uint32) { + prunedCount += pruned + skippedCount += skipped + + if totalCount == 0 { + totalCount = pruningHeight + } + + pruningProgressBar(prunedCount, skippedCount, totalCount) + }) + cmd.PrintLine() + cmd.FatalErrorCheck(err) + + str.Close() + _ = fileLock.Unlock() + + cmd.PrintLine() + cmd.PrintInfoMsgf("βœ… Your node successfully pruned and changed to prune mode.") + cmd.PrintLine() + cmd.PrintInfoMsgf("You can start the node by running this command:") + cmd.PrintInfoMsgf("./pactus-daemon start -w %v", workingDir) + } +} + +func pruningProgressBar(prunedCount, skippedCount, totalCount uint32) { + percentage := float64(prunedCount+skippedCount) / float64(totalCount) * 100 + if percentage > 100 { + percentage = 100 + } + + barLength := 40 + filledLength := int(float64(barLength) * percentage / 100) + + bar := strings.Repeat("=", filledLength) + strings.Repeat(" ", barLength-filledLength) + fmt.Printf("\r [%s] %.0f%% Pruned: %d | Skipped: %d", //nolint + bar, percentage, prunedCount, skippedCount) +} diff --git a/cmd/daemon/start.go b/cmd/daemon/start.go index 84a430a80..14b705571 100644 --- a/cmd/daemon/start.go +++ b/cmd/daemon/start.go @@ -22,8 +22,7 @@ func buildStartCmd(parentCmd *cobra.Command) { parentCmd.AddCommand(startCmd) - workingDirOpt := startCmd.Flags().StringP("working-dir", "w", cmd.PactusDefaultHomeDir(), - "the path to the working directory to load the wallet and node files") + workingDirOpt := addWorkingDirOption(startCmd) passwordOpt := startCmd.Flags().StringP("password", "p", "", "the wallet password") diff --git a/cmd/gtk/main.go b/cmd/gtk/main.go index 216ab874f..0f522c207 100644 --- a/cmd/gtk/main.go +++ b/cmd/gtk/main.go @@ -170,8 +170,6 @@ func run(n *node.Node, wlt *wallet.Wallet, app *gtk.Application) { grpcAddr := n.GRPC().Address() cmd.PrintInfoMsgf("connect wallet to grpc server: %s\n", grpcAddr) - wlt.SetServerAddr(grpcAddr) - nodeModel := newNodeModel(n) walletModel := newWalletModel(wlt, n) diff --git a/cmd/wallet/main.go b/cmd/wallet/main.go index 1282cc6ef..6da00a9fe 100644 --- a/cmd/wallet/main.go +++ b/cmd/wallet/main.go @@ -1,15 +1,18 @@ package main import ( + "time" + "github.com/pactus-project/pactus/cmd" "github.com/pactus-project/pactus/wallet" "github.com/spf13/cobra" ) var ( - pathOpt *string - offlineOpt *bool - serverAddrOpt *string + pathOpt *string + offlineOpt *bool + serverAddrsOpt *[]string + timeoutOpt *int ) func addPasswordOption(c *cobra.Command) *string { @@ -18,13 +21,19 @@ func addPasswordOption(c *cobra.Command) *string { } func openWallet() (*wallet.Wallet, error) { - wlt, err := wallet.Open(*pathOpt, *offlineOpt) - if err != nil { - return nil, err + opts := make([]wallet.Option, 0) + + if *serverAddrsOpt != nil { + opts = append(opts, wallet.WithCustomServers(*serverAddrsOpt)) } - if *serverAddrOpt != "" { - wlt.SetServerAddr(*serverAddrOpt) + if *timeoutOpt > 0 { + opts = append(opts, wallet.WithTimeout(time.Duration(*timeoutOpt)*time.Second)) + } + + wlt, err := wallet.Open(*pathOpt, *offlineOpt, opts...) + if err != nil { + return nil, err } return wlt, err @@ -43,7 +52,9 @@ func main() { pathOpt = rootCmd.PersistentFlags().String("path", cmd.PactusDefaultWalletPath(cmd.PactusDefaultHomeDir()), "the path to the wallet file") offlineOpt = rootCmd.PersistentFlags().Bool("offline", false, "offline mode") - serverAddrOpt = rootCmd.PersistentFlags().String("server", "", "server gRPC address") + serverAddrsOpt = rootCmd.PersistentFlags().StringSlice("servers", []string{}, "servers gRPC address") + timeoutOpt = rootCmd.PersistentFlags().Int("timeout", 1, + "specifies the timeout duration for the client connection in seconds") buildCreateCmd(rootCmd) buildRecoverCmd(rootCmd) diff --git a/config/example_config.toml b/config/example_config.toml index ef1ae39ac..f34fbaef3 100644 --- a/config/example_config.toml +++ b/config/example_config.toml @@ -16,6 +16,11 @@ # Default is `"data"`. path = "data" + # `retention_days` this parameter indicates the number of days for which the node should keep or retain the blocks + # before pruning them. It is only applicable if the node is in Prune Mode. + # Default is `10`. + retention_days = 10 + # `network` contains configuration options for the network module, which manages communication between nodes. [network] diff --git a/consensus/mock.go b/consensus/mock.go index 49b414fdb..25aa5ded6 100644 --- a/consensus/mock.go +++ b/consensus/mock.go @@ -1,10 +1,9 @@ package consensus import ( - "sync" - "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/state" "github.com/pactus-project/pactus/types/proposal" "github.com/pactus-project/pactus/types/vote" "github.com/pactus-project/pactus/util/testsuite" @@ -13,10 +12,9 @@ import ( var _ Consensus = &MockConsensus{} type MockConsensus struct { - // This locks prevents the Data Race in tests - lk sync.RWMutex ts *testsuite.TestSuite + State *state.MockState ValKey *bls.ValidatorKey Votes []*vote.Vote CurProposal *proposal.Proposal @@ -26,11 +24,13 @@ type MockConsensus struct { Round int16 } -func MockingManager(ts *testsuite.TestSuite, valKeys []*bls.ValidatorKey) (Manager, []*MockConsensus) { +func MockingManager(ts *testsuite.TestSuite, st *state.MockState, + valKeys []*bls.ValidatorKey, +) (Manager, []*MockConsensus) { mocks := make([]*MockConsensus, len(valKeys)) instances := make([]Consensus, len(valKeys)) - for i, s := range valKeys { - cons := MockingConsensus(ts, s) + for i, key := range valKeys { + cons := MockingConsensus(ts, st, key) mocks[i] = cons instances[i] = cons } @@ -42,9 +42,10 @@ func MockingManager(ts *testsuite.TestSuite, valKeys []*bls.ValidatorKey) (Manag }, mocks } -func MockingConsensus(ts *testsuite.TestSuite, valKey *bls.ValidatorKey) *MockConsensus { +func MockingConsensus(ts *testsuite.TestSuite, st *state.MockState, valKey *bls.ValidatorKey) *MockConsensus { return &MockConsensus{ ts: ts, + State: st, ValKey: valKey, } } @@ -54,39 +55,24 @@ func (m *MockConsensus) ConsensusKey() *bls.PublicKey { } func (m *MockConsensus) MoveToNewHeight() { - m.lk.Lock() - defer m.lk.Unlock() - - m.Height++ + m.Height = m.State.LastBlockHeight() + 1 } func (*MockConsensus) Start() {} func (m *MockConsensus) AddVote(v *vote.Vote) { - m.lk.Lock() - defer m.lk.Unlock() - m.Votes = append(m.Votes, v) } func (m *MockConsensus) AllVotes() []*vote.Vote { - m.lk.Lock() - defer m.lk.Unlock() - return m.Votes } func (m *MockConsensus) SetProposal(p *proposal.Proposal) { - m.lk.Lock() - defer m.lk.Unlock() - m.CurProposal = p } func (m *MockConsensus) HasVote(h hash.Hash) bool { - m.lk.Lock() - defer m.lk.Unlock() - for _, v := range m.Votes { if v.Hash() == h { return true @@ -97,16 +83,10 @@ func (m *MockConsensus) HasVote(h hash.Hash) bool { } func (m *MockConsensus) Proposal() *proposal.Proposal { - m.lk.Lock() - defer m.lk.Unlock() - return m.CurProposal } func (m *MockConsensus) HeightRound() (uint32, int16) { - m.lk.Lock() - defer m.lk.Unlock() - return m.Height, m.Round } @@ -115,9 +95,6 @@ func (*MockConsensus) String() string { } func (m *MockConsensus) PickRandomVote(_ int16) *vote.Vote { - m.lk.Lock() - defer m.lk.Unlock() - if len(m.Votes) == 0 { return nil } @@ -127,22 +104,13 @@ func (m *MockConsensus) PickRandomVote(_ int16) *vote.Vote { } func (m *MockConsensus) IsActive() bool { - m.lk.Lock() - defer m.lk.Unlock() - return m.Active } func (m *MockConsensus) IsProposer() bool { - m.lk.Lock() - defer m.lk.Unlock() - return m.Proposer } func (m *MockConsensus) SetActive(active bool) { - m.lk.Lock() - defer m.lk.Unlock() - m.Active = active } diff --git a/genesis/genesis.go b/genesis/genesis.go index 0bbff1640..464e1ed50 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -2,7 +2,6 @@ package genesis import ( "encoding/json" - "fmt" "os" "time" @@ -66,7 +65,7 @@ type genesisData struct { func (gen *Genesis) Hash() hash.Hash { bs, err := cbor.Marshal(gen.data) if err != nil { - panic(fmt.Errorf("could not create hash of Genesis: %w", err)) + return hash.UndefHash } return hash.CalcHash(bs) diff --git a/go.mod b/go.mod index 7f2539fae..c625a42ec 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( go.nanomsg.org/mangos/v3 v3.4.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 - google.golang.org/grpc v1.64.0 + google.golang.org/grpc v1.64.1 google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) diff --git a/go.sum b/go.sum index ea8db6d6a..87f67f0e2 100644 --- a/go.sum +++ b/go.sum @@ -778,8 +778,8 @@ google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQ google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= -google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/network/gossip.go b/network/gossip.go index 6c6d4f279..6f94bfb01 100644 --- a/network/gossip.go +++ b/network/gossip.go @@ -75,6 +75,9 @@ func (g *gossipService) Broadcast(msg []byte, topicID TopicID) error { g.logger.Debug("publishing new message", "topic", topicID) switch topicID { + case TopicIDUnspecified: + return InvalidTopicError{TopicID: topicID} + case TopicIDBlock: if g.topicBlock == nil { return NotSubscribedError{TopicID: topicID} @@ -114,6 +117,9 @@ func (g *gossipService) publish(msg []byte, topic *lp2pps.Topic) error { // JoinTopic joins to the topic with the given name and subscribes to receive topic messages. func (g *gossipService) JoinTopic(topicID TopicID, sp ShouldPropagate) error { switch topicID { + case TopicIDUnspecified: + return InvalidTopicError{TopicID: topicID} + case TopicIDBlock: if g.topicBlock != nil { g.logger.Warn("already subscribed to block topic") @@ -247,7 +253,7 @@ func (g *gossipService) onReceiveMessage(m *lp2pps.Message) { return } - g.logger.Debug("receiving new gossip message", "source", m.GetFrom()) + g.logger.Debug("receiving new gossip message", "from", m.ReceivedFrom) event := &GossipMessage{ From: m.ReceivedFrom, Data: m.Data, diff --git a/network/gossip_test.go b/network/gossip_test.go index 72717dab9..867648c6a 100644 --- a/network/gossip_test.go +++ b/network/gossip_test.go @@ -3,7 +3,7 @@ package network import ( "testing" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) func TestJoinConsensusTopic(t *testing.T) { @@ -11,12 +11,26 @@ func TestJoinConsensusTopic(t *testing.T) { msg := []byte("test-consensus-topic") - require.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDConsensus), + assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDConsensus), NotSubscribedError{ TopicID: TopicIDConsensus, }) - require.NoError(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate)) - require.NoError(t, net.gossip.Broadcast(msg, TopicIDConsensus)) + assert.NoError(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate)) + assert.NoError(t, net.gossip.Broadcast(msg, TopicIDConsensus)) +} + +func TestJoinInvalidTopic(t *testing.T) { + net := makeTestNetwork(t, testConfig(), nil) + + assert.ErrorIs(t, net.JoinTopic(TopicIDUnspecified, alwaysPropagate), + InvalidTopicError{ + TopicID: TopicIDUnspecified, + }) + + assert.ErrorIs(t, net.JoinTopic(TopicID(-1), alwaysPropagate), + InvalidTopicError{ + TopicID: TopicID(-1), + }) } func TestInvalidTopic(t *testing.T) { @@ -24,7 +38,12 @@ func TestInvalidTopic(t *testing.T) { msg := []byte("test-invalid-topic") - require.ErrorIs(t, net.gossip.Broadcast(msg, -1), + assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDUnspecified), + InvalidTopicError{ + TopicID: TopicIDUnspecified, + }) + + assert.ErrorIs(t, net.gossip.Broadcast(msg, -1), InvalidTopicError{ TopicID: TopicID(-1), }) diff --git a/network/interface.go b/network/interface.go index f56702942..ce2a02650 100644 --- a/network/interface.go +++ b/network/interface.go @@ -9,6 +9,7 @@ import ( type TopicID int const ( + TopicIDUnspecified TopicID = 0 TopicIDBlock TopicID = 1 TopicIDTransaction TopicID = 2 TopicIDConsensus TopicID = 3 @@ -16,6 +17,9 @@ const ( func (t TopicID) String() string { switch t { + case TopicIDUnspecified: + return "unspecified" + case TopicIDBlock: return "block" diff --git a/network/network_test.go b/network/network_test.go index 6048fe2c2..3cdc04c11 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -57,7 +57,7 @@ func testConfig() *Config { func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event { t.Helper() - timeout := time.NewTimer(8 * time.Second) + timeout := time.NewTimer(10 * time.Second) for { select { @@ -167,7 +167,7 @@ func TestNetwork(t *testing.T) { confM.EnableRelay = true confM.BootstrapAddrStrings = bootstrapAddresses confM.ListenAddrStrings = []string{ - "/ip4/127.0.0.1/tcp/9987", + "/ip4/127.0.0.1/tcp/0", } fmt.Println("Starting Private node M") networkM := makeTestNetwork(t, confM, []lp2p.Option{ @@ -179,7 +179,7 @@ func TestNetwork(t *testing.T) { confN.EnableRelay = true confN.BootstrapAddrStrings = bootstrapAddresses confN.ListenAddrStrings = []string{ - "/ip4/127.0.0.1/tcp/5678", + "/ip4/127.0.0.1/tcp/0", } fmt.Println("Starting Private node N") networkN := makeTestNetwork(t, confN, []lp2p.Option{ @@ -232,6 +232,41 @@ func TestNetwork(t *testing.T) { assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop) assert.Contains(t, protos, lp2pproto.ProtoIDv2Hop) }, time.Second, 100*time.Millisecond) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + protos := networkX.Protocols() + assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop) + assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop) + }, time.Second, 100*time.Millisecond) + }) + + t.Run("Reachability", func(t *testing.T) { + fmt.Printf("Running %s\n", t.Name()) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + reachability := networkB.ReachabilityStatus() + assert.Equal(t, "Public", reachability) + }, time.Second, 100*time.Millisecond) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + reachability := networkM.ReachabilityStatus() + assert.Equal(t, "Private", reachability) + }, time.Second, 100*time.Millisecond) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + reachability := networkN.ReachabilityStatus() + assert.Equal(t, "Private", reachability) + }, time.Second, 100*time.Millisecond) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + reachability := networkP.ReachabilityStatus() + assert.Equal(t, "Public", reachability) + }, time.Second, 100*time.Millisecond) + + require.EventuallyWithT(t, func(_ *assert.CollectT) { + reachability := networkP.ReachabilityStatus() + assert.Equal(t, "Public", reachability) + }, time.Second, 100*time.Millisecond) }) t.Run("all nodes have at least one connection to the bootstrap node B", func(t *testing.T) { @@ -355,7 +390,7 @@ func TestNetwork(t *testing.T) { // TODO: How to test this? // t.Run("nodes M and N (private, connected via relay) can communicate using the relay node R", func(t *testing.T) { // msgM := ts.RandBytes(64) - // require.NoError(t, networkM.SendTo(msgM, networkN.SelfID())) + // networkM.SendTo(msgM, networkN.SelfID()) // eM := shouldReceiveEvent(t, networkN, EventTypeStream).(*StreamMessage) // assert.Equal(t, readData(t, eM.Reader, len(msgM)), msgM) // }) diff --git a/scripts/snapshot.py b/scripts/snapshot.py new file mode 100644 index 000000000..96ea4fd08 --- /dev/null +++ b/scripts/snapshot.py @@ -0,0 +1,316 @@ +# Pactus Blockchain Snapshot tool +# +# This script first stops the Pactus node if it is running, then creates a backup of the blockchain data by copying +# or compressing it based on the specified options. The backup is stored in a timestamped snapshot directory along +# with a `metadata.json` file that contains detailed information about the snapshot, including file paths and +# checksums. Finally, the script manages the retention of snapshots, ensuring only a specified number of recent +# backups are kept. +# +# Arguments +# +# - `--service_path`: This argument specifies the path to the `pactus` service file to manage systemctl service. +# - `--data_path`: This argument specifies the path to the Pactus data folder to create snapshots. +# - Windows: `C:\Users\{user}\pactus\data` +# - Linux or Mac: `/home/{user}/pactus/data` +# - `--compress`: This argument specifies the compression method based on your choice ['none', 'zip', 'tar'], +# with 'none' being without compression. +# - `--retention`: This argument sets the number of snapshots to keep. +# - `--snapshot_path`: This argument sets a custom path for snapshots, with the default being the current +# working directory of the script. +# +# How to run? +# +# For create snapshots just run this command: +# +# sudo python3 snapshot.py --service_path /etc/systemd/system/pactus.service --data_path /home/{user}/pactus/data +# --compress zip --retention 3 + + +import argparse +import os +import shutil +import subprocess +import hashlib +import json +import logging +import zipfile +from datetime import datetime + + +def setup_logging(): + logging.basicConfig( + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d-%H:%M', + level=logging.INFO + ) + + +def get_timestamp_str(): + return datetime.now().strftime("%Y%m%d%H%M%S") + + +def get_current_time_iso(): + return datetime.now().isoformat() + + +class Metadata: + @staticmethod + def sha256(file_path): + hash_sha = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha.update(chunk) + return hash_sha.hexdigest() + + @staticmethod + def update_metadata_file(snapshot_path, snapshot_metadata): + metadata_file = os.path.join(snapshot_path, 'snapshots', 'metadata.json') + if os.path.isfile(metadata_file): + logging.info(f"Updating existing metadata file '{metadata_file}'") + with open(metadata_file, 'r') as f: + metadata = json.load(f) + else: + logging.info(f"Creating new metadata file '{metadata_file}'") + metadata = [] + + formatted_metadata = { + "name": snapshot_metadata["name"], + "created_at": snapshot_metadata["created_at"], + "compress": snapshot_metadata["compress"], + "data": snapshot_metadata["data"] + } + + metadata.append(formatted_metadata) + + with open(metadata_file, 'w') as f: + json.dump(metadata, f, indent=4) + + @staticmethod + def update_metadata_after_removal(snapshots_dir, removed_snapshots): + metadata_file = os.path.join(snapshots_dir, 'metadata.json') + if not os.path.isfile(metadata_file): + return + + logging.info(f"Updating metadata file '{metadata_file}' after snapshot removal") + with open(metadata_file, 'r') as f: + metadata = json.load(f) + + updated_metadata = [entry for entry in metadata if entry["name"] not in removed_snapshots] + + with open(metadata_file, 'w') as f: + json.dump(updated_metadata, f, indent=4) + + @staticmethod + def create_snapshot_json(data_dir, snapshot_subdir): + files = [] + for root, _, filenames in os.walk(data_dir): + for filename in filenames: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, data_dir) + snapshot_rel_path = os.path.join(snapshot_subdir, rel_path).replace('\\', '/') + file_info = { + "name": filename, + "path": snapshot_rel_path, + "sha": Metadata.sha256(file_path) + } + files.append(file_info) + + return {"data": files} + + @staticmethod + def create_compressed_snapshot_json(compressed_file, rel_path): + file_info = { + "name": os.path.basename(compressed_file), + "path": rel_path, + "sha": Metadata.sha256(compressed_file) + } + + return {"data": file_info} + + +def run_command(command): + logging.info(f"Running command: {' '.join(command)}") + try: + result = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + logging.info(f"Command output: {result.stdout.strip()}") + if result.stderr.strip(): + logging.error(f"Command error: {result.stderr.strip()}") + return result.stdout.strip() + except subprocess.CalledProcessError as e: + logging.error(f"Command failed with error: {e.stderr.strip()}") + return f"Error: {e.stderr.strip()}" + + +def get_service_name(service_path): + base_name = os.path.basename(service_path) + service_name = os.path.splitext(base_name)[0] + return service_name + + +class DaemonManager: + @staticmethod + def start_service(service_path): + sv = get_service_name(service_path) + logging.info(f"Starting '{sv}' service") + return run_command(['sudo', 'systemctl', 'start', sv]) + + @staticmethod + def stop_service(service_path): + sv = get_service_name(service_path) + logging.info(f"Stopping '{sv}' service") + return run_command(['sudo', 'systemctl', 'stop', sv]) + + +class SnapshotManager: + def __init__(self, args): + self.args = args + + def manage_snapshots(self): + snapshots_dir = os.path.join(self.args.snapshot_path, 'snapshots') + logging.info(f"Managing snapshots in '{snapshots_dir}'") + + if not os.path.exists(snapshots_dir): + logging.info(f"Snapshots directory '{snapshots_dir}' does not exist. Creating it.") + os.makedirs(snapshots_dir) + + snapshots = sorted([s for s in os.listdir(snapshots_dir) if s != 'metadata.json']) + + logging.info(f"Found snapshots: {snapshots}") + logging.info(f"Retention policy is to keep {self.args.retention} snapshots") + + if len(snapshots) >= self.args.retention: + num_to_remove = len(snapshots) - self.args.retention + 1 + to_remove = snapshots[:num_to_remove] + logging.info(f"Snapshots to remove: {to_remove}") + for snapshot in to_remove: + snapshot_path = os.path.join(snapshots_dir, snapshot) + logging.info(f"Removing old snapshot '{snapshot_path}'") + shutil.rmtree(snapshot_path) + + Metadata.update_metadata_after_removal(snapshots_dir, to_remove) + + def create_snapshot(self): + timestamp_str = get_timestamp_str() + snapshot_dir = os.path.join(self.args.snapshot_path, 'snapshots', timestamp_str) + logging.info(f"Creating snapshot directory '{snapshot_dir}'") + os.makedirs(snapshot_dir, exist_ok=True) + + data_dir = os.path.join(snapshot_dir, 'data') + if self.args.compress == 'none': + logging.info(f"Copying data from '{self.args.data_path}' to '{data_dir}'") + shutil.copytree(self.args.data_path, data_dir) + snapshot_metadata = Metadata.create_snapshot_json(data_dir, timestamp_str) + elif self.args.compress == 'zip': + zip_file = os.path.join(snapshot_dir, 'data.zip') + rel = os.path.relpath(zip_file, snapshot_dir) + meta_path = os.path.join(timestamp_str, rel) + logging.info(f"Creating ZIP archive '{zip_file}'") + with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(self.args.data_path): + for file in files: + full_path = os.path.join(root, file) + rel_path = os.path.relpath(full_path, self.args.data_path) + zipf.write(full_path, os.path.join('data', rel_path)) + snapshot_metadata = Metadata.create_compressed_snapshot_json(zip_file, meta_path) + elif self.args.compress == 'tar': + tar_file = os.path.join(snapshot_dir, 'data.tar.gz') + rel = os.path.relpath(tar_file, snapshot_dir) + meta_path = os.path.join(timestamp_str, rel) + logging.info(f"Creating TAR.GZ archive '{tar_file}'") + subprocess.run(['tar', '-czvf', tar_file, '-C', self.args.data_path, '.']) + snapshot_metadata = Metadata.create_compressed_snapshot_json(tar_file, meta_path) + + snapshot_metadata["name"] = timestamp_str + snapshot_metadata["created_at"] = get_current_time_iso() + snapshot_metadata["compress"] = self.args.compress + + Metadata.update_metadata_file(self.args.snapshot_path, snapshot_metadata) + + +class Validation: + @staticmethod + def validate_args(args): + logging.info('Validating arguments') + + if not os.path.isfile(args.service_path): + raise ValueError(f"Service file '{args.service_path}' does not exist.") + logging.info(f"Service file '{args.service_path}' exists") + + if not os.path.isdir(args.data_path): + raise ValueError(f"Data path '{args.data_path}' does not exist.") + logging.info(f"Data path '{args.data_path}' exists") + + if not os.access(args.data_path, os.W_OK): + raise PermissionError(f"No permission to access data path '{args.data_path}'.") + logging.info(f"Permission to access data path '{args.data_path}' confirmed") + + if args.compress == 'zip' and not shutil.which('zip'): + raise EnvironmentError("The 'zip' command is not available.") + elif args.compress == 'zip': + logging.info("The 'zip' command is available") + + if args.compress == 'tar' and not shutil.which('tar'): + raise EnvironmentError("The 'tar' command is not available.") + elif args.compress == 'tar': + logging.info("The 'tar' command is available") + + if args.retention <= 0: + raise ValueError("Retention value must be greater than 0.") + logging.info(f"Retention value is set to {args.retention}") + + if not os.access(args.snapshot_path, os.W_OK): + raise PermissionError(f"No permission to access snapshot path '{args.snapshot_path}'.") + logging.info(f"Permission to access snapshot path '{args.snapshot_path}' confirmed") + + snapshots_dir = os.path.join(args.snapshot_path, 'snapshots') + if not os.path.isdir(snapshots_dir): + logging.info("Snapshots directory does not exist, creating it") + os.makedirs(snapshots_dir) + else: + logging.info("Snapshots directory exists") + + @staticmethod + def validate(): + if os.name == "nt": + raise EnvironmentError("Windows not supported.") + if os.geteuid() != 0: + raise PermissionError("This script requires sudo/root access. Please run with sudo.") + + +class ProcessBackup: + def __init__(self, args): + self.args = args + + def run(self): + Validation.validate() + Validation.validate_args(self.args) + DaemonManager.stop_service(self.args.service_path) + snapshot_manager = SnapshotManager(self.args) + snapshot_manager.manage_snapshots() + snapshot_manager.create_snapshot() + DaemonManager.start_service(self.args.service_path) + + +def parse_args(): + user_home = os.path.expanduser("~") + default_data_path = os.path.join(user_home, 'pactus') + + parser = argparse.ArgumentParser(description='Pactus Blockchain Snapshot Tool') + parser.add_argument('--service_path', required=True, help='Path to pactus systemctl service') + parser.add_argument('--data_path', default=default_data_path, help='Path to data directory') + parser.add_argument('--compress', choices=['none', 'zip', 'tar'], default='none', help='Compression type') + parser.add_argument('--retention', type=int, default=3, help='Number of snapshots to retain') + parser.add_argument('--snapshot_path', default=os.getcwd(), help='Path to store snapshots') + + return parser.parse_args() + + +def main(): + setup_logging() + args = parse_args() + process_backup = ProcessBackup(args) + process_backup.run() + + +if __name__ == "__main__": + main() diff --git a/state/mock.go b/state/mock.go index e0f7c55e7..6cc5c08d5 100644 --- a/state/mock.go +++ b/state/mock.go @@ -42,7 +42,7 @@ type MockState struct { func MockingState(ts *testsuite.TestSuite) *MockState { cmt, valKeys := ts.GenerateTestCommittee(21) - genDoc := genesis.TestnetGenesis() + genDoc := genesis.MainnetGenesis() return &MockState{ ts: ts, diff --git a/state/state.go b/state/state.go index 86cfbb39b..1da7dad13 100644 --- a/state/state.go +++ b/state/state.go @@ -138,25 +138,6 @@ func (st *state) concreteSandbox() sandbox.Sandbox { } func (st *state) tryLoadLastInfo() error { - // Make sure the genesis doc is the same as before. - // - // This check is not strictly necessary, since the genesis state is already committed. - // However, it is good to perform this check to ensure that the genesis document has not been modified. - genStateRoot := st.calculateGenesisStateRootFromGenesisDoc() - committedBlockOne, err := st.store.Block(1) - if err != nil { - return err - } - - blockOne, err := committedBlockOne.ToBlock() - if err != nil { - return err - } - - if genStateRoot != blockOne.Header().StateRoot() { - return fmt.Errorf("invalid genesis doc") - } - logger.Debug("try to restore the last state") committeeInstance, err := st.lastInfo.RestoreLastInfo(st.store, st.params.CommitteeSize) if err != nil { @@ -244,27 +225,6 @@ func (st *state) stateRoot() hash.Hash { return *stateRoot } -func (st *state) calculateGenesisStateRootFromGenesisDoc() hash.Hash { - accs := st.genDoc.Accounts() - vals := st.genDoc.Validators() - - accHashes := make([]hash.Hash, len(accs)) - valHashes := make([]hash.Hash, len(vals)) - for _, acc := range accs { - accHashes[acc.Number()] = acc.Hash() - } - for _, val := range vals { - valHashes[val.Number()] = val.Hash() - } - - accTree := simplemerkle.NewTreeFromHashes(accHashes) - valTree := simplemerkle.NewTreeFromHashes(valHashes) - accRootHash := accTree.Root() - valRootHash := valTree.Root() - - return *simplemerkle.HashMerkleBranches(&accRootHash, &valRootHash) -} - func (st *state) Close() { st.lk.RLock() defer st.lk.RUnlock() @@ -509,7 +469,7 @@ func (st *state) CommitBlock(blk *block.Block, cert *certificate.BlockCertificat } // ----------------------------------- - // Publishing the events to the zmq + // Publishing the events to the nano message. st.publishEvents(height, blk) return nil diff --git a/state/state_test.go b/state/state_test.go index 96f758c52..1e73c3f06 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -561,29 +561,6 @@ func TestLoadState(t *testing.T) { require.NoError(t, newState.CommitBlock(blk6, cert6)) } -func TestLoadStateAfterChangingGenesis(t *testing.T) { - td := setup(t) - - _, err := LoadOrNewState(td.state.genDoc, td.state.valKeys, - td.state.store, txpool.MockingTxPool(), nil) - require.NoError(t, err) - - pub, _ := td.RandBLSKeyPair() - val := validator.NewValidator(pub, 4) - newVals := append(td.state.genDoc.Validators(), val) - - genDoc := genesis.MakeGenesis( - td.state.genDoc.GenesisTime(), - td.state.genDoc.Accounts(), - newVals, - td.state.genDoc.Params()) - - // Load last state info after modifying genesis - _, err = LoadOrNewState(genDoc, td.state.valKeys, - td.state.store, txpool.MockingTxPool(), nil) - require.Error(t, err) -} - func TestIsValidator(t *testing.T) { td := setup(t) @@ -643,9 +620,6 @@ func TestCommittedBlock(t *testing.T) { assert.NoError(t, err) assert.Nil(t, blockOne.PrevCertificate()) assert.Equal(t, hash.UndefHash, blockOne.Header().PrevBlockHash()) - - r := td.state.calculateGenesisStateRootFromGenesisDoc() - assert.Equal(t, blockOne.Header().StateRoot(), r) }) t.Run("Last block", func(t *testing.T) { diff --git a/store/block.go b/store/block.go index 689b316fe..5edf6251e 100644 --- a/store/block.go +++ b/store/block.go @@ -11,7 +11,6 @@ import ( "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/util" "github.com/pactus-project/pactus/util/encoding" - "github.com/pactus-project/pactus/util/logger" "github.com/pactus-project/pactus/util/pairslice" "github.com/syndtr/goleveldb/leveldb" ) @@ -47,15 +46,6 @@ func newBlockStore(db *leveldb.DB, sortitionCacheSize uint32, publicKeyCacheSize } func (bs *blockStore) saveBlock(batch *leveldb.Batch, height uint32, blk *block.Block) []blockRegion { - if height > 1 { - if !bs.hasBlock(height - 1) { - logger.Panic("previous block not found", "height", height) - } - } - if bs.hasBlock(height) { - logger.Panic("duplicated block", "height", height) - } - blockHash := blk.Hash() regs := make([]blockRegion, blk.Transactions().Len()) w := bytes.NewBuffer(make([]byte, 0, blk.SerializeSize()+hash.HashSize)) diff --git a/store/block_test.go b/store/block_test.go index e69d75032..39e956a0b 100644 --- a/store/block_test.go +++ b/store/block_test.go @@ -13,16 +13,6 @@ func TestBlockStore(t *testing.T) { lastCert := td.store.LastCertificate() lastHeight := lastCert.Height() nextBlk, nextCert := td.GenerateTestBlock(lastHeight + 1) - nextNextBlk, nextNextCert := td.GenerateTestBlock(lastHeight + 2) - - t.Run("Missed block, Should panic ", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - td.store.SaveBlock(nextNextBlk, nextNextCert) - }) t.Run("Add block, don't batch write", func(t *testing.T) { td.store.SaveBlock(nextBlk, nextCert) @@ -46,15 +36,6 @@ func TestBlockStore(t *testing.T) { assert.NoError(t, err) assert.Equal(t, cert.Hash(), nextCert.Hash()) }) - - t.Run("Duplicated block, Should panic ", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - td.store.SaveBlock(nextBlk, nextCert) - }) } func TestSortitionSeed(t *testing.T) { diff --git a/store/config.go b/store/config.go index f8358e97c..f79e87d74 100644 --- a/store/config.go +++ b/store/config.go @@ -8,7 +8,8 @@ import ( ) type Config struct { - Path string `toml:"path"` + Path string `toml:"path"` + RetentionDays uint32 `toml:"retention_days"` // Private configs TxCacheSize uint32 `toml:"-"` @@ -21,6 +22,7 @@ type Config struct { func DefaultConfig() *Config { return &Config{ Path: "data", + RetentionDays: 10, TxCacheSize: 1024, SortitionCacheSize: 1024, AccountCacheSize: 1024, @@ -54,5 +56,15 @@ func (conf *Config) BasicCheck() error { } } + if conf.RetentionDays < 10 { + return ConfigError{ + Reason: "Retention days can't be less than 10 days", + } + } + return nil } + +func (conf *Config) RetentionBlocks() uint32 { + return conf.RetentionDays * 8640 +} diff --git a/store/interface.go b/store/interface.go index 5c51ecacf..147d43477 100644 --- a/store/interface.go +++ b/store/interface.go @@ -98,6 +98,7 @@ type Reader interface { TotalValidators() int32 LastCertificate() *certificate.BlockCertificate IsBanned(addr crypto.Address) bool + IsPruned() bool } type Store interface { @@ -106,6 +107,7 @@ type Store interface { UpdateAccount(addr crypto.Address, acc *account.Account) UpdateValidator(val *validator.Validator) SaveBlock(blk *block.Block, cert *certificate.BlockCertificate) + Prune(resultFunc func(pruned, skipped, pruningHeight uint32)) error WriteBatch() error Close() } diff --git a/store/mock.go b/store/mock.go index 117e42011..05b54a9cb 100644 --- a/store/mock.go +++ b/store/mock.go @@ -289,3 +289,11 @@ func (m *MockStore) RandomTestVal() *validator.Validator { func (*MockStore) IsBanned(_ crypto.Address) bool { return false } + +func (*MockStore) Prune(_ func(pruned, skipped, pruningHeight uint32)) error { + return nil +} + +func (*MockStore) IsPruned() bool { + return false +} diff --git a/store/store.go b/store/store.go index 0621b85ed..eba891d76 100644 --- a/store/store.go +++ b/store/store.go @@ -73,6 +73,7 @@ type store struct { txStore *txStore accountStore *accountStore validatorStore *validatorStore + isPruned bool } func NewStore(conf *Config) (Store, error) { @@ -93,13 +94,20 @@ func NewStore(conf *Config) (Store, error) { txStore: newTxStore(db, conf.TxCacheSize), accountStore: newAccountStore(db, conf.AccountCacheSize), validatorStore: newValidatorStore(db), + isPruned: false, } - lc := s.LastCertificate() + lc := s.lastCertificate() if lc == nil { return s, nil } + // Check if the node is pruned by checking genesis block. + blockOne, _ := s.block(1) + if blockOne == nil { + s.isPruned = true + } + currentHeight := lc.Height() startHeight := uint32(1) if currentHeight > conf.TxCacheSize { @@ -107,7 +115,7 @@ func NewStore(conf *Config) (Store, error) { } for i := startHeight; i < currentHeight+1; i++ { - committedBlock, err := s.Block(i) + committedBlock, err := s.block(i) if err != nil { return nil, err } @@ -147,6 +155,22 @@ func (s *store) SaveBlock(blk *block.Block, cert *certificate.BlockCertificate) s.txStore.saveTxs(s.batch, blk.Transactions(), regs) s.txStore.pruneCache(height) + // Removing old block from prune node store. + if s.isPruned && height > s.config.RetentionBlocks() { + pruneHeight := height - s.config.RetentionBlocks() + deleted, err := s.pruneBlock(pruneHeight) + if err != nil { + panic(err) + } + + if deleted { + // TODO: Let's use state logger in store[?]. + logger.Debug("old block is pruned", "height", pruneHeight) + } else { + logger.Warn("unable to prune the old block", "height", pruneHeight, "error", err) + } + } + // Save last certificate: [version: 4 bytes]+[certificate: variant] w := bytes.NewBuffer(make([]byte, 0, 4+cert.SerializeSize())) err := encoding.WriteElements(w, lastStoreVersion) @@ -165,6 +189,10 @@ func (s *store) Block(height uint32) (*CommittedBlock, error) { s.lk.Lock() defer s.lk.Unlock() + return s.block(height) +} + +func (s *store) block(height uint32) (*CommittedBlock, error) { data, err := s.blockStore.block(height) if err != nil { return nil, err @@ -341,6 +369,10 @@ func (s *store) LastCertificate() *certificate.BlockCertificate { s.lk.Lock() defer s.lk.Unlock() + return s.lastCertificate() +} + +func (s *store) lastCertificate() *certificate.BlockCertificate { data, _ := tryGet(s.db, lastInfoKey) if data == nil { // Genesis block @@ -378,3 +410,66 @@ func (s *store) WriteBatch() error { func (s *store) IsBanned(addr crypto.Address) bool { return s.config.BannedAddrs[addr] } + +func (s *store) IsPruned() bool { + return s.isPruned +} + +func (s *store) Prune(resultFunc func(pruned, skipped, pruningHeight uint32)) error { + cert := s.lastCertificate() + + // at genesis block + if cert == nil { + return nil + } + + retentionBlocks := s.config.RetentionBlocks() + + if cert.Height() < retentionBlocks { + return nil + } + + pruningHeight := cert.Height() - retentionBlocks + + for i := pruningHeight; i >= 1; i-- { + deleted, err := s.pruneBlock(i) + if err != nil { + return err + } + + if err := s.WriteBatch(); err != nil { + return err + } + + if deleted { + resultFunc(1, 0, i) + + continue + } + + resultFunc(0, 1, i) + } + + return nil +} + +func (s *store) pruneBlock(blockHeight uint32) (bool, error) { + if !s.blockStore.hasBlock(blockHeight) { + return false, nil + } + + cBlock, _ := s.block(blockHeight) + blk, err := block.FromBytes(cBlock.Data) + if err != nil { + return false, err + } + + s.batch.Delete(blockHashKey(blk.Hash())) + s.batch.Delete(blockKey(blockHeight)) + + for _, t := range blk.Transactions() { + s.batch.Delete(t.ID().Bytes()) + } + + return true, nil +} diff --git a/sync/bundle/message/block_announce.go b/sync/bundle/message/block_announce.go index e7696cd21..3beb1d984 100644 --- a/sync/bundle/message/block_announce.go +++ b/sync/bundle/message/block_announce.go @@ -3,6 +3,7 @@ package message import ( "fmt" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/types/certificate" ) @@ -35,6 +36,14 @@ func (*BlockAnnounceMessage) Type() Type { return TypeBlockAnnounce } +func (*BlockAnnounceMessage) TopicID() network.TopicID { + return network.TopicIDBlock +} + +func (*BlockAnnounceMessage) ShouldBroadcast() bool { + return true +} + func (m *BlockAnnounceMessage) String() string { return fmt.Sprintf("{⌘ %d %v}", m.Certificate.Height(), diff --git a/sync/bundle/message/blocks_request.go b/sync/bundle/message/blocks_request.go index cd2dff286..dcb2ce923 100644 --- a/sync/bundle/message/blocks_request.go +++ b/sync/bundle/message/blocks_request.go @@ -3,6 +3,7 @@ package message import ( "fmt" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/util/errors" ) @@ -39,6 +40,14 @@ func (*BlocksRequestMessage) Type() Type { return TypeBlocksRequest } +func (*BlocksRequestMessage) TopicID() network.TopicID { + return network.TopicIDUnspecified +} + +func (*BlocksRequestMessage) ShouldBroadcast() bool { + return false +} + func (m *BlocksRequestMessage) String() string { return fmt.Sprintf("{βš“ %d %v:%v}", m.SessionID, m.From, m.To()) } diff --git a/sync/bundle/message/blocks_response.go b/sync/bundle/message/blocks_response.go index fab79148c..f4e94f3e5 100644 --- a/sync/bundle/message/blocks_response.go +++ b/sync/bundle/message/blocks_response.go @@ -3,6 +3,7 @@ package message import ( "fmt" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/certificate" ) @@ -42,6 +43,14 @@ func (*BlocksResponseMessage) Type() Type { return TypeBlocksResponse } +func (*BlocksResponseMessage) TopicID() network.TopicID { + return network.TopicIDUnspecified +} + +func (*BlocksResponseMessage) ShouldBroadcast() bool { + return false +} + func (m *BlocksResponseMessage) Count() uint32 { return uint32(len(m.CommittedBlocksData)) } diff --git a/sync/bundle/message/hello.go b/sync/bundle/message/hello.go index 3e75fa3f1..081b3c2e8 100644 --- a/sync/bundle/message/hello.go +++ b/sync/bundle/message/hello.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/sync/peerset/peer/service" "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/version" @@ -64,6 +65,14 @@ func (*HelloMessage) Type() Type { return TypeHello } +func (*HelloMessage) TopicID() network.TopicID { + return network.TopicIDUnspecified +} + +func (*HelloMessage) ShouldBroadcast() bool { + return false +} + func (m *HelloMessage) String() string { return fmt.Sprintf("{%s %d %s}", m.Moniker, m.Height, m.Services) } diff --git a/sync/bundle/message/hello_ack.go b/sync/bundle/message/hello_ack.go index 843196e7a..65f9d85f4 100644 --- a/sync/bundle/message/hello_ack.go +++ b/sync/bundle/message/hello_ack.go @@ -2,6 +2,8 @@ package message import ( "fmt" + + "github.com/pactus-project/pactus/network" ) type HelloAckMessage struct { @@ -26,6 +28,14 @@ func (*HelloAckMessage) Type() Type { return TypeHelloAck } +func (*HelloAckMessage) TopicID() network.TopicID { + return network.TopicIDUnspecified +} + +func (*HelloAckMessage) ShouldBroadcast() bool { + return false +} + func (m *HelloAckMessage) String() string { return fmt.Sprintf("{%s: %s %v}", m.ResponseCode, m.Reason, m.Height) } diff --git a/sync/bundle/message/message.go b/sync/bundle/message/message.go index a6dc7803e..9a617682c 100644 --- a/sync/bundle/message/message.go +++ b/sync/bundle/message/message.go @@ -52,36 +52,6 @@ const ( TypeBlocksResponse = Type(10) ) -func (t Type) TopicID() network.TopicID { - switch t { - case TypeBlockAnnounce: - - return network.TopicIDBlock - - case TypeTransaction: - - return network.TopicIDTransaction - - case TypeQueryProposal, - TypeProposal, - TypeQueryVote, - TypeVote: - - return network.TopicIDConsensus - - case TypeHello, - TypeHelloAck, - TypeBlocksRequest, - TypeBlocksResponse: - - return -1 // topic id for direct message - - default: - - return -2 // topic id for unknown message - } -} - func (t Type) String() string { switch t { case TypeHello: @@ -91,7 +61,7 @@ func (t Type) String() string { return "hello-ack" case TypeTransaction: - return "txs" + return "transaction" case TypeQueryProposal: return "query-proposal" @@ -100,7 +70,7 @@ func (t Type) String() string { return "proposal" case TypeQueryVote: - return "query-votes" + return "query-vote" case TypeVote: return "vote" @@ -109,10 +79,10 @@ func (t Type) String() string { return "block-announce" case TypeBlocksRequest: - return "blocks-req" + return "blocks-request" case TypeBlocksResponse: - return "blocks-res" + return "blocks-response" default: return fmt.Sprintf("%d", t) @@ -159,5 +129,7 @@ func MakeMessage(t Type) Message { type Message interface { BasicCheck() error Type() Type + TopicID() network.TopicID + ShouldBroadcast() bool String() string } diff --git a/sync/bundle/message/message_test.go b/sync/bundle/message/message_test.go new file mode 100644 index 000000000..4e4b18f70 --- /dev/null +++ b/sync/bundle/message/message_test.go @@ -0,0 +1,35 @@ +package message + +import ( + "testing" + + "github.com/pactus-project/pactus/network" + "github.com/stretchr/testify/assert" +) + +func TestMessage(t *testing.T) { + testCases := []struct { + msgType Type + typeName string + topicID network.TopicID + shouldBroadcast bool + }{ + {TypeHello, "hello", network.TopicIDUnspecified, false}, + {TypeHelloAck, "hello-ack", network.TopicIDUnspecified, false}, + {TypeTransaction, "transaction", network.TopicIDTransaction, true}, + {TypeQueryProposal, "query-proposal", network.TopicIDConsensus, true}, + {TypeProposal, "proposal", network.TopicIDConsensus, true}, + {TypeQueryVote, "query-vote", network.TopicIDConsensus, true}, + {TypeVote, "vote", network.TopicIDConsensus, true}, + {TypeBlockAnnounce, "block-announce", network.TopicIDBlock, true}, + {TypeBlocksRequest, "blocks-request", network.TopicIDUnspecified, false}, + {TypeBlocksResponse, "blocks-response", network.TopicIDUnspecified, false}, + } + + for _, tc := range testCases { + msg := MakeMessage(tc.msgType) + assert.Equal(t, tc.typeName, msg.Type().String()) + assert.Equal(t, tc.topicID, msg.TopicID()) + assert.Equal(t, tc.shouldBroadcast, msg.ShouldBroadcast()) + } +} diff --git a/sync/bundle/message/proposal.go b/sync/bundle/message/proposal.go index 51d836b86..b3a406b39 100644 --- a/sync/bundle/message/proposal.go +++ b/sync/bundle/message/proposal.go @@ -1,6 +1,7 @@ package message import ( + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/proposal" ) @@ -24,6 +25,14 @@ func (*ProposalMessage) Type() Type { return TypeProposal } +func (*ProposalMessage) TopicID() network.TopicID { + return network.TopicIDConsensus +} + +func (*ProposalMessage) ShouldBroadcast() bool { + return true +} + func (m *ProposalMessage) String() string { return m.Proposal.String() } diff --git a/sync/bundle/message/query_proposal.go b/sync/bundle/message/query_proposal.go index 16d076bc6..8e0cc945f 100644 --- a/sync/bundle/message/query_proposal.go +++ b/sync/bundle/message/query_proposal.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/pactus-project/pactus/crypto" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/util/errors" ) @@ -33,6 +34,14 @@ func (*QueryProposalMessage) Type() Type { return TypeQueryProposal } +func (*QueryProposalMessage) TopicID() network.TopicID { + return network.TopicIDConsensus +} + +func (*QueryProposalMessage) ShouldBroadcast() bool { + return true +} + func (m *QueryProposalMessage) String() string { return fmt.Sprintf("{%v %s}", m.Height, m.Querier.ShortString()) } diff --git a/sync/bundle/message/query_votes.go b/sync/bundle/message/query_votes.go index 502f4d49b..4339a9465 100644 --- a/sync/bundle/message/query_votes.go +++ b/sync/bundle/message/query_votes.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/pactus-project/pactus/crypto" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/util/errors" ) @@ -33,6 +34,14 @@ func (*QueryVotesMessage) Type() Type { return TypeQueryVote } +func (*QueryVotesMessage) TopicID() network.TopicID { + return network.TopicIDConsensus +} + +func (*QueryVotesMessage) ShouldBroadcast() bool { + return true +} + func (m *QueryVotesMessage) String() string { return fmt.Sprintf("{%d/%d %s}", m.Height, m.Round, m.Querier.ShortString()) } diff --git a/sync/bundle/message/transactions.go b/sync/bundle/message/transactions.go index e4d1b573a..2d5a3ed26 100644 --- a/sync/bundle/message/transactions.go +++ b/sync/bundle/message/transactions.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/tx" "github.com/pactus-project/pactus/util/errors" ) @@ -35,6 +36,14 @@ func (*TransactionsMessage) Type() Type { return TypeTransaction } +func (*TransactionsMessage) TopicID() network.TopicID { + return network.TopicIDTransaction +} + +func (*TransactionsMessage) ShouldBroadcast() bool { + return true +} + func (m *TransactionsMessage) String() string { var builder strings.Builder diff --git a/sync/bundle/message/vote.go b/sync/bundle/message/vote.go index 2a3695af0..34a1ed861 100644 --- a/sync/bundle/message/vote.go +++ b/sync/bundle/message/vote.go @@ -1,6 +1,7 @@ package message import ( + "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/vote" ) @@ -22,6 +23,14 @@ func (*VoteMessage) Type() Type { return TypeVote } +func (*VoteMessage) TopicID() network.TopicID { + return network.TopicIDConsensus +} + +func (*VoteMessage) ShouldBroadcast() bool { + return true +} + func (m *VoteMessage) String() string { return m.Vote.String() } diff --git a/sync/firewall/errors.go b/sync/firewall/errors.go new file mode 100644 index 000000000..da3bf78e3 --- /dev/null +++ b/sync/firewall/errors.go @@ -0,0 +1,24 @@ +package firewall + +import ( + "errors" + "fmt" + + lp2pcore "github.com/libp2p/go-libp2p/core" +) + +// PeerBannedError is returned when a message received from a banned peer-id or banned address. +type PeerBannedError struct { + PeerID lp2pcore.PeerID + Address string +} + +func (e PeerBannedError) Error() string { + return fmt.Sprintf("peer is banned, peer-id: %s, remote-address: %s", e.PeerID, e.Address) +} + +// ErrGossipMessage is returned when a stream message sends as gossip message. +var ErrGossipMessage = errors.New("receive stream message as gossip message") + +// ErrStreamMessage is returned when a gossip message sends as stream message. +var ErrStreamMessage = errors.New("receive gossip message as stream message") diff --git a/sync/firewall/firewall.go b/sync/firewall/firewall.go index d2d1398d3..eed0a68f8 100644 --- a/sync/firewall/firewall.go +++ b/sync/firewall/firewall.go @@ -12,6 +12,7 @@ import ( "github.com/pactus-project/pactus/sync/bundle" "github.com/pactus-project/pactus/sync/peerset" "github.com/pactus-project/pactus/sync/peerset/peer" + "github.com/pactus-project/pactus/sync/peerset/peer/status" "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/ipblocker" "github.com/pactus-project/pactus/util/logger" @@ -56,26 +57,29 @@ func NewFirewall(conf *Config, net network.Network, peerSet *peerset.PeerSet, st }, nil } -func (f *Firewall) OpenGossipBundle(data []byte, from peer.ID) *bundle.Bundle { +func (f *Firewall) OpenGossipBundle(data []byte, from peer.ID) (*bundle.Bundle, error) { bdl, err := f.openBundle(bytes.NewReader(data), from) if err != nil { - f.logger.Debug("firewall: unable to open a gossip bundle", + return nil, err + } + + if !bdl.Message.ShouldBroadcast() { + f.logger.Warn("firewall: receive stream message as gossip message", "error", err, "bundle", bdl, "from", from) - return nil - } + f.closeConnection(from) - // TODO: check if gossip flag is set - // TODO: check if bundle is a gossip bundle + return nil, ErrGossipMessage + } - return bdl + return bdl, nil } // IsBannedAddress checks if the remote IP address is banned. func (f *Firewall) IsBannedAddress(remoteAddr string) bool { ip, err := f.getIPFromMultiAddress(remoteAddr) if err != nil { - f.logger.Warn("firewall: unable to parse remote address", "err", err, "addr", remoteAddr) + f.logger.Warn("firewall: unable to parse remote address", "error", err, "addr", remoteAddr) return false } @@ -83,29 +87,49 @@ func (f *Firewall) IsBannedAddress(remoteAddr string) bool { return f.ipBlocker.IsBanned(ip) } -func (f *Firewall) OpenStreamBundle(r io.Reader, from peer.ID) *bundle.Bundle { +func (f *Firewall) OpenStreamBundle(r io.Reader, from peer.ID) (*bundle.Bundle, error) { bdl, err := f.openBundle(r, from) if err != nil { f.logger.Debug("firewall: unable to open a stream bundle", "error", err, "bundle", bdl, "from", from) - return nil + return nil, err } - // TODO: check if gossip flag is NOT set - // TODO: check if bundle is a stream bundle + if bdl.Message.ShouldBroadcast() { + f.logger.Warn("firewall: receive gossip message as stream message", + "error", err, "bundle", bdl, "from", from) + + f.closeConnection(from) - return bdl + return nil, ErrStreamMessage + } + + return bdl, nil } func (f *Firewall) openBundle(r io.Reader, from peer.ID) (*bundle.Bundle, error) { f.peerSet.UpdateLastReceived(from) f.peerSet.IncreaseReceivedBundlesCounter(from) - if f.isPeerBanned(from) { + p := f.peerSet.GetPeer(from) + if p.Status.IsBanned() { f.closeConnection(from) - return nil, errors.Errorf(errors.ErrInvalidMessage, "peer is banned: %s", from) + return nil, PeerBannedError{ + PeerID: p.PeerID, + Address: p.Address, + } + } + + if f.IsBannedAddress(p.Address) { + f.closeConnection(from) + f.peerSet.UpdateStatus(from, status.StatusBanned) + + return nil, PeerBannedError{ + PeerID: p.PeerID, + Address: p.Address, + } } bdl, err := f.decodeBundle(r, from) @@ -163,12 +187,6 @@ func (f *Firewall) checkBundle(bdl *bundle.Bundle) error { return nil } -func (f *Firewall) isPeerBanned(pid peer.ID) bool { - p := f.peerSet.GetPeer(pid) - - return p.Status.IsBanned() -} - func (f *Firewall) closeConnection(pid peer.ID) { f.network.CloseConnection(pid) } diff --git a/sync/firewall/firewall_test.go b/sync/firewall/firewall_test.go index 2c3aa8c29..067c831dc 100644 --- a/sync/firewall/firewall_test.go +++ b/sync/firewall/firewall_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/pactus-project/pactus/genesis" "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/state" "github.com/pactus-project/pactus/sync/bundle" @@ -23,7 +24,7 @@ type testData struct { *testsuite.TestSuite firewall *Firewall - badPeerID peer.ID + bannedPeerID peer.ID goodPeerID peer.ID unknownPeerID peer.ID network *network.MockNetwork @@ -50,117 +51,211 @@ func setup(t *testing.T, conf *Config) *testData { } assert.NotNil(t, firewall) - badPeerID := ts.RandPeerID() + bannedPeerID := ts.RandPeerID() goodPeerID := ts.RandPeerID() unknownPeerID := ts.RandPeerID() net.AddAnotherNetwork(network.MockingNetwork(ts, goodPeerID)) net.AddAnotherNetwork(network.MockingNetwork(ts, unknownPeerID)) - net.AddAnotherNetwork(network.MockingNetwork(ts, badPeerID)) + net.AddAnotherNetwork(network.MockingNetwork(ts, bannedPeerID)) firewall.peerSet.UpdateStatus(goodPeerID, status.StatusKnown) - firewall.peerSet.UpdateStatus(badPeerID, status.StatusBanned) + firewall.peerSet.UpdateStatus(bannedPeerID, status.StatusBanned) return &testData{ TestSuite: ts, firewall: firewall, network: net, state: st, - badPeerID: badPeerID, + bannedPeerID: bannedPeerID, goodPeerID: goodPeerID, unknownPeerID: unknownPeerID, } } -func TestInvalidBundlesCounter(t *testing.T) { - td := setup(t, nil) +func (td *testData) testGossipBundle() []byte { + bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) + d, _ := bdl.Encode() - assert.Nil(t, td.firewall.OpenGossipBundle([]byte("bad"), td.unknownPeerID)) - assert.Nil(t, td.firewall.OpenGossipBundle(nil, td.unknownPeerID)) + return d +} - bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), -1, td.RandValAddress())) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) +func (td *testData) testStreamBundle() []byte { + bdl := bundle.NewBundle(message.NewBlocksRequestMessage(td.RandInt(100), 1, 100)) + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) d, _ := bdl.Encode() - assert.Nil(t, td.firewall.OpenGossipBundle(d, td.unknownPeerID)) + + return d +} + +func TestDecodeBundles(t *testing.T) { + td := setup(t, nil) + + testCases := []struct { + name string + data string + peerID string + wantErr bool + }{ + { + name: "invalid data", + data: "bad0", + wantErr: true, + }, + { + name: "nil data", + data: "", + wantErr: true, + }, + { + name: "invalid bundle (round is -1)", + data: "a4" + // Map with 4 key-value pairs + "01" + "01" + // Key 1 (Flags), Value: 1 (Mainnet) + "02" + "06" + // Key 2 (Message Type), Value: 6 (QueryVote) + "03" + "581d" + // Key 2 (Message), Value: 30 Bytes + "" + "a3" + // Map with 3 key-value pairs + "" + "01" + "1864" + // Key 1 (Height), Value: 100 + "" + "02" + "20" + // Key 2 (Round), Value: -1 + "" + "03" + "5501aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + // Key 3 (Querier), Value: 21 Bytes + "04" + "05", // Key 4 (Sequence number), Value: 5 + wantErr: true, + }, + + { + name: "valid bundle (invalid network, Testnet)", + data: "a4" + // Map with 4 key-value pairs + "01" + "02" + // Key 1 (Flags), Value: 1 (Testnet) + "02" + "06" + // Key 2 (Message Type), Value: 6 (QueryVote) + "03" + "581d" + // Key 2 (Message), Value: 30 Bytes + "" + "a3" + // Map with 3 key-value pairs + "" + "01" + "1864" + // Key 1 (Height), Value: 100 + "" + "02" + "00" + // Key 2 (Round), Value: 0 + "" + "03" + "5501aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + // Key 3 (Querier), Value: 21 Bytes + "04" + "05", // Key 4 (Sequence number), Value: 5 + wantErr: true, + }, + { + name: "valid bundle", + data: "a4" + // Map with 4 key-value pairs + "01" + "01" + // Key 1 (Flags), Value: 1 (Mainnet) + "02" + "06" + // Key 2 (Message Type), Value: 6 (QueryVote) + "03" + "581d" + // Key 2 (Message), Value: 30 Bytes + "" + "a3" + // Map with 3 key-value pairs + "" + "01" + "1864" + // Key 1 (Height), Value: 100 + "" + "02" + "00" + // Key 2 (Round), Value: 0 + "" + "03" + "5501aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + // Key 3 (Querier), Value: 21 Bytes + "04" + "05", // Key 4 (Sequence number), Value: 5 + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bs := td.DecodingHex(tc.data) + _, err := td.firewall.OpenGossipBundle(bs, td.unknownPeerID) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } p := td.firewall.peerSet.GetPeer(td.unknownPeerID) - assert.Equal(t, p.InvalidBundles, 3) + assert.Equal(t, 5, p.ReceivedBundles) + assert.Equal(t, 4, p.InvalidBundles) } func TestGossipMessage(t *testing.T) { - t.Run("Message from: unknown => should NOT close the connection", func(t *testing.T) { + t.Run("Message is nil", func(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() - - assert.False(t, td.network.IsClosed(td.unknownPeerID)) - assert.NotNil(t, td.firewall.OpenGossipBundle(d, td.unknownPeerID)) + _, err := td.firewall.OpenGossipBundle(nil, td.unknownPeerID) + require.Error(t, err) assert.False(t, td.network.IsClosed(td.unknownPeerID)) }) - t.Run("Message from: bad => should close the connection", func(t *testing.T) { + t.Run("Message from banned peer", func(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() + data := td.testGossipBundle() - assert.False(t, td.network.IsClosed(td.badPeerID)) - assert.Nil(t, td.firewall.OpenGossipBundle(d, td.badPeerID)) - assert.True(t, td.network.IsClosed(td.badPeerID)) + assert.False(t, td.network.IsClosed(td.bannedPeerID)) + _, err := td.firewall.OpenGossipBundle(data, td.bannedPeerID) + require.ErrorIs(t, err, PeerBannedError{ + PeerID: td.bannedPeerID, + Address: "", + }) + assert.True(t, td.network.IsClosed(td.bannedPeerID)) }) - t.Run("Message is nil => should close the connection", func(t *testing.T) { + t.Run("Stream message as gossip message", func(t *testing.T) { td := setup(t, nil) - assert.Nil(t, td.firewall.OpenGossipBundle(nil, td.unknownPeerID)) + data := td.testStreamBundle() + + assert.False(t, td.network.IsClosed(td.unknownPeerID)) + _, err := td.firewall.OpenGossipBundle(data, td.unknownPeerID) + require.ErrorIs(t, err, ErrGossipMessage) + assert.True(t, td.network.IsClosed(td.unknownPeerID)) }) - t.Run("Ok => should NOT close the connection", func(t *testing.T) { + t.Run("Ok", func(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() + data := td.testGossipBundle() assert.False(t, td.network.IsClosed(td.goodPeerID)) - assert.NotNil(t, td.firewall.OpenGossipBundle(d, td.goodPeerID)) + _, err := td.firewall.OpenGossipBundle(data, td.goodPeerID) + require.NoError(t, err) assert.False(t, td.network.IsClosed(td.goodPeerID)) }) } func TestStreamMessage(t *testing.T) { - t.Run("Message is nil => should close the connection", func(t *testing.T) { + t.Run("Message is nil", func(t *testing.T) { td := setup(t, nil) - assert.False(t, td.network.IsClosed(td.badPeerID)) - assert.Nil(t, td.firewall.OpenStreamBundle(bytes.NewReader(nil), td.badPeerID)) - assert.True(t, td.network.IsClosed(td.badPeerID)) + assert.False(t, td.network.IsClosed(td.unknownPeerID)) + _, err := td.firewall.OpenStreamBundle(bytes.NewReader(nil), td.unknownPeerID) + assert.Error(t, err) + }) + + t.Run("Message from banned peer", func(t *testing.T) { + td := setup(t, nil) + + data := td.testStreamBundle() + + assert.False(t, td.network.IsClosed(td.bannedPeerID)) + _, err := td.firewall.OpenStreamBundle(bytes.NewReader(data), td.bannedPeerID) + assert.ErrorIs(t, err, PeerBannedError{ + PeerID: td.bannedPeerID, + Address: "", + }) + + assert.True(t, td.network.IsClosed(td.bannedPeerID)) }) - t.Run("Message from: bad => should close the connection", func(t *testing.T) { + t.Run("Gossip message as direct message", func(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewBlocksRequestMessage(td.RandInt(100), 1, 100)) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() + data := td.testGossipBundle() - assert.False(t, td.network.IsClosed(td.badPeerID)) - assert.Nil(t, td.firewall.OpenStreamBundle(bytes.NewReader(d), td.badPeerID)) - assert.True(t, td.network.IsClosed(td.badPeerID)) + assert.False(t, td.network.IsClosed(td.unknownPeerID)) + _, err := td.firewall.OpenStreamBundle(bytes.NewReader(data), td.unknownPeerID) + require.ErrorIs(t, err, ErrStreamMessage) + assert.True(t, td.network.IsClosed(td.unknownPeerID)) }) - t.Run("Ok => should NOT close the connection", func(t *testing.T) { + t.Run("Ok", func(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewBlocksRequestMessage(td.RandInt(100), 1, 100)) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() + data := td.testStreamBundle() assert.False(t, td.network.IsClosed(td.goodPeerID)) - assert.NotNil(t, td.firewall.OpenStreamBundle(bytes.NewReader(d), td.goodPeerID)) + _, err := td.firewall.OpenStreamBundle(bytes.NewReader(data), td.goodPeerID) + require.NoError(t, err) assert.False(t, td.network.IsClosed(td.goodPeerID)) }) } @@ -168,11 +263,10 @@ func TestStreamMessage(t *testing.T) { func TestUpdateLastReceived(t *testing.T) { td := setup(t, nil) - bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) - bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - d, _ := bdl.Encode() + data := td.testGossipBundle() now := time.Now().UnixNano() - assert.NotNil(t, td.firewall.OpenGossipBundle(d, td.goodPeerID)) + _, err := td.firewall.OpenGossipBundle(data, td.goodPeerID) + require.NoError(t, err) peerGood := td.firewall.peerSet.GetPeer(td.goodPeerID) assert.GreaterOrEqual(t, peerGood.LastReceived.UnixNano(), now) @@ -210,22 +304,43 @@ func TestBannedAddress(t *testing.T) { } for i, tc := range testCases { - banned := td.firewall.IsBannedAddress(tc.addr) + peerID := td.RandPeerID() + td.firewall.peerSet.UpdateAddress(peerID, tc.addr, "inbound") + data := td.testGossipBundle() + _, err := td.firewall.OpenGossipBundle(data, peerID) if tc.banned { - assert.True(t, banned, + expectedErr := PeerBannedError{ + PeerID: peerID, + Address: tc.addr, + } + assert.ErrorIs(t, err, expectedErr, "test %v failed, addr %v should be banned", i, tc.addr) } else { - assert.False(t, banned, + assert.NoError(t, err, "test %v failed, addr %v should not be banned", i, tc.addr) } } } -func TestNetworkFlags(t *testing.T) { +func TestNetworkFlagsMainnet(t *testing.T) { td := setup(t, nil) - // TODO: add tests for Mainnet and Testnet flags + bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) + assert.NoError(t, td.firewall.checkBundle(bdl)) + + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) + assert.Error(t, td.firewall.checkBundle(bdl)) + + bdl.Flags = 0 + assert.Error(t, td.firewall.checkBundle(bdl)) +} + +func TestNetworkFlagsTestnet(t *testing.T) { + td := setup(t, nil) + td.state.TestGenesis = genesis.TestnetGenesis() + bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) assert.NoError(t, td.firewall.checkBundle(bdl)) @@ -235,10 +350,21 @@ func TestNetworkFlags(t *testing.T) { bdl.Flags = 0 assert.Error(t, td.firewall.checkBundle(bdl)) +} +func TestNetworkFlagsLocalnet(t *testing.T) { + td := setup(t, nil) td.state.TestParams.BlockVersion = 0x3f // changing genesis hash - bdl.Flags = 1 + + bdl := bundle.NewBundle(message.NewQueryVotesMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) assert.Error(t, td.firewall.checkBundle(bdl)) + + bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) + assert.Error(t, td.firewall.checkBundle(bdl)) + + bdl.Flags = 0 + assert.NoError(t, td.firewall.checkBundle(bdl)) } func TestParseP2PAddr(t *testing.T) { diff --git a/sync/handler.go b/sync/handler.go index 9dea72bcc..1e78a1776 100644 --- a/sync/handler.go +++ b/sync/handler.go @@ -7,6 +7,6 @@ import ( ) type messageHandler interface { - ParseMessage(message.Message, peer.ID) error + ParseMessage(message.Message, peer.ID) PrepareBundle(message.Message) *bundle.Bundle } diff --git a/sync/handler_block_announce.go b/sync/handler_block_announce.go index 8e004bfd9..34685e910 100644 --- a/sync/handler_block_announce.go +++ b/sync/handler_block_announce.go @@ -16,29 +16,23 @@ func newBlockAnnounceHandler(sync *synchronizer) messageHandler { } } -func (handler *blockAnnounceHandler) ParseMessage(m message.Message, pid peer.ID) error { +func (handler *blockAnnounceHandler) ParseMessage(m message.Message, pid peer.ID) { msg := m.(*message.BlockAnnounceMessage) handler.logger.Trace("parsing BlockAnnounce message", "msg", msg) if handler.cache.HasBlockInCache(msg.Height()) { // We have processed this block before. - return nil + return } + handler.peerSet.UpdateHeight(pid, msg.Height(), msg.Block.Hash()) handler.cache.AddCertificate(msg.Certificate) handler.cache.AddBlock(msg.Block) - err := handler.tryCommitBlocks() - if err != nil { - return err - } + handler.tryCommitBlocks() handler.moveConsensusToNewHeight() - - handler.peerSet.UpdateHeight(pid, msg.Height(), msg.Block.Hash()) handler.updateBlockchain() - - return nil } func (*blockAnnounceHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_block_announce_test.go b/sync/handler_block_announce_test.go index 5ef568065..8f929b66b 100644 --- a/sync/handler_block_announce_test.go +++ b/sync/handler_block_announce_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/pactus-project/pactus/sync/bundle/message" - "github.com/pactus-project/pactus/types/certificate" "github.com/stretchr/testify/assert" ) @@ -22,31 +21,21 @@ func TestParsingBlockAnnounceMessages(t *testing.T) { msg2 := message.NewBlockAnnounceMessage(blk2, cert2) t.Run("Receiving new block announce message, without committing previous block", func(t *testing.T) { - assert.NoError(t, td.receivingNewMessage(td.sync, msg2, pid)) + td.receivingNewMessage(td.sync, msg2, pid) - assert.Equal(t, td.sync.state.LastBlockHeight(), lastHeight) + stateHeight := td.sync.state.LastBlockHeight() + consHeight, _ := td.consMgr.HeightRound() + assert.Equal(t, lastHeight, stateHeight) + assert.Equal(t, lastHeight+1, consHeight) }) t.Run("Receiving missed block, should commit both blocks", func(t *testing.T) { - assert.NoError(t, td.receivingNewMessage(td.sync, msg1, pid)) + td.receivingNewMessage(td.sync, msg1, pid) assert.Equal(t, td.sync.state.LastBlockHeight(), lastHeight+2) }) } -func TestInvalidBlockAnnounce(t *testing.T) { - td := setup(t, nil) - - pid := td.RandPeerID() - height := td.state.LastBlockHeight() + 1 - blk, _ := td.GenerateTestBlock(height) - invCert := certificate.NewBlockCertificate(height, 0) - msg := message.NewBlockAnnounceMessage(blk, invCert) - - err := td.receivingNewMessage(td.sync, msg, pid) - assert.Error(t, err) -} - func TestBroadcastingBlockAnnounceMessages(t *testing.T) { td := setup(t, nil) @@ -57,3 +46,21 @@ func TestBroadcastingBlockAnnounceMessages(t *testing.T) { msg1 := td.shouldPublishMessageWithThisType(t, message.TypeBlockAnnounce) assert.Equal(t, msg1.Message.(*message.BlockAnnounceMessage).Certificate.Height(), msg.Certificate.Height()) } + +func TestCacheAnnouncedBlock(t *testing.T) { + td := setup(t, nil) + + height := td.RandHeight() + blk1, cert1 := td.GenerateTestBlock(height) + blk2, cert2 := td.GenerateTestBlock(height) + msg1 := message.NewBlockAnnounceMessage(blk1, cert1) + msg2 := message.NewBlockAnnounceMessage(blk2, cert2) + + td.receivingNewMessage(td.sync, msg1, td.RandPeerID()) + td.receivingNewMessage(td.sync, msg2, td.RandPeerID()) + + cachedBlock := td.sync.cache.GetBlock(height) + cachedCert := td.sync.cache.GetCertificate(height) + assert.Equal(t, cachedBlock, blk1) + assert.Equal(t, cachedCert, cert1) +} diff --git a/sync/handler_blocks_request.go b/sync/handler_blocks_request.go index c1ccd8ef8..b49196791 100644 --- a/sync/handler_blocks_request.go +++ b/sync/handler_blocks_request.go @@ -19,7 +19,7 @@ func newBlocksRequestHandler(sync *synchronizer) messageHandler { } } -func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID) error { +func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID) { msg := m.(*message.BlocksRequestMessage) handler.logger.Trace("parsing BlocksRequest message", "msg", msg) @@ -30,7 +30,7 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } if !status.IsKnown() { @@ -39,7 +39,7 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } ourHeight := handler.state.LastBlockHeight() @@ -50,7 +50,7 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } } @@ -60,7 +60,7 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } if msg.Count > handler.config.LatestBlockInterval { @@ -69,7 +69,7 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } // Help this peer to sync up @@ -100,15 +100,13 @@ func (handler *blocksRequestHandler) ParseMessage(m message.Message, pid peer.ID handler.respond(response, pid) - return nil + return } response := message.NewBlocksResponseMessage(message.ResponseCodeNoMoreBlocks, message.ResponseCodeNoMoreBlocks.String(), msg.SessionID, 0, nil, nil) handler.respond(response, pid) - - return nil } func (*blocksRequestHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_blocks_request_test.go b/sync/handler_blocks_request_test.go index 662f03d1c..22a4ccc3b 100644 --- a/sync/handler_blocks_request_test.go +++ b/sync/handler_blocks_request_test.go @@ -24,7 +24,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Reject request from unknown peers", func(t *testing.T) { pid := td.RandPeerID() msg := message.NewBlocksRequestMessage(sid, curHeight-1, 1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res := bdl.Message.(*message.BlocksResponseMessage) @@ -37,7 +37,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Reject request from peers without handshaking", func(t *testing.T) { pid := td.addPeer(t, status.StatusConnected, service.New(service.None)) msg := message.NewBlocksRequestMessage(sid, curHeight-1, 1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res := bdl.Message.(*message.BlocksResponseMessage) @@ -49,7 +49,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Peer requested blocks that we don't have", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, curHeight+1, 1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res := bdl.Message.(*message.BlocksResponseMessage) @@ -59,7 +59,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Reject requests not within `LatestBlockInterval`", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, curHeight-config.LatestBlockInterval-1, 1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res := bdl.Message.(*message.BlocksResponseMessage) @@ -69,7 +69,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Request blocks more than `LatestBlockInterval`", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, 10, config.LatestBlockInterval+1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res := bdl.Message.(*message.BlocksResponseMessage) @@ -80,7 +80,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Accept request within `LatestBlockInterval`", func(t *testing.T) { t.Run("Peer needs more block", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, curHeight-config.BlockPerMessage, config.BlockPerMessage) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl1 := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res1 := bdl1.Message.(*message.BlocksResponseMessage) @@ -99,7 +99,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Peer synced", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, curHeight-config.BlockPerMessage+1, config.BlockPerMessage) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl1 := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) res1 := bdl1.Message.(*message.BlocksResponseMessage) @@ -124,7 +124,7 @@ func TestBlocksRequestMessages(t *testing.T) { t.Run("Requesting one block", func(t *testing.T) { msg := message.NewBlocksRequestMessage(sid, 1, 2) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) msg1 := td.shouldPublishMessageWithThisType(t, message.TypeBlocksResponse) assert.Equal(t, msg1.Message.(*message.BlocksResponseMessage).ResponseCode, message.ResponseCodeMoreBlocks) diff --git a/sync/handler_blocks_response.go b/sync/handler_blocks_response.go index 42b9a3c7e..968766bbd 100644 --- a/sync/handler_blocks_response.go +++ b/sync/handler_blocks_response.go @@ -17,15 +17,16 @@ func newBlocksResponseHandler(sync *synchronizer) messageHandler { } } -func (handler *blocksResponseHandler) ParseMessage(m message.Message, pid peer.ID) error { +func (handler *blocksResponseHandler) ParseMessage(m message.Message, pid peer.ID) { msg := m.(*message.BlocksResponseMessage) handler.logger.Trace("parsing BlocksResponse message", "msg", msg) if msg.IsRequestRejected() { - handler.logger.Warn("blocks request is rejected", "pid", pid, "reason", msg.Reason, "sid", msg.SessionID) + handler.logger.Warn("blocks request is rejected", "pid", pid, + "reason", msg.Reason, "sid", msg.SessionID) } else { handler.logger.Info("blocks received", "from", msg.From, "count", msg.Count(), - "pid", pid, "reason", msg.Reason, "sid", msg.SessionID) + "pid", pid, "sid", msg.SessionID) // TODO: // It is good to check the latest height before adding blocks to the cache. @@ -34,20 +35,17 @@ func (handler *blocksResponseHandler) ParseMessage(m message.Message, pid peer.I for _, data := range msg.CommittedBlocksData { blk, err := block.FromBytes(data) if err != nil { - return err + handler.logger.Warn("unable to decode block data", + "from", msg.From, "pid", pid, "error", err) + } else { + handler.cache.AddBlock(blk) } - handler.cache.AddBlock(blk) } handler.cache.AddCertificate(msg.LastCertificate) - err := handler.tryCommitBlocks() - if err != nil { - return err - } + handler.tryCommitBlocks() } handler.updateSession(msg.SessionID, msg.ResponseCode) - - return nil } func (*blocksResponseHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_blocks_response_test.go b/sync/handler_blocks_response_test.go index 6c162e58b..42093b6fd 100644 --- a/sync/handler_blocks_response_test.go +++ b/sync/handler_blocks_response_test.go @@ -2,7 +2,6 @@ package sync import ( "fmt" - "io" "testing" "time" @@ -10,7 +9,6 @@ import ( "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/state" - "github.com/pactus-project/pactus/store" "github.com/pactus-project/pactus/sync/bundle/message" "github.com/pactus-project/pactus/sync/peerset/peer/service" "github.com/pactus-project/pactus/sync/peerset/peer/status" @@ -26,27 +24,14 @@ func TestInvalidBlockData(t *testing.T) { td := setup(t, nil) td.state.CommitTestBlocks(10) - lastHeight := td.state.LastBlockHeight() - prevCert := td.GenerateTestBlockCertificate(lastHeight) - cert := td.GenerateTestBlockCertificate(lastHeight + 1) - blk := block.MakeBlock(1, time.Now(), nil, td.RandHash(), td.RandHash(), - prevCert, td.RandSeed(), td.RandValAddress()) + blk, cert := td.GenerateTestBlock(lastHeight+1, testsuite.BlockWithPrevCert(nil)) data, _ := blk.Bytes() tests := []struct { data []byte - err error }{ - { - td.RandBytes(16), - io.ErrUnexpectedEOF, - }, - { - data, - block.BasicCheckError{ - Reason: "no subsidy transaction", - }, - }, + {data: td.RandBytes(16)}, + {data: data}, } for _, test := range tests { @@ -56,8 +41,8 @@ func TestInvalidBlockData(t *testing.T) { message.ResponseCodeMoreBlocks.String(), sid, lastHeight+1, [][]byte{test.data}, cert) - err := td.receivingNewMessage(td.sync, msg, pid) - assert.ErrorIs(t, err, test.err) + td.receivingNewMessage(td.sync, msg, pid) + assert.Nil(t, td.sync.cache.GetBlock(msg.From)) } } @@ -74,7 +59,7 @@ func TestOneBlockShorter(t *testing.T) { sid := td.RandInt(1000) msg := message.NewBlocksResponseMessage(message.ResponseCodeSynced, t.Name(), sid, lastHeight+1, [][]byte{d1}, cert1) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) assert.Equal(t, td.state.LastBlockHeight(), lastHeight+1) } @@ -87,59 +72,56 @@ func TestStrippedPublicKey(t *testing.T) { lastHeight := td.state.LastBlockHeight() // Add a new block and keep the signer key - indexedPub, indexedPrv := td.RandBLSKeyPair() - trx0 := tx.NewTransferTx(lastHeight, indexedPub.AccountAddress(), td.RandAccAddress(), 1, 1, "") - td.HelperSignTransaction(indexedPrv, trx0) + _, indexedPrv := td.RandBLSKeyPair() + trx0 := td.GenerateTestTransferTx(testsuite.TransactionWithSigner(indexedPrv)) trxs0 := []*tx.Tx{trx0} - blk0 := block.MakeBlock(1, time.Now(), trxs0, td.RandHash(), td.RandHash(), - td.state.LastCertificate(), td.RandSeed(), td.RandValAddress()) - cert0 := td.GenerateTestBlockCertificate(lastHeight + 1) + blk0, cert0 := td.GenerateTestBlock(lastHeight+1, testsuite.BlockWithTransactions(trxs0)) err := td.state.CommitBlock(blk0, cert0) require.NoError(t, err) lastHeight++ // ----- - rndPub, rndPrv := td.RandBLSKeyPair() - trx1 := tx.NewTransferTx(lastHeight, rndPub.AccountAddress(), td.RandAccAddress(), 1, 1, "") - td.HelperSignTransaction(rndPrv, trx1) + _, rndPrv := td.RandBLSKeyPair() + trx1 := td.GenerateTestTransferTx(testsuite.TransactionWithSigner(rndPrv)) trx1.StripPublicKey() trxs1 := []*tx.Tx{trx1} - blk1 := block.MakeBlock(1, time.Now(), trxs1, td.RandHash(), td.RandHash(), - cert0, td.RandSeed(), td.RandValAddress()) + blk1, _ := td.GenerateTestBlock(lastHeight+1, testsuite.BlockWithTransactions(trxs1)) - trx2 := tx.NewTransferTx(lastHeight, indexedPub.AccountAddress(), td.RandAccAddress(), 1, 1, "") - td.HelperSignTransaction(indexedPrv, trx2) + trx2 := td.GenerateTestTransferTx(testsuite.TransactionWithSigner(indexedPrv)) trx2.StripPublicKey() trxs2 := []*tx.Tx{trx2} - blk2 := block.MakeBlock(1, time.Now(), trxs2, td.RandHash(), td.RandHash(), - cert0, td.RandSeed(), td.RandValAddress()) + blk2, _ := td.GenerateTestBlock(lastHeight+1, testsuite.BlockWithTransactions(trxs2)) tests := []struct { - blk *block.Block - err error + receivedBlock *block.Block + shouldFail bool }{ { - blk1, - store.ErrNotFound, + receivedBlock: blk1, + shouldFail: true, }, { - blk2, - nil, + receivedBlock: blk2, + shouldFail: false, }, } // Add a peer pid := td.addPeer(t, status.StatusKnown, service.New(service.None)) - for _, test := range tests { - blkData, _ := test.blk.Bytes() + for _, tc := range tests { + blkData, _ := tc.receivedBlock.Bytes() sid := td.RandInt(1000) cert := td.GenerateTestBlockCertificate(lastHeight + 1) msg := message.NewBlocksResponseMessage(message.ResponseCodeMoreBlocks, message.ResponseCodeMoreBlocks.String(), sid, lastHeight+1, [][]byte{blkData}, cert) - err := td.receivingNewMessage(td.sync, msg, pid) + td.receivingNewMessage(td.sync, msg, pid) - assert.ErrorIs(t, err, test.err) + if tc.shouldFail { + assert.Nil(t, td.sync.cache.GetBlock(msg.From)) + } else { + assert.NotNil(t, td.sync.cache.GetBlock(msg.From)) + } } } @@ -186,8 +168,8 @@ func makeAliceAndBobNetworks(t *testing.T) *networkAliceBob { valKeyBob := []*bls.ValidatorKey{ts.RandValKey()} stateAlice := state.MockingState(ts) stateBob := state.MockingState(ts) - consMgrAlice, _ := consensus.MockingManager(ts, valKeyAlice) - consMgrBob, _ := consensus.MockingManager(ts, valKeyBob) + consMgrAlice, _ := consensus.MockingManager(ts, stateAlice, valKeyAlice) + consMgrBob, _ := consensus.MockingManager(ts, stateBob, valKeyBob) internalMessageCh := make(chan message.Message, 1000) networkAlice := network.MockingNetwork(ts, ts.RandPeerID()) networkBob := network.MockingNetwork(ts, ts.RandPeerID()) @@ -244,7 +226,7 @@ func makeAliceAndBobNetworks(t *testing.T) *networkAliceBob { require.Eventually(t, func() bool { return syncAlice.PeerSet().Len() == 1 && syncBob.PeerSet().Len() == 1 - }, time.Second, 100*time.Millisecond) + }, 2*time.Second, 100*time.Millisecond) require.Equal(t, status.StatusKnown, syncAlice.PeerSet().GetPeerStatus(syncBob.SelfID())) require.Equal(t, status.StatusKnown, syncBob.PeerSet().GetPeerStatus(syncAlice.SelfID())) diff --git a/sync/handler_hello.go b/sync/handler_hello.go index 6dd64fc23..11568cb37 100644 --- a/sync/handler_hello.go +++ b/sync/handler_hello.go @@ -23,7 +23,7 @@ func newHelloHandler(sync *synchronizer) messageHandler { } } -func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error { +func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) { msg := m.(*message.HelloMessage) handler.logger.Trace("parsing Hello message", "msg", msg) @@ -45,7 +45,7 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error handler.acknowledge(response, pid) - return nil + return } if msg.GenesisHash != handler.state.Genesis().Hash() { @@ -55,7 +55,7 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error handler.acknowledge(response, pid) - return nil + return } if math.Abs(time.Since(msg.MyTime()).Seconds()) > 10 { @@ -64,7 +64,7 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error handler.acknowledge(response, pid) - return nil + return } agent, _ := version.ParseAgent(msg.Agent) @@ -74,7 +74,7 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error handler.acknowledge(response, pid) - return nil + return } handler.peerSet.UpdateHeight(pid, msg.Height, msg.BlockHash) @@ -82,8 +82,6 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error response := message.NewHelloAckMessage(message.ResponseCodeOK, "Ok", handler.state.LastBlockHeight()) handler.acknowledge(response, pid) - - return nil } func (*helloHandler) PrepareBundle(m message.Message) *bundle.Bundle { @@ -100,7 +98,6 @@ func (handler *helloHandler) acknowledge(msg *message.HelloAckMessage, to peer.I handler.sendTo(msg, to) handler.peerSet.UpdateStatus(to, status.StatusBanned) - handler.network.CloseConnection(to) } else { handler.logger.Info("acknowledging hello message", "msg", msg, "to", to) diff --git a/sync/handler_hello_ack.go b/sync/handler_hello_ack.go index 593fcf2e5..325a8aa6f 100644 --- a/sync/handler_hello_ack.go +++ b/sync/handler_hello_ack.go @@ -18,7 +18,7 @@ func newHelloAckHandler(sync *synchronizer) messageHandler { } } -func (handler *helloAckHandler) ParseMessage(m message.Message, pid peer.ID) error { +func (handler *helloAckHandler) ParseMessage(m message.Message, pid peer.ID) { msg := m.(*message.HelloAckMessage) handler.logger.Trace("parsing HelloAck message", "msg", msg) @@ -28,7 +28,7 @@ func (handler *helloAckHandler) ParseMessage(m message.Message, pid peer.ID) err handler.network.CloseConnection(pid) - return nil + return } handler.peerSet.UpdateStatus(pid, status.StatusKnown) @@ -37,8 +37,6 @@ func (handler *helloAckHandler) ParseMessage(m message.Message, pid peer.ID) err if msg.Height > handler.state.LastBlockHeight() { handler.updateBlockchain() } - - return nil } func (*helloAckHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_hello_ack_test.go b/sync/handler_hello_ack_test.go index 8c067bae9..a55174777 100644 --- a/sync/handler_hello_ack_test.go +++ b/sync/handler_hello_ack_test.go @@ -5,7 +5,6 @@ import ( "github.com/pactus-project/pactus/sync/bundle/message" "github.com/pactus-project/pactus/sync/peerset/peer/status" - "github.com/stretchr/testify/assert" ) func TestParsingHelloAckMessages(t *testing.T) { @@ -16,7 +15,8 @@ func TestParsingHelloAckMessages(t *testing.T) { pid := td.RandPeerID() msg := message.NewHelloAckMessage(message.ResponseCodeRejected, "rejected", td.RandHeight()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) + td.checkPeerStatus(t, pid, status.StatusUnknown) }) t.Run("Receiving HelloAck message: OK hello", @@ -24,7 +24,7 @@ func TestParsingHelloAckMessages(t *testing.T) { pid := td.RandPeerID() msg := message.NewHelloAckMessage(message.ResponseCodeOK, "ok", td.RandHeight()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusKnown) }) } diff --git a/sync/handler_hello_test.go b/sync/handler_hello_test.go index 4faf34dd7..f97b9cb8a 100644 --- a/sync/handler_hello_test.go +++ b/sync/handler_hello_test.go @@ -28,7 +28,7 @@ func TestParsingHelloMessages(t *testing.T) { msg.Sign([]*bls.ValidatorKey{valKey}) from := td.RandPeerID() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, from)) + td.receivingNewMessage(td.sync, msg, from) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) }) @@ -42,7 +42,7 @@ func TestParsingHelloMessages(t *testing.T) { td.state.LastBlockHash(), invGenHash) msg.Sign([]*bls.ValidatorKey{valKey}) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusBanned) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) @@ -58,7 +58,7 @@ func TestParsingHelloMessages(t *testing.T) { msg.Sign([]*bls.ValidatorKey{valKey}) msg.MyTimeUnixMilli = msg.MyTime().Add(-10 * time.Second).UnixMilli() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusBanned) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) @@ -74,7 +74,7 @@ func TestParsingHelloMessages(t *testing.T) { msg.Sign([]*bls.ValidatorKey{valKey}) msg.MyTimeUnixMilli = msg.MyTime().Add(20 * time.Second).UnixMilli() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusBanned) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) @@ -96,7 +96,7 @@ func TestParsingHelloMessages(t *testing.T) { msg.Agent = nodeAgent.String() msg.Sign([]*bls.ValidatorKey{valKey}) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusBanned) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) @@ -112,7 +112,7 @@ func TestParsingHelloMessages(t *testing.T) { msg.Agent = "invalid-agent" msg.Sign([]*bls.ValidatorKey{valKey}) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.checkPeerStatus(t, pid, status.StatusBanned) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeRejected) @@ -127,7 +127,7 @@ func TestParsingHelloMessages(t *testing.T) { td.state.LastBlockHash(), td.state.Genesis().Hash()) msg.Sign([]*bls.ValidatorKey{valKey}) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeHelloAck) assert.Equal(t, bdl.Message.(*message.HelloAckMessage).ResponseCode, message.ResponseCodeOK) diff --git a/sync/handler_proposal.go b/sync/handler_proposal.go index 831de281c..7b16f571c 100644 --- a/sync/handler_proposal.go +++ b/sync/handler_proposal.go @@ -16,13 +16,11 @@ func newProposalHandler(sync *synchronizer) messageHandler { } } -func (handler *proposalHandler) ParseMessage(m message.Message, _ peer.ID) error { +func (handler *proposalHandler) ParseMessage(m message.Message, _ peer.ID) { msg := m.(*message.ProposalMessage) handler.logger.Trace("parsing Proposal message", "msg", msg) handler.consMgr.SetProposal(msg.Proposal) - - return nil } func (*proposalHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_proposal_test.go b/sync/handler_proposal_test.go index 1c3e5dd74..791d8b9bf 100644 --- a/sync/handler_proposal_test.go +++ b/sync/handler_proposal_test.go @@ -16,7 +16,7 @@ func TestParsingProposalMessages(t *testing.T) { msg := message.NewProposalMessage(prop) pid := td.RandPeerID() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) assert.NotNil(t, td.consMgr.Proposal()) }) } diff --git a/sync/handler_query_proposal.go b/sync/handler_query_proposal.go index ec29effd7..cf02a996d 100644 --- a/sync/handler_query_proposal.go +++ b/sync/handler_query_proposal.go @@ -16,20 +16,20 @@ func newQueryProposalHandler(sync *synchronizer) messageHandler { } } -func (handler *queryProposalHandler) ParseMessage(m message.Message, _ peer.ID) error { +func (handler *queryProposalHandler) ParseMessage(m message.Message, _ peer.ID) { msg := m.(*message.QueryProposalMessage) handler.logger.Trace("parsing QueryProposal message", "msg", msg) if !handler.consMgr.HasActiveInstance() { handler.logger.Debug("ignoring QueryProposal, not active", "msg", msg) - return nil + return } if !handler.consMgr.HasProposer() { handler.logger.Debug("ignoring QueryProposal, not proposer", "msg", msg) - return nil + return } height, round := handler.consMgr.HeightRound() @@ -37,7 +37,7 @@ func (handler *queryProposalHandler) ParseMessage(m message.Message, _ peer.ID) handler.logger.Debug("ignoring QueryProposal, not same height/round", "msg", msg, "height", height, "round", round) - return nil + return } prop := handler.consMgr.Proposal() @@ -45,8 +45,6 @@ func (handler *queryProposalHandler) ParseMessage(m message.Message, _ peer.ID) response := message.NewProposalMessage(prop) handler.broadcast(response) } - - return nil } func (*queryProposalHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_query_proposal_test.go b/sync/handler_query_proposal_test.go index 065aa2ff0..1a16dc632 100644 --- a/sync/handler_query_proposal_test.go +++ b/sync/handler_query_proposal_test.go @@ -17,7 +17,7 @@ func TestParsingQueryProposalMessages(t *testing.T) { t.Run("doesn't have active validator", func(t *testing.T) { msg := message.NewQueryProposalMessage(consHeight, consRound, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeProposal) }) @@ -26,7 +26,7 @@ func TestParsingQueryProposalMessages(t *testing.T) { t.Run("not the proposer", func(t *testing.T) { msg := message.NewQueryProposalMessage(consHeight, consRound, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeProposal) }) @@ -35,21 +35,21 @@ func TestParsingQueryProposalMessages(t *testing.T) { t.Run("not the same height", func(t *testing.T) { msg := message.NewQueryProposalMessage(consHeight+1, consRound, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeProposal) }) t.Run("not the same round", func(t *testing.T) { msg := message.NewQueryProposalMessage(consHeight, consRound+1, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeProposal) }) t.Run("should respond to the query proposal message", func(t *testing.T) { msg := message.NewQueryProposalMessage(consHeight, consRound, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeProposal) assert.Equal(t, bdl.Message.(*message.ProposalMessage).Proposal.Hash(), prop.Hash()) @@ -58,7 +58,7 @@ func TestParsingQueryProposalMessages(t *testing.T) { t.Run("doesn't have the proposal", func(t *testing.T) { td.consMocks[0].CurProposal = nil msg := message.NewQueryProposalMessage(consHeight, consRound, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeProposal) }) diff --git a/sync/handler_query_votes.go b/sync/handler_query_votes.go index d71de440e..b01508072 100644 --- a/sync/handler_query_votes.go +++ b/sync/handler_query_votes.go @@ -16,14 +16,14 @@ func newQueryVotesHandler(sync *synchronizer) messageHandler { } } -func (handler *queryVotesHandler) ParseMessage(m message.Message, _ peer.ID) error { +func (handler *queryVotesHandler) ParseMessage(m message.Message, _ peer.ID) { msg := m.(*message.QueryVotesMessage) handler.logger.Trace("parsing QueryVotes message", "msg", msg) if !handler.consMgr.HasActiveInstance() { handler.logger.Debug("ignoring QueryVotes, not active", "msg", msg) - return nil + return } height, _ := handler.consMgr.HeightRound() @@ -31,7 +31,7 @@ func (handler *queryVotesHandler) ParseMessage(m message.Message, _ peer.ID) err handler.logger.Debug("ignoring QueryVotes, not same height", "msg", msg, "height", height) - return nil + return } v := handler.consMgr.PickRandomVote(msg.Round) @@ -39,8 +39,6 @@ func (handler *queryVotesHandler) ParseMessage(m message.Message, _ peer.ID) err response := message.NewVoteMessage(v) handler.broadcast(response) } - - return nil } func (*queryVotesHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_query_votes_test.go b/sync/handler_query_votes_test.go index 741208c23..508a50c8f 100644 --- a/sync/handler_query_votes_test.go +++ b/sync/handler_query_votes_test.go @@ -17,7 +17,7 @@ func TestParsingQueryVotesMessages(t *testing.T) { t.Run("doesn't have active validator", func(t *testing.T) { msg := message.NewQueryVotesMessage(consensusHeight, 1, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeVote) }) @@ -26,7 +26,7 @@ func TestParsingQueryVotesMessages(t *testing.T) { t.Run("should respond to the query votes message", func(t *testing.T) { msg := message.NewQueryVotesMessage(consensusHeight, 1, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) bdl := td.shouldPublishMessageWithThisType(t, message.TypeVote) assert.Equal(t, bdl.Message.(*message.VoteMessage).Vote.Hash(), v1.Hash()) @@ -34,7 +34,7 @@ func TestParsingQueryVotesMessages(t *testing.T) { t.Run("doesn't have any votes", func(t *testing.T) { msg := message.NewQueryVotesMessage(consensusHeight+1, 1, td.RandValAddress()) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeVote) }) diff --git a/sync/handler_transactions.go b/sync/handler_transactions.go index 150d5bf8a..5edc2199f 100644 --- a/sync/handler_transactions.go +++ b/sync/handler_transactions.go @@ -16,7 +16,7 @@ func newTransactionsHandler(sync *synchronizer) messageHandler { } } -func (handler *transactionsHandler) ParseMessage(m message.Message, _ peer.ID) error { +func (handler *transactionsHandler) ParseMessage(m message.Message, _ peer.ID) { msg := m.(*message.TransactionsMessage) handler.logger.Trace("parsing Transactions message", "msg", msg) @@ -25,8 +25,6 @@ func (handler *transactionsHandler) ParseMessage(m message.Message, _ peer.ID) e handler.logger.Debug("cannot append transaction", "tx", trx, "error", err) } } - - return nil } func (*transactionsHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_transactions_test.go b/sync/handler_transactions_test.go index 4d69acaa8..f682d5b00 100644 --- a/sync/handler_transactions_test.go +++ b/sync/handler_transactions_test.go @@ -16,8 +16,18 @@ func TestParsingTransactionsMessages(t *testing.T) { msg := message.NewTransactionsMessage([]*tx.Tx{trx1}) pid := td.RandPeerID() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) assert.NotNil(t, td.sync.state.PendingTx(trx1.ID())) }) } + +func TestBroadcastingTransactionVotesMessages(t *testing.T) { + td := setup(t, nil) + + trx1 := td.GenerateTestBondTx() + msg := message.NewTransactionsMessage([]*tx.Tx{trx1}) + td.sync.broadcast(msg) + + td.shouldPublishMessageWithThisType(t, message.TypeTransaction) +} diff --git a/sync/handler_vote.go b/sync/handler_vote.go index dddcb92aa..23c520de5 100644 --- a/sync/handler_vote.go +++ b/sync/handler_vote.go @@ -16,13 +16,11 @@ func newVoteHandler(sync *synchronizer) messageHandler { } } -func (handler *voteHandler) ParseMessage(m message.Message, _ peer.ID) error { +func (handler *voteHandler) ParseMessage(m message.Message, _ peer.ID) { msg := m.(*message.VoteMessage) handler.logger.Trace("parsing Vote message", "msg", msg) handler.consMgr.AddVote(msg.Vote) - - return nil } func (*voteHandler) PrepareBundle(m message.Message) *bundle.Bundle { diff --git a/sync/handler_vote_test.go b/sync/handler_vote_test.go index ce461c2b7..63211add3 100644 --- a/sync/handler_vote_test.go +++ b/sync/handler_vote_test.go @@ -15,7 +15,7 @@ func TestParsingVoteMessages(t *testing.T) { msg := message.NewVoteMessage(v) pid := td.RandPeerID() - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) assert.Equal(t, td.consMgr.PickRandomVote(0).Hash(), v.Hash()) }) } diff --git a/sync/sync.go b/sync/sync.go index 97a5c8c69..35f7ca260 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -214,7 +214,7 @@ func (sync *synchronizer) broadcast(msg message.Message) { bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagBroadcasted) data, _ := bdl.Encode() - sync.network.Broadcast(data, msg.Type().TopicID()) + sync.network.Broadcast(data, msg.TopicID()) sync.peerSet.IncreaseSentCounters(msg.Type(), int64(len(data)), nil) sync.logger.Debug("bundle broadcasted", "bundle", bdl) @@ -304,43 +304,39 @@ func (sync *synchronizer) receiveLoop() { func (sync *synchronizer) processGossipMessage(msg *network.GossipMessage) { sync.logger.Debug("processing gossip message", "pid", msg.From) - bdl := sync.firewall.OpenGossipBundle(msg.Data, msg.From) - err := sync.processIncomingBundle(bdl, msg.From) + bdl, err := sync.firewall.OpenGossipBundle(msg.Data, msg.From) if err != nil { sync.logger.Debug("error on parsing a Gossip bundle", "from", msg.From, "bundle", bdl, "error", err) + + return } + sync.processIncomingBundle(bdl, msg.From) } func (sync *synchronizer) processStreamMessage(msg *network.StreamMessage) { sync.logger.Debug("processing stream message", "pid", msg.From) - bdl := sync.firewall.OpenStreamBundle(msg.Reader, msg.From) + bdl, err := sync.firewall.OpenStreamBundle(msg.Reader, msg.From) + if err != nil { + sync.logger.Debug("error on parsing a Stream bundle", + "from", msg.From, "bundle", bdl, "error", err) + + return + } if err := msg.Reader.Close(); err != nil { // TODO: write test for me sync.logger.Debug("error on closing stream", "error", err, "source", msg.From) return } - err := sync.processIncomingBundle(bdl, msg.From) - if err != nil { - sync.logger.Debug("error on parsing a Stream bundle", - "source", msg.From, "bundle", bdl, "error", err) - } + sync.processIncomingBundle(bdl, msg.From) } func (sync *synchronizer) processConnectEvent(ce *network.ConnectEvent) { sync.logger.Debug("processing connect event", "pid", ce.PeerID) sync.peerSet.UpdateAddress(ce.PeerID, ce.RemoteAddress, ce.Direction) - - if sync.firewall.IsBannedAddress(ce.RemoteAddress) { - sync.logger.Debug("Peer is blacklisted", "peer_id", ce.PeerID, "remote_address", ce.RemoteAddress) - sync.peerSet.UpdateStatus(ce.PeerID, status.StatusBanned) - - return - } - sync.peerSet.UpdateStatus(ce.PeerID, status.StatusConnected) } @@ -358,18 +354,16 @@ func (sync *synchronizer) processDisconnectEvent(de *network.DisconnectEvent) { sync.peerSet.UpdateStatus(de.PeerID, status.StatusDisconnected) } -func (sync *synchronizer) processIncomingBundle(bdl *bundle.Bundle, from peer.ID) error { - if bdl == nil { - return nil - } - +func (sync *synchronizer) processIncomingBundle(bdl *bundle.Bundle, from peer.ID) { sync.logger.Debug("received a bundle", "from", from, "bundle", bdl) h := sync.handlers[bdl.Message.Type()] if h == nil { - return fmt.Errorf("invalid message type: %v", bdl.Message.Type()) + sync.logger.Error("invalid message type", "type", bdl.Message.Type()) + + return } - return h.ParseMessage(bdl.Message, from) + h.ParseMessage(bdl.Message, from) } func (sync *synchronizer) String() string { @@ -384,7 +378,7 @@ func (sync *synchronizer) String() string { // Otherwise, the node can request the latest blocks from any nodes. func (sync *synchronizer) updateBlockchain() { // Maybe we have some blocks inside the cache? - _ = sync.tryCommitBlocks() + sync.tryCommitBlocks() // Check if we have any expired sessions sync.peerSet.SetExpiredSessionsAsUncompleted() @@ -520,7 +514,7 @@ func (sync *synchronizer) sendBlockRequestToRandomPeer(from, count uint32, onlyN return false } -func (sync *synchronizer) tryCommitBlocks() error { +func (sync *synchronizer) tryCommitBlocks() { onError := func(height uint32, err error) { sync.logger.Warn("committing block failed, removing block from the cache", "height", height, "error", err) @@ -546,7 +540,7 @@ func (sync *synchronizer) tryCommitBlocks() error { if err != nil { onError(height, err) - return err + return } trx.SetPublicKey(pub) } @@ -555,24 +549,22 @@ func (sync *synchronizer) tryCommitBlocks() error { if err := blk.BasicCheck(); err != nil { onError(height, err) - return err + return } if err := cert.BasicCheck(); err != nil { onError(height, err) - return err + return } sync.logger.Trace("committing block", "height", height, "block", blk) if err := sync.state.CommitBlock(blk, cert); err != nil { onError(height, err) - return err + return } height++ } - - return nil } func (sync *synchronizer) prepareBlocks(from, count uint32) [][]byte { diff --git a/sync/sync_test.go b/sync/sync_test.go index 72c133ccd..92c1e84fd 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -9,6 +9,7 @@ import ( "github.com/pactus-project/pactus/consensus" "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" + "github.com/pactus-project/pactus/genesis" "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/state" "github.com/pactus-project/pactus/sync/bundle" @@ -63,7 +64,7 @@ func setup(t *testing.T, config *Config) *testData { valKeys := []*bls.ValidatorKey{ts.RandValKey(), ts.RandValKey()} mockState := state.MockingState(ts) - consMgr, consMocks := consensus.MockingManager(ts, []*bls.ValidatorKey{valKeys[0], valKeys[1]}) + consMgr, consMocks := consensus.MockingManager(ts, mockState, []*bls.ValidatorKey{valKeys[0], valKeys[1]}) consMgr.MoveToNewHeight() broadcastCh := make(chan message.Message, 1000) @@ -119,7 +120,7 @@ func shouldPublishMessageWithThisType(t *testing.T, net *network.MockNetwork, ms // ----------- // Check flags require.True(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagCarrierLibP2P), "invalid flag: %v", bdl) - require.True(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagNetworkTestnet), "invalid flag: %v", bdl) + require.True(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagNetworkMainnet), "invalid flag: %v", bdl) if b.Target == nil { require.True(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagBroadcasted), "invalid flag: %v", bdl) @@ -179,11 +180,11 @@ func (td *testData) shouldNotPublishMessageWithThisType(t *testing.T, msgType me shouldNotPublishMessageWithThisType(t, td.network, msgType) } -func (*testData) receivingNewMessage(sync *synchronizer, msg message.Message, from peer.ID) error { +func (*testData) receivingNewMessage(sync *synchronizer, msg message.Message, from peer.ID) { bdl := bundle.NewBundle(msg) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagCarrierLibP2P|bundle.BundleFlagNetworkMainnet) - return sync.processIncomingBundle(bdl, from) + sync.processIncomingBundle(bdl, from) } func (td *testData) addPeer(t *testing.T, s status.Status, services service.Services) peer.ID { @@ -259,32 +260,6 @@ func TestConnectEvent(t *testing.T) { p1 := td.sync.peerSet.GetPeer(pid) assert.Equal(t, status.StatusConnected, p1.Status) - - // Receiving connect event for the banned address - pid = td.RandPeerID() - ce = &network.ConnectEvent{ - PeerID: pid, - RemoteAddress: "/ip4/115.193.2.1/tcp/21888", - } - td.network.EventCh <- ce - - assert.Eventually(t, func() bool { - p := td.sync.peerSet.GetPeer(pid) - if p == nil { - return false - } - - isBlocked := td.sync.firewall.IsBannedAddress(p.Address) - - if isBlocked { - p.Status = status.StatusBanned - } - - return isBlocked - }, time.Second, 100*time.Millisecond) - - p2 := td.sync.peerSet.GetPeer(pid) - assert.Equal(t, status.StatusBanned, p2.Status) } func TestDisconnectEvent(t *testing.T) { @@ -315,8 +290,10 @@ func TestProtocolsEvent(t *testing.T) { func TestTestNetFlags(t *testing.T) { td := setup(t, nil) + td.state.TestGenesis = genesis.TestnetGenesis() td.addValidatorToCommittee(t, td.sync.valKeys[0].PublicKey()) bdl := td.sync.prepareBundle(message.NewQueryProposalMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) + require.False(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagNetworkMainnet), "invalid flag: %v", bdl) require.True(t, util.IsFlagSet(bdl.Flags, bundle.BundleFlagNetworkTestnet), "invalid flag: %v", bdl) } @@ -332,7 +309,7 @@ func TestDownload(t *testing.T) { pid := td.addPeer(t, status.StatusConnected, service.New(service.None)) blk, cert := td.GenerateTestBlock(td.RandHeight()) baMsg := message.NewBlockAnnounceMessage(blk, cert) - assert.NoError(t, td.receivingNewMessage(td.sync, baMsg, pid)) + td.receivingNewMessage(td.sync, baMsg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeBlocksRequest) td.network.IsClosed(pid) @@ -344,7 +321,7 @@ func TestDownload(t *testing.T) { pid := td.addPeer(t, status.StatusKnown, service.New(service.None)) blk, cert := td.GenerateTestBlock(td.RandHeight()) baMsg := message.NewBlockAnnounceMessage(blk, cert) - assert.NoError(t, td.receivingNewMessage(td.sync, baMsg, pid)) + td.receivingNewMessage(td.sync, baMsg, pid) td.shouldNotPublishMessageWithThisType(t, message.TypeBlocksRequest) td.network.IsClosed(pid) @@ -356,7 +333,7 @@ func TestDownload(t *testing.T) { pid := td.addPeer(t, status.StatusKnown, service.New(service.Network)) blk, cert := td.GenerateTestBlock(td.RandHeight()) baMsg := message.NewBlockAnnounceMessage(blk, cert) - assert.NoError(t, td.receivingNewMessage(td.sync, baMsg, pid)) + td.receivingNewMessage(td.sync, baMsg, pid) td.shouldPublishMessageWithThisType(t, message.TypeBlocksRequest) }) @@ -370,7 +347,7 @@ func TestDownload(t *testing.T) { sid := td.sync.peerSet.OpenSession(pid, from, count) msg := message.NewBlocksResponseMessage(message.ResponseCodeRejected, t.Name(), sid, 1, nil, nil) - assert.NoError(t, td.receivingNewMessage(td.sync, msg, pid)) + td.receivingNewMessage(td.sync, msg, pid) assert.False(t, td.sync.peerSet.HasOpenSession(pid)) }) diff --git a/tests/block_test.go b/tests/block_test.go index 0900b724d..364f33fe7 100644 --- a/tests/block_test.go +++ b/tests/block_test.go @@ -19,8 +19,8 @@ func lastHeight() uint32 { } func waitForNewBlocks(num uint32) { - height := lastHeight() + num for i := uint32(0); i < num; i++ { + height := lastHeight() if lastHeight() > height { break } diff --git a/tests/main_test.go b/tests/main_test.go index 2f121420a..749327dd6 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -72,9 +72,9 @@ func TestMain(m *testing.M) { tConfigs[i].TxPool.MinFeePAC = 0.000001 tConfigs[i].Store.Path = util.TempDirPath() - tConfigs[i].Consensus.ChangeProposerTimeout = 4 * time.Second - tConfigs[i].Consensus.ChangeProposerDelta = 4 * time.Second - tConfigs[i].Consensus.QueryVoteTimeout = 4 * time.Second + tConfigs[i].Consensus.ChangeProposerTimeout = 2 * time.Second + tConfigs[i].Consensus.ChangeProposerDelta = 2 * time.Second + tConfigs[i].Consensus.QueryVoteTimeout = 2 * time.Second tConfigs[i].Logger.Levels["default"] = "info" tConfigs[i].Logger.Levels["_state"] = "info" tConfigs[i].Logger.Levels["_sync"] = "info" diff --git a/types/block/block.go b/types/block/block.go index bbd442065..ccec41de0 100644 --- a/types/block/block.go +++ b/types/block/block.go @@ -134,6 +134,14 @@ func (b *Block) Hash() hash.Hash { return h } +func (b *Block) Height() uint32 { + if b.data.PrevCert == nil { + return 1 + } + + return b.PrevCertificate().Height() + 1 +} + func (b *Block) String() string { return fmt.Sprintf("{⌘ %v πŸ‘€ %v πŸ’» %v πŸ“¨ %d}", b.Hash().ShortString(), diff --git a/types/block/block_test.go b/types/block/block_test.go index 35f590e4e..5d0006303 100644 --- a/types/block/block_test.go +++ b/types/block/block_test.go @@ -341,3 +341,16 @@ func TestMakeBlock(t *testing.T) { assert.Equal(t, blk0.Hash(), blk1.Hash()) } + +func TestBlockHeight(t *testing.T) { + ts := testsuite.NewTestSuite(t) + + blk1, _ := ts.GenerateTestBlock(1, testsuite.BlockWithPrevCert(nil), testsuite.BlockWithPrevHash(hash.UndefHash)) + blk2, _ := ts.GenerateTestBlock(2) + + assert.NoError(t, blk1.BasicCheck()) + assert.NoError(t, blk2.BasicCheck()) + + assert.Equal(t, uint32(1), blk1.Height()) + assert.Equal(t, uint32(2), blk2.Height()) +} diff --git a/wallet/client.go b/wallet/client.go index 6864a35a8..9ec400414 100644 --- a/wallet/client.go +++ b/wallet/client.go @@ -4,6 +4,8 @@ import ( "context" "encoding/hex" "errors" + "net" + "time" "github.com/pactus-project/pactus/crypto/hash" "github.com/pactus-project/pactus/types/amount" @@ -20,23 +22,27 @@ type grpcClient struct { ctx context.Context servers []string conn *grpc.ClientConn + timeout time.Duration blockchainClient pactus.BlockchainClient transactionClient pactus.TransactionClient } -func newGrpcClient() *grpcClient { - ctx := context.WithoutCancel(context.Background()) +func newGrpcClient(timeout time.Duration, servers []string) *grpcClient { + ctx := context.Background() - return &grpcClient{ + cli := &grpcClient{ ctx: ctx, + timeout: timeout, conn: nil, blockchainClient: nil, transactionClient: nil, } -} -func (c *grpcClient) SetServerAddrs(servers []string) { - c.servers = servers + if len(servers) > 0 { + cli.servers = servers + } + + return cli } func (c *grpcClient) connect() error { @@ -46,7 +52,10 @@ func (c *grpcClient) connect() error { for _, server := range c.servers { conn, err := grpc.NewClient(server, - grpc.WithTransportCredentials(insecure.NewCredentials())) + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(_ context.Context, s string) (net.Conn, error) { + return net.DialTimeout("tcp", s, c.timeout) + })) if err != nil { continue } @@ -58,6 +67,8 @@ func (c *grpcClient) connect() error { _, err = blockchainClient.GetBlockchainInfo(c.ctx, &pactus.GetBlockchainInfoRequest{}) if err != nil { + _ = conn.Close() + continue } diff --git a/wallet/manager.go b/wallet/manager.go index 9e7914295..eef29ea88 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -84,11 +84,10 @@ func (wm *Manager) LoadWallet(walletName, serverAddr string) error { } walletPath := util.MakeAbs(filepath.Join(wm.walletDirectory, walletName)) - wlt, err := Open(walletPath, true) + wlt, err := Open(walletPath, true, WithCustomServers([]string{serverAddr})) if err != nil { return err } - wlt.SetServerAddr(serverAddr) wm.wallets[walletName] = wlt diff --git a/wallet/options.go b/wallet/options.go new file mode 100644 index 000000000..87e3614f9 --- /dev/null +++ b/wallet/options.go @@ -0,0 +1,27 @@ +package wallet + +import "time" + +type walletOpt struct { + timeout time.Duration + servers []string +} + +type Option func(*walletOpt) + +var defaultWalletOpt = &walletOpt{ + timeout: 1 * time.Second, + servers: make([]string, 0), +} + +func WithTimeout(timeout time.Duration) Option { + return func(opt *walletOpt) { + opt.timeout = timeout + } +} + +func WithCustomServers(servers []string) Option { + return func(opt *walletOpt) { + opt.servers = servers + } +} diff --git a/wallet/servers.json b/wallet/servers.json index f416b0550..43e1a9e0d 100644 --- a/wallet/servers.json +++ b/wallet/servers.json @@ -5,14 +5,15 @@ "bootstrap2.pactus.org:50051", "bootstrap3.pactus.org:50051", "bootstrap4.pactus.org:50051", - "65.109.234.125:50051", - "157.90.111.140:50051" + "rpc.javad.dev:50051", + "65.109.234.125:50051" ], "testnet": [ "localhost:50052", "testnet1.pactus.org:50052", "testnet2.pactus.org:50052", "testnet3.pactus.org:50052", - "testnet4.pactus.org:50052" + "testnet4.pactus.org:50052", + "testnet-rpc.javad.dev:50052" ] } diff --git a/wallet/wallet.go b/wallet/wallet.go index 26f860acc..ed54e7c71 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -47,7 +47,7 @@ func CheckMnemonic(mnemonic string) error { // A wallet can be opened in offline or online modes. // Offline wallet doesn’t have any connection to any node. // Online wallet has a connection to one of the pre-defined servers. -func Open(walletPath string, offline bool) (*Wallet, error) { +func Open(walletPath string, offline bool, options ...Option) (*Wallet, error) { data, err := util.ReadFile(walletPath) if err != nil { return nil, err @@ -59,12 +59,24 @@ func Open(walletPath string, offline bool) (*Wallet, error) { return nil, err } - return newWallet(walletPath, store, offline) + opts := defaultWalletOpt + + for _, opt := range options { + opt(opts) + } + + return newWallet(walletPath, store, offline, opts) } // Create creates a wallet from mnemonic (seed phrase) and save it at the // given path. -func Create(walletPath, mnemonic, password string, chain genesis.ChainType) (*Wallet, error) { +func Create(walletPath, mnemonic, password string, chain genesis.ChainType, options ...Option) (*Wallet, error) { + opts := defaultWalletOpt + + for _, opt := range options { + opt(opts) + } + walletPath = util.MakeAbs(walletPath) if util.PathExists(walletPath) { return nil, ExitsError{ @@ -89,7 +101,7 @@ func Create(walletPath, mnemonic, password string, chain genesis.ChainType) (*Wa Network: chain, Vault: nil, } - wallet, err := newWallet(walletPath, store, true) + wallet, err := newWallet(walletPath, store, true, opts) if err != nil { return nil, err } @@ -106,7 +118,7 @@ func Create(walletPath, mnemonic, password string, chain genesis.ChainType) (*Wa return wallet, nil } -func newWallet(walletPath string, store *store, offline bool) (*Wallet, error) { +func newWallet(walletPath string, store *store, offline bool, option *walletOpt) (*Wallet, error) { if !store.Network.IsMainnet() { crypto.AddressHRP = "tpc" crypto.PublicKeyHRP = "tpublic" @@ -115,7 +127,7 @@ func newWallet(walletPath string, store *store, offline bool) (*Wallet, error) { crypto.XPrivateKeyHRP = "txsecret" } - client := newGrpcClient() + client := newGrpcClient(option.timeout, option.servers) w := &Wallet{ store: store, @@ -146,16 +158,15 @@ func newWallet(walletPath string, store *store, offline bool) (*Wallet, error) { } util.Shuffle(netServers) - w.grpcClient.SetServerAddrs(netServers) + + if client.servers == nil { + client.servers = netServers + } } return w, nil } -func (w *Wallet) SetServerAddr(addr string) { - w.grpcClient.SetServerAddrs([]string{addr}) -} - func (w *Wallet) Name() string { return path.Base(w.path) } diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 75ae5e0f7..fae4ccbc8 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -36,11 +36,6 @@ func setup(t *testing.T) *testData { password := "" walletPath := util.TempFilePath() mnemonic, _ := wallet.GenerateMnemonic(128) - wlt, err := wallet.Create(walletPath, mnemonic, password, genesis.Mainnet) - assert.NoError(t, err) - assert.False(t, wlt.IsEncrypted()) - assert.Equal(t, wlt.Path(), walletPath) - assert.Equal(t, wlt.Name(), path.Base(walletPath)) grpcConf := &grpc.Config{ Enable: true, @@ -59,7 +54,12 @@ func setup(t *testing.T) *testData { assert.NoError(t, gRPCServer.StartServer()) - wlt.SetServerAddr(gRPCServer.Address()) + wlt, err := wallet.Create(walletPath, mnemonic, password, genesis.Mainnet, + wallet.WithCustomServers([]string{gRPCServer.Address()})) + assert.NoError(t, err) + assert.False(t, wlt.IsEncrypted()) + assert.Equal(t, wlt.Path(), walletPath) + assert.Equal(t, wlt.Name(), path.Base(walletPath)) return &testData{ TestSuite: ts, diff --git a/www/grpc/server_test.go b/www/grpc/server_test.go index dd9cc64c8..3f8d9091d 100644 --- a/www/grpc/server_test.go +++ b/www/grpc/server_test.go @@ -60,14 +60,12 @@ func setup(t *testing.T, conf *Config) *testData { const bufSize = 1024 * 1024 - mockConsMgr, consMocks := consensus.MockingManager(ts, []*bls.ValidatorKey{ - ts.RandValKey(), ts.RandValKey(), - }) - listener := bufconn.Listen(bufSize) + valKeys := []*bls.ValidatorKey{ts.RandValKey(), ts.RandValKey()} mockState := state.MockingState(ts) mockNet := network.MockingNetwork(ts, ts.RandPeerID()) mockSync := sync.MockingSync(ts) + mockConsMgr, consMocks := consensus.MockingManager(ts, mockState, valKeys) mockState.CommitTestBlocks(10) diff --git a/www/http/http_test.go b/www/http/http_test.go index f5e2c90ec..6e6fdb4ea 100644 --- a/www/http/http_test.go +++ b/www/http/http_test.go @@ -43,12 +43,11 @@ func setup(t *testing.T) *testData { // http.DefaultServeMux = new(http.ServeMux) + valKeys := []*bls.ValidatorKey{ts.RandValKey(), ts.RandValKey()} mockState := state.MockingState(ts) mockSync := sync.MockingSync(ts) mockNet := network.MockingNetwork(ts, ts.RandPeerID()) - mockConsMgr, _ := consensus.MockingManager(ts, []*bls.ValidatorKey{ - ts.RandValKey(), ts.RandValKey(), - }) + mockConsMgr, _ := consensus.MockingManager(ts, mockState, valKeys) mockConsMgr.MoveToNewHeight()