Skip to content

Commit

Permalink
fix multiple concurrent tcp connections for dmsgweb
Browse files Browse the repository at this point in the history
  • Loading branch information
0pcom committed Jan 30, 2025
1 parent 4d68c98 commit 0e09fd7
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 96 deletions.
168 changes: 87 additions & 81 deletions cmd/dmsgweb/commands/dmsgweb.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"syscall"

"github.com/chen3feng/safecast"
Expand Down Expand Up @@ -51,7 +52,6 @@ const dwenv = "DMSGWEB"
var dwcfg = os.Getenv(dwenv)

func init() {
dLog = logging.MustGetLogger("dmsgweb")
dmsgDisc = dmsg.DiscAddr(false)
webPort = scriptExecUintSlice("${WEBPORT[@]:-8080}", dwcfg)
proxyPort = scriptExecUint("${PROXYPORT:-4445}", dwcfg)
Expand Down Expand Up @@ -115,6 +115,7 @@ dmsgweb conf file detected: ` + dwcfg
logging.SetLevel(lvl)
}
}
dLog = logging.MustGetLogger("dmsgweb")
if dmsgDisc == "" {
dLog.Fatal("Dmsg Discovery URL not specified")
}
Expand Down Expand Up @@ -161,7 +162,7 @@ dmsgweb conf file detected: ` + dwcfg
signal.Notify(c, os.Interrupt, syscall.SIGTERM) //nolint
go func() {
<-c
os.Exit(1)
os.Exit(0)
}()

ctx, cancel := cmdutil.SignalContext(context.Background(), dLog)
Expand All @@ -173,8 +174,8 @@ dmsgweb conf file detected: ` + dwcfg
}
dLog.Info("dmsg client pk: ", pk.String())
if len(resolveDmsgAddr) > 0 {
dialPK := make([]cipher.PubKey, len(resolveDmsgAddr))
dmsgPorts := make([]uint, len(resolveDmsgAddr))
dialPK = make([]cipher.PubKey, len(resolveDmsgAddr))
dmsgPorts = make([]uint, len(resolveDmsgAddr))

for i, dmsgaddr := range resolveDmsgAddr {
dLog.Info("dmsg address to dial: ", dmsgaddr)
Expand Down Expand Up @@ -207,23 +208,23 @@ dmsgweb conf file detected: ` + dwcfg
}
}

