Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: Fix compression protocol for larger packets | tidb-test=pr/2232 (#47495) #48255

Merged
2 changes: 1 addition & 1 deletion pkg/server/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ go_test(
srcs = ["packetio_test.go"],
embed = [":internal"],
flaky = True,
shard_count = 4,
shard_count = 7,
deps = [
"//pkg/parser/mysql",
"//pkg/server/internal/testutil",
Expand Down
169 changes: 119 additions & 50 deletions pkg/server/internal/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/server/err"
server_err "github.com/pingcap/tidb/pkg/server/err"
"github.com/pingcap/tidb/pkg/server/internal/util"
server_metrics "github.com/pingcap/tidb/pkg/server/metrics"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand All @@ -60,6 +60,7 @@ type PacketIO struct {
bufReadConn *util.BufferedReadConn
bufWriter *bufio.Writer
compressedWriter *compressedWriter
compressedReader *compressedReader
readTimeout time.Duration
// maxAllowedPacket is the maximum size of one packet in ReadPacket.
maxAllowedPacket uint64
Expand Down Expand Up @@ -119,8 +120,10 @@ func (p *PacketIO) ResetBufWriter(w io.Writer) {
// SetCompressionAlgorithm sets the compression algorithm of PacketIO.
func (p *PacketIO) SetCompressionAlgorithm(ca int) {
p.compressionAlgorithm = ca
p.compressedWriter = newCompressedWriter(p.bufReadConn, ca)
p.compressedWriter = newCompressedWriter(p.bufReadConn, ca, &p.compressedSequence)
p.compressedWriter.zstdLevel = p.zstdLevel
p.compressedReader = newCompressedReader(p.bufReadConn, ca, &p.compressedSequence)
p.compressedReader.zstdLevel = p.zstdLevel
p.bufWriter.Flush()
}

Expand All @@ -143,56 +146,30 @@ func (p *PacketIO) readOnePacket() ([]byte, error) {
return nil, err
}
}
if p.compressionAlgorithm != mysql.CompressionNone {
var compressedHeader [7]byte
if _, err := io.ReadFull(p.bufReadConn, compressedHeader[:]); err != nil {
if p.compressionAlgorithm == mysql.CompressionNone {
if _, err := io.ReadFull(r, header[:]); err != nil {
return nil, errors.Trace(err)
}
compressedSequence := compressedHeader[3]
if compressedSequence != p.compressedSequence {
return nil, err.ErrInvalidSequence.GenWithStack(
"invalid compressed sequence %d != %d", compressedSequence, p.compressedSequence)
}
p.compressedSequence++
p.compressedWriter.compressedSequence = p.compressedSequence
uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16)

if uncompressedLength > 0 {
switch p.compressionAlgorithm {
case mysql.CompressionZlib:
var err error
r, err = zlib.NewReader(p.bufReadConn)
if err != nil {
return nil, errors.Trace(err)
}
case mysql.CompressionZstd:
zstdReader, err := zstd.NewReader(p.bufReadConn, zstd.WithDecoderConcurrency(1))
if err != nil {
return nil, errors.Trace(err)
}
r = zstdReader.IOReadCloser()
default:
return nil, errors.New("Unknown compression algorithm")
}
} else {
if _, err := io.ReadFull(p.compressedReader, header[:]); err != nil {
return nil, errors.Trace(err)
}
}
if _, err := io.ReadFull(r, header[:]); err != nil {
return nil, errors.Trace(err)
}

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
sequence := header[3]

if sequence != p.sequence {
return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence)
return nil, server_err.ErrInvalidSequence.GenWithStack(
"invalid sequence, received %d while expecting %d", sequence, p.sequence)
}

p.sequence++

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

// Accumulated payload length exceeds the limit.
if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket {
terror.Log(err.ErrNetPacketTooLarge)
return nil, err.ErrNetPacketTooLarge
terror.Log(server_err.ErrNetPacketTooLarge)
return nil, server_err.ErrNetPacketTooLarge
}

data := make([]byte, length)
Expand All @@ -201,8 +178,14 @@ func (p *PacketIO) readOnePacket() ([]byte, error) {
return nil, err
}
}
if _, err := io.ReadFull(r, data); err != nil {
return nil, errors.Trace(err)
if p.compressionAlgorithm == mysql.CompressionNone {
if _, err := io.ReadFull(r, data); err != nil {
return nil, errors.Trace(err)
}
} else {
if _, err := io.ReadFull(p.compressedReader, data); err != nil {
return nil, errors.Trace(err)
}
}
err := r.Close()
if err != nil {
Expand Down Expand Up @@ -256,11 +239,7 @@ func (p *PacketIO) ReadPacket() ([]byte, error) {
func (p *PacketIO) WritePacket(data []byte) error {
length := len(data) - 4
server_metrics.WritePacketBytes.Add(float64(len(data)))

maxPayloadLen := mysql.MaxPayloadLen
if p.compressionAlgorithm != mysql.CompressionNone {
maxPayloadLen -= 4
}

for length >= maxPayloadLen {
data[3] = p.sequence
Expand Down Expand Up @@ -324,24 +303,28 @@ func (p *PacketIO) Flush() error {
if err != nil {
return errors.Trace(err)
}

if p.compressionAlgorithm != mysql.CompressionNone {
p.sequence = p.compressedSequence
}
return err
}

func newCompressedWriter(w io.Writer, ca int) *compressedWriter {
func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter {
return &compressedWriter{
w,
new(bytes.Buffer),
seq,
ca,
0,
3,
}
}

type compressedWriter struct {
w io.Writer
buf *bytes.Buffer
compressedSequence *uint8
compressionAlgorithm int
compressedSequence uint8
zstdLevel zstd.EncoderLevel
}

Expand Down Expand Up @@ -421,15 +404,15 @@ func (cw *compressedWriter) Flush() error {
compressedHeader[0] = byte(compressedLength)
compressedHeader[1] = byte(compressedLength >> 8)
compressedHeader[2] = byte(compressedLength >> 16)
compressedHeader[3] = cw.compressedSequence
compressedHeader[3] = *cw.compressedSequence
compressedHeader[4] = byte(uncompressedLength)
compressedHeader[5] = byte(uncompressedLength >> 8)
compressedHeader[6] = byte(uncompressedLength >> 16)
_, err = compressedPacket.Write(compressedHeader)
if err != nil {
return errors.Trace(err)
}
cw.compressedSequence++
*cw.compressedSequence++

if len(data) > minCompressLength {
_, err = compressedPacket.Write(payload.Bytes())
Expand All @@ -446,3 +429,89 @@ func (cw *compressedWriter) Flush() error {
}
return nil
}

func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader {
return &compressedReader{
r,
seq,
nil,
ca,
3,
0,
}
}

type compressedReader struct {
r io.Reader
compressedSequence *uint8
data []byte
compressionAlgorithm int
zstdLevel zstd.EncoderLevel
pos uint64
}

func (cr *compressedReader) Read(data []byte) (n int, err error) {
if cr.data == nil {
var compressedHeader [7]byte
if _, err = io.ReadFull(cr.r, compressedHeader[:]); err != nil {
return
}

compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16)
compressedSequence := compressedHeader[3]
uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16)

if compressedSequence != *cr.compressedSequence {
return n, server_err.ErrInvalidSequence.GenWithStack(
"invalid compressed sequence, received %d while expecting %d", compressedSequence, cr.compressedSequence)
}
*cr.compressedSequence++

r := io.NopCloser(cr.r)
if uncompressedLength > 0 {
switch cr.compressionAlgorithm {
case mysql.CompressionZlib:
var err error
lr := io.LimitReader(cr.r, int64(compressedLength))
r, err = zlib.NewReader(lr)
if err != nil {
return n, errors.Trace(err)
}
case mysql.CompressionZstd:
zstdReader, err := zstd.NewReader(cr.r, zstd.WithDecoderConcurrency(1))
if err != nil {
return n, errors.Trace(err)
}
r = zstdReader.IOReadCloser()
default:
return n, errors.New("Unknown compression algorithm")
}
cr.data = make([]byte, uncompressedLength)
if _, err := io.ReadFull(r, cr.data); err != nil {
return n, errors.Trace(err)
}
n = copy(data, cr.data)
} else {
cr.data = make([]byte, compressedLength)
if _, err := io.ReadFull(r, cr.data); err != nil {
return n, errors.Trace(err)
}
n = copy(data, cr.data)
}
} else {
if cr.pos > uint64(len(cr.data)) {
return n, io.EOF
}
n = copy(data, cr.data[cr.pos:])
}
cr.pos += uint64(n)
if cr.pos == uint64(len(cr.data)) {
cr.pos = 0
cr.data = nil
}
return
}

func (*compressedReader) Close() error {
return nil
}
Loading