diff --git a/go.mod b/go.mod index 34908ae6..2385c355 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/stretchr/testify v1.7.0 github.com/vishvananda/netlink v1.1.1-0.20220118170537-d6b03fdeb845 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 - go.uber.org/atomic v1.7.0 go.uber.org/goleak v1.1.12 golang.org/x/sys v0.0.0-20220307203707-22a9840ba4d7 google.golang.org/genproto v0.0.0-20211129164237-f09f9a12af12 // indirect diff --git a/go.sum b/go.sum index d6cebb35..83fba17b 100644 --- a/go.sum +++ b/go.sum @@ -262,7 +262,6 @@ go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKu go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.11.0 h1:cLDgIBTf4lLOlztkhzAEdQsJ4Lj+i5Wc9k6Nn0K1VyU= go.opentelemetry.io/proto/otlp v0.11.0/go.mod h1:QpEjXPrNQzrFDZgoTo49dgHR9RYRSrg3NAKnUGl9YpQ= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go index dbcd1b8f..03c4b75f 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/ipaddress/common.go @@ -16,6 +16,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package ipaddress @@ -93,7 +94,7 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool) ch := make(chan netlink.AddrUpdate) done := make(chan struct{}) - if err := netlink.AddrSubscribeAt(targetNetNS, ch, done); err != nil { + if err = netlink.AddrSubscribeAt(targetNetNS, ch, done); err != nil { return errors.Wrapf(err, "failed to subscribe for interface address updates") } @@ -106,7 +107,29 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool) }() }() - for _, ipNet := range ipNets { + // Get IP addresses to add and to remove + toAdd, toRemove, err := getIPAddrDifferences(netlinkHandle, l, ipNets) + if err != nil { + return err + } + + // Remove no longer existing IPs + for _, ipNet := range toRemove { + now := time.Now() + addr := &netlink.Addr{ + IPNet: ipNet, + } + if err := netlinkHandle.AddrDel(l, addr); err != nil { + return errors.Wrapf(err, "attempting to delete ip address %s to %s", addr.IPNet, l.Attrs().Name) + } + log.FromContext(ctx). + WithField("link.Name", l.Attrs().Name). + WithField("Addr", ipNet.String()). + WithField("duration", time.Since(now)). + WithField("netlink", "AddrDel").Debug("completed") + } + // Add new IP addresses + for _, ipNet := range toAdd { now := time.Now() addr := &netlink.Addr{ IPNet: ipNet, @@ -129,14 +152,42 @@ func create(ctx context.Context, conn *networkservice.Connection, isClient bool) WithField("duration", time.Since(now)). WithField("netlink", "AddrAdd").Debug("completed") } - return waitForIPNets(ctx, ch, l, ipNets) + return waitForIPNets(ctx, ch, l, toAdd) } return nil } +func getIPAddrDifferences(netlinkHandle *netlink.Handle, l netlink.Link, newIPs []*net.IPNet) (toAdd, toRemove []*net.IPNet, err error) { + currentIPs, err := netlinkHandle.AddrList(l, netlink.FAMILY_ALL) + if err != nil { + return nil, nil, errors.Wrapf(err, "failed to list ip addresses") + } + currentIPsMap := make(map[string]*net.IPNet) + for _, addr := range currentIPs { + // ignore link-local addresses (fe80::/10...) + if addr.Scope != unix.RT_SCOPE_UNIVERSE { + continue + } + currentIPsMap[addr.IPNet.String()] = addr.IPNet + } + for _, ipNet := range newIPs { + if _, ok := currentIPsMap[ipNet.String()]; !ok { + toAdd = append(toAdd, ipNet) + } + delete(currentIPsMap, ipNet.String()) + } + for _, ipNet := range currentIPsMap { + toRemove = append(toRemove, ipNet) + } + return toAdd, toRemove, nil +} + func waitForIPNets(ctx context.Context, ch chan netlink.AddrUpdate, l netlink.Link, ipNets []*net.IPNet) error { now := time.Now() for { + if len(ipNets) == 0 { + return nil + } j := -1 select { case <-ctx.Done(): @@ -162,8 +213,5 @@ func waitForIPNets(ctx context.Context, ch chan netlink.AddrUpdate, l netlink.Li if j != -1 { ipNets = append(ipNets[:j], ipNets[j+1:]...) } - if len(ipNets) == 0 { - return nil - } } } diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/common.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/common.go index 4591e54e..2472da12 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/common.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/common.go @@ -21,12 +21,14 @@ package iprule import ( "context" + "fmt" "strconv" "time" + "golang.org/x/sys/unix" + "github.com/pkg/errors" "github.com/vishvananda/netlink" - "go.uber.org/atomic" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" @@ -36,7 +38,7 @@ import ( link "github.com/networkservicemesh/sdk-kernel/pkg/kernel" ) -func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map, counter *atomic.Int32) error { +func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map) error { if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { // Construct the netlink handle for the target namespace for this kernel interface netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL()) @@ -50,49 +52,80 @@ func create(ctx context.Context, conn *networkservice.Connection, tableIDs *Map, return errors.WithStack(err) } - if err = netlinkHandle.LinkSetUp(l); err != nil { - return errors.WithStack(err) + ps, ok := tableIDs.Load(conn.GetId()) + if !ok { + if len(conn.Context.IpContext.Policies) == 0 { + return nil + } + ps = make(map[int]*networkservice.PolicyRoute) + tableIDs.Store(conn.GetId(), ps) } + // Get policies to add and to remove + toAdd, toRemove := getPolicyDifferences(ps, conn.Context.IpContext.Policies) - for _, policy := range conn.Context.IpContext.Policies { - // Check if we already created required ip table - key := tableKey{ - connId: conn.GetId(), - from: policy.From, - protocol: policy.Proto, - dstPort: policy.DstPort, - srcPort: policy.SrcPort, + // Remove no longer existing policies + for tableID, policy := range toRemove { + if err := delRule(ctx, netlinkHandle, policy, tableID); err != nil { + return err } - tableID, ok := tableIDs.Load(key) - if !ok { - counter.Inc() - tableID = int(counter.Load()) + delete(ps, tableID) + tableIDs.Store(conn.GetId(), ps) + } + // Add new policies + for _, policy := range toAdd { + tableID, err := getFreeTableID(ctx, netlinkHandle) + if err != nil { + return errors.Wrapf(err, "failed to get free tableId") } - // If policy doesn't contain any route - add default if len(policy.Routes) == 0 { policy.Routes = append(policy.Routes, defaultRoute()) } + for _, route := range policy.Routes { if err := routeAdd(ctx, netlinkHandle, l, route, tableID); err != nil { - return err + return errors.Wrapf(err, "failed to add route") } } - - if !ok { - // Check and delete old rules if they don't fit - _ = delOldRules(ctx, netlinkHandle, policy, tableID) - // Add new rule - if err := ruleAdd(ctx, netlinkHandle, policy, tableID); err != nil { - return err - } - tableIDs.Store(key, tableID) + if err := ruleAdd(ctx, netlinkHandle, policy, tableID); err != nil { + return errors.Wrapf(err, "failed to add rule") } + ps[tableID] = policy + tableIDs.Store(conn.GetId(), ps) } } return nil } +func getPolicyDifferences(current map[int]*networkservice.PolicyRoute, newPolicies []*networkservice.PolicyRoute) (toAdd []*networkservice.PolicyRoute, toRemove map[int]*networkservice.PolicyRoute) { + type table struct { + tableID int + policyRoute *networkservice.PolicyRoute + } + toRemove = make(map[int]*networkservice.PolicyRoute) + currentMap := make(map[string]*table) + for tableID, policy := range current { + currentMap[policyKey(policy)] = &table{ + tableID: tableID, + policyRoute: policy, + } + } + for _, policy := range newPolicies { + if _, ok := currentMap[policyKey(policy)]; !ok { + toAdd = append(toAdd, policy) + } + delete(currentMap, policyKey(policy)) + } + for _, table := range currentMap { + toRemove[table.tableID] = table.policyRoute + } + return toAdd, toRemove +} + +func policyKey(policy *networkservice.PolicyRoute) string { + return fmt.Sprintf("%s;%s;%s;%s", policy.DstPort, policy.SrcPort, policy.From, policy.Proto) +} + func policyToRule(policy *networkservice.PolicyRoute) (*netlink.Rule, error) { rule := netlink.NewRule() if policy.From != "" { @@ -156,32 +189,6 @@ func ruleAdd(ctx context.Context, handle *netlink.Handle, policy *networkservice return nil } -func delOldRules(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute, tableID int) error { - rule, err := policyToRule(policy) - if err != nil { - return errors.WithStack(err) - } - flags := netlink.RT_FILTER_PROTOCOL - if rule.Src != nil { - flags |= netlink.RT_FILTER_SRC - } - rules, err := handle.RuleListFiltered(netlink.FAMILY_ALL, rule, flags) - if err != nil { - return errors.WithStack(err) - } - for i := range rules { - if rules[i].Dport == rule.Dport { - if rules[i].Table != tableID { - err = delRule(ctx, handle, policy) - if err != nil { - return errors.WithStack(err) - } - } - } - } - return nil -} - func defaultRoute() *networkservice.Route { return &networkservice.Route{ Prefix: "0.0.0.0/0", @@ -232,28 +239,35 @@ func routeAdd(ctx context.Context, handle *netlink.Handle, l netlink.Link, route return nil } -func del(ctx context.Context, conn *networkservice.Connection) error { +func del(ctx context.Context, conn *networkservice.Connection, tableIDs *Map) error { if mechanism := kernel.ToMechanism(conn.GetMechanism()); mechanism != nil && mechanism.GetVLAN() == 0 { netlinkHandle, err := link.GetNetlinkHandle(mechanism.GetNetNSURL()) if err != nil { return errors.WithStack(err) } defer netlinkHandle.Close() - for _, policy := range conn.Context.IpContext.Policies { - if err := delRule(ctx, netlinkHandle, policy); err != nil { - return errors.WithStack(err) + ps, ok := tableIDs.LoadAndDelete(conn.GetId()) + if ok { + for tableID, policy := range ps { + if err := delRule(ctx, netlinkHandle, policy, tableID); err != nil { + return errors.WithStack(err) + } } } } return nil } -func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute) error { +func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice.PolicyRoute, tableID int) error { rule, err := policyToRule(policy) if err != nil { return errors.WithStack(err) } + if err := flushTable(ctx, handle, tableID); err != nil { + return err + } + now := time.Now() if err := handle.RuleDel(rule); err != nil { log.FromContext(ctx). @@ -263,7 +277,7 @@ func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice WithField("SrcPort", policy.SrcPort). WithField("duration", time.Since(now)). WithField("netlink", "RuleDel").Errorf("error %+v", err) - return errors.WithStack(err) + return errors.Wrapf(errors.WithStack(err), "failed to delete rule") } log.FromContext(ctx). WithField("From", policy.From). @@ -274,3 +288,57 @@ func delRule(ctx context.Context, handle *netlink.Handle, policy *networkservice WithField("netlink", "RuleDel").Debug("completed") return nil } + +func flushTable(ctx context.Context, handle *netlink.Handle, tableID int) error { + routes, err := handle.RouteListFiltered(netlink.FAMILY_ALL, + &netlink.Route{ + Table: tableID, + }, + netlink.RT_FILTER_TABLE) + if err != nil { + return errors.Wrapf(errors.WithStack(err), "failed to list routes") + } + for i := 0; i < len(routes); i++ { + err := handle.RouteDel(&routes[i]) + if err != nil { + return errors.Wrapf(errors.WithStack(err), "failed to delete route") + } + } + log.FromContext(ctx). + WithField("tableID", tableID). + WithField("netlink", "flushTable").Debug("completed") + return nil +} + +func getFreeTableID(ctx context.Context, handle *netlink.Handle) (int, error) { + routes, err := handle.RouteListFiltered(netlink.FAMILY_ALL, + &netlink.Route{ + Table: unix.RT_TABLE_UNSPEC, + }, + netlink.RT_FILTER_TABLE) + if err != nil { + return 0, errors.Wrapf(errors.WithStack(err), "getFreeTableID: failed to list routes") + } + + // tableID = 0 is reserved + ids := make(map[int]int) + ids[0] = 0 + for i := 0; i < len(routes); i++ { + ids[routes[i].Table] = routes[i].Table + } + + // Find first missing table id + tableID := len(ids) + for i := 0; i < len(ids); i++ { + if _, ok := ids[i]; !ok { + tableID = i + break + } + } + + log.FromContext(ctx). + WithField("tableID", tableID). + WithField("netlink", "getFreeTableID").Debug("completed") + + return tableID, nil +} diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/gen.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/gen.go index 9459abcd..578cd3f2 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/gen.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/gen.go @@ -16,17 +16,15 @@ package iprule -import "sync" +import ( + "sync" -//go:generate go-syncmap -output table_map.gen.go -type Map + "github.com/networkservicemesh/api/pkg/api/networkservice" +) -type tableKey struct { - connId string - from string - protocol string - dstPort string - srcPort string -} +//go:generate go-syncmap -output table_map.gen.go -type Map -// Map - sync.Map with key == tableKey and value == uint32 +type policies map[int]*networkservice.PolicyRoute + +// Map - sync.Map with key == string (connID) and value == policies type Map sync.Map diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/server.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/server.go index eee22181..297af212 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/server.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/server.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package iprule @@ -26,12 +27,10 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" "github.com/networkservicemesh/sdk/pkg/tools/postpone" "github.com/pkg/errors" - "go.uber.org/atomic" ) type ipruleServer struct { - counter atomic.Int32 - tables Map + tables Map } // NewServer creates a new server chain element setting ip rules @@ -47,7 +46,7 @@ func (i *ipruleServer) Request(ctx context.Context, request *networkservice.Netw return nil, err } - if err := create(ctx, conn, &i.tables, &i.counter); err != nil { + if err := create(ctx, conn, &i.tables); err != nil { closeCtx, cancelClose := postponeCtxFunc() defer cancelClose() @@ -62,6 +61,6 @@ func (i *ipruleServer) Request(ctx context.Context, request *networkservice.Netw } func (i *ipruleServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - _ = del(ctx, conn) + _ = del(ctx, conn, &i.tables) return next.Server(ctx).Close(ctx, conn) } diff --git a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/table_map.gen.go b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/table_map.gen.go index 713cc8a3..630daf4e 100644 --- a/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/table_map.gen.go +++ b/pkg/kernel/networkservice/connectioncontextkernel/ipcontext/iprule/table_map.gen.go @@ -1,4 +1,4 @@ -// Code generated by "-output table_map.gen.go -type Map -output table_map.gen.go -type Map"; DO NOT EDIT. +// Code generated by "-output table_map.gen.go -type Map -output table_map.gen.go -type Map"; DO NOT EDIT. package iprule import ( @@ -12,47 +12,47 @@ func _() { _ = (sync.Map)(Map{}) } -var _nil_Map_int_value = func() (val int) { return }() +var _nil_Map_policies_value = func() (val policies) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *Map) Load(key tableKey) (int, bool) { +func (m *Map) Load(key string) (policies, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_Map_int_value, ok + return _nil_Map_policies_value, ok } - return value.(int), ok + return value.(policies), ok } // Store sets the value for a key. -func (m *Map) Store(key tableKey, value int) { +func (m *Map) Store(key string, value policies) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *Map) LoadOrStore(key tableKey, value int) (int, bool) { +func (m *Map) LoadOrStore(key string, value policies) (policies, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_Map_int_value, loaded + return _nil_Map_policies_value, loaded } - return actual.(int), loaded + return actual.(policies), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *Map) LoadAndDelete(key tableKey) (value int, loaded bool) { +func (m *Map) LoadAndDelete(key string) (value policies, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_Map_int_value, loaded + return _nil_Map_policies_value, loaded } - return actual.(int), loaded + return actual.(policies), loaded } // Delete deletes the value for a key. -func (m *Map) Delete(key tableKey) { +func (m *Map) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -66,8 +66,8 @@ func (m *Map) Delete(key tableKey) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *Map) Range(f func(key tableKey, value int) bool) { +func (m *Map) Range(f func(key string, value policies) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(tableKey), value.(int)) + return f(key.(string), value.(policies)) }) }