Skip to content

Commit

Permalink
prevent port number reuse with TTL-based caching
Browse files Browse the repository at this point in the history
  • Loading branch information
wdbaruni committed Nov 5, 2024
1 parent f68f794 commit 0539d2d
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 48 deletions.
102 changes: 71 additions & 31 deletions pkg/lib/network/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,78 @@ package network
import (
"net"
"strconv"
"sync"
"time"
)

// IsPortOpen checks if a port is open by attempting to listen on it. If the
// port is open, it returns true, otherwise it returns false. The port listen
// socket will be closed if the function was able to create it.
func IsPortOpen(port int) bool {
addr := net.JoinHostPort("", strconv.Itoa(port))
ln, err := net.Listen("tcp", addr)
const defaultPortAllocatorTTL = 5 * time.Second

// PortAllocator manages thread-safe allocation of network ports with a time-based reservation system.
// Once a port is allocated, it won't be reallocated until after the TTL expires, helping prevent
// race conditions in concurrent port allocation scenarios.
type PortAllocator struct {
mu sync.Mutex
reservedPorts map[int]time.Time
ttl time.Duration
}

var (
// globalAllocator is the package-level port allocator instance used by GetFreePort
globalAllocator = NewPortAllocator(defaultPortAllocatorTTL)
)

// NewPortAllocator creates a new PortAllocator instance with the specified TTL.
// The TTL determines how long a port remains reserved after allocation.
func NewPortAllocator(ttl time.Duration) *PortAllocator {
return &PortAllocator{
reservedPorts: make(map[int]time.Time),
ttl: ttl,
}
}

// GetFreePort returns an available port using the global port allocator.
// The returned port is guaranteed to not be reallocated by this package
// for the duration of the TTL (default 5 seconds).
func GetFreePort() (int, error) {
return globalAllocator.GetFreePort()
}

// GetFreePort returns an available port and reserves it for the duration of the TTL.
// If a port is already reserved but its TTL has expired, it may be returned if it's
// still available on the system.
func (pa *PortAllocator) GetFreePort() (int, error) {
pa.mu.Lock()
defer pa.mu.Unlock()

port, err := getFreePortFromSystem()
if err != nil {
// There was a problem listening, the port is probably in use
return false
return 0, err
}

// We were able to use the port, so it is free, but we should close it
// first
_ = ln.Close()
return true
// Keep trying until we find a port that isn't reserved or has expired reservation
now := time.Now()
for {
if expiration, reserved := pa.reservedPorts[port]; !reserved || now.After(expiration) {
break
}
port, err = getFreePortFromSystem()
if err != nil {
return 0, err
}
}

pa.reservedPorts[port] = now.Add(pa.ttl)
return port, nil
}

// GetFreePort returns a single available port by asking the operating
// system to pick one for us. Luckily ports are not re-used so after asking
// for a port number, we attempt to create a tcp listener.
// getFreePortFromSystem asks the operating system for an available port by creating
// a TCP listener with port 0, which triggers the OS to assign a random available port.
//
// Essentially the same code as https://github.com/phayes/freeport but we bind
// to 0.0.0.0 to ensure the port is free on all interfaces, and not just localhost.GetFreePort
// Ports must be unique for an address, not an entire system and so checking just localhost
// is not enough.
func GetFreePort() (int, error) {
func getFreePortFromSystem() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", ":0")
if err != nil {
return 0, err
Expand All @@ -44,20 +88,16 @@ func GetFreePort() (int, error) {
return l.Addr().(*net.TCPAddr).Port, nil
}

// GetFreePorts returns an array available ports by asking the operating
// system to pick one for us.
//
// Essentially the same code as https://github.com/phayes/freeport apart from
// the caveats described in GetFreePort.
func GetFreePorts(count int) ([]int, error) {
ports := []int{}

for i := 0; i < count; i++ {
port, err := GetFreePort()
if err != nil {
return nil, err
}
ports = append(ports, port)
// IsPortOpen checks if a specific port is available for use by attempting to create
// a TCP listener on that port. It returns true if the port is available, false otherwise.
// The caller should note that the port's availability may change immediately after
// this check returns.
func IsPortOpen(port int) bool {
addr := net.JoinHostPort("", strconv.Itoa(port))
ln, err := net.Listen("tcp", addr)
if err != nil {
return false
}
return ports, nil
ln.Close()
return true
}
95 changes: 78 additions & 17 deletions pkg/lib/network/ports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,104 @@ package network_test
import (
"net"
"strconv"
"sync"
"testing"
"time"

"github.com/bacalhau-project/bacalhau/pkg/lib/network"
"github.com/stretchr/testify/suite"

"github.com/bacalhau-project/bacalhau/pkg/lib/network"
)

type FreePortTestSuite struct {
type PortAllocatorTestSuite struct {
suite.Suite
}

func TestFreePortTestSuite(t *testing.T) {
suite.Run(t, new(FreePortTestSuite))
func TestPortAllocatorTestSuite(t *testing.T) {
suite.Run(t, new(PortAllocatorTestSuite))
}

func (s *FreePortTestSuite) TestGetFreePort() {
// TestGetFreePort verifies that GetFreePort returns a usable port
func (s *PortAllocatorTestSuite) TestGetFreePort() {
port, err := network.GetFreePort()
s.Require().NoError(err)
s.NotEqual(0, port, "expected a non-zero port")

// Try to listen on the port
l, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port))
// Verify we can listen on the port
l, err := net.Listen("tcp", ":"+strconv.Itoa(port))
s.Require().NoError(err)
defer l.Close()
}

func (s *FreePortTestSuite) TestGetFreePorts() {
count := 3
ports, err := network.GetFreePorts(count)
// TestPortReservation verifies that ports aren't reused within TTL
func (s *PortAllocatorTestSuite) TestPortReservation() {
// Create allocator with 1 second TTL for testing
allocator := network.NewPortAllocator(time.Second)

// Get first port
port1, err := allocator.GetFreePort()
s.Require().NoError(err)

// Get second port - should be different
port2, err := allocator.GetFreePort()
s.Require().NoError(err)
s.NotEqual(port1, port2, "got same port within TTL period")
}

// TestConcurrentPortAllocation verifies thread-safety of port allocation
func (s *PortAllocatorTestSuite) TestConcurrentPortAllocation() {
var wg sync.WaitGroup
allocator := network.NewPortAllocator(time.Second)
ports := make(map[int]bool)
var mu sync.Mutex

// Spawn 20 goroutines to allocate ports concurrently
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
port, err := allocator.GetFreePort()
s.Require().NoError(err)

mu.Lock()
s.False(ports[port], "port %d was allocated multiple times", port)
ports[port] = true
mu.Unlock()

// Verify we can listen on the port
l, err := net.Listen("tcp", ":"+strconv.Itoa(port))
s.Require().NoError(err)
l.Close()
}()
}
wg.Wait()
}

// TestIsPortOpen verifies the port availability check
func (s *PortAllocatorTestSuite) TestIsPortOpen() {
// Get a port we know should be available
port, err := network.GetFreePort()
s.Require().NoError(err)
s.Equal(count, len(ports), "expected %d ports", count)
s.True(network.IsPortOpen(port), "newly allocated port should be open")

for _, port := range ports {
s.NotEqual(0, port, "expected a non-zero port")
// Listen on the port
l, err := net.Listen("tcp", ":"+strconv.Itoa(port))
s.Require().NoError(err)
defer l.Close()

// Port should now be in use
s.False(network.IsPortOpen(port), "port should be in use")
}

// Try to listen on the port
l, err := net.Listen("tcp", ":"+strconv.Itoa(port))
s.Require().NoError(err, "failed to listen on newly given port")
defer l.Close()
// TestGlobalAllocator verifies that the global GetFreePort function
// prevents immediate port reuse
func (s *PortAllocatorTestSuite) TestGlobalAllocator() {
// Get a batch of ports using the global allocator
usedPorts := make(map[int]bool)
for i := 0; i < 10; i++ {
port, err := network.GetFreePort()
s.Require().NoError(err)
s.False(usedPorts[port], "global allocator reused port %d", port)
usedPorts[port] = true
}
}

0 comments on commit 0539d2d

Please sign in to comment.