diff --git a/session/sshforward/sshprovider/agentprovider.go b/session/sshforward/sshprovider/agentprovider.go index f8ed2811da4a..981eb96f5628 100644 --- a/session/sshforward/sshprovider/agentprovider.go +++ b/session/sshforward/sshprovider/agentprovider.go @@ -35,7 +35,11 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) { } if conf.Paths[0] == "" { - return nil, errors.Errorf("invalid empty ssh agent socket, make sure SSH_AUTH_SOCK is set") + p, err := getFallbackAgentPath() + if err != nil { + return nil, errors.Wrap(err, "invalid empty ssh agent socket") + } + conf.Paths[0] = p } src, err := toAgentSource(conf.Paths) @@ -56,7 +60,20 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) { type source struct { agent agent.Agent - socket string + socket *socketDialer +} + +type socketDialer struct { + path string + dialer func(string) (net.Conn, error) +} + +func (s socketDialer) Dial() (net.Conn, error) { + return s.dialer(s.path) +} + +func (s socketDialer) String() string { + return s.path } type socketProvider struct { @@ -94,8 +111,8 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) var a agent.Agent - if src.socket != "" { - conn, err := net.DialTimeout("unix", src.socket, time.Second) + if src.socket != nil { + conn, err := src.socket.Dial() if err != nil { return errors.Wrapf(err, "failed to connect to %s", src.socket) } @@ -124,21 +141,24 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) func toAgentSource(paths []string) (source, error) { var keys bool - var socket string + var socket *socketDialer a := agent.NewKeyring() for _, p := range paths { - if socket != "" { + if socket != nil { return source{}, errors.New("only single socket allowed") } + + if parsed := getWindowsPipeDialer(p); parsed != nil { + socket = parsed + continue + } + fi, err := os.Stat(p) if err != nil { return source{}, errors.WithStack(err) } if fi.Mode()&os.ModeSocket > 0 { - if keys { - return source{}, errors.Errorf("invalid combination of keys and sockets") - } - socket = p + socket = &socketDialer{path: p, dialer: unixSocketDialer} continue } @@ -160,7 +180,7 @@ func toAgentSource(paths []string) (source, error) { if keys { return source{}, errors.Errorf("invalid combination of keys and sockets") } - socket = p + socket = &socketDialer{path: p, dialer: unixSocketDialer} continue } @@ -173,13 +193,20 @@ func toAgentSource(paths []string) (source, error) { keys = true } - if socket != "" { + if socket != nil { + if keys { + return source{}, errors.Errorf("invalid combination of keys and sockets") + } return source{socket: socket}, nil } return source{agent: a}, nil } +func unixSocketDialer(path string) (net.Conn, error) { + return net.DialTimeout("unix", path, 2*time.Second) +} + func sockPair() (io.ReadWriteCloser, io.ReadWriteCloser) { pr1, pw1 := io.Pipe() pr2, pw2 := io.Pipe() diff --git a/session/sshforward/sshprovider/agentprovider_unix.go b/session/sshforward/sshprovider/agentprovider_unix.go new file mode 100644 index 000000000000..07b6b7b1e93d --- /dev/null +++ b/session/sshforward/sshprovider/agentprovider_unix.go @@ -0,0 +1,15 @@ +// +build !windows + +package sshprovider + +import ( + "github.com/pkg/errors" +) + +func getFallbackAgentPath() (string, error) { + return "", errors.Errorf("make sure SSH_AUTH_SOCK is set") +} + +func getWindowsPipeDialer(path string) *socketDialer { + return nil +} diff --git a/session/sshforward/sshprovider/agentprovider_windows.go b/session/sshforward/sshprovider/agentprovider_windows.go new file mode 100644 index 000000000000..812e273c2e4c --- /dev/null +++ b/session/sshforward/sshprovider/agentprovider_windows.go @@ -0,0 +1,60 @@ +// +build windows + +package sshprovider + +import ( + "net" + "regexp" + "strings" + + "github.com/Microsoft/go-winio" + "github.com/pkg/errors" + "golang.org/x/sys/windows" +) + +// Returns the Windows OpenSSH agent named pipe path, but +// only if the agent is running. Returns an error otherwise. +func getFallbackAgentPath() (string, error) { + // Windows OpenSSH agent uses a named pipe rather + // than a UNIX socket. These pipes do not play nice + // with os.Stat (which tries to open its target), so + // use a FindFirstFile syscall to check for existence. + var fd windows.Win32finddata + + path := `\\.\pipe\openssh-ssh-agent` + pathPtr, _ := windows.UTF16PtrFromString(path) + handle, err := windows.FindFirstFile(pathPtr, &fd) + + if err != nil { + msg := "Windows OpenSSH agent not available at %s." + + " Enable the SSH agent service or set SSH_AUTH_SOCK." + return "", errors.Errorf(msg, path) + } + + _ = windows.CloseHandle(handle) + + return path, nil +} + +// Returns true if the path references a named pipe. +func isWindowsPipePath(path string) bool { + // If path matches \\*\pipe\* then it references a named pipe + // and requires winio.DialPipe() rather than DialTimeout("unix"). + // Slashes and backslashes may be used interchangeably in the path. + // Path separators may consist of multiple consecutive (back)slashes. + pipePattern := strings.ReplaceAll("^[/]{2}[^/]+[/]+pipe[/]+", "/", `\\/`) + ok, _ := regexp.MatchString(pipePattern, path) + return ok +} + +func getWindowsPipeDialer(path string) *socketDialer { + if isWindowsPipePath(path) { + return &socketDialer{path: path, dialer: windowsPipeDialer} + } + + return nil +} + +func windowsPipeDialer(path string) (net.Conn, error) { + return winio.DialPipe(path, nil) +}