diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index bf0aa1fe23..e981cc49a6 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -641,17 +641,28 @@ func TestClientAuthMaxAuthTries(t *testing.T) { defer c1.Close() defer c2.Close() - go newServer(c1, serverConfig) - _, _, _, err = NewClientConn(c2, "", clientConfig) - if tries > 2 { - if err == nil { + errCh := make(chan error, 1) + + go func() { + _, err := newServer(c1, serverConfig) + errCh <- err + }() + _, _, _, cliErr := NewClientConn(c2, "", clientConfig) + srvErr := <-errCh + + if tries > serverConfig.MaxAuthTries { + if cliErr == nil { t.Fatalf("client: got no error, want %s", expectedErr) - } else if err.Error() != expectedErr.Error() { + } else if cliErr.Error() != expectedErr.Error() { t.Fatalf("client: got %s, want %s", err, expectedErr) } + var authErr *ServerAuthError + if !errors.As(srvErr, &authErr) { + t.Errorf("expected ServerAuthError, got: %v", srvErr) + } } else { - if err != nil { - t.Fatalf("client: got %s, want no error", err) + if cliErr != nil { + t.Fatalf("client: got %s, want no error", cliErr) } } } diff --git a/ssh/server.go b/ssh/server.go index 3ca9e89e22..c0d1c29e6f 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -510,8 +510,8 @@ userAuthLoop: if err := s.transport.writePacket(Marshal(discMsg)); err != nil { return nil, err } - - return nil, discMsg + authErrs = append(authErrs, discMsg) + return nil, &ServerAuthError{Errors: authErrs} } var userAuthReq userAuthRequestMsg