From 39393452ce28450e2a10e3bc0e3409782cb68f97 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Tue, 21 Jun 2022 19:52:00 +0700 Subject: [PATCH] wireguard: add policy based routes to allowedIPs Signed-off-by: Artem Glazychev --- .../mechanisms/wireguard/peer/common.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pkg/networkservice/mechanisms/wireguard/peer/common.go b/pkg/networkservice/mechanisms/wireguard/peer/common.go index e26ee0c4..ace9d0f8 100644 --- a/pkg/networkservice/mechanisms/wireguard/peer/common.go +++ b/pkg/networkservice/mechanisms/wireguard/peer/common.go @@ -1,4 +1,6 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. +// Copyright (c) 2021-2022 Doc.ai and/or its affiliates. +// +// Copyright (c) 2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -66,7 +68,7 @@ func createPeer(ctx context.Context, conn *networkservice.Connection, vppConn ap } ipContext := conn.GetContext().GetIpContext() if !isClient { - allowedIPs, err := extractAllowedIPs(ipContext.GetSrcIpAddrs(), ipContext.GetDstRoutes()) + allowedIPs, err := extractAllowedIPs(ipContext.GetSrcIpAddrs(), ipContext.GetDstRoutes(), ipContext.GetPolicies()) if err != nil { return errors.WithStack(e) } @@ -75,7 +77,7 @@ func createPeer(ctx context.Context, conn *networkservice.Connection, vppConn ap peer.Port = mechanism.SrcPort() peer.Endpoint = types.ToVppAddress(mechanism.SrcIP()) } else { - allowedIPs, err := extractAllowedIPs(ipContext.GetDstIpAddrs(), ipContext.GetSrcRoutes()) + allowedIPs, err := extractAllowedIPs(ipContext.GetDstIpAddrs(), ipContext.GetSrcRoutes(), ipContext.GetPolicies()) if err != nil { return errors.WithStack(e) } @@ -101,7 +103,7 @@ func createPeer(ctx context.Context, conn *networkservice.Connection, vppConn ap return nil } -func extractAllowedIPs(ips []string, routes []*networkservice.Route) ([]ip_types.Prefix, error) { +func extractAllowedIPs(ips []string, routes []*networkservice.Route, policies []*networkservice.PolicyRoute) ([]ip_types.Prefix, error) { var allowedIPs []ip_types.Prefix for _, ip := range ips { allowedIP, e := ip_types.ParsePrefix(ip) @@ -117,6 +119,15 @@ func extractAllowedIPs(ips []string, routes []*networkservice.Route) ([]ip_types } allowedIPs = append(allowedIPs, allowedIP) } + for _, p := range policies { + for _, route := range p.Routes { + allowedIP, e := ip_types.ParsePrefix(route.Prefix) + if e != nil { + return nil, errors.WithStack(e) + } + allowedIPs = append(allowedIPs, allowedIP) + } + } return allowedIPs, nil }