Skip to content

Commit

Permalink
Feat: Support port range in server config
Browse files Browse the repository at this point in the history
  • Loading branch information
Musixal committed Oct 16, 2024
1 parent b8bb625 commit ad864a7
Show file tree
Hide file tree
Showing 6 changed files with 427 additions and 75 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ To start using the solution, you'll need to configure both server and client com
log_level = "info" # Log level ("panic", "fatal", "error", "warn", "info", "debug", "trace", optional, default: "info").
ports = [
"443-600" # Listen on all ports in the range 443 to 600
"443-600=1.1.1.1:5003" # Listen on all ports in the range 443 to 600 and forward traffic to 1.1.1.1:5003
"443", # Listen on local port 443 and forward to remote port 443 (default forwarding).
"4000=5000", # Listen on local port 4000 (bind to all local IPs) and forward to remote port 5000.
"127.0.0.2:4001=5001", # Bind to specific local IP (127.0.0.2), listen on port 4001, and forward to remote port 5001.
Expand Down
109 changes: 92 additions & 17 deletions internal/server/transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,30 +336,105 @@ func (s *TcpTransport) acceptTunnelConn(listener net.Listener) {

func (s *TcpTransport) parsePortMappings() {
for _, portMapping := range s.config.Ports {
var localAddr string
parts := strings.Split(portMapping, "=")
if len(parts) < 2 {
port, err := strconv.Atoi(parts[0])
if err != nil {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)
}
localAddr = fmt.Sprintf(":%d", port)
parts = append(parts, strconv.Itoa(port))
} else {
localAddr = strings.TrimSpace(parts[0])
if _, err := strconv.Atoi(localAddr); err == nil {
localAddr = ":" + localAddr // :3080 format

var localAddr, remoteAddr string

// Check if only a single port or a port range is provided (no "=" present)
if len(parts) == 1 {
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = localPortOrRange // If no remote addr is provided, use the local port as the remote port

// Check if it's a port range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.startListeners(localAddr, strconv.Itoa(port)) // Use port as the remoteAddr
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single port case
port, err := strconv.Atoi(localPortOrRange)
if err != nil || port < 1 || port > 65535 {
s.logger.Fatalf("invalid port format: %s", localPortOrRange)
}
localAddr = fmt.Sprintf(":%d", port)
}
}
} else if len(parts) == 2 {
// Handle "local=remote" format
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = strings.TrimSpace(parts[1])

// Check if local port is a range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

remoteAddr := strings.TrimSpace(parts[1])
// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

go s.localListener(localAddr, remoteAddr)
endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

if s.config.AcceptUDP {
go s.udpListener(localAddr, remoteAddr)
// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.startListeners(localAddr, remoteAddr)
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single local port case
port, err := strconv.Atoi(localPortOrRange)
if err == nil && port > 1 && port < 65535 { // format port=remoteAddress
localAddr = fmt.Sprintf(":%d", port)
} else {
localAddr = localPortOrRange // format ip:port=remoteAddress
}
}
} else {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)
}
// Start listeners for single port
go s.startListeners(localAddr, remoteAddr)
}
}

func (s *TcpTransport) startListeners(localAddr, remoteAddr string) {
// Start TCP listener
go s.localListener(localAddr, remoteAddr)

// Start UDP listener if configured
if s.config.AcceptUDP {
go s.udpListener(localAddr, remoteAddr)
}

s.logger.Debugf("Started listening on %s, forwarding to %s", localAddr, remoteAddr)
}

