Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

truncated buffer and freeze fixes #23

Merged
merged 1 commit into from
Dec 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions internal/sshagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,87 @@ import (
"bufio"
"encoding/binary"
"fmt"
"log"
"time"

"github.com/Microsoft/go-winio"
)

// AgentMaxMessageLength is the maximum length of a message sent to the agent
const (
AgentMaxMessageLength = 1<<14 - 1 // 16383
AgentMaxMessageLength = 1<<14 - 1 // 16383
SSH_AGENT_FAIL byte = 0x05
)

var genericFail = []byte{0x00, 0x00, 0x00, 0x01, SSH_AGENT_FAIL}

// QueryAgent provides a way to query the named windows openssh agent pipe
func QueryAgent(pipeName string, buf []byte) (result []byte, err error) {
if len(buf) > AgentMaxMessageLength {
return nil, fmt.Errorf("message too long")
fmt.Println("message too long")
return genericFail, nil
}

conn, err := winio.DialPipe(pipeName, nil)
if err != nil {
return nil, fmt.Errorf("cannot connect to pipe %s: %w", pipeName, err)
fmt.Printf("cannot connect to pipe %s: %s", pipeName, err.Error())
return genericFail, nil
}
defer conn.Close()
// If the agent needs the user to do something, give them time to do so, but don't wait forever.
conn.SetDeadline(time.Now().Add(time.Second * 20))

byteCount, err := conn.Write(buf)
_, err = conn.Write(buf)
if err != nil {
return nil, fmt.Errorf("cannot write to pipe %s: %w", pipeName, err)
fmt.Printf("cannot write to pipe %s: %s", pipeName, err.Error())
return genericFail, nil
}

reader := bufio.NewReader(conn)
// The buffer needs to be at least as large as the expected message size
reader := bufio.NewReaderSize(conn, AgentMaxMessageLength)

// Magic numbers from the ssh-agent protocol specification.
// <https://github.com/openssh/openssh-portable/blob/4e636cf/PROTOCOL.agent>
// first 4 bytes are magic numbers related to the named pipe
magic := make([]byte, 4)
byteCount, err = reader.Read(magic)
_, err = reader.Read(magic)
if err != nil {
return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err)
fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error())
return genericFail, nil
}
// next byte is the SSH2_AGENT_IDENTITIES_ANSWER
sshHeader := make([]byte, 1)
byteCount, err = reader.Read(sshHeader)
// next byte is the reply code
replyCode := make([]byte, 1)
_, err = reader.Read(replyCode)
if err != nil {
return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err)
fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error())
return append(magic, []byte{SSH_AGENT_FAIL}...), nil
}
if replyCode[0] == SSH_AGENT_FAIL {
return append(magic, replyCode...), nil
}
// next 4 bytes (Uint32) is the number of keys
keyCountSlice := make([]byte, 4)
byteCount, err = reader.Read(keyCountSlice)
_, err = reader.Read(keyCountSlice)
if err != nil {
return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err)
fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error())
return append(magic, []byte{SSH_AGENT_FAIL}...), nil
}
// convert to Uint32
keyCount := binary.BigEndian.Uint32(keyCountSlice)

// set to max agent message length minus the previous 9 bytes
res := make([]byte, AgentMaxMessageLength-9)
// verify the key count is > 0, otherwise skip
byteCount := 0
if keyCount > 0 {
byteCount, err = reader.Read(res)
if err != nil {
log.Println(err)
return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err)
fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error())
return append(magic, []byte{SSH_AGENT_FAIL}...), nil
}
} else {
byteCount = 0
}

// Concat all slices together
concatRes := append(magic, sshHeader...)
concatRes := append(magic, replyCode...)
concatRes = append(concatRes, keyCountSlice...)
concatRes = append(concatRes, res[0:byteCount]...)

Expand Down