diff --git a/pkg/ebpf/bpf_client.go b/pkg/ebpf/bpf_client.go index 71cfae8..8a7fa25 100644 --- a/pkg/ebpf/bpf_client.go +++ b/pkg/ebpf/bpf_client.go @@ -120,7 +120,6 @@ func NewBpfClient(policyEndpointeBPFContext *sync.Map, nodeIP string, enablePoli GlobalMaps: new(sync.Map), } ebpfClient.logger = ctrl.Log.WithName("ebpf-client") - ingressBinary, egressBinary, eventsBinary, cliBinary, hostMask := TC_INGRESS_BINARY, TC_EGRESS_BINARY, EVENTS_BINARY, EKS_CLI_BINARY, IPv4_HOST_MASK if enableIPv6 { @@ -739,6 +738,8 @@ func sortFirewallRulesByPrefixLength(rules []EbpfFirewallRules, prefixLenStr str } func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirewallRules) (map[string]uintptr, error) { + + firewallMap := make(map[string][]byte) mapEntries := make(map[string]uintptr) ipCIDRs := make(map[string][]v1alpha1.Port) nonHostCIDRs := make(map[string][]v1alpha1.Port) @@ -749,7 +750,7 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew _, mapKey, _ := net.ParseCIDR(l.nodeIP + l.hostMask) key := utils.ComputeTrieKey(*mapKey, l.enableIPv6) value := utils.ComputeTrieValue([]v1alpha1.Port{}, l.logger, true, false) - mapEntries[string(key)] = uintptr(unsafe.Pointer(&value[0])) + firewallMap[string(key)] = value //Sort the rules sortFirewallRulesByPrefixLength(firewallRules, l.hostMask) @@ -758,10 +759,10 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew catchAllIPPorts, isCatchAllIPEntryPresent, allowAll = l.checkAndDeriveCatchAllIPPorts(firewallRules) if isCatchAllIPEntryPresent { //Add the Catch All IP entry - _, mapKey, _ = net.ParseCIDR("0.0.0.0/0") - key = utils.ComputeTrieKey(*mapKey, l.enableIPv6) - value = utils.ComputeTrieValue(catchAllIPPorts, l.logger, allowAll, false) - mapEntries[string(key)] = uintptr(unsafe.Pointer(&value[0])) + _, mapKey, _ := net.ParseCIDR("0.0.0.0/0") + key := utils.ComputeTrieKey(*mapKey, l.enableIPv6) + value := utils.ComputeTrieValue(catchAllIPPorts, l.logger, allowAll, false) + firewallMap[string(key)] = value } for _, firewallRule := range firewallRules { @@ -812,22 +813,28 @@ func (l *bpfClient) computeMapEntriesFromEndpointRules(firewallRules []EbpfFirew firewallRule.L4Info = append(firewallRule.L4Info, catchAllIPPorts...) l.logger.Info("Updating Map with ", "IP Key:", firewallRule.IPCidr) - _, mapKey, _ = net.ParseCIDR(string(firewallRule.IPCidr)) + _, firewallMapKey, _ := net.ParseCIDR(string(firewallRule.IPCidr)) // Key format: Prefix length (4 bytes) followed by 4/16byte IP address - key = utils.ComputeTrieKey(*mapKey, l.enableIPv6) - value = utils.ComputeTrieValue(firewallRule.L4Info, l.logger, allowAll, false) - mapEntries[string(key)] = uintptr(unsafe.Pointer(&value[0])) + firewallKey := utils.ComputeTrieKey(*firewallMapKey, l.enableIPv6) + firewallValue := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, allowAll, false) + firewallMap[string(firewallKey)] = firewallValue } if firewallRule.Except != nil { for _, exceptCIDR := range firewallRule.Except { - _, mapKey, _ = net.ParseCIDR(string(exceptCIDR)) - key = utils.ComputeTrieKey(*mapKey, l.enableIPv6) + _, mapKey, _ := net.ParseCIDR(string(exceptCIDR)) + key := utils.ComputeTrieKey(*mapKey, l.enableIPv6) l.logger.Info("Parsed Except CIDR", "IP Key: ", mapKey) - value = utils.ComputeTrieValue(firewallRule.L4Info, l.logger, false, true) - mapEntries[string(key)] = uintptr(unsafe.Pointer(&value[0])) + value := utils.ComputeTrieValue(firewallRule.L4Info, l.logger, false, true) + firewallMap[string(key)] = value } } } + + //Add to mapEntries + for key, value := range firewallMap { + byteSlicePtr := unsafe.Pointer(&value[0]) + mapEntries[key] = uintptr(byteSlicePtr) + } return mapEntries, nil }