diff --git a/plugins/meta/vrf/vrf.go b/plugins/meta/vrf/vrf.go
index f265c071f..f32e74ea6 100644
--- a/plugins/meta/vrf/vrf.go
+++ b/plugins/meta/vrf/vrf.go
@@ -17,6 +17,7 @@ package main
import (
"fmt"
"math"
+ "strings"
"github.com/vishvananda/netlink"
)
@@ -104,6 +105,14 @@ func addInterface(vrf *netlink.Vrf, intf string) error {
if err != nil {
return fmt.Errorf("failed getting ipv6 addresses for %s", intf)
}
+
+ // Save routes that are setup for the interface, before setting master,
+ // because otherwise the routes will be deleted after interface is moved.
+ routes, err := netlink.RouteList(i, netlink.FAMILY_ALL)
+ if err != nil {
+ return fmt.Errorf("failed getting all routes for %s", intf)
+ }
+
err = netlink.LinkSetMaster(i, vrf)
if err != nil {
return fmt.Errorf("could not set vrf %s as master of %s: %v", vrf.Name, intf, err)
@@ -130,6 +139,21 @@ CONTINUE:
}
}
+ // Apply all saved routes for the interface that was moved to the VRF
+ for _, route := range routes {
+ r := route
+ // Modify original table to vrf one,
+ // equivalent of 'ip route add
table '.
+ r.Table = int(vrf.Table)
+ err = netlink.RouteAdd(&r)
+ if err != nil {
+ // If route is already present, returned error is "file exists"
+ if !strings.Contains(err.Error(), "file exists") {
+ return fmt.Errorf("could not add route '%s': %v", r, err)
+ }
+ }
+ }
+
return nil
}
diff --git a/plugins/meta/vrf/vrf_test.go b/plugins/meta/vrf/vrf_test.go
index 8eb2fbea3..b8e8ac046 100644
--- a/plugins/meta/vrf/vrf_test.go
+++ b/plugins/meta/vrf/vrf_test.go
@@ -17,10 +17,13 @@ package main
import (
"encoding/json"
"fmt"
+ "net"
+ "strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
+ "golang.org/x/sys/unix"
"github.com/containernetworking/cni/pkg/skel"
"github.com/containernetworking/cni/pkg/types"
@@ -107,7 +110,7 @@ var _ = Describe("vrf plugin", func() {
},
})
Expect(err).NotTo(HaveOccurred())
- _, err = netlink.LinkByName(IF0Name)
+ _, err = netlink.LinkByName(IF1Name)
Expect(err).NotTo(HaveOccurred())
return nil
})
@@ -177,6 +180,102 @@ var _ = Describe("vrf plugin", func() {
Expect(err).NotTo(HaveOccurred())
})
+ It("adds the interface and custom routing to new VRF", func() {
+ conf := configWithRouteFor("test", IF0Name, VRF0Name, "10.0.0.2/24", "10.10.10.0/24")
+
+ By("Setting custom routing first", func() {
+ err := targetNS.Do(func(ns.NetNS) error {
+ defer GinkgoRecover()
+
+ ipv4, err := types.ParseCIDR("10.0.0.2/24")
+ Expect(err).NotTo(HaveOccurred())
+ Expect(ipv4).NotTo(BeNil())
+
+ _, routev4, err := net.ParseCIDR("10.10.10.0/24")
+ Expect(err).NotTo(HaveOccurred())
+
+ ipv6, err := types.ParseCIDR("abcd:1234:ffff::cdde/64")
+ Expect(err).NotTo(HaveOccurred())
+ Expect(ipv6).NotTo(BeNil())
+
+ _, routev6, err := net.ParseCIDR("1111:dddd::/80")
+ Expect(err).NotTo(HaveOccurred())
+ Expect(routev6).NotTo(BeNil())
+
+ link, err := netlink.LinkByName(IF0Name)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Add IP addresses for network reachability
+ netlink.AddrAdd(link, &netlink.Addr{IPNet: ipv4})
+ netlink.AddrAdd(link, &netlink.Addr{IPNet: ipv6})
+
+ ipAddrs, err := netlink.AddrList(link, netlink.FAMILY_V4)
+ Expect(err).NotTo(HaveOccurred())
+ // Check if address was assigned properly
+ Expect(ipAddrs[0].IP.String()).To(Equal("10.0.0.2"))
+
+ // Set interface UP, otherwise local route to 10.0.0.0/24 is not present
+ err = netlink.LinkSetUp(link)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Add additional route to 10.10.10.0/24 via 10.0.0.1 gateway
+ r := netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Src: ipv4.IP,
+ Dst: routev4,
+ Gw: net.ParseIP("10.0.0.1"),
+ }
+ err = netlink.RouteAdd(&r)
+ Expect(err).NotTo(HaveOccurred())
+
+ r6 := netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Src: ipv6.IP,
+ Dst: routev6,
+ Gw: net.ParseIP("abcd:1234:ffff::1"),
+ }
+ err = netlink.RouteAdd(&r6)
+ Expect(err).NotTo(HaveOccurred())
+
+ return nil
+ })
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ args := &skel.CmdArgs{
+ ContainerID: "dummy",
+ Netns: targetNS.Path(),
+ IfName: IF0Name,
+ StdinData: conf,
+ }
+
+ err := originalNS.Do(func(ns.NetNS) error {
+ defer GinkgoRecover()
+ r, _, err := testutils.CmdAddWithArgs(args, func() error {
+ return cmdAdd(args)
+ })
+ Expect(err).NotTo(HaveOccurred())
+
+ result, err := current.GetResult(r)
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(result.Interfaces).To(HaveLen(1))
+ Expect(result.Interfaces[0].Name).To(Equal(IF0Name))
+ Expect(result.Routes).To(HaveLen(1))
+ Expect(result.Routes[0].Dst.IP.String()).To(Equal("10.10.10.0"))
+ return nil
+ })
+ Expect(err).NotTo(HaveOccurred())
+
+ err = targetNS.Do(func(ns.NetNS) error {
+ defer GinkgoRecover()
+ checkInterfaceOnVRF(VRF0Name, IF0Name)
+ checkRoutesOnVRF(VRF0Name, IF0Name, "10.10.10.0/24", "1111:dddd::/80")
+ return nil
+ })
+ Expect(err).NotTo(HaveOccurred())
+ })
+
It("fails if the interface already has a master set", func() {
conf := configFor("test", IF0Name, VRF0Name, "10.0.0.2/24")
@@ -690,6 +789,35 @@ func configWithTableFor(name, intf, vrf, ip string, tableID int) []byte {
return []byte(conf)
}
+func configWithRouteFor(name, intf, vrf, ip, route string) []byte {
+ conf := fmt.Sprintf(`{
+ "name": "%s",
+ "type": "vrf",
+ "cniVersion": "0.3.1",
+ "vrfName": "%s",
+ "prevResult": {
+ "interfaces": [
+ {"name": "%s", "sandbox":"netns"}
+ ],
+ "ips": [
+ {
+ "version": "4",
+ "address": "%s",
+ "gateway": "10.0.0.1",
+ "interface": 0
+ }
+ ],
+ "routes": [
+ {
+ "dst": "%s",
+ "gw": "10.0.0.1"
+ }
+ ]
+ }
+ }`, name, vrf, intf, ip, route)
+ return []byte(conf)
+}
+
func checkInterfaceOnVRF(vrfName, intfName string) {
vrf, err := netlink.LinkByName(vrfName)
Expect(err).NotTo(HaveOccurred())
@@ -702,3 +830,41 @@ func checkInterfaceOnVRF(vrfName, intfName string) {
Expect(err).NotTo(HaveOccurred())
Expect(master.Attrs().Name).To(Equal(vrfName))
}
+
+func checkRoutesOnVRF(vrfName, intfName string, routesToCheck ...string) {
+ vrf, err := netlink.LinkByName(vrfName)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(vrf).To(BeAssignableToTypeOf(&netlink.Vrf{}))
+
+ link, err := netlink.LinkByName(intfName)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = netlink.LinkSetUp(link)
+ Expect(err).NotTo(HaveOccurred())
+
+ ipAddrs, err := netlink.AddrList(link, netlink.FAMILY_V4)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(ipAddrs).To(HaveLen(1))
+ Expect(ipAddrs[0].IP.String()).To(Equal("10.0.0.2"))
+
+ // Need to read all tables, so cannot use RouteList
+ routeFilter := &netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Table: unix.RT_TABLE_UNSPEC,
+ }
+
+ routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL,
+ routeFilter,
+ netlink.RT_FILTER_OIF|netlink.RT_FILTER_TABLE)
+ Expect(err).NotTo(HaveOccurred())
+
+ routesRead := []string{}
+ for _, route := range routes {
+ routesRead = append(routesRead, route.String())
+ Expect(uint32(route.Table)).To(Equal(vrf.(*netlink.Vrf).Table))
+ }
+ routesStr := strings.Join(routesRead, "\n")
+ for _, route := range routesToCheck {
+ Expect(routesStr).To(ContainSubstring(route))
+ }
+}