From a5de20340187308d1b2fcfeda228b21888a817e1 Mon Sep 17 00:00:00 2001 From: VHSgunzo Date: Sun, 19 Jan 2025 08:42:43 +0300 Subject: [PATCH] Fix hang on system with legacy forking --- go.mod | 6 ++--- go.sum | 4 ++++ ssrv.go | 69 +++++++++++++++++++++++++++++------------------------ tls/go.mod | 8 +++---- tls/go.sum | 6 +++++ tls/ssrv.go | 68 +++++++++++++++++++++++++++++----------------------- 6 files changed, 93 insertions(+), 68 deletions(-) diff --git a/go.mod b/go.mod index 201be0e..f2494a7 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,16 @@ module github.com/VHSgunzo/ssrv -go 1.23.4 +go 1.23.5 require ( github.com/creack/pty v1.1.24 github.com/hashicorp/yamux v0.1.2 - golang.org/x/term v0.27.0 + golang.org/x/term v0.28.0 ) require ( github.com/shirou/gopsutil/v3 v3.24.5 - golang.org/x/sys v0.28.0 + golang.org/x/sys v0.29.0 ) require ( diff --git a/go.sum b/go.sum index 5626742..7d356ee 100644 --- a/go.sum +++ b/go.sum @@ -34,7 +34,11 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ssrv.go b/ssrv.go index edb3b64..853d571 100644 --- a/ssrv.go +++ b/ssrv.go @@ -29,7 +29,7 @@ import ( "golang.org/x/term" ) -var VERSION string = "v0.3.2" +var VERSION string = "v0.3.3" const ENV_VARS = "TERM" const BINARY_NAME = "ssrv" @@ -359,7 +359,6 @@ func read_environ(pid string) (map[string]string, error) { } func srv_handle(conn net.Conn, self_cpids_dir string) { - var wg sync.WaitGroup disconnect := func(session *yamux.Session, remote string) { session.Close() log_out.Printf("[%s] [ DISCONNECT ]", remote) @@ -502,6 +501,9 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } + done_stdout := make(chan struct{}) + done_stderr := make(chan struct{}) + var cmd_ptmx *os.File var cmd_stdout, cmd_stderr io.ReadCloser if is_alloc_pty { @@ -541,7 +543,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { defer os.Remove(cpid) cp := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) } @@ -551,6 +552,8 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { return } if is_alloc_pty { + close(done_stdout) + close(done_stderr) go func() { decoder := gob.NewDecoder(control_channel) for { @@ -570,7 +573,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } }() - wg.Add(2) go cp(data_channel, cmd_ptmx) go cp(cmd_ptmx, data_channel) } else { @@ -620,9 +622,14 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { exec_cmd_kill(syscall.SIGUSR2) } }() - wg.Add(2) - go cp(data_channel, cmd_stdout) - go cp(stderr_channel, cmd_stderr) + go func() { + cp(data_channel, cmd_stdout) + close(done_stdout) + }() + go func() { + cp(stderr_channel, cmd_stderr) + close(done_stderr) + }() } state, err := exec_cmd.Process.Wait() @@ -639,11 +646,15 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { return } + <-done_stdout + <-done_stderr + if is_alloc_pty { session.Close() + } else { + data_channel.Close() + stderr_channel.Close() } - - wg.Wait() } func server(proto, socket string) { @@ -781,8 +792,6 @@ func server(proto, socket string) { } func client(proto, socket string, exec_args []string) int { - var wg sync.WaitGroup - is_alloc_pty := true if len(exec_args) != 0 { is_alloc_pty = !pty_blocklist[exec_args[0]] @@ -935,12 +944,10 @@ func client(proto, socket string, exec_args []string) int { } pipe_stdin := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) stdin_channel.Close() } cp := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) } @@ -994,21 +1001,26 @@ func client(proto, socket string, exec_args []string) int { }() } + done_stdout := make(chan struct{}) + done_stderr := make(chan struct{}) + if is_foreground { - if !is_stdin_term { - wg.Add(1) - go pipe_stdin(stdin_channel, os.Stdin) - } else { - wg.Add(1) + if is_stdin_term { go cp(data_channel, os.Stdin) + } else { + go pipe_stdin(stdin_channel, os.Stdin) } } if !is_alloc_pty { - wg.Add(1) - go cp(os.Stderr, stderr_channel) + go func() { + cp(os.Stderr, stderr_channel) + close(done_stderr) + }() } - wg.Add(1) - go cp(os.Stdout, data_channel) + go func() { + cp(os.Stdout, data_channel) + close(done_stdout) + }() var exit_code = 1 exit_reader := bufio.NewReader(command_channel) @@ -1026,18 +1038,13 @@ func client(proto, socket string, exec_args []string) int { if term_old_state != nil { term.Restore(stdin, term_old_state) - if is_foreground { - wg.Done() - } } - if is_foreground && is_stdin_term && ((!*is_pty && !*is_no_pty) || - (*is_no_pty && (!is_stdout_term || !is_stderr_term)) || *is_no_pty) { - if !is_stderr_term || !is_alloc_pty { - wg.Done() - } + + <-done_stdout + if !is_alloc_pty { + <-done_stderr } - wg.Wait() return exit_code } diff --git a/tls/go.mod b/tls/go.mod index a28a229..5ebba3c 100644 --- a/tls/go.mod +++ b/tls/go.mod @@ -1,12 +1,12 @@ module github.com/VHSgunzo/ssrv/tls -go 1.23.4 +go 1.23.5 require ( github.com/creack/pty v1.1.24 github.com/hashicorp/yamux v0.1.2 github.com/shirou/gopsutil/v3 v3.24.5 - golang.org/x/term v0.27.0 + golang.org/x/term v0.28.0 ) require ( @@ -20,6 +20,6 @@ require ( ) require ( - golang.org/x/crypto v0.31.0 - golang.org/x/sys v0.28.0 + golang.org/x/crypto v0.32.0 + golang.org/x/sys v0.29.0 ) diff --git a/tls/go.sum b/tls/go.sum index d68ee90..a166146 100644 --- a/tls/go.sum +++ b/tls/go.sum @@ -31,12 +31,18 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tls/ssrv.go b/tls/ssrv.go index 63ee243..b7e53dd 100644 --- a/tls/ssrv.go +++ b/tls/ssrv.go @@ -32,7 +32,7 @@ import ( "golang.org/x/term" ) -var VERSION string = "v0.3.2" +var VERSION string = "v0.3.3" const ENV_VARS = "TERM" const TLS_KEY = "key.pem" @@ -418,7 +418,6 @@ func verify_cert_hash(provided_cert_hash, cert string) (bool, error) { } func srv_handle(conn net.Conn, self_cpids_dir string) { - var wg sync.WaitGroup disconnect := func(session *yamux.Session, remote string) { session.Close() log_out.Printf("[%s] [ DISCONNECT ]", remote) @@ -578,6 +577,9 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } + done_stdout := make(chan struct{}) + done_stderr := make(chan struct{}) + var cmd_ptmx *os.File var cmd_stdout, cmd_stderr io.ReadCloser if is_alloc_pty { @@ -617,7 +619,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { defer os.Remove(cpid) cp := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) } @@ -627,6 +628,8 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { return } if is_alloc_pty { + close(done_stdout) + close(done_stderr) go func() { decoder := gob.NewDecoder(control_channel) for { @@ -646,7 +649,6 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { } } }() - wg.Add(2) go cp(data_channel, cmd_ptmx) go cp(cmd_ptmx, data_channel) } else { @@ -696,9 +698,14 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { exec_cmd_kill(syscall.SIGUSR2) } }() - wg.Add(2) - go cp(data_channel, cmd_stdout) - go cp(stderr_channel, cmd_stderr) + go func() { + cp(data_channel, cmd_stdout) + close(done_stdout) + }() + go func() { + cp(stderr_channel, cmd_stderr) + close(done_stderr) + }() } state, err := exec_cmd.Process.Wait() @@ -715,11 +722,15 @@ func srv_handle(conn net.Conn, self_cpids_dir string) { return } + <-done_stdout + <-done_stderr + if is_alloc_pty { session.Close() + } else { + data_channel.Close() + stderr_channel.Close() } - - wg.Wait() } func server(proto, socket string) { @@ -879,7 +890,6 @@ func server(proto, socket string) { func client(proto, socket string, exec_args []string) int { var err error - var wg sync.WaitGroup is_alloc_pty := true if len(exec_args) != 0 { @@ -1061,12 +1071,10 @@ func client(proto, socket string, exec_args []string) int { } pipe_stdin := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) stdin_channel.Close() } cp := func(dst io.Writer, src io.Reader) { - defer wg.Done() io.Copy(dst, src) } @@ -1120,21 +1128,26 @@ func client(proto, socket string, exec_args []string) int { }() } + done_stdout := make(chan struct{}) + done_stderr := make(chan struct{}) + if is_foreground { - if !is_stdin_term { - wg.Add(1) - go pipe_stdin(stdin_channel, os.Stdin) - } else { - wg.Add(1) + if is_stdin_term { go cp(data_channel, os.Stdin) + } else { + go pipe_stdin(stdin_channel, os.Stdin) } } if !is_alloc_pty { - wg.Add(1) - go cp(os.Stderr, stderr_channel) + go func() { + cp(os.Stderr, stderr_channel) + close(done_stderr) + }() } - wg.Add(1) - go cp(os.Stdout, data_channel) + go func() { + cp(os.Stdout, data_channel) + close(done_stdout) + }() var exit_code = 1 exit_reader := bufio.NewReader(command_channel) @@ -1152,18 +1165,13 @@ func client(proto, socket string, exec_args []string) int { if term_old_state != nil { term.Restore(stdin, term_old_state) - if is_foreground { - wg.Done() - } } - if is_foreground && is_stdin_term && ((!*is_pty && !*is_no_pty) || - (*is_no_pty && (!is_stdout_term || !is_stderr_term)) || *is_no_pty) { - if !is_stderr_term || !is_alloc_pty { - wg.Done() - } + + <-done_stdout + if !is_alloc_pty { + <-done_stderr } - wg.Wait() return exit_code }