Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Windows OpenSSH agent forwarding #2127

Merged
merged 4 commits into from
Jun 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions session/sshforward/sshprovider/agentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
}

if conf.Paths[0] == "" {
return nil, errors.Errorf("invalid empty ssh agent socket, make sure SSH_AUTH_SOCK is set")
p, err := getFallbackAgentPath()
if err != nil {
return nil, errors.Wrap(err, "invalid empty ssh agent socket")
}
conf.Paths[0] = p
}

src, err := toAgentSource(conf.Paths)
Expand All @@ -56,7 +60,20 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {

type source struct {
agent agent.Agent
socket string
socket *socketDialer
}

type socketDialer struct {
path string
dialer func(string) (net.Conn, error)
}

func (s socketDialer) Dial() (net.Conn, error) {
return s.dialer(s.path)
}

func (s socketDialer) String() string {
return s.path
}

type socketProvider struct {
Expand Down Expand Up @@ -94,8 +111,8 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)

var a agent.Agent

if src.socket != "" {
conn, err := net.DialTimeout("unix", src.socket, time.Second)
if src.socket != nil {
conn, err := src.socket.Dial()
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", src.socket)
}
Expand Down Expand Up @@ -124,21 +141,24 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)

func toAgentSource(paths []string) (source, error) {
var keys bool
var socket string
var socket *socketDialer
a := agent.NewKeyring()
for _, p := range paths {
if socket != "" {
if socket != nil {
return source{}, errors.New("only single socket allowed")
}

if parsed := getWindowsPipeDialer(p); parsed != nil {
socket = parsed
continue
}

fi, err := os.Stat(p)
if err != nil {
return source{}, errors.WithStack(err)
}
if fi.Mode()&os.ModeSocket > 0 {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}

Expand All @@ -160,7 +180,7 @@ func toAgentSource(paths []string) (source, error) {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}

Expand All @@ -173,13 +193,20 @@ func toAgentSource(paths []string) (source, error) {
keys = true
}

if socket != "" {
if socket != nil {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
return source{socket: socket}, nil
}

return source{agent: a}, nil
}

func unixSocketDialer(path string) (net.Conn, error) {
return net.DialTimeout("unix", path, 2*time.Second)
}

func sockPair() (io.ReadWriteCloser, io.ReadWriteCloser) {
pr1, pw1 := io.Pipe()
pr2, pw2 := io.Pipe()
Expand Down
15 changes: 15 additions & 0 deletions session/sshforward/sshprovider/agentprovider_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// +build !windows

package sshprovider

import (
"github.com/pkg/errors"
)

func getFallbackAgentPath() (string, error) {
return "", errors.Errorf("make sure SSH_AUTH_SOCK is set")
}

func getWindowsPipeDialer(path string) *socketDialer {
return nil
}
60 changes: 60 additions & 0 deletions session/sshforward/sshprovider/agentprovider_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// +build windows

package sshprovider

import (
"net"
"regexp"
"strings"

"github.com/Microsoft/go-winio"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
)

// Returns the Windows OpenSSH agent named pipe path, but
// only if the agent is running. Returns an error otherwise.
func getFallbackAgentPath() (string, error) {
// Windows OpenSSH agent uses a named pipe rather
// than a UNIX socket. These pipes do not play nice
// with os.Stat (which tries to open its target), so
// use a FindFirstFile syscall to check for existence.
var fd windows.Win32finddata

path := `\\.\pipe\openssh-ssh-agent`
sschaap marked this conversation as resolved.
Show resolved Hide resolved
pathPtr, _ := windows.UTF16PtrFromString(path)
handle, err := windows.FindFirstFile(pathPtr, &fd)

if err != nil {
msg := "Windows OpenSSH agent not available at %s." +
" Enable the SSH agent service or set SSH_AUTH_SOCK."
return "", errors.Errorf(msg, path)
}

_ = windows.CloseHandle(handle)

return path, nil
}

// Returns true if the path references a named pipe.
func isWindowsPipePath(path string) bool {
// If path matches \\*\pipe\* then it references a named pipe
// and requires winio.DialPipe() rather than DialTimeout("unix").
// Slashes and backslashes may be used interchangeably in the path.
// Path separators may consist of multiple consecutive (back)slashes.
pipePattern := strings.ReplaceAll("^[/]{2}[^/]+[/]+pipe[/]+", "/", `\\/`)
ok, _ := regexp.MatchString(pipePattern, path)
return ok
}

func getWindowsPipeDialer(path string) *socketDialer {
if isWindowsPipePath(path) {
return &socketDialer{path: path, dialer: windowsPipeDialer}
}

return nil
}

func windowsPipeDialer(path string) (net.Conn, error) {
return winio.DialPipe(path, nil)
}