diff --git a/pkg/ipamd/ipamd.go b/pkg/ipamd/ipamd.go index 58f5cd3b0f..0f36713d3e 100644 --- a/pkg/ipamd/ipamd.go +++ b/pkg/ipamd/ipamd.go @@ -426,21 +426,7 @@ func New(rawK8SClient client.Client, cachedK8SClient client.Client) (*IPAMContex checkpointer := datastore.NewJSONFile(dsBackingStorePath()) c.dataStore = datastore.NewDataStore(log, checkpointer, c.enablePrefixDelegation) - // Retrieve security groups - mac := c.awsClient.GetPrimaryENImac() - if c.enableIPv4 && !c.disableENIProvisioning { - err = c.awsClient.RefreshSGIDs(mac) - if err != nil { - return nil, err - } - - // Refresh security groups and VPC CIDR blocks in the background - // Ignoring errors since we will retry in 30s - go wait.Forever(func() { _ = c.awsClient.RefreshSGIDs(mac) }, 30*time.Second) - } - - err = c.nodeInit() - if err != nil { + if err := c.nodeInit(); err != nil { return nil, err } return c, nil @@ -454,7 +440,6 @@ func (c *IPAMContext) nodeInit() error { ctx := context.TODO() log.Debugf("Start node init") - primaryV4IP := c.awsClient.GetLocalIPv4() if err = c.initENIAndIPLimits(); err != nil { return err @@ -469,7 +454,8 @@ func (c *IPAMContext) nodeInit() error { } } - err = c.networkClient.SetupHostNetwork(vpcV4CIDRs, c.awsClient.GetPrimaryENImac(), &primaryV4IP, c.enablePodENI, c.enableIPv4, c.enableIPv6) + primaryENIMac := c.awsClient.GetPrimaryENImac() + err = c.networkClient.SetupHostNetwork(vpcV4CIDRs, primaryENIMac, &primaryV4IP, c.enablePodENI, c.enableIPv4, c.enableIPv6) if err != nil { return errors.Wrap(err, "ipamd init: failed to set up host network") } @@ -552,6 +538,21 @@ func (c *IPAMContext) nodeInit() error { vpcV4CIDRs = c.updateCIDRsRulesOnChange(vpcV4CIDRs) }, 30*time.Second) + // RefreshSGIDs populates the ENI cache with ENI -> security group ID mappings, and so it must be called: + // 1. after managed/unmanaged ENIs have been determined + // 2. before any new ENIs are attached + if c.enableIPv4 && !c.disableENIProvisioning { + if err := c.awsClient.RefreshSGIDs(primaryENIMac); err != nil { + return err + } + + // Refresh security groups and VPC CIDR blocks in the background + // Ignoring errors since we will retry in 30s + go wait.Forever(func() { + c.awsClient.RefreshSGIDs(primaryENIMac) + }, 30*time.Second) + } + node, err := k8sapi.GetNode(ctx, c.cachedK8SClient) if err != nil { log.Errorf("Failed to get node", err) diff --git a/pkg/ipamd/ipamd_test.go b/pkg/ipamd/ipamd_test.go index 906f7a391b..b3e99710ee 100644 --- a/pkg/ipamd/ipamd_test.go +++ b/pkg/ipamd/ipamd_test.go @@ -150,8 +150,8 @@ func TestNodeInit(t *testing.T) { m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil) m.awsutils.EXPECT().GetPrimaryENImac().Return("") m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil) - m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid) + m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil) eniMetadataSlice := []awsutils.ENIMetadata{eni1, eni2} resp := awsutils.DescribeAllENIsResult{ @@ -179,7 +179,7 @@ func TestNodeInit(t *testing.T) { Spec: v1.NodeSpec{}, Status: v1.NodeStatus{}, } - _ = m.cachedK8SClient.Create(ctx, &fakeNode) + m.cachedK8SClient.Create(ctx, &fakeNode) // Add IPs m.awsutils.EXPECT().AllocIPAddresses(gomock.Any(), gomock.Any()) @@ -236,8 +236,8 @@ func TestNodeInitwithPDenabledIPv4Mode(t *testing.T) { m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil) m.awsutils.EXPECT().GetPrimaryENImac().Return("") m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil) - m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid) + m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil) eniMetadataSlice := []awsutils.ENIMetadata{eni1, eni2} resp := awsutils.DescribeAllENIsResult{ @@ -264,8 +264,9 @@ func TestNodeInitwithPDenabledIPv4Mode(t *testing.T) { Spec: v1.NodeSpec{}, Status: v1.NodeStatus{}, } - _ = m.cachedK8SClient.Create(ctx, &fakeNode) + m.cachedK8SClient.Create(ctx, &fakeNode) + os.Setenv("MY_NODE_NAME", myNodeName) err := mockContext.nodeInit() assert.NoError(t, err) }