diff --git a/.github/workflows/basic-sanity.yml b/.github/workflows/basic-sanity.yml index 2750ee5..73396ec 100644 --- a/.github/workflows/basic-sanity.yml +++ b/.github/workflows/basic-sanity.yml @@ -20,5 +20,5 @@ jobs: - uses: actions/setup-go@v3 with: go-version: '>=1.17.0' - - run: go test . - - run: go test -bench=. + - run: sudo go test . + - run: sudo go test -bench=. diff --git a/lib_test.go b/lib_test.go index a927be7..7f98607 100644 --- a/lib_test.go +++ b/lib_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "testing" - "github.com/loxilb-io/sctp" ) @@ -1036,4 +1035,10 @@ func TestProber(t *testing.T) { sOk = L4ServiceProber("udp", "127.0.0.1:12234", "", "", "") t.Logf("udp prober test2 %v\n", sOk) + + sOk = L4ServiceProber("udp", "127.0.0.1:8080", "", "", "") + t.Logf("udp prober test3 %v\n", sOk) + + sOk = L4ServiceProber("udp", "192.168.20.55:2234", "", "", "") + t.Logf("udp prober test4 %v\n\n\n", sOk) } diff --git a/serviceprobe.go b/serviceprobe.go index cb48a0f..5fc9ac3 100644 --- a/serviceprobe.go +++ b/serviceprobe.go @@ -12,11 +12,79 @@ import ( "golang.org/x/net/ipv4" "net" "net/http" + "os" "strconv" "strings" + "sync" "time" ) +// SvcWait - Channel to wait for service reply +type SvcWait struct { + wait chan bool +} + +// SvcKey - Service Key +type SvcKey struct { + Dst string + Port int +} + +var ( + icmpRunner chan bool + svcLock sync.RWMutex + svcs map[SvcKey]*SvcWait +) + +func waitForBoolChannelOrTimeout(ch <-chan bool, timeout time.Duration) (bool, bool) { + select { + case val := <-ch: + return val, true + case <-time.After(timeout): + return false, false + } +} + +func listenForICMPUNreachable() { + + // Open a raw socket to listen for ICMP messages + rc, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") + if err != nil { + os.Exit(1) + } + defer rc.Close() + pktData := make([]byte, 1500) + //rc.SetDeadline(time.Now().Add(5 * time.Second)) + icmpRunner <- true + for { + plen, _, err := rc.ReadFrom(pktData) + if err != nil { + continue + } + + icmpNr, err := icmp.ParseMessage(1, pktData) + if err != nil { + continue + } + if icmpNr.Code == 3 && plen >= 8+20+8 { + iph, err := ipv4.ParseHeader(pktData[8:]) + if err != nil { + continue + } + + if iph.Protocol == 17 { + dport := int(binary.BigEndian.Uint16(pktData[30:32])) + svcLock.Lock() + key := SvcKey{Dst: iph.Dst.String(), Port: dport} + if svcWait := svcs[key]; svcWait != nil { + svcWait.wait <- true + } + svcLock.Unlock() + } + } + } +} + // HTTPProber - Do a http probe for given url // returns true/false depending on whether probing was successful func HTTPProber(urls string) bool { @@ -35,6 +103,16 @@ func HTTPProber(urls string) bool { // resp is the response expected from server (empty for none) // returns true/false depending on whether probing was successful func L4ServiceProber(sType string, sName string, sHint, req, resp string) bool { + + svcLock.Lock() + if svcs == nil { + icmpRunner = make(chan bool) + svcs = map[SvcKey]*SvcWait{} + go listenForICMPUNreachable() + <-icmpRunner + } + svcLock.Unlock() + sOk := false timeout := 1 * time.Second @@ -145,14 +223,16 @@ func L4ServiceProber(sType string, sName string, sHint, req, resp string) bool { return false } } else if sType == "udp" { - var lc net.ListenConfig - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(3*time.Second)) - defer cancel() - rc, err := lc.ListenPacket(ctx, "ip4:1", "0.0.0.0") - if err != nil { - return sOk + + svcLock.Lock() + key := SvcKey{Dst: svcPair[0], Port: svcPort} + svcWait := svcs[key] + if svcWait == nil { + svcWait = &SvcWait{wait: make(chan bool)} + svcs[key] = svcWait } - defer rc.Close() + svcLock.Unlock() + c.SetDeadline(time.Now().Add(1 * time.Second)) sOk = true _, err = c.Write([]byte("probe")) @@ -160,32 +240,19 @@ func L4ServiceProber(sType string, sName string, sHint, req, resp string) bool { return false } pktData := make([]byte, 1500) - rc.SetDeadline(time.Now().Add(1 * time.Second)) _, err = c.Read(pktData) if err == nil { return sOk } - plen, _, err := rc.ReadFrom(pktData) - if err != nil { - return sOk - } - icmpNr, err := icmp.ParseMessage(1, pktData) - if err != nil { - return sOk - } - if icmpNr.Code == 3 && plen >= 8+20+8 { - iph, err := ipv4.ParseHeader(pktData[8:]) - if err != nil { - return sOk - } - if iph.Dst.String() == svcPair[0] && iph.Protocol == 17 { - dport := int(binary.BigEndian.Uint16(pktData[30:32])) - if dport == svcPort { - sOk = false - } - } + _, unRch := waitForBoolChannelOrTimeout(svcWait.wait, 1*time.Second) + if unRch { + sOk = false } + + svcLock.Lock() + delete(svcs, key) + svcLock.Unlock() } return sOk