From d484a0e77f29369012910beb69e250c889d27671 Mon Sep 17 00:00:00 2001 From: Dragoon <4461926+1Dragoon@users.noreply.github.com> Date: Thu, 15 Dec 2022 15:15:03 -0800 Subject: [PATCH] Truncated buffer and freeze fixes (#23) Fixes for #22 and #18 in addition to a problem where keys would be truncated and give an invalid response when the agent is being forwarded. --- internal/sshagent/agent.go | 54 ++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/internal/sshagent/agent.go b/internal/sshagent/agent.go index 39b5636..5611a9c 100644 --- a/internal/sshagent/agent.go +++ b/internal/sshagent/agent.go @@ -4,54 +4,69 @@ import ( "bufio" "encoding/binary" "fmt" - "log" + "time" "github.com/Microsoft/go-winio" ) // AgentMaxMessageLength is the maximum length of a message sent to the agent const ( - AgentMaxMessageLength = 1<<14 - 1 // 16383 + AgentMaxMessageLength = 1<<14 - 1 // 16383 + SSH_AGENT_FAIL byte = 0x05 ) +var genericFail = []byte{0x00, 0x00, 0x00, 0x01, SSH_AGENT_FAIL} + // QueryAgent provides a way to query the named windows openssh agent pipe func QueryAgent(pipeName string, buf []byte) (result []byte, err error) { if len(buf) > AgentMaxMessageLength { - return nil, fmt.Errorf("message too long") + fmt.Println("message too long") + return genericFail, nil } conn, err := winio.DialPipe(pipeName, nil) if err != nil { - return nil, fmt.Errorf("cannot connect to pipe %s: %w", pipeName, err) + fmt.Printf("cannot connect to pipe %s: %s", pipeName, err.Error()) + return genericFail, nil } defer conn.Close() + // If the agent needs the user to do something, give them time to do so, but don't wait forever. + conn.SetDeadline(time.Now().Add(time.Second * 20)) - byteCount, err := conn.Write(buf) + _, err = conn.Write(buf) if err != nil { - return nil, fmt.Errorf("cannot write to pipe %s: %w", pipeName, err) + fmt.Printf("cannot write to pipe %s: %s", pipeName, err.Error()) + return genericFail, nil } - reader := bufio.NewReader(conn) + // The buffer needs to be at least as large as the expected message size + reader := bufio.NewReaderSize(conn, AgentMaxMessageLength) // Magic numbers from the ssh-agent protocol specification. // // first 4 bytes are magic numbers related to the named pipe magic := make([]byte, 4) - byteCount, err = reader.Read(magic) + _, err = reader.Read(magic) if err != nil { - return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err) + fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error()) + return genericFail, nil } - // next byte is the SSH2_AGENT_IDENTITIES_ANSWER - sshHeader := make([]byte, 1) - byteCount, err = reader.Read(sshHeader) + // next byte is the reply code + replyCode := make([]byte, 1) + _, err = reader.Read(replyCode) if err != nil { - return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err) + fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error()) + return append(magic, []byte{SSH_AGENT_FAIL}...), nil + } + if replyCode[0] == SSH_AGENT_FAIL { + return append(magic, replyCode...), nil } // next 4 bytes (Uint32) is the number of keys keyCountSlice := make([]byte, 4) - byteCount, err = reader.Read(keyCountSlice) + _, err = reader.Read(keyCountSlice) if err != nil { - return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err) + fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error()) + return append(magic, []byte{SSH_AGENT_FAIL}...), nil } // convert to Uint32 keyCount := binary.BigEndian.Uint32(keyCountSlice) @@ -59,18 +74,17 @@ func QueryAgent(pipeName string, buf []byte) (result []byte, err error) { // set to max agent message length minus the previous 9 bytes res := make([]byte, AgentMaxMessageLength-9) // verify the key count is > 0, otherwise skip + byteCount := 0 if keyCount > 0 { byteCount, err = reader.Read(res) if err != nil { - log.Println(err) - return nil, fmt.Errorf("cannot read from pipe %s: %w", pipeName, err) + fmt.Printf("cannot read from pipe %s: %s", pipeName, err.Error()) + return append(magic, []byte{SSH_AGENT_FAIL}...), nil } - } else { - byteCount = 0 } // Concat all slices together - concatRes := append(magic, sshHeader...) + concatRes := append(magic, replyCode...) concatRes = append(concatRes, keyCountSlice...) concatRes = append(concatRes, res[0:byteCount]...)