diff --git a/pkg/evpn/svi.go b/pkg/evpn/svi.go index 064a0032..65dc8ad0 100644 --- a/pkg/evpn/svi.go +++ b/pkg/evpn/svi.go @@ -63,13 +63,6 @@ func (s *Server) CreateSvi(_ context.Context, in *pb.CreateSviRequest) (*pb.Svi, log.Printf("error: %v", err) return nil, err } - // not found, so create a new one - bridge, err := s.nLink.LinkByName(tenantbridgeName) - if err != nil { - err := status.Errorf(codes.NotFound, "unable to find key %s", tenantbridgeName) - log.Printf("error: %v", err) - return nil, err - } // now get LogicalBridge object to fetch VID field bridgeObject, ok := s.Bridges[in.Svi.Spec.LogicalBridge] if !ok { @@ -77,6 +70,20 @@ func (s *Server) CreateSvi(_ context.Context, in *pb.CreateSviRequest) (*pb.Svi, log.Printf("error: %v", err) return nil, err } + // now get Vrf to plug this vlandev into + vrf, ok := s.Vrfs[in.Svi.Spec.Vrf] + if !ok { + err := status.Errorf(codes.NotFound, "unable to find key %s", in.Svi.Spec.Vrf) + log.Printf("error: %v", err) + return nil, err + } + // not found, so create a new one + bridge, err := s.nLink.LinkByName(tenantbridgeName) + if err != nil { + err := status.Errorf(codes.NotFound, "unable to find key %s", tenantbridgeName) + log.Printf("error: %v", err) + return nil, err + } vid := uint16(bridgeObject.Spec.VlanId) // Example: bridge vlan add dev br-tenant vid self if err := s.nLink.BridgeVlanAdd(bridge, vid, false, false, true, false); err != nil { @@ -85,13 +92,7 @@ func (s *Server) CreateSvi(_ context.Context, in *pb.CreateSviRequest) (*pb.Svi, } // Example: ip link add link br-tenant name type vlan id vlanName := fmt.Sprintf("vlan%d", vid) - vlandev := &netlink.Vlan{ - LinkAttrs: netlink.LinkAttrs{ - Name: vlanName, - ParentIndex: bridge.Attrs().Index, - }, - VlanId: int(bridgeObject.Spec.VlanId), - } + vlandev := &netlink.Vlan{LinkAttrs: netlink.LinkAttrs{Name: vlanName, ParentIndex: bridge.Attrs().Index}, VlanId: int(vid)} log.Printf("Creating VLAN %v", vlandev) if err := s.nLink.LinkAdd(vlandev); err != nil { fmt.Printf("Failed to create vlan link: %v", err) @@ -115,13 +116,6 @@ func (s *Server) CreateSvi(_ context.Context, in *pb.CreateSviRequest) (*pb.Svi, return nil, err } } - // now get Vrf to plug this vlandev into - vrf, ok := s.Vrfs[in.Svi.Spec.Vrf] - if !ok { - err := status.Errorf(codes.NotFound, "unable to find key %s", in.Svi.Spec.Vrf) - log.Printf("error: %v", err) - return nil, err - } // get net device by name vrfdev, err := s.nLink.LinkByName(path.Base(vrf.Name)) if err != nil { diff --git a/pkg/evpn/svi_test.go b/pkg/evpn/svi_test.go index 4a2f8899..5b82839f 100644 --- a/pkg/evpn/svi_test.go +++ b/pkg/evpn/svi_test.go @@ -156,6 +156,36 @@ func Test_CreateSvi(t *testing.T) { fmt.Sprintf("segment '%s': not a valid DNS name", "-ABC-DEF"), false, }, + "missing LogicalBridge name": { + testSviID, + &pb.Svi{ + Spec: &pb.SviSpec{ + Vrf: testVrfName, + LogicalBridge: "unknown-bridge-id", + MacAddress: []byte{0xCB, 0xB8, 0x33, 0x4C, 0x88, 0x4F}, + GwIpPrefix: []*pc.IPPrefix{{Len: 24}}, + }, + }, + nil, + codes.NotFound, + fmt.Sprintf("unable to find key %v", "unknown-bridge-id"), + false, + }, + "missing Vrf name": { + testSviID, + &pb.Svi{ + Spec: &pb.SviSpec{ + Vrf: "unknown-vrf-id", + LogicalBridge: testLogicalBridgeName, + MacAddress: []byte{0xCB, 0xB8, 0x33, 0x4C, 0x88, 0x4F}, + GwIpPrefix: []*pc.IPPrefix{{Len: 24}}, + }, + }, + nil, + codes.NotFound, + fmt.Sprintf("unable to find key %v", "unknown-vrf-id"), + false, + }, "failed LinkByName call": { testSviID, &testSvi, @@ -218,6 +248,7 @@ func Test_CreateSvi(t *testing.T) { if tt.out != nil { tt.out.Name = testSviName } + opi.Vrfs[testVrfName] = &testVrf opi.Bridges[testLogicalBridgeName] = &testLogicalBridge // TODO: refactor this mocking