diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index bf78a05cdc7e..6d1911cdb23c 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -164,15 +164,7 @@ func (m *manager) GetValidatorSet(ctx context.Context, height uint64, subnetID i return nil, err } } - currentSubnetValidatorList := currentSubnetValidators.List() - subnetSet := make(map[ids.NodeID]*validators.GetValidatorOutput, len(currentSubnetValidatorList)) - for _, vdr := range currentSubnetValidatorList { - subnetSet[vdr.NodeID] = &validators.GetValidatorOutput{ - NodeID: vdr.NodeID, - // PublicKey will be picked from primary validators - Weight: vdr.Weight, - } - } + subnetSet := make(map[ids.NodeID]*validators.GetValidatorOutput, currentSubnetValidators.Len()) currentPrimaryNetworkValidators, ok := m.cfg.Validators.Get(constants.PrimaryNetworkID) if !ok { @@ -182,16 +174,19 @@ func (m *manager) GetValidatorSet(ctx context.Context, height uint64, subnetID i currentPrimaryValidatorList := currentPrimaryNetworkValidators.List() primarySet := make(map[ids.NodeID]*validators.GetValidatorOutput, len(currentPrimaryValidatorList)) for _, vdr := range currentPrimaryValidatorList { + if currentSubnetValidators.Contains(vdr.NodeID) { + subnetSet[vdr.NodeID] = &validators.GetValidatorOutput{ + NodeID: vdr.NodeID, + PublicKey: vdr.PublicKey, + Weight: vdr.Weight, + } + } + primarySet[vdr.NodeID] = &validators.GetValidatorOutput{ NodeID: vdr.NodeID, PublicKey: vdr.PublicKey, Weight: vdr.Weight, } - - // fill PK from primary network - if _, found := subnetSet[vdr.NodeID]; found { - subnetSet[vdr.NodeID].PublicKey = vdr.PublicKey - } } for diffHeight := lastAcceptedHeight; diffHeight > height; diffHeight-- { diff --git a/vms/platformvm/validators/manager_test.go b/vms/platformvm/validators/manager_test.go index 43539771775a..acb173e099aa 100644 --- a/vms/platformvm/validators/manager_test.go +++ b/vms/platformvm/validators/manager_test.go @@ -397,6 +397,16 @@ func TestVM_GetValidatorSet(t *testing.T) { mockSubnetVdrSet = validators.NewMockSet(ctrl) mockSubnetVdrSet.EXPECT().List().Return(tt.currentSubnetValidators).AnyTimes() } + mockSubnetVdrSet.EXPECT().Len().Return(len(tt.currentSubnetValidators)).AnyTimes() + mockSubnetVdrSet.EXPECT().Contains(gomock.Any()).DoAndReturn(func(nodeID ids.NodeID) bool { + for _, vdr := range tt.currentSubnetValidators { + if vdr.NodeID == nodeID { + return true + } + } + + return false + }).AnyTimes() vdrs.EXPECT().Get(tt.subnetID).Return(mockSubnetVdrSet, true).AnyTimes() for _, vdr := range testValidators {