diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 8f11a4ec..393d4660 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -262,8 +262,8 @@ func (p *PacketIO) Flush() error { } func (p *PacketIO) GracefulClose() error { - if p.conn != nil { - return p.conn.SetDeadline(time.Now()) + if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) { + return err } return nil } @@ -276,11 +276,8 @@ func (p *PacketIO) Close() error { errs = append(errs, err) } */ - if p.conn != nil { - if err := p.conn.Close(); err != nil { - errs = append(errs, err) - } - p.conn = nil + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + errs = append(errs, err) } return p.wrapErr(errors.Collect(ErrCloseConn, errs...)) } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 54e78741..281853b8 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -184,3 +184,23 @@ func TestTLS(t *testing.T) { 500, // unable to reproduce stably, loop 500 times ) } + +func TestPacketIOClose(t *testing.T) { + testTCPConn(t, + func(t *testing.T, cli *PacketIO) { + require.NoError(t, cli.Close()) + require.NoError(t, cli.Close()) + require.NoError(t, cli.GracefulClose()) + require.NotEqual(t, cli.LocalAddr(), "") + require.NotEqual(t, cli.RemoteAddr(), "") + }, + func(t *testing.T, srv *PacketIO) { + require.NoError(t, srv.GracefulClose()) + require.NoError(t, srv.Close()) + require.NoError(t, srv.Close()) + require.NotEqual(t, srv.LocalAddr(), "") + require.NotEqual(t, srv.RemoteAddr(), "") + }, + 1, + ) +}