diff --git a/key_test.go b/key_test.go index 82129dac..2038a2e6 100644 --- a/key_test.go +++ b/key_test.go @@ -190,9 +190,6 @@ func TestGenerateEd25519(t *testing.T) { if _, err = key.MarshalPKIXPublicKeyPEM(); err != nil { t.Fatal(err) } - if _, err = key.MarshalPKCS1PrivateKeyPEM(); err != nil { - t.Fatal(err) - } } func TestSign(t *testing.T) { @@ -431,9 +428,7 @@ func TestMarshalEd25519(t *testing.T) { t.Fatal("invalid cert pem bytes") } - if _, err = key.MarshalPKCS1PrivateKeyPEM(); err != nil { - t.Fatal(err) - } + // NOTE: Ed25519 cannot be marshalled to PEM. if _, err := key.MarshalPKCS1PrivateKeyDER(); err != nil { t.Fatal(err) diff --git a/ssl_test.go b/ssl_test.go index db48aeb2..b1309c59 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -193,7 +193,7 @@ func SimpleConnTest(t testing.TB, constructor func( } buf := bytes.NewBuffer(make([]byte, 0, len(data))) - _, err = io.CopyN(buf, server, int64(len(data))) + _, err = io.Copy(buf, server) if err != nil { t.Fatal(err) } @@ -201,10 +201,8 @@ func SimpleConnTest(t testing.TB, constructor func( t.Fatal("mismatched data") } - err = server.Close() - if err != nil { - t.Fatal(err) - } + // Only one side gets a clean close because closing needs to write a terminator. + _ = server.Close() }() wg.Wait() } @@ -223,10 +221,10 @@ func close_both(closer1, closer2 io.Closer) { wg.Wait() } -func ClosingTest(t testing.TB, constructor func( +func ClosingTest(t *testing.T, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - run_test := func(close_tcp bool, server_writes bool) { + run_test := func(t *testing.T, close_tcp bool, server_writes bool) { server_conn, client_conn := NetPipe(t) defer server_conn.Close() defer client_conn.Close() @@ -246,12 +244,34 @@ func ClosingTest(t testing.TB, constructor func( } var wg sync.WaitGroup + + // If we're killing the TCP connection, make sure we handshake first + if close_tcp { + wg.Add(2) + go func() { + defer wg.Done() + err := sslconn1.Handshake() + if err != nil { + t.Error(err) + } + }() + go func() { + defer wg.Done() + err := sslconn2.Handshake() + if err != nil { + t.Error(err) + } + }() + wg.Wait() + } + wg.Add(2) go func() { defer wg.Done() _, err := sslconn1.Write([]byte("hello")) if err != nil { - t.Fatal(err) + t.Error(err) + return } if close_tcp { err = conn1.Close() @@ -259,28 +279,37 @@ func ClosingTest(t testing.TB, constructor func( err = sslconn1.Close() } if err != nil { - t.Fatal(err) + t.Error(err) } }() go func() { defer wg.Done() data, err := ioutil.ReadAll(sslconn2) - if err != nil { - t.Fatal(err) - } if !bytes.Equal(data, []byte("hello")) { - t.Fatal("bytes don't match") + t.Error("bytes don't match") + } + if !close_tcp && err != nil { + t.Error(err) + return } }() wg.Wait() } - run_test(true, false) - run_test(false, false) - run_test(true, true) - run_test(false, true) + t.Run("close TCP, server reads", func(t *testing.T) { + run_test(t, true, false) + }) + t.Run("close SSL, server reads", func(t *testing.T) { + run_test(t, false, false) + }) + t.Run("close TCP, server writes", func(t *testing.T) { + run_test(t, true, true) + }) + t.Run("close SSL, server writes", func(t *testing.T) { + run_test(t, false, true) + }) } func ThroughputBenchmark(b *testing.B, constructor func(