diff --git a/tools/remotecommand/errorstream.go b/tools/remotecommand/errorstream.go index e60dd7cdc..90bb39b4a 100644 --- a/tools/remotecommand/errorstream.go +++ b/tools/remotecommand/errorstream.go @@ -41,7 +41,7 @@ func watchErrorStream(errorStream io.Reader, d errorStreamDecoder) chan error { message, err := io.ReadAll(errorStream) switch { case err != nil && err != io.EOF: - errorChan <- fmt.Errorf("error reading from error stream: %s", err) + errorChan <- fmt.Errorf("error reading from error stream: %w", err) case len(message) > 0: errorChan <- d.decode(message) default: diff --git a/tools/remotecommand/v2_test.go b/tools/remotecommand/v2_test.go index e303f57a9..412cee8d2 100644 --- a/tools/remotecommand/v2_test.go +++ b/tools/remotecommand/v2_test.go @@ -19,12 +19,13 @@ package remotecommand import ( "errors" "io" + "net" "net/http" "strings" "testing" "time" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/wait" ) @@ -181,17 +182,34 @@ func TestV2ErrorStreamReading(t *testing.T) { tests := []struct { name string stream io.Reader - expectedError error + expectedError func(*testing.T, error) }{ { - name: "error reading from stream", - stream: &fakeReader{errors.New("foo")}, - expectedError: errors.New("error reading from error stream: foo"), + name: "error reading from stream", + stream: &fakeReader{errors.New("foo")}, + expectedError: func(t *testing.T, err error) { + if e, a := "error reading from error stream: foo", err.Error(); e != a { + t.Errorf("expected '%s', got '%s'", e, a) + } + }, }, { - name: "stream returns an error", - stream: strings.NewReader("some error"), - expectedError: errors.New("error executing remote command: some error"), + name: "stream returns an error", + stream: strings.NewReader("some error"), + expectedError: func(t *testing.T, err error) { + if e, a := "error executing remote command: some error", err.Error(); e != a { + t.Errorf("expected '%s', got '%s'", e, a) + } + }, + }, + { + name: "typed error", + stream: &fakeReader{net.ErrClosed}, + expectedError: func(t *testing.T, err error) { + if !errors.Is(err, net.ErrClosed) { + t.Errorf("expected errors.Is(err, net.ErrClosed), failed on %#v", err) + } + }, }, } @@ -214,8 +232,8 @@ func TestV2ErrorStreamReading(t *testing.T) { if test.expectedError != nil { if err == nil { t.Errorf("%s: expected an error", test.name) - } else if e, a := test.expectedError, err; e.Error() != a.Error() { - t.Errorf("%s: expected %q, got %q", test.name, e, a) + } else { + test.expectedError(t, err) } continue }