Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssh: add support for extension negotiation (rfc 8308) #197

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions ssh/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,31 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.transport.readPacket()
if err != nil {
return err
}

var serviceAccept serviceAcceptMsg
if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
readAcceptLoop:
for {
packet, err := c.transport.readPacket()
if err != nil {
return err
}

switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return err
}
c.transport.extensions = extInfo.Extensions
continue
case msgServiceAccept:
if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
}
break readAcceptLoop
default:
return fmt.Errorf("ssh: unexpected message received")
}
}

// during the authentication phase the client first attempts the "none" method
Expand Down Expand Up @@ -337,6 +355,14 @@ func handleAuthResponse(c packetConn) (authResult, []string, error) {
}

switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return authFailure, nil, err
}
if transport, ok := c.(*handshakeTransport); ok {
transport.extensions = extInfo.Extensions
}
case msgUserAuthBanner:
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
Expand Down Expand Up @@ -420,6 +446,15 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe

// like handleAuthResponse, but with less options.
switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return authFailure, nil, err
}
if transport, ok := c.(*handshakeTransport); ok {
transport.extensions = extInfo.Extensions
}
continue
case msgUserAuthBanner:
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
Expand Down
128 changes: 128 additions & 0 deletions ssh/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,75 @@ func TestClientAuthPublicKey(t *testing.T) {
}
}

func TestClientAuthPublicKeyExtensions(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()

certChecker := CertChecker{
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}

return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
IsRevoked: func(c *Certificate) bool {
return c.Serial == 666
},
}
serverConfig := &ServerConfig{
PublicKeyCallback: certChecker.Authenticate,
}
serverConfig.AddHostKey(testSigners["rsa"])

go newServer(c1, serverConfig)
clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}

conn, ok := clientConn.(*connection)
if !ok {
t.Fatalf("conn is not a *connection")
}

rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs]
if !ok {
t.Fatalf("did not receive server-sig-algs extension")
}

serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",")
if len(serverSigAlgs) == 0 {
t.Fatalf("did not receive any server-sig-algs")
}

for _, expectedAlg := range supportedSigAlgs() {
hasAlg := false
for _, receivedAlg := range serverSigAlgs {
if receivedAlg == expectedAlg {
hasAlg = true
break
}
}
if !hasAlg {
t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg)
}
}
}

func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Expand All @@ -131,6 +200,65 @@ func TestAuthMethodPassword(t *testing.T) {
}
}

func TestClientAuthPasswordExtensions(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()

serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])

go newServer(c1, serverConfig)
clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}

conn, ok := clientConn.(*connection)
if !ok {
t.Fatalf("conn is not a *connection")
}

rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs]
if !ok {
t.Fatalf("did not receive server-sig-algs extension")
}

serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",")
if len(serverSigAlgs) == 0 {
t.Fatalf("did not receive any server-sig-algs")
}

for _, expectedAlg := range supportedSigAlgs() {
hasAlg := false
for _, receivedAlg := range serverSigAlgs {
if receivedAlg == expectedAlg {
hasAlg = true
break
}
}
if !hasAlg {
t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg)
}
}
}

func TestAuthMethodFallback(t *testing.T) {
var passwordCalled bool
config := &ClientConfig{
Expand Down
51 changes: 51 additions & 0 deletions ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ const (
serviceSSH = "ssh-connection"
)

// These are string constants related to extensions and extension negotiation
const (
extInfoServer = "ext-info-s"
extInfoClient = "ext-info-c"

ExtServerSigAlgs = "server-sig-algs"
// extDelayCompression = "delay-compression"
// extNoFlowControl = "no-flow-control"
// extElevation = "elevation"
)

// defaultExtensions lists extensions enabled by default.
var defaultExtensions = []string{
ExtServerSigAlgs,
}

// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
Expand Down Expand Up @@ -102,6 +118,18 @@ var hashFuncs = map[string]crypto.Hash{
CertAlgoECDSA521v01: crypto.SHA512,
}

// supportedSigAlgs returns a slice of algorithms supported for pubkey authentication
// in no particular order.
func supportedSigAlgs() []string {
// TODO(kxd) I'm not sure if hashFuncs is the best place to get this set but it seemed
// like a sensible first step. Should this be a curated list?
var serverSigAlgs []string
for k := range hashFuncs {
serverSigAlgs = append(serverSigAlgs, k)
}
return serverSigAlgs
}

// unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
func unexpectedMessageError(expected, got uint8) error {
Expand All @@ -124,6 +152,16 @@ func findCommon(what string, client []string, server []string) (common string, e
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}

// hasString returns true if string "a" is in slice of strings "x", false otherwise.
func hasString(a string, x []string) bool {
for _, s := range x {
if a == s {
return true
}
}
return false
}

// directionAlgorithms records algorithm choices in one direction (either read or write)
type directionAlgorithms struct {
Cipher string
Expand Down Expand Up @@ -159,6 +197,11 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil {
return
} else if result.kex == extInfoClient || result.kex == extInfoServer {
// According to RFC8308 section 2.2 if either the client or server extension signal
// is chosen as the kex algorithm the parties must disconnect.
// chosen
return result, fmt.Errorf("ssh: invalid kex algorithm chosen")
}

result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
Expand Down Expand Up @@ -232,6 +275,10 @@ type Config struct {
// The allowed MAC algorithms. If unspecified then a sensible default
// is used.
MACs []string

// A list of enabled extensions. If unspecified then a sensible
// default is used
Extensions []string
}

// SetDefaults sets sensible values for unset fields in config. This is
Expand Down Expand Up @@ -261,6 +308,10 @@ func (c *Config) SetDefaults() {
c.MACs = supportedMACs
}

if c.Extensions == nil {
c.Extensions = defaultExtensions
}

if c.RekeyThreshold == 0 {
// cipher specific default
} else if c.RekeyThreshold < minRekeyThreshold {
Expand Down
22 changes: 22 additions & 0 deletions ssh/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,28 @@ func TestFindAgreedAlgorithms(t *testing.T) {
wantErr: true,
},

testcase{
name: "server ext info kex chosen",
serverIn: kexInitMsg{
KexAlgos: []string{extInfoServer},
},
clientIn: kexInitMsg{
KexAlgos: []string{extInfoServer},
},
wantErr: true,
},

testcase{
name: "client ext info kex chosen",
serverIn: kexInitMsg{
KexAlgos: []string{extInfoClient},
},
clientIn: kexInitMsg{
KexAlgos: []string{extInfoClient},
},
wantErr: true,
},

testcase{
name: "client decides cipher",
serverIn: kexInitMsg{
Expand Down
Loading