From 420b524b1ddd3c64c1b3bf826b42111f7a497cef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 3 Nov 2023 02:05:09 +0100 Subject: [PATCH] server: Fix compression protocol for larger packets (#47495) close pingcap/tidb#47152 --- pkg/server/internal/BUILD.bazel | 2 +- pkg/server/internal/packetio.go | 169 +++++++++++++++++++-------- pkg/server/internal/packetio_test.go | 144 ++++++++++++++++++++++- 3 files changed, 261 insertions(+), 54 deletions(-) diff --git a/pkg/server/internal/BUILD.bazel b/pkg/server/internal/BUILD.bazel index 9b7576c20c57c..05f0b1cf84edc 100644 --- a/pkg/server/internal/BUILD.bazel +++ b/pkg/server/internal/BUILD.bazel @@ -23,7 +23,7 @@ go_test( srcs = ["packetio_test.go"], embed = [":internal"], flaky = True, - shard_count = 4, + shard_count = 7, deps = [ "//pkg/parser/mysql", "//pkg/server/internal/testutil", diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 2fae04517bcd1..cf1b10d55901a 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -46,7 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/server/err" + server_err "github.com/pingcap/tidb/pkg/server/err" "github.com/pingcap/tidb/pkg/server/internal/util" server_metrics "github.com/pingcap/tidb/pkg/server/metrics" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -60,6 +60,7 @@ type PacketIO struct { bufReadConn *util.BufferedReadConn bufWriter *bufio.Writer compressedWriter *compressedWriter + compressedReader *compressedReader readTimeout time.Duration // maxAllowedPacket is the maximum size of one packet in ReadPacket. maxAllowedPacket uint64 @@ -119,8 +120,10 @@ func (p *PacketIO) ResetBufWriter(w io.Writer) { // SetCompressionAlgorithm sets the compression algorithm of PacketIO. 
func (p *PacketIO) SetCompressionAlgorithm(ca int) { p.compressionAlgorithm = ca - p.compressedWriter = newCompressedWriter(p.bufReadConn, ca) + p.compressedWriter = newCompressedWriter(p.bufReadConn, ca, &p.compressedSequence) p.compressedWriter.zstdLevel = p.zstdLevel + p.compressedReader = newCompressedReader(p.bufReadConn, ca, &p.compressedSequence) + p.compressedReader.zstdLevel = p.zstdLevel p.bufWriter.Flush() } @@ -143,56 +146,30 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { return nil, err } } - if p.compressionAlgorithm != mysql.CompressionNone { - var compressedHeader [7]byte - if _, err := io.ReadFull(p.bufReadConn, compressedHeader[:]); err != nil { + if p.compressionAlgorithm == mysql.CompressionNone { + if _, err := io.ReadFull(r, header[:]); err != nil { return nil, errors.Trace(err) } - compressedSequence := compressedHeader[3] - if compressedSequence != p.compressedSequence { - return nil, err.ErrInvalidSequence.GenWithStack( - "invalid compressed sequence %d != %d", compressedSequence, p.compressedSequence) - } - p.compressedSequence++ - p.compressedWriter.compressedSequence = p.compressedSequence - uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) - - if uncompressedLength > 0 { - switch p.compressionAlgorithm { - case mysql.CompressionZlib: - var err error - r, err = zlib.NewReader(p.bufReadConn) - if err != nil { - return nil, errors.Trace(err) - } - case mysql.CompressionZstd: - zstdReader, err := zstd.NewReader(p.bufReadConn, zstd.WithDecoderConcurrency(1)) - if err != nil { - return nil, errors.Trace(err) - } - r = zstdReader.IOReadCloser() - default: - return nil, errors.New("Unknown compression algorithm") - } + } else { + if _, err := io.ReadFull(p.compressedReader, header[:]); err != nil { + return nil, errors.Trace(err) } } - if _, err := io.ReadFull(r, header[:]); err != nil { - return nil, errors.Trace(err) - } + length := int(uint32(header[0]) | 
uint32(header[1])<<8 | uint32(header[2])<<16) sequence := header[3] + if sequence != p.sequence { - return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) + return nil, server_err.ErrInvalidSequence.GenWithStack( + "invalid sequence, received %d while expecting %d", sequence, p.sequence) } p.sequence++ - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - // Accumulated payload length exceeds the limit. if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket { - terror.Log(err.ErrNetPacketTooLarge) - return nil, err.ErrNetPacketTooLarge + terror.Log(server_err.ErrNetPacketTooLarge) + return nil, server_err.ErrNetPacketTooLarge } data := make([]byte, length) @@ -201,8 +178,14 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { return nil, err } } - if _, err := io.ReadFull(r, data); err != nil { - return nil, errors.Trace(err) + if p.compressionAlgorithm == mysql.CompressionNone { + if _, err := io.ReadFull(r, data); err != nil { + return nil, errors.Trace(err) + } + } else { + if _, err := io.ReadFull(p.compressedReader, data); err != nil { + return nil, errors.Trace(err) + } } err := r.Close() if err != nil { @@ -256,11 +239,7 @@ func (p *PacketIO) ReadPacket() ([]byte, error) { func (p *PacketIO) WritePacket(data []byte) error { length := len(data) - 4 server_metrics.WritePacketBytes.Add(float64(len(data))) - maxPayloadLen := mysql.MaxPayloadLen - if p.compressionAlgorithm != mysql.CompressionNone { - maxPayloadLen -= 4 - } for length >= maxPayloadLen { data[3] = p.sequence @@ -324,15 +303,19 @@ func (p *PacketIO) Flush() error { if err != nil { return errors.Trace(err) } + + if p.compressionAlgorithm != mysql.CompressionNone { + p.sequence = p.compressedSequence + } return err } -func newCompressedWriter(w io.Writer, ca int) *compressedWriter { +func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter { return &compressedWriter{ w, 
new(bytes.Buffer), + seq, ca, - 0, 3, } } @@ -340,8 +323,8 @@ func newCompressedWriter(w io.Writer, ca int) *compressedWriter { type compressedWriter struct { w io.Writer buf *bytes.Buffer + compressedSequence *uint8 compressionAlgorithm int - compressedSequence uint8 zstdLevel zstd.EncoderLevel } @@ -421,7 +404,7 @@ func (cw *compressedWriter) Flush() error { compressedHeader[0] = byte(compressedLength) compressedHeader[1] = byte(compressedLength >> 8) compressedHeader[2] = byte(compressedLength >> 16) - compressedHeader[3] = cw.compressedSequence + compressedHeader[3] = *cw.compressedSequence compressedHeader[4] = byte(uncompressedLength) compressedHeader[5] = byte(uncompressedLength >> 8) compressedHeader[6] = byte(uncompressedLength >> 16) @@ -429,7 +412,7 @@ func (cw *compressedWriter) Flush() error { if err != nil { return errors.Trace(err) } - cw.compressedSequence++ + *cw.compressedSequence++ if len(data) > minCompressLength { _, err = compressedPacket.Write(payload.Bytes()) @@ -446,3 +429,89 @@ func (cw *compressedWriter) Flush() error { } return nil } + +func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader { + return &compressedReader{ + r, + seq, + nil, + ca, + 3, + 0, + } +} + +type compressedReader struct { + r io.Reader + compressedSequence *uint8 + data []byte + compressionAlgorithm int + zstdLevel zstd.EncoderLevel + pos uint64 +} + +func (cr *compressedReader) Read(data []byte) (n int, err error) { + if cr.data == nil { + var compressedHeader [7]byte + if _, err = io.ReadFull(cr.r, compressedHeader[:]); err != nil { + return + } + + compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16) + compressedSequence := compressedHeader[3] + uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) + + if compressedSequence != *cr.compressedSequence { + return n, server_err.ErrInvalidSequence.GenWithStack( + 
"invalid compressed sequence, received %d while expecting %d", compressedSequence, *cr.compressedSequence)
+	}
+	*cr.compressedSequence++
+
+	r := io.NopCloser(cr.r)
+	if uncompressedLength > 0 {
+		switch cr.compressionAlgorithm {
+		case mysql.CompressionZlib:
+			var err error
+			lr := io.LimitReader(cr.r, int64(compressedLength))
+			r, err = zlib.NewReader(lr)
+			if err != nil {
+				return n, errors.Trace(err)
+			}
+		case mysql.CompressionZstd:
+			zstdReader, err := zstd.NewReader(cr.r, zstd.WithDecoderConcurrency(1))
+			if err != nil {
+				return n, errors.Trace(err)
+			}
+			r = zstdReader.IOReadCloser()
+		default:
+			return n, errors.New("Unknown compression algorithm")
+		}
+		cr.data = make([]byte, uncompressedLength)
+		if _, err := io.ReadFull(r, cr.data); err != nil {
+			return n, errors.Trace(err)
+		}
+		n = copy(data, cr.data)
+	} else {
+		cr.data = make([]byte, compressedLength)
+		if _, err := io.ReadFull(r, cr.data); err != nil {
+			return n, errors.Trace(err)
+		}
+		n = copy(data, cr.data)
+	}
+	} else {
+		if cr.pos > uint64(len(cr.data)) {
+			return n, io.EOF
+		}
+		n = copy(data, cr.data[cr.pos:])
+	}
+	cr.pos += uint64(n)
+	if cr.pos == uint64(len(cr.data)) {
+		cr.pos = 0
+		cr.data = nil
+	}
+	return
+}
+
+func (*compressedReader) Close() error {
+	return nil
+}
diff --git a/pkg/server/internal/packetio_test.go b/pkg/server/internal/packetio_test.go
index 0953f573e0774..7e4230dcdbf05 100644
--- a/pkg/server/internal/packetio_test.go
+++ b/pkg/server/internal/packetio_test.go
@@ -62,6 +62,32 @@ func TestPacketIOWrite(t *testing.T) {
 	require.Equal(t, byte(0), res[3])
 }
+func TestPacketIOWriteCompressed(t *testing.T) {
+	var testdata, outBuffer bytes.Buffer
+
+	seq := uint8(0)
+	pkt := &PacketIO{
+		bufWriter:            bufio.NewWriter(&outBuffer),
+		compressionAlgorithm: mysql.CompressionZlib,
+		compressedWriter:     newCompressedWriter(&testdata, mysql.CompressionZlib, &seq),
+	}
+
+	payload := bytes.Repeat([]byte{'A'}, 16*1024*1024)
+	err := pkt.WritePacket(payload)
+	require.NoError(t, 
err) + + err = pkt.Flush() + require.NoError(t, err) + + compressedLength := []byte{0x18, 0x4, 0x0} // 1048 bytes + packetNr := []byte{0x0} + uncompressedLength := []byte{0x0, 0x0, 0x10} // 1048576 bytes + + require.Equal(t, compressedLength, testdata.Bytes()[:3]) + require.Equal(t, packetNr, testdata.Bytes()[3:4]) + require.Equal(t, uncompressedLength, testdata.Bytes()[4:7]) +} + func TestPacketIORead(t *testing.T) { t.Run("uncompressed", func(t *testing.T) { var inBuffer bytes.Buffer @@ -215,8 +241,9 @@ func TestPacketIORead(t *testing.T) { func TestCompressedWriterShort(t *testing.T) { var testdata bytes.Buffer payload := []byte("test_short") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZlib) + cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq) cw.Write(payload) cw.Flush() @@ -236,8 +263,9 @@ func TestCompressedWriterLong(t *testing.T) { t.Run("zlib", func(t *testing.T) { var testdata, decoded bytes.Buffer payload := []byte("test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZlib) + cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq) cw.Write(payload) cw.Flush() @@ -262,8 +290,9 @@ func TestCompressedWriterLong(t *testing.T) { t.Run("zstd", func(t *testing.T) { var testdata bytes.Buffer payload := []byte("test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZstd) + cw := newCompressedWriter(&testdata, mysql.CompressionZstd, &seq) cw.Write(payload) cw.Flush() @@ -286,3 +315,112 @@ func TestCompressedWriterLong(t *testing.T) { require.Equal(t, payload, decoded) }) } + +// TestCompressedReaderShort test a compressed protocol packet that has an uncompressed +// length of 0, which means the actual payload isn't compressed. 
+func TestCompressedReaderShort(t *testing.T) { + // payload: 7 bytes compressed header, 37 bytes uncompressed payload + payload := []byte{0x25, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x0, 0x0, 0x3, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, + 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZlib, &seq) + + // Read 4 byte header from the payload. This is the regular packet header, + // not the compressed header. + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x21, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 33, length) + require.Equal(t, uint8(0), sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + require.NoError(t, err) + expected := []byte{0x3, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, + 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31} + require.Equal(t, expected, data) +} + +func TestCompressedReaderLong(t *testing.T) { + t.Run("zlib", func(t *testing.T) { + payload := []byte{0x19, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x78, 0x5e, 0x9b, + 0xc1, 0xc0, 0xc0, 0xc0, 0x1c, 0xec, 0xea, 0xe3, 0xea, 0x1c, 0xa2, + 0xa0, 0xe4, 0x38, 0xa8, 0x80, 0x12, 0x0, 0xbe, 0xe6, 0x26, 0xce} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZlib, &seq) + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 152, length) + require.Equal(t, uint8(0), 
sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + require.NoError(t, err) + expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22} + require.Equal(t, expected, data) + }) + t.Run("zstd", func(t *testing.T) { + payload := []byte{0x1f, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x28, 0xb5, 0x2f, 0xfd, + 0x20, 0x9c, 0xb5, 0x0, 0x0, 0x78, 0x98, 0x0, 0x0, 0x0, 0x3, 0x53, + 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x22, 0x1, 0x0, 0xa, + 0xa, 0x28, 0x1} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZstd, &seq) + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 152, length) + require.Equal(t, uint8(0), sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + require.NoError(t, err) + expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 
0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22} + require.Equal(t, expected, data) + }) +}