/*
if proxyAddr != "" {
// Use SOCKS5 proxy dialer if specified
dialer, err = proxy.SOCKS5("tcp", proxyAddr, nil, proxy.Direct)
if err != nil {
dLog.WithError(err).Fatal("Error creating SOCKS5 dialer")
}
transport := &http.Transport{
Dial: dialer.Dial,
}
httpClient = &http.Client{
Transport: transport,
/*
if proxyAddr != "" {
// Use SOCKS5 proxy dialer if specified
dialer, err = proxy.SOCKS5("tcp", proxyAddr, nil, proxy.Direct)
if err != nil {
dLog.WithError(err).Fatal("Error creating SOCKS5 dialer")
}
transport := &http.Transport{
Dial: dialer.Dial,
}
httpClient = &http.Client{
Transport: transport,
}
ctx = context.WithValue(context.Background(), "socks5_proxy", proxyAddr) //nolint
}
ctx = context.WithValue(context.Background(), "socks5_proxy", proxyAddr) //nolint
}
*/
// dmsgC, closeDmsg, err := cli.StartDmsg(ctx, dLog, pk, sk, &httpC, dmsgDisc, dmsgSessions)
*/
// dmsgC, closeDmsg, err := cli.StartDmsg(ctx, dLog, pk, sk, &httpC, dmsgDisc, dmsgSessions)
dmsgC, closeDmsg, err = cli.StartDmsg(ctx, dLog, pk, sk, &http.Client{}, dmsgDisc, dmsgSessions)
if err != nil {
dLog.WithError(err).Fatal("failed to start dmsg")
Expand Down Expand Up @@ -299,19 +300,83 @@ dmsgweb conf file detected: ` + dwcfg
}
} else {
for i := range resolveDmsgAddr {
wg.Add(1)
if rawTCP[i] {
dLog.Debug("proxyTCPConn(" + fmt.Sprintf("%v", i) + ")")
proxyTCPConn(i)
go proxyTCPConn(i)
} else {
dLog.Debug("proxyHTTPConn(" + fmt.Sprintf("%v", i) + ")")
proxyHTTPConn(i)
go proxyHTTPConn(i)
}
}
}
wg.Wait()
},
}

func proxyTCPConn(n int) {
var thiswebport uint
if n == -1 {
thiswebport = webPort[0]
} else {
thiswebport = webPort[n]
}
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", thiswebport))
if err != nil {
dLog.WithError(err).Fatal(fmt.Sprintf("Failed to start TCP listener on port: %v", thiswebport))
}
defer listener.Close() //nolint
dLog.Debug("Serving TCP on 127.0.0.1:", thiswebport)
if dmsgC == nil {
dLog.Fatal("dmsgC is nil")
}

for {
conn, err := listener.Accept()
if err != nil {
dLog.WithError(err).Warn("Failed to accept connection")
continue
}

go func(conn net.Conn, n int, dmsgC *dmsg.Client) {
defer conn.Close()

Check failure on line 342 in cmd/dmsgweb/commands/dmsgweb.go

View workflow job for this annotation

GitHub Actions / linux

Error return value of `conn.Close` is not checked (errcheck)
dp, ok := safecast.To[uint16](dmsgPorts[n])
if !ok {
dLog.Fatal("uint16 overflow when converting dmsg port")
}
dLog.Debug(fmt.Sprintf("Dialing %v:%v", dialPK[n].String(), dp))
dmsgConn, err := dmsgC.DialStream(context.Background(), dmsg.Addr{PK: dialPK[n], Port: dp}) //nolint
if err != nil {
dLog.WithError(err).Warn(fmt.Sprintf("Failed to dial dmsg address %v port %v", dialPK[n].String(), dmsgPorts[n]))
return
}

defer dmsgConn.Close()

Check failure on line 354 in cmd/dmsgweb/commands/dmsgweb.go

View workflow job for this annotation

GitHub Actions / linux

Error return value of `dmsgConn.Close` is not checked (errcheck)

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
_, err := io.Copy(dmsgConn, conn)
if err != nil {
dLog.WithError(err).Warn("Error on io.Copy(dmsgConn, conn)")
}
}()

go func() {
defer wg.Done()
_, err := io.Copy(conn, dmsgConn)
if err != nil {
dLog.WithError(err).Warn("Error on io.Copy(conn, dmsgConn)")
}
}()

wg.Wait()
}(conn, n, dmsgC)
}
}

func proxyHTTPConn(n int) {
r := gin.New()

Expand Down Expand Up @@ -388,65 +453,6 @@ func proxyHTTPConn(n int) {
wg.Done()
}()
}
func proxyTCPConn(n int) {
var thiswebport uint
if n == -1 {
thiswebport = webPort[0]
} else {
thiswebport = webPort[n]
}
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", thiswebport))
if err != nil {
dLog.WithError(err).Fatal(fmt.Sprintf("Failed to start TCP listener on port: %v", thiswebport))
}
defer listener.Close() //nolint
dLog.Debug("Serving TCP on 127.0.0.1:", thiswebport)
if dmsgC == nil {
dLog.Fatal("dmsgC is nil")
}

for {
conn, err := listener.Accept()
if err != nil {
dLog.WithError(err).Warn("Failed to accept connection")
continue
}

wg.Add(1)
go func(conn net.Conn, n int, dmsgC *dmsg.Client) {
defer wg.Done()

dp, ok := safecast.To[uint16](dmsgPorts[n])
if !ok {
dLog.Fatal("uint16 overflow when converting dmsg port")
}
dLog.Debug(fmt.Sprintf("Dialing dmsg address: %v ; port: %v", dialPK[n].String(), dp))
dmsgConn, err := dmsgC.DialStream(context.Background(), dmsg.Addr{PK: dialPK[n], Port: dp}) //nolint
if err != nil {
dLog.WithError(err).Warn(fmt.Sprintf("Failed to dial dmsg address %v port %v", dialPK[n].String(), dmsgPorts[n]))
return
}
defer dmsgConn.Close() //nolint

go func() {
defer dmsgConn.Close()
_, err := io.Copy(dmsgConn, conn)
if err != nil {
dLog.WithError(err).Warn("Error on io.Copy(dmsgConn, conn)")
}
}()

go func() {
defer conn.Close() //nolint
_, err := io.Copy(conn, dmsgConn)
if err != nil {
dLog.WithError(err).Warn("Error on io.Copy(conn, dmsgConn)")
}
}()
}(conn, n, dmsgC)
wg.Wait()
}
}

const envfileLinux = `
#########################################################################
Expand Down
45 changes: 38 additions & 7 deletions cmd/dmsgweb/commands/dmsgwebsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ const dwsenv = "DMSGWEBSRV"
var dwscfg = os.Getenv(dwsenv)

func init() {
dLog = logging.MustGetLogger("dmsgwebsrv")
dmsgPort = scriptExecUintSlice("${DMSGPORT[@]:-80}", dwscfg)
dmsgSess = scriptExecInt("${DMSGSESSIONS:-1}", dwscfg)
wl = scriptExecStringSlice("${WHITELISTPKS[@]}", dwscfg)
Expand All @@ -43,13 +42,14 @@ func init() {
pk, _ = sk.PubKey()

Check failure on line 42 in cmd/dmsgweb/commands/dmsgwebsrv.go

View workflow job for this annotation

GitHub Actions / linux

Error return value of `sk.PubKey` is not checked (errcheck)

RootCmd.AddCommand(srvCmd)
srvCmd.Flags().UintSliceVarP(&localPort, "lport", "l", localPort, "local application HTTP interface port(s)")
srvCmd.Flags().UintSliceVarP(&localPort, "lport", "p", localPort, "local application interface port(s)")
srvCmd.Flags().UintSliceVarP(&dmsgPort, "dport", "d", dmsgPort, "DMSG port(s) to serve")
srvCmd.Flags().StringSliceVarP(&wl, "wl", "w", wl, "whitelisted keys for DMSG authenticated routes")
srvCmd.Flags().StringVarP(&dmsgDisc, "dmsg-disc", "D", dmsgDisc, "DMSG discovery URL(s)")
srvCmd.Flags().StringVarP(&dmsgDisc, "dmsg-disc", "D", dmsgDisc, "DMSG discovery URL")
srvCmd.Flags().StringVarP(&proxyAddr, "proxy", "x", proxyAddr, "connect to DMSG via proxy (e.g., '127.0.0.1:1080')")
srvCmd.Flags().IntVarP(&dmsgSess, "dsess", "e", dmsgSess, "DMSG sessions")
srvCmd.Flags().BoolSliceVarP(&rawTCP, "rt", "c", rawTCP, "proxy local port as raw TCP")
srvCmd.Flags().BoolSliceVarP(&rawTCP, "rt", "c", rawTCP, "proxy local port as raw TCP, comma separated")
srvCmd.Flags().StringVarP(&logLvl, "loglvl", "l", "", "[ debug | warn | error | fatal | panic | trace | info ]\033[0m")
srvCmd.Flags().BoolVarP(&isEnvs, "envs", "z", false, "show example .conf file")
srvCmd.Flags().VarP(&sk, "sk", "s", "a random key is generated if unspecified")
srvCmd.CompletionOptions.DisableDefaultCmd = true
Expand All @@ -68,14 +68,20 @@ var srvCmd = &cobra.Command{
if isEnvs {
printEnvs(srvenvfileLinux)
}
if logLvl != "" {
if lvl, err := logging.LevelFromString(logLvl); err == nil {
logging.SetLevel(lvl)
}
}
dLog = logging.MustGetLogger("dmsgwebsrv")
if len(localPort) != len(dmsgPort) || len(localPort) != len(rawTCP) {
dLog.Fatal("The number of local ports, DMSG ports, and raw TCP flags must be the same")
}
pk, err = sk.PubKey()
if err != nil {
pk, sk = cipher.GenerateKeyPair()
}
dLog.Infof("DMSG client public key: %v", pk.String())
dLog.Debugf("DMSG client public key: %v", pk.String())

if len(wl) > 0 {
for _, key := range wl {
Expand Down Expand Up @@ -179,6 +185,8 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste
// To track active connections for cleanup
var connWg sync.WaitGroup
connChan := make(chan net.Conn)
activeConns := make(map[net.Conn]struct{})
connMutex := &sync.Mutex{} // Protect access to activeConns

// Goroutine to accept new connections
go func() {
Expand Down Expand Up @@ -206,15 +214,28 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste
// Context canceled: stop accepting new connections and clean up
dLog.Info("Shutting down TCP proxy connections...")
listener.Close() // Close the listener to stop new connections

Check failure on line 216 in cmd/dmsgweb/commands/dmsgwebsrv.go

View workflow job for this annotation

GitHub Actions / linux

Error return value of `listener.Close` is not checked (errcheck)
connWg.Wait() // Wait for all active connections to finish

// Close all active connections
connMutex.Lock()
for conn := range activeConns {
conn.Close() // Forcefully close connections to unblock io.Copy()

Check failure on line 221 in cmd/dmsgweb/commands/dmsgwebsrv.go

View workflow job for this annotation

GitHub Actions / linux

Error return value of `conn.Close` is not checked (errcheck)
}
connMutex.Unlock()

connWg.Wait() // Now it should not hang because io.Copy() is unblocked
return

case conn, ok := <-connChan:
if !ok {
// connChan closed, exit the loop
return
}

// Handle each connection
// Track the connection
connMutex.Lock()
activeConns[conn] = struct{}{}
connMutex.Unlock()

connWg.Add(1)
go func(dmsgConn net.Conn) {
defer connWg.Done()
Expand All @@ -223,13 +244,23 @@ func proxyTCPConnections(ctx context.Context, localPort uint, listener net.Liste
localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
if err != nil {
dLog.Errorf("Error connecting to local port %d: %v", localPort, err)

connMutex.Lock()
delete(activeConns, dmsgConn) // Remove from tracking
connMutex.Unlock()

return
}
defer localConn.Close()

// Start bidirectional copy
go io.Copy(dmsgConn, localConn)
io.Copy(localConn, dmsgConn)

// Remove connection from tracking on completion
connMutex.Lock()
delete(activeConns, dmsgConn)
connMutex.Unlock()
}(conn)
}
}
Expand Down
5 changes: 3 additions & 2 deletions examples/dmsghttp-client/dmsghttp-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"net/url"
"os"

"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"

"github.com/skycoin/dmsg/internal/cli"
"github.com/skycoin/dmsg/pkg/dmsg"
"github.com/skycoin/dmsg/pkg/dmsghttp"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"
)

func main() {
Expand Down
6 changes: 3 additions & 3 deletions examples/dmsghttp/dmsghttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"strings"
"time"

cc "github.com/ivanpirog/coloredcobra"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cmdutil"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"
cc "github.com/ivanpirog/coloredcobra"
"github.com/spf13/cobra"

"github.com/skycoin/dmsg/pkg/disc"
Expand All @@ -39,8 +39,8 @@ var RootCmd = &cobra.Command{
Use: func() string {
return strings.Split(os.Args[0], " ")[0]
}(),
Short: "DMSG HTTP Hello World server",
Long: "DMSG HTTP Hello World server",
Short: "DMSG HTTP Hello World server",
Long: "DMSG HTTP Hello World server",
SilenceErrors: true,
SilenceUsage: true,
DisableSuggestions: true,
Expand Down
Loading

0 comments on commit 0e09fd7

Please sign in to comment.