diff --git a/common/common.go b/common/common.go index c3364fa608..3d5e874c23 100644 --- a/common/common.go +++ b/common/common.go @@ -1,8 +1,13 @@ package common import ( + "fmt" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-vm-v1_2-go/ipc/marshaling" ) // IsEpochChangeBlockForFlagActivation returns true if the provided header is the first one after the specified flag's activation @@ -24,3 +29,50 @@ func IsFlagEnabledAfterEpochsStartBlock(header data.HeaderHandler, enableEpochsH func ShouldBlockHavePrevProof(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { return IsFlagEnabledAfterEpochsStartBlock(header, enableEpochsHandler, flag) && header.GetNonce() > 1 } + +// VerifyProofAgainstHeader verifies the fields on the proof match the ones on the header +func VerifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderHandler) error { + if check.IfNilReflect(proof) { + return ErrInvalidHeaderProof + } + + if proof.GetHeaderNonce() != header.GetNonce() { + return fmt.Errorf("%w, nonce mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderShardId() != header.GetShardID() { + return fmt.Errorf("%w, shard id mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderEpoch() != header.GetEpoch() { + return fmt.Errorf("%w, epoch mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderRound() != header.GetRound() { + return fmt.Errorf("%w, round mismatch", ErrInvalidHeaderProof) + } + + return nil +} + +// GetHeader tries to get the header from pool first and if not found, searches for it through storer +func GetHeader( + headerHash []byte, + headersPool HeadersPool, + headersStorer storage.Storer, + marshaller marshaling.Marshalizer, +) (data.HeaderHandler, error) { + header, err := headersPool.GetHeaderByHash(headerHash) + if err == nil { + return header, nil + } + + headerBytes, err := headersStorer.SearchFirst(headerHash) + if err != nil { + return nil, err + } + + err = marshaller.Unmarshal(header, headerBytes) + if err != nil { + return nil, err + } + + return header, nil +} diff --git a/common/errors.go b/common/errors.go index 47b976de9a..eeeaf94c80 100644 --- a/common/errors.go +++ b/common/errors.go @@ -10,3 +10,6 @@ var ErrNilWasmChangeLocker = errors.New("nil wasm change locker") // ErrNilStateSyncNotifierSubscriber signals that a nil state sync notifier subscriber has been provided var ErrNilStateSyncNotifierSubscriber = errors.New("nil state sync notifier subscriber") + +// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided +var ErrInvalidHeaderProof = errors.New("invalid equivalent proof") diff --git a/common/interface.go b/common/interface.go index 72a2cba262..e3a42e0ee4 100644 --- a/common/interface.go +++ b/common/interface.go @@ -379,3 +379,8 @@ type ChainParametersSubscriptionHandler interface { ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) IsInterfaceNil() bool } + +// HeadersPool defines what a headers pool structure can perform +type HeadersPool interface { + GetHeaderByHash(hash []byte) (data.HeaderHandler, error) +} diff --git a/consensus/spos/bls/v2/subroundBlock.go b/consensus/spos/bls/v2/subroundBlock.go index 2454ad3643..3ff1459d66 100644 --- a/consensus/spos/bls/v2/subroundBlock.go +++ b/consensus/spos/bls/v2/subroundBlock.go @@ -389,7 +389,7 @@ func isProofEmpty(proof data.HeaderProofHandler) bool { len(proof.GetHeaderHash()) == 0 } -func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler) { +func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler, prevHeader data.HeaderHandler) { hasProof := sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash()) if hasProof { log.Debug("saveProofForPreviousHeaderIfNeeded: no need to set proof since it is already saved") @@ -397,11 +397,19 @@ func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHa } proof := header.GetPreviousProof() - err := sr.EquivalentProofsPool().AddProof(proof) + err := common.VerifyProofAgainstHeader(proof, prevHeader) + if err != nil { + log.Debug("saveProofForPreviousHeaderIfNeeded: invalid proof, %w", err) + return + } + + err = sr.EquivalentProofsPool().AddProof(proof) if err != nil { log.Debug("saveProofForPreviousHeaderIfNeeded: failed to add proof, %w", err) return } + + return } // receivedBlockBody method is called when a block body is received through the block body channel @@ -445,30 +453,30 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return blockProcessedWithSuccess } -func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) bool { +func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) (bool, data.HeaderHandler) { if check.IfNil(header) { - return false + return false, nil } if header.GetShardID() != sr.ShardCoordinator().SelfId() { - return false + return false, nil } if header.GetRound() != uint64(sr.RoundHandler().Index()) { - return false + return false, nil } prevHeader, prevHash := sr.getPrevHeaderAndHash() if check.IfNil(prevHeader) { - return false + return false, nil } if !bytes.Equal(header.GetPrevHash(), prevHash) { - return false + return false, nil } if header.GetNonce() != prevHeader.GetNonce()+1 { - return false + return false, nil } prevRandSeed := prevHeader.GetRandSeed() - return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed) + return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed), prevHeader } func (sr *subroundBlock) getLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) { @@ -495,7 +503,8 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { return } - if !sr.isHeaderForCurrentConsensus(headerHandler) { + isHeaderForCurrentConsensus, prevHeader := sr.isHeaderForCurrentConsensus(headerHandler) + if !isHeaderForCurrentConsensus { return } @@ -539,7 +548,7 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { sr.SetData(sr.Hasher().Compute(string(marshalledHeader))) sr.SetHeader(headerHandler) - sr.saveProofForPreviousHeaderIfNeeded(headerHandler) + sr.saveProofForPreviousHeaderIfNeeded(headerHandler, prevHeader) log.Debug("step 1: block header has been received", "nonce", sr.GetHeader().GetNonce(), diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 4f2a3661ec..9acb2f3475 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -222,7 +222,16 @@ func (bp *baseProcessor) checkBlockValidity( return process.ErrEpochDoesNotMatch } - return nil + return bp.checkPrevProofValidity(currentBlockHeader, headerHandler) +} + +func (bp *baseProcessor) checkPrevProofValidity(prevHeader, headerHandler data.HeaderHandler) error { + if !common.ShouldBlockHavePrevProof(headerHandler, bp.enableEpochsHandler, common.EquivalentMessagesFlag) { + return nil + } + + prevProof := headerHandler.GetPreviousProof() + return common.VerifyProofAgainstHeader(prevProof, prevHeader) } // checkScheduledRootHash checks if the scheduled root hash from the given header is the same with the current user accounts state root hash diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof.go b/process/block/interceptedBlocks/interceptedEquivalentProof.go index 7712aa483b..5d3eeda8ba 100644 --- a/process/block/interceptedBlocks/interceptedEquivalentProof.go +++ b/process/block/interceptedBlocks/interceptedEquivalentProof.go @@ -1,6 +1,7 @@ package interceptedBlocks import ( + "encoding/hex" "fmt" "github.com/multiversx/mx-chain-core-go/core" @@ -8,12 +9,15 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/storage" logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-vm-v1_2-go/ipc/marshaling" ) const interceptedEquivalentProofType = "intercepted equivalent proof" @@ -25,6 +29,8 @@ type ArgInterceptedEquivalentProof struct { ShardCoordinator sharding.Coordinator HeaderSigVerifier consensus.HeaderSigVerifier Proofs dataRetriever.ProofsPool + Headers dataRetriever.HeadersPool + Storage dataRetriever.StorageService } type interceptedEquivalentProof struct { @@ -32,6 +38,9 @@ type interceptedEquivalentProof struct { isForCurrentShard bool headerSigVerifier consensus.HeaderSigVerifier proofsPool dataRetriever.ProofsPool + headersPool dataRetriever.HeadersPool + storage dataRetriever.StorageService + marshaller marshaling.Marshalizer } // NewInterceptedEquivalentProof returns a new instance of interceptedEquivalentProof @@ -51,6 +60,9 @@ func NewInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) (*interce isForCurrentShard: extractIsForCurrentShard(args.ShardCoordinator, equivalentProof), headerSigVerifier: args.HeaderSigVerifier, proofsPool: args.Proofs, + headersPool: args.Headers, + marshaller: args.Marshaller, + storage: args.Storage, }, nil } @@ -70,6 +82,12 @@ func checkArgInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) erro if check.IfNil(args.Proofs) { return process.ErrNilProofsPool } + if check.IfNil(args.Headers) { + return process.ErrNilHeadersDataPool + } + if check.IfNil(args.Storage) { + return process.ErrNilStore + } return nil } @@ -115,9 +133,28 @@ func (iep *interceptedEquivalentProof) CheckValidity() error { return proofscache.ErrAlreadyExistingEquivalentProof } + err = iep.checkHeaderParamsFromProof() + if err != nil { + return err + } + return iep.headerSigVerifier.VerifyHeaderProof(iep.proof) } +func (iep *interceptedEquivalentProof) checkHeaderParamsFromProof() error { + headersStorer, err := iep.getHeadersStorer(iep.proof.GetHeaderShardId()) + if err != nil { + return err + } + + header, err := common.GetHeader(iep.proof.GetHeaderHash(), iep.headersPool, headersStorer, iep.marshaller) + if err != nil { + return fmt.Errorf("%w while getting header for proof hash %s", err, hex.EncodeToString(iep.proof.GetHeaderHash())) + } + + return common.VerifyProofAgainstHeader(iep.proof, header) +} + func (iep *interceptedEquivalentProof) integrity() error { isProofValid := len(iep.proof.AggregatedSignature) > 0 && len(iep.proof.PubKeysBitmap) > 0 && @@ -129,6 +166,14 @@ func (iep *interceptedEquivalentProof) integrity() error { return nil } +func (iep *interceptedEquivalentProof) getHeadersStorer(shardID uint32) (storage.Storer, error) { + if shardID == core.MetachainShardId { + return iep.storage.GetStorer(dataRetriever.MetaBlockUnit) + } + + return iep.storage.GetStorer(dataRetriever.BlockHeaderUnit) +} + // GetProof returns the underlying intercepted header proof func (iep *interceptedEquivalentProof) GetProof() data.HeaderProofHandler { return iep.proof diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go index b0a8cd6c9c..3d1a96afd1 100644 --- a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go +++ b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go @@ -4,16 +4,23 @@ import ( "bytes" "errors" "fmt" + "strings" "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/mock" proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/pool" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/require" ) @@ -21,6 +28,10 @@ import ( var ( expectedErr = errors.New("expected error") testMarshaller = &marshallerMock.MarshalizerMock{} + providedEpoch = uint32(123) + providedNonce = uint64(345) + providedShard = uint32(0) + providedRound = uint64(123456) ) func createMockDataBuff() []byte { @@ -28,9 +39,10 @@ func createMockDataBuff() []byte { PubKeysBitmap: []byte("bitmap"), AggregatedSignature: []byte("sig"), HeaderHash: []byte("hash"), - HeaderEpoch: 123, - HeaderNonce: 345, - HeaderShardId: 0, + HeaderEpoch: providedEpoch, + HeaderNonce: providedNonce, + HeaderShardId: providedShard, + HeaderRound: providedRound, } dataBuff, _ := testMarshaller.Marshal(proof) @@ -44,6 +56,24 @@ func createMockArgInterceptedEquivalentProof() ArgInterceptedEquivalentProof { ShardCoordinator: &mock.ShardCoordinatorMock{}, HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, Proofs: &dataRetriever.ProofsPoolMock{}, + Headers: &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + EpochField: providedEpoch, + RoundField: providedRound, + GetNonceCalled: func() uint64 { + return providedNonce + }, + GetShardIDCalled: func() uint32 { + return providedShard + }, + }, nil + }, + }, + Storage: &genericMocks.ChainStorerMock{ + BlockHeaders: genericMocks.NewStorerMock(), + Metablocks: genericMocks.NewStorerMock(), + }, } } @@ -105,6 +135,15 @@ func TestNewInterceptedEquivalentProof(t *testing.T) { require.Equal(t, process.ErrNilProofsPool, err) require.Nil(t, iep) }) + t.Run("nil headers pool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHeadersDataPool, err) + require.Nil(t, iep) + }) t.Run("unmarshal error should error", func(t *testing.T) { t.Parallel() @@ -118,6 +157,15 @@ func TestNewInterceptedEquivalentProof(t *testing.T) { require.Equal(t, expectedErr, err) require.Nil(t, iep) }) + t.Run("nil storage should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Storage = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilStore, err) + require.Nil(t, iep) + }) t.Run("should work", func(t *testing.T) { t.Parallel() @@ -164,6 +212,127 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) { require.Equal(t, proofscache.ErrAlreadyExistingEquivalentProof, err) }) + t.Run("missing header for proof hash should error", func(t *testing.T) { + t.Parallel() + + providedErr := errors.New("missing header") + args := createMockArgInterceptedEquivalentProof() + args.Headers = &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return nil, providedErr + }, + } + args.Storage = &genericMocks.ChainStorerMock{ + BlockHeaders: genericMocks.NewStorerMockWithErrKeyNotFound(0), + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.True(t, errors.Is(err, storage.ErrKeyNotFound)) + }) + + t.Run("nonce mismatch should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return providedNonce + 1 + }, + }, nil + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.True(t, errors.Is(err, common.ErrInvalidHeaderProof)) + require.True(t, strings.Contains(err.Error(), "nonce mismatch")) + }) + + t.Run("shard id mismatch should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return providedNonce + }, + GetShardIDCalled: func() uint32 { + return providedShard + 1 + }, + }, nil + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.True(t, errors.Is(err, common.ErrInvalidHeaderProof)) + require.True(t, strings.Contains(err.Error(), "shard id mismatch")) + }) + + t.Run("epoch mismatch should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return providedNonce + }, + GetShardIDCalled: func() uint32 { + return providedShard + }, + EpochField: providedEpoch + 1, + }, nil + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.True(t, errors.Is(err, common.ErrInvalidHeaderProof)) + require.True(t, strings.Contains(err.Error(), "epoch mismatch")) + }) + + t.Run("round mismatch should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return providedNonce + }, + GetShardIDCalled: func() uint32 { + return providedShard + }, + EpochField: providedEpoch, + RoundField: providedRound + 1, + }, nil + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.True(t, errors.Is(err, common.ErrInvalidHeaderProof)) + require.True(t, strings.Contains(err.Error(), "round mismatch")) + }) + t.Run("should work", func(t *testing.T) { t.Parallel() diff --git a/process/block/metablock.go b/process/block/metablock.go index 04220d9936..75ffe1a2f6 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -430,6 +430,23 @@ func (mp *metaProcessor) checkProofsForShardData(header *block.MetaBlock) error if !mp.proofsPool.HasProof(shardData.ShardID, shardData.HeaderHash) { return fmt.Errorf("%w for header hash %s", process.ErrMissingHeaderProof, hex.EncodeToString(shardData.HeaderHash)) } + + shardHeadersStorer, err := mp.store.GetStorer(dataRetriever.BlockHeaderUnit) + if err != nil { + return err + } + + prevProof := shardData.GetPreviousProof() + headersPool := mp.dataPool.Headers() + prevHeader, err := common.GetHeader(prevProof.GetHeaderHash(), headersPool, shardHeadersStorer, mp.marshalizer) + if err != nil { + return err + } + + err = common.VerifyProofAgainstHeader(prevProof, prevHeader) + if err != nil { + return err + } } return nil diff --git a/process/errors.go b/process/errors.go index 52d5981ab0..4f786e86e3 100644 --- a/process/errors.go +++ b/process/errors.go @@ -1251,9 +1251,6 @@ var ErrEmptyChainParametersConfiguration = errors.New("empty chain parameters co // ErrNoMatchingConfigForProvidedEpoch signals that there is no matching configuration for the provided epoch var ErrNoMatchingConfigForProvidedEpoch = errors.New("no matching configuration") -// ErrInvalidHeader is raised when header is invalid -var ErrInvalidHeader = errors.New("header is invalid") - // ErrNilHeaderProof signals that a nil header proof has been provided var ErrNilHeaderProof = errors.New("nil header proof") diff --git a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go index bc167e0dab..271f2ac26a 100644 --- a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go @@ -913,7 +913,13 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() } func (bicf *baseInterceptorsContainerFactory) createOneShardEquivalentProofsInterceptor(topic string) (process.Interceptor, error) { - equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(*bicf.argInterceptorFactory, bicf.dataPool.Proofs()) + args := interceptorFactory.ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: *bicf.argInterceptorFactory, + ProofsPool: bicf.dataPool.Proofs(), + HeadersPool: bicf.dataPool.Headers(), + Storage: bicf.store, + } + equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(args) marshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() argProcessor := processor.ArgEquivalentProofsInterceptorProcessor{ diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory.go b/process/interceptors/factory/interceptedEquivalentProofsFactory.go index 4c5694d1e4..f1ba5a150f 100644 --- a/process/interceptors/factory/interceptedEquivalentProofsFactory.go +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory.go @@ -9,20 +9,32 @@ import ( "github.com/multiversx/mx-chain-go/sharding" ) +// ArgInterceptedEquivalentProofsFactory is the DTO used to create a new instance of interceptedEquivalentProofsFactory +type ArgInterceptedEquivalentProofsFactory struct { + ArgInterceptedDataFactory + ProofsPool dataRetriever.ProofsPool + HeadersPool dataRetriever.HeadersPool + Storage dataRetriever.StorageService +} + type interceptedEquivalentProofsFactory struct { marshaller marshal.Marshalizer shardCoordinator sharding.Coordinator headerSigVerifier consensus.HeaderSigVerifier proofsPool dataRetriever.ProofsPool + headersPool dataRetriever.HeadersPool + storage dataRetriever.StorageService } // NewInterceptedEquivalentProofsFactory creates a new instance of interceptedEquivalentProofsFactory -func NewInterceptedEquivalentProofsFactory(args ArgInterceptedDataFactory, proofsPool dataRetriever.ProofsPool) *interceptedEquivalentProofsFactory { +func NewInterceptedEquivalentProofsFactory(args ArgInterceptedEquivalentProofsFactory) *interceptedEquivalentProofsFactory { return &interceptedEquivalentProofsFactory{ marshaller: args.CoreComponents.InternalMarshalizer(), shardCoordinator: args.ShardCoordinator, headerSigVerifier: args.HeaderSigVerifier, - proofsPool: proofsPool, + proofsPool: args.ProofsPool, + headersPool: args.HeadersPool, + storage: args.Storage, } } @@ -34,6 +46,8 @@ func (factory *interceptedEquivalentProofsFactory) Create(buff []byte) (process. ShardCoordinator: factory.shardCoordinator, HeaderSigVerifier: factory.headerSigVerifier, Proofs: factory.proofsPool, + Headers: factory.headersPool, + Storage: factory.storage, } return interceptedBlocks.NewInterceptedEquivalentProof(args) } diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go index c96ade9528..d5c57d0a31 100644 --- a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go @@ -9,16 +9,21 @@ import ( processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/pool" "github.com/stretchr/testify/require" ) -func createMockArgInterceptedDataFactory() ArgInterceptedDataFactory { - return ArgInterceptedDataFactory{ - CoreComponents: &processMock.CoreComponentsMock{ - IntMarsh: &mock.MarshalizerMock{}, +func createMockArgInterceptedEquivalentProofsFactory() ArgInterceptedEquivalentProofsFactory { + return ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: ArgInterceptedDataFactory{ + CoreComponents: &processMock.CoreComponentsMock{ + IntMarsh: &mock.MarshalizerMock{}, + }, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, }, - ShardCoordinator: &mock.ShardCoordinatorMock{}, - HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + HeadersPool: &pool.HeadersPoolStub{}, } } @@ -28,22 +33,22 @@ func TestInterceptedEquivalentProofsFactory_IsInterfaceNil(t *testing.T) { var factory *interceptedEquivalentProofsFactory require.True(t, factory.IsInterfaceNil()) - factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) require.False(t, factory.IsInterfaceNil()) } func TestNewInterceptedEquivalentProofsFactory(t *testing.T) { t.Parallel() - factory := NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + factory := NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) require.NotNil(t, factory) } func TestInterceptedEquivalentProofsFactory_Create(t *testing.T) { t.Parallel() - args := createMockArgInterceptedDataFactory() - factory := NewInterceptedEquivalentProofsFactory(args, &dataRetriever.ProofsPoolMock{}) + args := createMockArgInterceptedEquivalentProofsFactory() + factory := NewInterceptedEquivalentProofsFactory(args) require.NotNil(t, factory) providedProof := &block.HeaderProof{ diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go index b11eca03ae..320262fb58 100644 --- a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go @@ -11,7 +11,9 @@ import ( "github.com/multiversx/mx-chain-go/process/transaction" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/pool" "github.com/stretchr/testify/require" ) @@ -105,6 +107,8 @@ func TestEquivalentProofsInterceptorProcessor_Save(t *testing.T) { ShardCoordinator: &mock.ShardCoordinatorMock{}, HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, Proofs: &dataRetriever.ProofsPoolMock{}, + Headers: &pool.HeadersPoolStub{}, + Storage: &genericMocks.ChainStorerMock{}, } argInterceptedEquivalentProof.DataBuff, _ = argInterceptedEquivalentProof.Marshaller.Marshal(&block.HeaderProof{ PubKeysBitmap: []byte("bitmap"), diff --git a/testscommon/headerHandlerStub.go b/testscommon/headerHandlerStub.go index 00613c26d4..733c8b5c16 100644 --- a/testscommon/headerHandlerStub.go +++ b/testscommon/headerHandlerStub.go @@ -40,6 +40,7 @@ type HeaderHandlerStub struct { SetLeaderSignatureCalled func(signature []byte) error GetPreviousProofCalled func() data.HeaderProofHandler SetPreviousProofCalled func(proof data.HeaderProofHandler) + GetShardIDCalled func() uint32 } // GetAccumulatedFees - @@ -91,6 +92,9 @@ func (hhs *HeaderHandlerStub) ShallowClone() data.HeaderHandler { // GetShardID - func (hhs *HeaderHandlerStub) GetShardID() uint32 { + if hhs.GetShardIDCalled != nil { + return hhs.GetShardIDCalled() + } return 1 }