func (s *TcpTransport) localListener(localAddr string, remoteAddr string) {
Expand Down
95 changes: 82 additions & 13 deletions internal/server/transport/tcpmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,27 +351,96 @@ func (s *TcpMuxTransport) acceptTunnelConn(listener net.Listener) {

func (s *TcpMuxTransport) parsePortMappings() {
for _, portMapping := range s.config.Ports {
var localAddr string
parts := strings.Split(portMapping, "=")
if len(parts) < 2 {
port, err := strconv.Atoi(parts[0])
if err != nil {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)

var localAddr, remoteAddr string

// Check if only a single port or a port range is provided (no "=" present)
if len(parts) == 1 {
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = localPortOrRange // If no remote addr is provided, use the local port as the remote port

// Check if it's a port range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.localListener(localAddr, strconv.Itoa(port)) // Use port as the remoteAddr
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single port case
port, err := strconv.Atoi(localPortOrRange)
if err != nil || port < 1 || port > 65535 {
s.logger.Fatalf("invalid port format: %s", localPortOrRange)
}
localAddr = fmt.Sprintf(":%d", port)
}
localAddr = fmt.Sprintf(":%d", port)
parts = append(parts, strconv.Itoa(port))
} else {
localAddr = strings.TrimSpace(parts[0])
if _, err := strconv.Atoi(localAddr); err == nil {
localAddr = ":" + localAddr // :3080 format
} else if len(parts) == 2 {
// Handle "local=remote" format
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = strings.TrimSpace(parts[1])

// Check if local port is a range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.localListener(localAddr, remoteAddr)
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single local port case
port, err := strconv.Atoi(localPortOrRange)
if err == nil && port > 1 && port < 65535 { // format port=remoteAddress
localAddr = fmt.Sprintf(":%d", port)
} else {
localAddr = localPortOrRange // format ip:port=remoteAddress
}
}
} else {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)
}
remoteAddr := strings.TrimSpace(parts[1])

// Start listeners for single port
go s.localListener(localAddr, remoteAddr)
}
}


func (s *TcpMuxTransport) localListener(localAddr string, remoteAddr string) {
listener, err := net.Listen("tcp", localAddr)
if err != nil {
Expand Down
100 changes: 85 additions & 15 deletions internal/server/transport/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,31 +363,101 @@ func (s *UdpTransport) acceptTunnelConn(listener *net.UDPConn) {
}
}



func (s *UdpTransport) parsePortMappings() {
for _, portMapping := range s.config.Ports {
var localAddr string
parts := strings.Split(portMapping, "=")
if len(parts) < 2 {
port, err := strconv.Atoi(parts[0])
if err != nil {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)
}
localAddr = fmt.Sprintf(":%d", port)
parts = append(parts, strconv.Itoa(port))
} else {
localAddr = strings.TrimSpace(parts[0])
if _, err := strconv.Atoi(localAddr); err == nil {
localAddr = ":" + localAddr // :3080 format

var localAddr, remoteAddr string

// Check if only a single port or a port range is provided (no "=" present)
if len(parts) == 1 {
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = localPortOrRange // If no remote addr is provided, use the local port as the remote port

// Check if it's a port range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.localListener(localAddr, strconv.Itoa(port)) // Use port as the remoteAddr
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single port case
port, err := strconv.Atoi(localPortOrRange)
if err != nil || port < 1 || port > 65535 {
s.logger.Fatalf("invalid port format: %s", localPortOrRange)
}
localAddr = fmt.Sprintf(":%d", port)
}
}
} else if len(parts) == 2 {
// Handle "local=remote" format
localPortOrRange := strings.TrimSpace(parts[0])
remoteAddr = strings.TrimSpace(parts[1])

// Check if local port is a range
if strings.Contains(localPortOrRange, "-") {
rangeParts := strings.Split(localPortOrRange, "-")
if len(rangeParts) != 2 {
s.logger.Fatalf("invalid port range format: %s", localPortOrRange)
}

remoteAddr := strings.TrimSpace(parts[1])
// Parse and validate start and end ports
startPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
if err != nil || startPort < 1 || startPort > 65535 {
s.logger.Fatalf("invalid start port in range: %s", rangeParts[0])
}

go s.localListener(localAddr, remoteAddr)
endPort, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
if err != nil || endPort < 1 || endPort > 65535 || endPort < startPort {
s.logger.Fatalf("invalid end port in range: %s", rangeParts[1])
}

// Create listeners for all ports in the range
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.localListener(localAddr, remoteAddr)
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
// Handle single local port case
port, err := strconv.Atoi(localPortOrRange)
if err == nil && port > 1 && port < 65535 { // format port=remoteAddress
localAddr = fmt.Sprintf(":%d", port)
} else {
localAddr = localPortOrRange // format ip:port=remoteAddress
}
}
} else {
s.logger.Fatalf("invalid port mapping format: %s", portMapping)
}
// Start listeners for single port
go s.localListener(localAddr, remoteAddr)
}
}



func (s *UdpTransport) localListener(localAddr, remoteAddr string) {
localUDPAddr, err := net.ResolveUDPAddr("udp", localAddr)
if err != nil {
Expand Down
Loading

0 comments on commit ad864a7

Please sign in to comment.