From 26645b0b8fe5e73f254b6ad8312de11f86a345f5 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Mon, 14 Nov 2022 20:35:48 +0700 Subject: [PATCH] use source address for ping to check if the connection is alive Signed-off-by: Artem Glazychev --- pkg/kernel/tools/heal/liveness_check.go | 158 +++++++++++++----- pkg/kernel/tools/heal/liveness_check_test.go | 160 +++++++++++++++++++ 2 files changed, 282 insertions(+), 36 deletions(-) create mode 100644 pkg/kernel/tools/heal/liveness_check_test.go diff --git a/pkg/kernel/tools/heal/liveness_check.go b/pkg/kernel/tools/heal/liveness_check.go index df0ee2d6..a9a82b0e 100644 --- a/pkg/kernel/tools/heal/liveness_check.go +++ b/pkg/kernel/tools/heal/liveness_check.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2023 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -19,9 +19,10 @@ package heal import ( "context" - "net" "time" + "github.com/pkg/errors" + "github.com/go-ping/ping" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" @@ -33,58 +34,143 @@ const ( packetCount = 4 ) -// KernelLivenessCheck is an implementation of heal.LivenessCheck. It sends ICMP -// ping and checks reply. Returns false if didn't get reply. +type options struct { + pingerFactory PingerFactory +} + +// Option is an option pattern for LivelinessChecker +type Option func(o *options) + +// WithPingerFactory - sets any custom pinger factory +func WithPingerFactory(pf PingerFactory) Option { + return func(o *options) { + o.pingerFactory = pf + } +} + +// KernelLivenessCheck is an implementation of heal.LivenessCheck func KernelLivenessCheck(deadlineCtx context.Context, conn *networkservice.Connection) bool { + return KernelLivenessCheckWithOptions(deadlineCtx, conn) +} + +// KernelLivenessCheckWithOptions is an implementation with options of heal.LivenessCheck. It sends ICMP +// ping and checks reply. Returns false if didn't get reply. +func KernelLivenessCheckWithOptions(deadlineCtx context.Context, conn *networkservice.Connection, opts ...Option) bool { + // Apply options + o := &options{ + pingerFactory: &defaultPingerFactory{}, + } + for _, opt := range opts { + opt(o) + } + var pingerFactory = o.pingerFactory + if mechanism := conn.GetMechanism().GetType(); mechanism != kernel.MECHANISM { log.FromContext(deadlineCtx).Warnf("ping is not supported for mechanism %v", mechanism) return true } + ipContext := conn.GetContext().GetIpContext() + combinationCount := len(ipContext.GetDstIpAddrs()) * len(ipContext.GetSrcIpAddrs()) + if combinationCount == 0 { + log.FromContext(deadlineCtx).Debug("No IP address") + return true + } deadline, ok := deadlineCtx.Deadline() if !ok { deadline = time.Now().Add(defaultTimeout) } + timeout := time.Until(deadline) - addrCount := len(conn.GetContext().GetIpContext().GetDstIpAddrs()) - if addrCount == 0 { - log.FromContext(deadlineCtx).Debug("No dst IP address") - return true + // Start ping for all Src/DstIPs combination + responseCh := make(chan error, combinationCount) + defer close(responseCh) + for _, srcIPNet := range ipContext.GetSrcIPNets() { + for _, dstIPNet := range ipContext.GetDstIPNets() { + // Skip if IPs don't belong to the same family + if (srcIPNet.IP.To4() != nil) != (dstIPNet.IP.To4() != nil) { + responseCh <- nil + continue + } + + go func(srcIP, dstIP string) { + logger := log.FromContext(deadlineCtx).WithField("srcIP", srcIP).WithField("dstIP", dstIP) + pinger := pingerFactory.CreatePinger(srcIP, dstIP, timeout, packetCount) + + err := pinger.Run() + if err != nil { + logger.Errorf("Ping failed: %s", err.Error()) + responseCh <- err + return + } + + if pinger.GetReceivedPackets() == 0 { + err = errors.New("No packets received") + logger.Errorf(err.Error()) + responseCh <- err + return + } + responseCh <- nil + }(srcIPNet.IP.String(), dstIPNet.IP.String()) + } } - timeout := time.Until(deadline) / time.Duration(addrCount) - var pinger *ping.Pinger + // Waiting for all ping results. If at least one fails - return false + return waitForResponses(responseCh) +} - for _, cidr := range conn.GetContext().GetIpContext().GetDstIpAddrs() { - addr, _, err := net.ParseCIDR(cidr) - if err != nil { - log.FromContext(deadlineCtx).Errorf("ParseCIDR failed: %s", err.Error()) +func waitForResponses(responseCh <-chan error) bool { + respCount := cap(responseCh) + success := true + for { + resp, ok := <-responseCh + if !ok { return false } - - ipAddr := &net.IPAddr{IP: addr} - if pinger == nil { - pinger, err = ping.NewPinger(addr.String()) - if err != nil { - log.FromContext(deadlineCtx).Errorf("Failed to create pinger: %s", err.Error()) - return false - } - pinger.SetPrivileged(true) - pinger.Interval = timeout / packetCount - pinger.Timeout = timeout - pinger.Count = packetCount - } else { - pinger.SetIPAddr(ipAddr) + if resp != nil { + success = false } - err = pinger.Run() - if err != nil { - log.FromContext(deadlineCtx).Errorf("Ping failed: %s", err.Error()) - return false + respCount-- + if respCount == 0 { + return success } + } +} - if pinger.Statistics().PacketsRecv == 0 { - return false - } +// PingerFactory - factory interface for creating pingers +type PingerFactory interface { + CreatePinger(srcIP, dstIP string, timeout time.Duration, count int) Pinger +} + +// Pinger - pinger interface +type Pinger interface { + Run() error + GetReceivedPackets() int +} + +type defaultPingerFactory struct{} + +func (p *defaultPingerFactory) CreatePinger(srcIP, dstIP string, timeout time.Duration, count int) Pinger { + pi := ping.New(dstIP) + pi.SetPrivileged(true) + pi.Source = srcIP + pi.Timeout = timeout + pi.Count = count + if count != 0 { + pi.Interval = timeout / time.Duration(count) } - return true + + return &defaultPinger{pinger: pi} +} + +type defaultPinger struct { + pinger *ping.Pinger +} + +func (p *defaultPinger) Run() error { + return p.pinger.Run() +} + +func (p *defaultPinger) GetReceivedPackets() int { + return p.pinger.Statistics().PacketsRecv } diff --git a/pkg/kernel/tools/heal/liveness_check_test.go b/pkg/kernel/tools/heal/liveness_check_test.go new file mode 100644 index 00000000..d08d6507 --- /dev/null +++ b/pkg/kernel/tools/heal/liveness_check_test.go @@ -0,0 +1,160 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package heal_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "go.uber.org/goleak" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" + "github.com/stretchr/testify/require" + + "github.com/networkservicemesh/sdk-kernel/pkg/kernel/tools/heal" +) + +const unPingableIPv4 = "172.168.1.1" +const unPingableIPv6 = "2005::1" + +func createConnection(srcIPs, dstIPs []string) *networkservice.Connection { + return &networkservice.Connection{ + Mechanism: &networkservice.Mechanism{ + Type: kernel.MECHANISM, + }, + Context: &networkservice.ConnectionContext{IpContext: &networkservice.IPContext{ + SrcIpAddrs: srcIPs, + DstIpAddrs: dstIPs, + }}, + } +} +func Test_LivenessChecker(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + samples := []struct { + Name string + Connection *networkservice.Connection + PingersCount int32 + ExpectedResult bool + }{ + { + Name: "Pingable IPv4 one pair", + Connection: createConnection( + []string{"172.168.0.1/32"}, + []string{"172.168.0.2/32"}, + ), + PingersCount: 1, + ExpectedResult: true, + }, + { + Name: "Pingable IPv4 two pairs", + Connection: createConnection( + []string{"172.168.0.1/32", "172.168.0.3/32"}, + []string{"172.168.0.2/32", "172.168.0.4/32"}, + ), + PingersCount: 4, + ExpectedResult: true, + }, + { + Name: "Unpingable IPv4 two pairs", + Connection: createConnection( + []string{"172.168.0.1/32", "172.168.0.3/32"}, + []string{"172.168.0.2/32", unPingableIPv4 + "/32"}, + ), + PingersCount: 4, + ExpectedResult: false, + }, + { + Name: "Pingable IPv4 and IPv6", + Connection: createConnection( + []string{"172.168.0.1/32", "2004::1/128"}, + []string{"172.168.0.2/32", "2004::2/128"}, + ), + PingersCount: 2, + ExpectedResult: true, + }, + { + Name: "Unpingable IPv4 and IPv6", + Connection: createConnection( + []string{"172.168.0.1/32", "2004::1/128"}, + []string{"172.168.0.2/32", unPingableIPv6 + "/128"}, + ), + PingersCount: 2, + ExpectedResult: false, + }, + { + Name: "SrcIPs is empty", + Connection: createConnection( + []string{}, + []string{"172.168.0.2/32"}, + ), + PingersCount: 0, + ExpectedResult: true, + }, + { + Name: "DstIPs is empty", + Connection: createConnection( + []string{"172.168.0.1/32"}, + []string{}, + ), + PingersCount: 0, + ExpectedResult: true, + }, + } + for _, s := range samples { + sample := s + t.Run(sample.Name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pingerFactory := &testPingerFactory{} + ok := heal.KernelLivenessCheckWithOptions(ctx, sample.Connection, heal.WithPingerFactory(pingerFactory)) + require.Equal(t, sample.ExpectedResult, ok) + require.Equal(t, pingerFactory.pingersCount, sample.PingersCount) + }) + } +} + +type testPingerFactory struct { + pingersCount int32 +} + +func (p *testPingerFactory) CreatePinger(srcIP, dstIP string, timeout time.Duration, count int) heal.Pinger { + atomic.AddInt32(&p.pingersCount, 1) + return &testPinger{ + dstIP: dstIP, + count: count, + } +} + +type testPinger struct { + dstIP string + count int +} + +func (p *testPinger) Run() error { + return nil +} + +func (p *testPinger) GetReceivedPackets() int { + if p.dstIP == unPingableIPv4 || p.dstIP == unPingableIPv6 { + return 0 + } + return p.count +}