Skip to content

Commit

Permalink
Address additional code review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
djjuhasz committed Nov 7, 2023
1 parent 390c986 commit 53e3241
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 26 deletions.
6 changes: 3 additions & 3 deletions internal/sftp/goclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ func (c *GoClient) Upload(ctx context.Context, src io.Reader, dest string) (int6
// Note: Some SFTP servers don't support O_RDWR mode.
w, err := c.sftp.OpenFile(dest, (os.O_WRONLY | os.O_CREATE | os.O_TRUNC))
if err != nil {
return 0, fmt.Errorf("SFTP: couldn't create remote file %q: %v", dest, err)
return 0, fmt.Errorf("SFTP: open remote file %q: %v", dest, err)
}
defer w.Close()

// Use contextio to stop the upload if a context cancellation signal is
// received.
bytes, err := io.Copy(contextio.NewWriter(ctx, w), contextio.NewReader(ctx, src))
if err != nil {
return 0, fmt.Errorf("SFTP: failed to write to %q: %v", dest, err)
return 0, fmt.Errorf("SFTP: upload to %q: %v", dest, err)
}

return bytes, nil
Expand All @@ -70,7 +70,7 @@ func (c *GoClient) dial() error {

c.sftp, err = sftp.NewClient(c.ssh)
if err != nil {
return fmt.Errorf("Unable to start SFTP subsystem: %v", err)
return fmt.Errorf("start SFTP subsystem: %v", err)
}

return nil
Expand Down
35 changes: 17 additions & 18 deletions internal/sftp/goclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,22 @@ const serverAddress = "127.0.0.1:2222"
func pubkeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
file, err := os.Open("./testdata/authorized_keys")
if err != nil {
log.Fatalln("SSH: couldn't open authorized_keys file.")
log.Fatalf("SFTP server: couldn't open authorized_keys file: %s", err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
allowed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(scanner.Text()))
if err != nil {
log.Fatalln("SSH: couldn't parse authorized key.")
log.Fatalf("SFTP server: couldn't parse authorized keys: %s", err)
}
if ssh.KeysEqual(key, allowed) {
return true
}
}

log.Println("SSH: unknown key provided.")
log.Println("SFTP server: unknown key provided.")
return false
}

Expand All @@ -55,12 +55,12 @@ func hostKeySigner() (gossh.Signer, error) {

key, err := os.ReadFile(keyfile)
if err != nil {
return nil, fmt.Errorf("couldn't read keyfile %q, %v\n", keyfile, err)
return nil, fmt.Errorf("read keyfile %q, %v\n", keyfile, err)
}

signer, err := gossh.ParsePrivateKey(key)
if err != nil {
return nil, fmt.Errorf("couldn't parse private key: %v\n", err)
return nil, fmt.Errorf("parse private key: %v\n", err)
}

return signer, nil
Expand All @@ -77,14 +77,13 @@ func sftpHandler(sess ssh.Session) {
serverOptions...,
)
if err != nil {
log.Printf("sftp server init error: %s\n", err)
return
log.Fatalf("SFTP server init error: %s", err)
}
if err := server.Serve(); err == io.EOF {
server.Close()
fmt.Println("sftp client exited session.")
fmt.Println("SFTP client exited session.")
} else if err != nil {
fmt.Println("sftp server completed with error:", err)
fmt.Println("SFTP server completed with error:", err)
}
}

Expand All @@ -110,7 +109,7 @@ func startSFTPServer(t *testing.T, addr string) *ssh.Server {

signer, err := hostKeySigner()
if err != nil {
t.Fatalf("SFTP server: %v", err)
t.Fatalf("SFTP server: couldn't create host key signer: %v", err)
}
srv.AddHostKey(signer)

Expand All @@ -124,7 +123,7 @@ func startSFTPServer(t *testing.T, addr string) *ssh.Server {
for {
select {
case err := <-errCh:
t.Fatalf("Couldn't start SFTP server: %v", err)
t.Fatalf("SFTP server: failed to start: %v", err)
default:
conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
if err == nil {
Expand All @@ -143,7 +142,7 @@ func startSFTPServer(t *testing.T, addr string) *ssh.Server {
func TestGoClient(t *testing.T) {
host, port, err := net.SplitHostPort(serverAddress)
if err != nil {
t.Fatalf("Bad serverAddress: %s", serverAddress)
t.Fatalf("Bad server address: %s", serverAddress)
}

_ = startSFTPServer(t, serverAddress)
Expand All @@ -152,7 +151,7 @@ func TestGoClient(t *testing.T) {
// server address.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("couldn't start listener: %v", err)
t.Fatalf("Couldn't start listener: %v", err)
}
defer listener.Close()
badHost, badPort, _ := net.SplitHostPort(listener.Addr().String())
Expand Down Expand Up @@ -211,7 +210,7 @@ func TestGoClient(t *testing.T) {
Passphrase: "wrong",
},
},
wantErr: "SSH: couldn't parse private key with passphrase: x509: decryption password incorrect",
wantErr: "SSH: parse private key with passphrase: x509: decryption password incorrect",
},
{
name: "Errors when the SFTP server isn't there",
Expand All @@ -224,7 +223,7 @@ func TestGoClient(t *testing.T) {
},
},
wantErr: fmt.Sprintf(
"SSH: failed to connect: dial tcp %s:%s: connect: connection refused",
"SSH: connect: dial tcp %s:%s: connect: connection refused",
badHost, badPort,
),
},
Expand All @@ -238,7 +237,7 @@ func TestGoClient(t *testing.T) {
Path: "./testdata/clientkeys/test_unk_ed25519",
},
},
wantErr: "SSH: failed to connect: ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain",
wantErr: "SSH: connect: ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain",
},
{
name: "Errors when the host key is not in known_hosts",
Expand All @@ -250,7 +249,7 @@ func TestGoClient(t *testing.T) {
Path: "./testdata/clientkeys/test_ed25519",
},
},
wantErr: "SSH: failed to connect: ssh: handshake failed: knownhosts: key is unknown",
wantErr: "SSH: connect: ssh: handshake failed: knownhosts: key is unknown",
},
{
name: "Errors when the known_hosts file doesn't exist",
Expand All @@ -262,7 +261,7 @@ func TestGoClient(t *testing.T) {
Path: "./testdata/clientkeys/test_ed25519",
},
},
wantErr: "SSH: couldn't parse known_hosts file: open testdata/missing: no such file or directory",
wantErr: "SSH: parse known_hosts: open testdata/missing: no such file or directory",
},
} {
tc := tc
Expand Down
10 changes: 5 additions & 5 deletions internal/sftp/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@ func sshConnect(cfg Config) (*ssh.Client, error) {
// Load private key for authentication.
keyBytes, err := os.ReadFile(filepath.Clean(cfg.PrivateKey.Path)) // #nosec G304 -- File data is validated below
if err != nil {
return nil, fmt.Errorf("couldn't read private key: %v", err)
return nil, fmt.Errorf("read private key: %v", err)
}

// Create a signer from the private key, with or without a passphrase.
var signer ssh.Signer
if cfg.PrivateKey.Passphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(keyBytes, []byte(cfg.PrivateKey.Passphrase))
if err != nil {
return nil, fmt.Errorf("couldn't parse private key with passphrase: %v", err)
return nil, fmt.Errorf("parse private key with passphrase: %v", err)
}
} else {
signer, err = ssh.ParsePrivateKey(keyBytes)
if err != nil {
return nil, fmt.Errorf("couldn't parse private key: %v", err)
return nil, fmt.Errorf("parse private key: %v", err)
}
}

// Check that the host key is in the client's known_hosts file.
hostcallback, err := knownhosts.New(filepath.Clean(cfg.KnownHostsFile))
if err != nil {
return nil, fmt.Errorf("couldn't parse known_hosts file: %v", err)
return nil, fmt.Errorf("parse known_hosts: %v", err)
}

// Configure the SSH client.
Expand All @@ -56,7 +56,7 @@ func sshConnect(cfg Config) (*ssh.Client, error) {
address := net.JoinHostPort(cfg.Host, cfg.Port)
conn, err := ssh.Dial("tcp", address, sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
return nil, fmt.Errorf("connect: %v", err)
}

return conn, nil
Expand Down

0 comments on commit 53e3241

Please sign in to comment.