From 28567d8f1b0461dcfe53481b2789a19b004f655d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 9 Oct 2023 20:43:34 +0200 Subject: [PATCH 01/11] server: Fix compression protocol for larger packets --- pkg/server/internal/packetio.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 2fae04517bcd1..002b1fa88876c 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -256,11 +256,7 @@ func (p *PacketIO) ReadPacket() ([]byte, error) { func (p *PacketIO) WritePacket(data []byte) error { length := len(data) - 4 server_metrics.WritePacketBytes.Add(float64(len(data))) - maxPayloadLen := mysql.MaxPayloadLen - if p.compressionAlgorithm != mysql.CompressionNone { - maxPayloadLen -= 4 - } for length >= maxPayloadLen { data[3] = p.sequence From 051d235654d092089efe544b3164d1a5e79dc2fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 10 Oct 2023 08:27:13 +0200 Subject: [PATCH 02/11] Add test: TestPacketIOWriteCompressed --- pkg/server/internal/packetio_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pkg/server/internal/packetio_test.go b/pkg/server/internal/packetio_test.go index 0953f573e0774..fc8a5e034eb59 100644 --- a/pkg/server/internal/packetio_test.go +++ b/pkg/server/internal/packetio_test.go @@ -62,6 +62,31 @@ func TestPacketIOWrite(t *testing.T) { require.Equal(t, byte(0), res[3]) } +func TestPacketIOWriteCompressed(t *testing.T) { + var testdata, outBuffer bytes.Buffer + + pkt := &PacketIO{ + bufWriter: bufio.NewWriter(&outBuffer), + compressionAlgorithm: mysql.CompressionZlib, + compressedWriter: newCompressedWriter(&testdata, mysql.CompressionZlib), + } + + payload := bytes.Repeat([]byte{'A'}, 16*1024*1024) + err := pkt.WritePacket(payload) + require.NoError(t, err) + + err = pkt.Flush() + require.NoError(t, err) + + compressedLength := []byte{0x18, 0x4, 0x0} // 1048 bytes + packetNr := []byte{0x0} + uncompressedLength := []byte{0x0, 0x0, 0x10} // 1048576 bytes + + require.Equal(t, compressedLength, testdata.Bytes()[:3]) + require.Equal(t, packetNr, testdata.Bytes()[3:4]) + require.Equal(t, uncompressedLength, testdata.Bytes()[4:7]) +} + func TestPacketIORead(t *testing.T) { t.Run("uncompressed", func(t *testing.T) { var inBuffer bytes.Buffer From 3e5a5255bfedcb3bada15ba35e6c1cc70738ad41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Tue, 10 Oct 2023 08:42:46 +0200 Subject: [PATCH 03/11] Update bazel build --- pkg/server/internal/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/server/internal/BUILD.bazel b/pkg/server/internal/BUILD.bazel index 9b7576c20c57c..a8375bcb32cdf 100644 --- a/pkg/server/internal/BUILD.bazel +++ b/pkg/server/internal/BUILD.bazel @@ -23,7 +23,7 @@ go_test( srcs = ["packetio_test.go"], embed = [":internal"], flaky = True, - shard_count = 4, + shard_count = 5, deps = [ "//pkg/parser/mysql", "//pkg/server/internal/testutil", From 7d944887413427d0db1ccd3220ecd0b572426b46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 16 Oct 2023 10:52:53 +0200 Subject: [PATCH 04/11] server: Avoid overreading compressed packets --- pkg/server/internal/packetio.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 002b1fa88876c..51909b60fee6f 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -155,13 +155,15 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { } p.compressedSequence++ p.compressedWriter.compressedSequence = p.compressedSequence + compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16) uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) if uncompressedLength > 0 { switch p.compressionAlgorithm { case mysql.CompressionZlib: var err error - r, err = zlib.NewReader(p.bufReadConn) + lr := io.LimitReader(p.bufReadConn, int64(compressedLength)) + r, err = zlib.NewReader(lr) if err != nil { return nil, errors.Trace(err) } From b747c7f1f9212474854863fa1c18a91bc3dee027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 20 Oct 2023 08:32:15 +0200 Subject: [PATCH 05/11] Improve error message for invalid sequence --- pkg/server/internal/packetio.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 51909b60fee6f..281af7b010be0 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -151,7 +151,7 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { compressedSequence := compressedHeader[3] if compressedSequence != p.compressedSequence { return nil, err.ErrInvalidSequence.GenWithStack( - "invalid compressed sequence %d != %d", compressedSequence, p.compressedSequence) + "invalid compressed sequence, received %d while expecting %d", compressedSequence, p.compressedSequence) } p.compressedSequence++ p.compressedWriter.compressedSequence = p.compressedSequence @@ -184,7 +184,7 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { sequence := header[3] if sequence != p.sequence { - return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) + return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence, received %d while expecting %d", sequence, p.sequence) } p.sequence++ From b28928fca54f648893b3e2791f83981a9b265c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 20 Oct 2023 15:25:15 +0200 Subject: [PATCH 06/11] Use io.Reader interface for compressed protocol --- pkg/server/internal/packetio.go | 168 +++++++++++++++++++-------- pkg/server/internal/packetio_test.go | 12 +- 2 files changed, 128 insertions(+), 52 deletions(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 281af7b010be0..b674309e9f677 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -46,7 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/server/err" + server_err "github.com/pingcap/tidb/pkg/server/err" "github.com/pingcap/tidb/pkg/server/internal/util" server_metrics "github.com/pingcap/tidb/pkg/server/metrics" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -60,6 +60,7 @@ type PacketIO struct { bufReadConn *util.BufferedReadConn bufWriter *bufio.Writer compressedWriter *compressedWriter + compressedReader *compressedReader readTimeout time.Duration // maxAllowedPacket is the maximum size of one packet in ReadPacket. maxAllowedPacket uint64 @@ -119,8 +120,10 @@ func (p *PacketIO) ResetBufWriter(w io.Writer) { // SetCompressionAlgorithm sets the compression algorithm of PacketIO. func (p *PacketIO) SetCompressionAlgorithm(ca int) { p.compressionAlgorithm = ca - p.compressedWriter = newCompressedWriter(p.bufReadConn, ca) + p.compressedWriter = newCompressedWriter(p.bufReadConn, ca, &p.compressedSequence) p.compressedWriter.zstdLevel = p.zstdLevel + p.compressedReader = newCompressedReader(p.bufReadConn, ca, &p.compressedSequence) + p.compressedReader.zstdLevel = p.zstdLevel p.bufWriter.Flush() } @@ -143,58 +146,32 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { return nil, err } } - if p.compressionAlgorithm != mysql.CompressionNone { - var compressedHeader [7]byte - if _, err := io.ReadFull(p.bufReadConn, compressedHeader[:]); err != nil { + if p.compressionAlgorithm == mysql.CompressionNone { + if _, err := io.ReadFull(r, header[:]); err != nil { return nil, errors.Trace(err) } - compressedSequence := compressedHeader[3] - if compressedSequence != p.compressedSequence { - return nil, err.ErrInvalidSequence.GenWithStack( - "invalid compressed sequence, received %d while expecting %d", compressedSequence, p.compressedSequence) - } - p.compressedSequence++ - p.compressedWriter.compressedSequence = p.compressedSequence - compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16) - uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) - - if uncompressedLength > 0 { - switch p.compressionAlgorithm { - case mysql.CompressionZlib: - var err error - lr := io.LimitReader(p.bufReadConn, int64(compressedLength)) - r, err = zlib.NewReader(lr) - if err != nil { - return nil, errors.Trace(err) - } - case mysql.CompressionZstd: - zstdReader, err := zstd.NewReader(p.bufReadConn, zstd.WithDecoderConcurrency(1)) - if err != nil { - return nil, errors.Trace(err) - } - r = zstdReader.IOReadCloser() - default: - return nil, errors.New("Unknown compression algorithm") - } + } else { + if _, err := io.ReadFull(p.compressedReader, header[:]); err != nil { + return nil, errors.Trace(err) } - } - if _, err := io.ReadFull(r, header[:]); err != nil { - return nil, errors.Trace(err) + //p.compressedSequence = p.compressedReader.compressedSequence + //p.compressedWriter.compressedSequence = p.compressedSequence } + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) sequence := header[3] + if sequence != p.sequence { - return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence, received %d while expecting %d", sequence, p.sequence) + return nil, server_err.ErrInvalidSequence.GenWithStack( + "invalid sequence, received %d while expecting %d", sequence, p.sequence) } p.sequence++ - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - // Accumulated payload length exceeds the limit. if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket { - terror.Log(err.ErrNetPacketTooLarge) - return nil, err.ErrNetPacketTooLarge + terror.Log(server_err.ErrNetPacketTooLarge) + return nil, server_err.ErrNetPacketTooLarge } data := make([]byte, length) @@ -203,8 +180,16 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { return nil, err } } - if _, err := io.ReadFull(r, data); err != nil { - return nil, errors.Trace(err) + if p.compressionAlgorithm == mysql.CompressionNone { + if _, err := io.ReadFull(r, data); err != nil { + return nil, errors.Trace(err) + } + } else { + if _, err := io.ReadFull(p.compressedReader, data); err != nil { + return nil, errors.Trace(err) + } + //p.compressedSequence = p.compressedReader.compressedSequence + //p.compressedWriter.compressedSequence = p.compressedSequence } err := r.Close() if err != nil { @@ -325,12 +310,12 @@ func (p *PacketIO) Flush() error { return err } -func newCompressedWriter(w io.Writer, ca int) *compressedWriter { +func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter { return &compressedWriter{ w, new(bytes.Buffer), ca, - 0, + seq, 3, } } @@ -339,7 +324,7 @@ type compressedWriter struct { w io.Writer buf *bytes.Buffer compressionAlgorithm int - compressedSequence uint8 + compressedSequence *uint8 zstdLevel zstd.EncoderLevel } @@ -419,7 +404,7 @@ func (cw *compressedWriter) Flush() error { compressedHeader[0] = byte(compressedLength) compressedHeader[1] = byte(compressedLength >> 8) compressedHeader[2] = byte(compressedLength >> 16) - compressedHeader[3] = cw.compressedSequence + compressedHeader[3] = *cw.compressedSequence compressedHeader[4] = byte(uncompressedLength) compressedHeader[5] = byte(uncompressedLength >> 8) compressedHeader[6] = byte(uncompressedLength >> 16) @@ -427,7 +412,7 @@ func (cw *compressedWriter) Flush() error { if err != nil { return errors.Trace(err) } - cw.compressedSequence++ + *cw.compressedSequence++ if len(data) > minCompressLength { _, err = compressedPacket.Write(payload.Bytes()) @@ -444,3 +429,90 @@ func (cw *compressedWriter) Flush() error { } return nil } + +func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader { + return &compressedReader{ + r, + ca, + seq, + 3, + nil, + 0, + } +} + +type compressedReader struct { + r io.Reader + compressionAlgorithm int + compressedSequence *uint8 + zstdLevel zstd.EncoderLevel + data []byte + pos uint64 +} + +func (cr *compressedReader) Read(data []byte) (n int, err error) { + if cr.data == nil { + + var compressedHeader [7]byte + if _, err = io.ReadFull(cr.r, compressedHeader[:]); err != nil { + return + } + + compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16) + compressedSequence := compressedHeader[3] + uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) + + if compressedSequence != *cr.compressedSequence { + return n, server_err.ErrInvalidSequence.GenWithStack( + "invalid compressed sequence, received %d while expecting %d", compressedSequence, cr.compressedSequence) + } + *cr.compressedSequence++ + + r := io.NopCloser(cr.r) + if uncompressedLength > 0 { + switch cr.compressionAlgorithm { + case mysql.CompressionZlib: + var err error + lr := io.LimitReader(cr.r, int64(compressedLength)) + r, err = zlib.NewReader(lr) + if err != nil { + return n, errors.Trace(err) + } + case mysql.CompressionZstd: + zstdReader, err := zstd.NewReader(cr.r, zstd.WithDecoderConcurrency(1)) + if err != nil { + return n, errors.Trace(err) + } + r = zstdReader.IOReadCloser() + default: + return n, errors.New("Unknown compression algorithm") + } + cr.data = make([]byte, uncompressedLength) + if _, err := io.ReadFull(r, cr.data); err != nil { + return n, errors.Trace(err) + } + n = copy(data, cr.data) + } else { + cr.data = make([]byte, compressedLength) + if _, err := io.ReadFull(r, cr.data); err != nil { + return n, errors.Trace(err) + } + n = copy(data, cr.data) + } + } else { + if cr.pos > uint64(len(cr.data)) { + return n, io.EOF + } + n = copy(data, cr.data[cr.pos:]) + } + cr.pos += uint64(n) + if cr.pos == uint64(len(cr.data)) { + cr.pos = 0 + cr.data = nil + } + return +} + +func (cr *compressedReader) Close() error { + return nil +} diff --git a/pkg/server/internal/packetio_test.go b/pkg/server/internal/packetio_test.go index fc8a5e034eb59..573ec69b16511 100644 --- a/pkg/server/internal/packetio_test.go +++ b/pkg/server/internal/packetio_test.go @@ -65,10 +65,11 @@ func TestPacketIOWrite(t *testing.T) { func TestPacketIOWriteCompressed(t *testing.T) { var testdata, outBuffer bytes.Buffer + seq := uint8(0) pkt := &PacketIO{ bufWriter: bufio.NewWriter(&outBuffer), compressionAlgorithm: mysql.CompressionZlib, - compressedWriter: newCompressedWriter(&testdata, mysql.CompressionZlib), + compressedWriter: newCompressedWriter(&testdata, mysql.CompressionZlib, &seq), } payload := bytes.Repeat([]byte{'A'}, 16*1024*1024) @@ -240,8 +241,9 @@ func TestPacketIORead(t *testing.T) { func TestCompressedWriterShort(t *testing.T) { var testdata bytes.Buffer payload := []byte("test_short") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZlib) + cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq) cw.Write(payload) cw.Flush() @@ -261,8 +263,9 @@ func TestCompressedWriterLong(t *testing.T) { t.Run("zlib", func(t *testing.T) { var testdata, decoded bytes.Buffer payload := []byte("test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZlib) + cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq) cw.Write(payload) cw.Flush() @@ -287,8 +290,9 @@ func TestCompressedWriterLong(t *testing.T) { t.Run("zstd", func(t *testing.T) { var testdata bytes.Buffer payload := []byte("test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd") + seq := uint8(0) - cw := newCompressedWriter(&testdata, mysql.CompressionZstd) + cw := newCompressedWriter(&testdata, mysql.CompressionZstd, &seq) cw.Write(payload) cw.Flush() From c3f638ea4e780946daa6a404e03ad16bd08bce9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 20 Oct 2023 18:24:56 +0200 Subject: [PATCH 07/11] Fix issues found by CI --- pkg/server/internal/packetio.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index b674309e9f677..9646754e76d01 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -314,8 +314,8 @@ func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter { return &compressedWriter{ w, new(bytes.Buffer), - ca, seq, + ca, 3, } } @@ -323,8 +323,8 @@ func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter { type compressedWriter struct { w io.Writer buf *bytes.Buffer - compressionAlgorithm int compressedSequence *uint8 + compressionAlgorithm int zstdLevel zstd.EncoderLevel } @@ -433,26 +433,25 @@ func (cw *compressedWriter) Flush() error { func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader { return &compressedReader{ r, - ca, seq, - 3, nil, + ca, + 3, 0, } } type compressedReader struct { r io.Reader - compressionAlgorithm int compressedSequence *uint8 - zstdLevel zstd.EncoderLevel data []byte + compressionAlgorithm int + zstdLevel zstd.EncoderLevel pos uint64 } func (cr *compressedReader) Read(data []byte) (n int, err error) { if cr.data == nil { - var compressedHeader [7]byte if _, err = io.ReadFull(cr.r, compressedHeader[:]); err != nil { return @@ -513,6 +512,6 @@ func (cr *compressedReader) Read(data []byte) (n int, err error) { return } -func (cr *compressedReader) Close() error { +func (*compressedReader) Close() error { return nil } From 3e7ecdbb347a8bc6a98ce31acc17d405f0baa795 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Sun, 22 Oct 2023 12:30:07 +0200 Subject: [PATCH 08/11] Add tests for compressed reader --- pkg/server/internal/BUILD.bazel | 2 +- pkg/server/internal/packetio_test.go | 106 +++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/pkg/server/internal/BUILD.bazel b/pkg/server/internal/BUILD.bazel index a8375bcb32cdf..05f0b1cf84edc 100644 --- a/pkg/server/internal/BUILD.bazel +++ b/pkg/server/internal/BUILD.bazel @@ -23,7 +23,7 @@ go_test( srcs = ["packetio_test.go"], embed = [":internal"], flaky = True, - shard_count = 5, + shard_count = 7, deps = [ "//pkg/parser/mysql", "//pkg/server/internal/testutil", diff --git a/pkg/server/internal/packetio_test.go b/pkg/server/internal/packetio_test.go index 573ec69b16511..ab4d0e8cddaf6 100644 --- a/pkg/server/internal/packetio_test.go +++ b/pkg/server/internal/packetio_test.go @@ -315,3 +315,109 @@ func TestCompressedWriterLong(t *testing.T) { require.Equal(t, payload, decoded) }) } + +// TestCompressedReaderShort test a compressed protocol packet that has an uncompressed +// length of 0, which means the actual payload isn't compressed. +func TestCompressedReaderShort(t *testing.T) { + // payload: 7 bytes compressed header, 37 bytes uncompressed payload + payload := []byte{0x25, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x0, 0x0, 0x3, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, + 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZlib, &seq) + + // Read 4 byte header from the payload. This is the regular packet header, + // not the compressed header. + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x21, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 33, length) + require.Equal(t, uint8(0), sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + expected := []byte{0x3, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, + 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31} + require.Equal(t, expected, data) +} + +func TestCompressedReaderLong(t *testing.T) { + t.Run("zlib", func(t *testing.T) { + payload := []byte{0x19, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x78, 0x5e, 0x9b, + 0xc1, 0xc0, 0xc0, 0xc0, 0x1c, 0xec, 0xea, 0xe3, 0xea, 0x1c, 0xa2, + 0xa0, 0xe4, 0x38, 0xa8, 0x80, 0x12, 0x0, 0xbe, 0xe6, 0x26, 0xce} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZlib, &seq) + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 152, length) + require.Equal(t, uint8(0), sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22} + require.Equal(t, expected, data) + }) + t.Run("zstd", func(t *testing.T) { + payload := []byte{0x1f, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x28, 0xb5, 0x2f, 0xfd, + 0x20, 0x9c, 0xb5, 0x0, 0x0, 0x78, 0x98, 0x0, 0x0, 0x0, 0x3, 0x53, + 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x22, 0x1, 0x0, 0xa, + 0xa, 0x28, 0x1} + r := bytes.NewReader(payload) + seq := uint8(0) + cr := newCompressedReader(r, mysql.CompressionZstd, &seq) + header := make([]byte, 4) + _, err := cr.Read(header) + require.NoError(t, err) + require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header) + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + sequence := header[3] + require.Equal(t, 152, length) + require.Equal(t, uint8(0), sequence) + + data := make([]byte, length) + _, err = cr.Read(data) + expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22} + require.Equal(t, expected, data) + }) +} From 88b3bf032cb5d3f4c22d2b25c36207a9080ea3d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Sun, 22 Oct 2023 12:31:49 +0200 Subject: [PATCH 09/11] Cleanup commented out code --- pkg/server/internal/packetio.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 9646754e76d01..28410235436c9 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -154,8 +154,6 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { if _, err := io.ReadFull(p.compressedReader, header[:]); err != nil { return nil, errors.Trace(err) } - //p.compressedSequence = p.compressedReader.compressedSequence - //p.compressedWriter.compressedSequence = p.compressedSequence } length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) @@ -188,8 +186,6 @@ func (p *PacketIO) readOnePacket() ([]byte, error) { if _, err := io.ReadFull(p.compressedReader, data); err != nil { return nil, errors.Trace(err) } - //p.compressedSequence = p.compressedReader.compressedSequence - //p.compressedWriter.compressedSequence = p.compressedSequence } err := r.Close() if err != nil { From 8269b6109faf493bb29c0afa0cc06d5b73af01e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Sun, 22 Oct 2023 12:44:26 +0200 Subject: [PATCH 10/11] Fix ineffectual assignment to err --- pkg/server/internal/packetio_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/server/internal/packetio_test.go b/pkg/server/internal/packetio_test.go index ab4d0e8cddaf6..7e4230dcdbf05 100644 --- a/pkg/server/internal/packetio_test.go +++ b/pkg/server/internal/packetio_test.go @@ -342,6 +342,7 @@ func TestCompressedReaderShort(t *testing.T) { data := make([]byte, length) _, err = cr.Read(data) + require.NoError(t, err) expected := []byte{0x3, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31} @@ -368,6 +369,7 @@ func TestCompressedReaderLong(t *testing.T) { data := make([]byte, length) _, err = cr.Read(data) + require.NoError(t, err) expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, @@ -404,6 +406,7 @@ func TestCompressedReaderLong(t *testing.T) { data := make([]byte, length) _, err = cr.Read(data) + require.NoError(t, err) expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, From d0fe2dce858c37576a12c688ac35a297808eae31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Thu, 2 Nov 2023 13:24:42 +0100 Subject: [PATCH 11/11] Sync sequence like MySQL does --- pkg/server/internal/packetio.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/server/internal/packetio.go b/pkg/server/internal/packetio.go index 28410235436c9..cf1b10d55901a 100644 --- a/pkg/server/internal/packetio.go +++ b/pkg/server/internal/packetio.go @@ -303,6 +303,10 @@ func (p *PacketIO) Flush() error { if err != nil { return errors.Trace(err) } + + if p.compressionAlgorithm != mysql.CompressionNone { + p.sequence = p.compressedSequence + } return err }