Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP address and policy removal on connection update #437

Merged
merged 6 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ require (
github.com/stretchr/testify v1.7.0
github.com/vishvananda/netlink v1.1.1-0.20220118170537-d6b03fdeb845
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74
go.uber.org/atomic v1.7.0
go.uber.org/goleak v1.1.12
golang.org/x/sys v0.0.0-20220307203707-22a9840ba4d7
google.golang.org/genproto v0.0.0-20211129164237-f09f9a12af12 // indirect
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKu
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.opentelemetry.io/proto/otlp v0.11.0 h1:cLDgIBTf4lLOlztkhzAEdQsJ4Lj+i5Wc9k6Nn0K1VyU=
go.opentelemetry.io/proto/otlp v0.11.0/go.mod h1:QpEjXPrNQzrFDZgoTo49dgHR9RYRSrg3NAKnUGl9YpQ=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build linux
// +build linux

package ipaddress
Expand Down Expand Up @@ -93,7 +94,7 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool)
ch := make(chan netlink.AddrUpdate)
done := make(chan struct{})

if err := netlink.AddrSubscribeAt(targetNetNS, ch, done); err != nil {
if err = netlink.AddrSubscribeAt(targetNetNS, ch, done); err != nil {
return errors.Wrapf(err, "failed to subscribe for interface address updates")
}

Expand All @@ -106,7 +107,29 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool)
}()
}()

for _, ipNet := range ipNets {
// Get IP addresses to add and to remove
toAdd, toRemove, err := getIPAddrDifferences(netlinkHandle, l, ipNets)
if err != nil {
return err
}

// Remove no longer existing IPs
for _, ipNet := range toRemove {
now := time.Now()
addr := &netlink.Addr{
IPNet: ipNet,
}
if err := netlinkHandle.AddrDel(l, addr); err != nil {
return errors.Wrapf(err, "attempting to delete ip address %s to %s", addr.IPNet, l.Attrs().Name)
}
log.FromContext(ctx).
WithField("link.Name", l.Attrs().Name).
WithField("Addr", ipNet.String()).
WithField("duration", time.Since(now)).
WithField("netlink", "AddrDel").Debug("completed")
}
// Add new IP addresses
for _, ipNet := range toAdd {
now := time.Now()
addr := &netlink.Addr{
IPNet: ipNet,
Expand All @@ -129,14 +152,42 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool)
WithField("duration", time.Since(now)).
WithField("netlink", "AddrAdd").Debug("completed")
}
return waitForIPNets(ctx, ch, l, ipNets)
return waitForIPNets(ctx, ch, l, toAdd)
}
return nil
}

func getIPAddrDifferences(netlinkHandle *netlink.Handle, l netlink.Link, newIPs []*net.IPNet) (toAdd, toRemove []*net.IPNet, err error) {
currentIPs, err := netlinkHandle.AddrList(l, netlink.FAMILY_ALL)
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to list ip addresses")
}
currentIPsMap := make(map[string]*net.IPNet)
for _, addr := range currentIPs {
// ignore link-local addresses (fe80::/10...)
if addr.Scope != unix.RT_SCOPE_UNIVERSE {
continue
}
currentIPsMap[addr.IPNet.String()] = addr.IPNet
}
for _, ipNet := range newIPs {
if _, ok := currentIPsMap[ipNet.String()]; !ok {
toAdd = append(toAdd, ipNet)
}
delete(currentIPsMap, ipNet.String())
}
for _, ipNet := range currentIPsMap {
toRemove = append(toRemove, ipNet)
}
return toAdd, toRemove, nil
}

