diff --git a/go.mod b/go.mod index 48218bae9..44807f078 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( golang.org/x/mod v0.5.1 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b + gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce gopkg.in/src-d/go-git.v4 v4.13.1 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) diff --git a/go.sum b/go.sum index c05c5d97c..263122aef 100644 --- a/go.sum +++ b/go.sum @@ -1447,6 +1447,8 @@ gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= diff --git a/internal/sshdialer/posix_test.go b/internal/sshdialer/posix_test.go index 5114d7e4a..a296894f9 100644 --- a/internal/sshdialer/posix_test.go +++ b/internal/sshdialer/posix_test.go @@ -3,7 +3,11 @@ package sshdialer_test -import "os" +import ( + "errors" + "net" + "os" +) func fixupPrivateKeyMod(path string) { err := os.Chmod(path, 0400) @@ -11,3 +15,11 @@ func fixupPrivateKeyMod(path string) { panic(err) } } + +func listen(addr string) (net.Listener, error) { + return net.Listen("unix", addr) +} + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/internal/sshdialer/ssh_agent_unix.go b/internal/sshdialer/ssh_agent_unix.go new file mode 100644 index 000000000..7e11b725f --- /dev/null +++ b/internal/sshdialer/ssh_agent_unix.go @@ -0,0 +1,10 @@ +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package sshdialer + +import "net" + +func dialSSHAgent(addr string) (net.Conn, error) { + return net.Dial("unix", addr) +} diff --git a/internal/sshdialer/ssh_agent_windows.go b/internal/sshdialer/ssh_agent_windows.go new file mode 100644 index 000000000..b6000ef8a --- /dev/null +++ b/internal/sshdialer/ssh_agent_windows.go @@ -0,0 +1,15 @@ +package sshdialer + +import ( + "net" + "strings" + + "gopkg.in/natefinch/npipe.v2" +) + +func dialSSHAgent(addr string) (net.Conn, error) { + if strings.Contains(addr, "\\pipe\\") { + return npipe.Dial(addr) + } + return net.Dial("unix", addr) +} diff --git a/internal/sshdialer/ssh_dialer.go b/internal/sshdialer/ssh_dialer.go index 83c247c10..5769a29a0 100644 --- a/internal/sshdialer/ssh_dialer.go +++ b/internal/sshdialer/ssh_dialer.go @@ -284,7 +284,7 @@ func getSignersFromAgent() ([]ssh.Signer, error) { var err error var agentSigners []ssh.Signer var agentConn net.Conn - agentConn, err = net.Dial("unix", sock) + agentConn, err = dialSSHAgent(sock) if err != nil { return nil, fmt.Errorf("failed to connect to ssh-agent's socket: %w", err) } diff --git a/internal/sshdialer/ssh_dialer_test.go b/internal/sshdialer/ssh_dialer_test.go index 8708d7077..654f12266 100644 --- a/internal/sshdialer/ssh_dialer_test.go +++ b/internal/sshdialer/ssh_dialer_test.go @@ -770,12 +770,21 @@ func withBadSSHAgent(t *testing.T) func() { } func withSSHAgent(t *testing.T, ag agent.Agent) func() { + var err error t.Helper() - tmpDirForSocket, err := ioutil.TempDir("", "forAuthSock") - th.AssertNil(t, err) - agentSocketPath := filepath.Join(tmpDirForSocket, "agent.sock") - unixListener, err := net.Listen("unix", agentSocketPath) + var tmpDirForSocket string + var agentSocketPath string + if runtime.GOOS == "windows" { + agentSocketPath = `\\.\pipe\openssh-ssh-agent-test` + } else { + tmpDirForSocket, err = ioutil.TempDir("", "forAuthSock") + th.AssertNil(t, err) + + agentSocketPath = filepath.Join(tmpDirForSocket, "agent.sock") + } + + unixListener, err := listen(agentSocketPath) th.AssertNil(t, err) os.Setenv("SSH_AUTH_SOCK", agentSocketPath) @@ -802,7 +811,7 @@ func withSSHAgent(t *testing.T, ag agent.Agent) func() { }() err := agent.ServeAgent(ag, conn) if err != nil { - if !errors.Is(err, net.ErrClosed) { + if !isErrClosed(err) { fmt.Fprintf(os.Stderr, "agent.ServeAgent() failed: %v\n", err) } } @@ -818,12 +827,14 @@ func withSSHAgent(t *testing.T, ag agent.Agent) func() { err = <-errChan - if !errors.Is(err, net.ErrClosed) { + if !isErrClosed(err) { t.Fatal(err) } cancel() wg.Wait() - os.RemoveAll(tmpDirForSocket) + if tmpDirForSocket != "" { + os.RemoveAll(tmpDirForSocket) + } } } diff --git a/internal/sshdialer/windows_test.go b/internal/sshdialer/windows_test.go index 70ff93b28..304549d96 100644 --- a/internal/sshdialer/windows_test.go +++ b/internal/sshdialer/windows_test.go @@ -4,9 +4,13 @@ package sshdialer_test import ( + "errors" + "net" "os/user" + "strings" "github.com/hectane/go-acl" + "gopkg.in/natefinch/npipe.v2" ) func fixupPrivateKeyMod(path string) { @@ -25,3 +29,14 @@ func fixupPrivateKeyMod(path string) { panic(err) } } + +func listen(addr string) (net.Listener, error) { + if strings.Contains(addr, "\\pipe\\") { + return npipe.Listen(addr) + } + return net.Listen("unix", addr) +} + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, npipe.ErrClosed) +}