func waitForIPNets(ctx context.Context, ch chan netlink.AddrUpdate, l netlink.Link, ipNets []*net.IPNet) error {
now := time.Now()
for {
if len(ipNets) == 0 {
return nil
}
j := -1
select {
case <-ctx.Done():
Expand All @@ -162,8 +213,5 @@ func waitForIPNets(ctx context.Context, ch chan netlink.AddrUpdate, l netlink.Li
if j != -1 {
ipNets = append(ipNets[:j], ipNets[j+1:]...)
}
if len(ipNets) == 0 {
return nil
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ package iprule

import (
"context"
"fmt"
"strconv"
"time"

"golang.org/x/sys/unix"

"github.com/pkg/errors"
"github.com/vishvananda/netlink"
"go.uber.org/atomic"

"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel"
Expand All @@ -36,7 +38,7 @@ import (
link "github.com/networkservicemesh/sdk-kernel/pkg/kernel"
)

func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map, counter *atomic.Int32) error {
func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map) error {
if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 {
// Construct the netlink handle for the target namespace for this kernel interface
netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL())
Expand All @@ -50,49 +52,80 @@ func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map,
return errors.WithStack(err)
}

if err = netlinkHandle.LinkSetUp(l); err != nil {
return errors.WithStack(err)
ps, ok := tableIDs.Load(conn.GetId())
if !ok {
if len(conn.Context.IpContext.Policies) == 0 {
return nil
}
ps = make(map[int]*networkservice.PolicyRoute)
tableIDs.Store(conn.GetId(), ps)
}
// Get policies to add and to remove
toAdd, toRemove := getPolicyDifferences(ps, conn.Context.IpContext.Policies)

for _, policy := range conn.Context.IpContext.Policies {
// Check if we already created required ip table
key := tableKey{
connId: conn.GetId(),
from: policy.From,
protocol: policy.Proto,
dstPort: policy.DstPort,
srcPort: policy.SrcPort,
// Remove no longer existing policies
for tableID, policy := range toRemove {
if err := delRule(ctx, netlinkHandle, policy, tableID); err != nil {
return err
}
tableID, ok := tableIDs.Load(key)
if !ok {
counter.Inc()
tableID = int(counter.Load())
delete(ps, tableID)
tableIDs.Store(conn.GetId(), ps)
}
// Add new policies
for _, policy := range toAdd {
tableID, err := getFreeTableID(ctx, netlinkHandle)
if err != nil {
return errors.Wrapf(err, "failed to get free tableId")
}

// If policy doesn't contain any route - add default
if len(policy.Routes) == 0 {
policy.Routes = append(policy.Routes, defaultRoute())
}

for _, route := range policy.Routes {
if err := routeAdd(ctx, netlinkHandle, l, route, tableID); err != nil {
return err
return errors.Wrapf(err, "failed to add route")
}
}

if !ok {
// Check and delete old rules if they don't fit
_ = delOldRules(ctx, netlinkHandle, policy, tableID)
// Add new rule
if err := ruleAdd(ctx, netlinkHandle, policy, tableID); err != nil {
return err
}
tableIDs.Store(key, tableID)
if err := ruleAdd(ctx, netlinkHandle, policy, tableID); err != nil {
return errors.Wrapf(err, "failed to add rule")
}
ps[tableID] = policy
tableIDs.Store(conn.GetId(), ps)
}
}
return nil
}

func getPolicyDifferences(current map[int]*networkservice.PolicyRoute, newPolicies []*networkservice.PolicyRoute) (toAdd []*networkservice.PolicyRoute, toRemove map[int]*networkservice.PolicyRoute) {
type table struct {
tableID int
policyRoute *networkservice.PolicyRoute
}
toRemove = make(map[int]*networkservice.PolicyRoute)
currentMap := make(map[string]*table)
for tableID, policy := range current {
currentMap[policyKey(policy)] = &table{
tableID: tableID,
policyRoute: policy,
}
}
for _, policy := range newPolicies {
if _, ok := currentMap[policyKey(policy)]; !ok {
toAdd = append(toAdd, policy)
}
delete(currentMap, policyKey(policy))
}
for _, table := range currentMap {
toRemove[table.tableID] = table.policyRoute
}
return toAdd, toRemove
}

func policyKey(policy *networkservice.PolicyRoute) string {
return fmt.Sprintf("%s;%s;%s;%s", policy.DstPort, policy.SrcPort, policy.From, policy.Proto)
}

func policyToRule(policy *networkservice.PolicyRoute) (*netlink.Rule, error) {
rule := netlink.NewRule()
if policy.From != "" {
Expand Down Expand Up @@ -156,32 +189,6 @@ func ruleAdd(ctx context.Context, handle *netlink.Handle, policy *networkservice
return nil
}

func delOldRules(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute, tableID int) error {
rule, err := policyToRule(policy)
if err != nil {
return errors.WithStack(err)
}
flags := netlink.RT_FILTER_PROTOCOL
if rule.Src != nil {
flags |= netlink.RT_FILTER_SRC
}
rules, err := handle.RuleListFiltered(netlink.FAMILY_ALL, rule, flags)
if err != nil {
return errors.WithStack(err)
}
for i := range rules {
if rules[i].Dport == rule.Dport {
if rules[i].Table != tableID {
err = delRule(ctx, handle, policy)
if err != nil {
return errors.WithStack(err)
}
}
}
}
return nil
}

func defaultRoute() *networkservice.Route {
return &networkservice.Route{
Prefix: "0.0.0.0/0",
Expand Down Expand Up @@ -232,28 +239,35 @@ func routeAdd(ctx context.Context, handle *netlink.Handle, l netlink.Link, route
return nil
}

func del(ctx context.Context, conn *networkservice.Connection) error {
func del(ctx context.Context, conn *networkservice.Connection, tableIDs *Map) error {
if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 {
netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL())
if err != nil {
return errors.WithStack(err)
}
defer netlinkHandle.Close()
for _, policy := range conn.Context.IpContext.Policies {
if err := delRule(ctx, netlinkHandle, policy); err != nil {
return errors.WithStack(err)
ps, ok := tableIDs.LoadAndDelete(conn.GetId())
if ok {
for tableID, policy := range ps {
if err := delRule(ctx, netlinkHandle, policy, tableID); err != nil {
return errors.WithStack(err)
}
}
}
}
return nil
}

func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute) error {
func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute, tableID int) error {
rule, err := policyToRule(policy)
if err != nil {
return errors.WithStack(err)
}

if err := flushTable(ctx, handle, tableID); err != nil {
return err
}

now := time.Now()
if err := handle.RuleDel(rule); err != nil {
log.FromContext(ctx).
Expand All @@ -263,7 +277,7 @@ func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice
WithField("SrcPort", policy.SrcPort).
WithField("duration", time.Since(now)).
WithField("netlink", "RuleDel").Errorf("error %+v", err)
return errors.WithStack(err)
return errors.Wrapf(errors.WithStack(err), "failed to delete rule")
}
log.FromContext(ctx).
WithField("From", policy.From).
Expand All @@ -274,3 +288,57 @@ func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice
WithField("netlink", "RuleDel").Debug("completed")
return nil
}

func flushTable(ctx context.Context, handle *netlink.Handle, tableID int) error {
routes, err := handle.RouteListFiltered(netlink.FAMILY_ALL,
&netlink.Route{
Table: tableID,
},
netlink.RT_FILTER_TABLE)
if err != nil {
return errors.Wrapf(errors.WithStack(err), "failed to list routes")
}
for i := 0; i < len(routes); i++ {
err := handle.RouteDel(&routes[i])
if err != nil {
return errors.Wrapf(errors.WithStack(err), "failed to delete route")
}
}
log.FromContext(ctx).
WithField("tableID", tableID).
WithField("netlink", "flushTable").Debug("completed")
return nil
}

func getFreeTableID(ctx context.Context, handle *netlink.Handle) (int, error) {
routes, err := handle.RouteListFiltered(netlink.FAMILY_ALL,
&netlink.Route{
Table: unix.RT_TABLE_UNSPEC,
},
netlink.RT_FILTER_TABLE)
if err != nil {
return 0, errors.Wrapf(errors.WithStack(err), "getFreeTableID: failed to list routes")
}

// tableID = 0 is reserved
ids := make(map[int]int)
ids[0] = 0
for i := 0; i < len(routes); i++ {
ids[routes[i].Table] = routes[i].Table
}

// Find first missing table id
tableID := len(ids)
for i := 0; i < len(ids); i++ {
if _, ok := ids[i]; !ok {
tableID = i
break
}
}

log.FromContext(ctx).
WithField("tableID", tableID).
WithField("netlink", "getFreeTableID").Debug("completed")

return tableID, nil
}
Loading