server: Fix compression protocol for larger packets (#47495)
close #47152
dveeden authored Nov 3, 2023
1 parent c1e28f3 commit 420b524
Showing 3 changed files with 261 additions and 54 deletions.
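
The fix concerns the MySQL compressed protocol, in which every frame carries a 7-byte header: a 3-byte little-endian compressed length, a 1-byte sequence number, and a 3-byte little-endian uncompressed length (zero when the payload is stored uncompressed). A minimal standalone Go sketch of that layout (names are illustrative, not code from this patch):

package main

import "fmt"

// compressedFrameHeader mirrors the 7-byte compressed-protocol header:
// bytes 0-2: compressed payload length, little-endian
// byte  3:   compressed sequence number
// bytes 4-6: uncompressed payload length, little-endian (0 = stored uncompressed)
type compressedFrameHeader struct {
	compressedLen   int
	sequence        byte
	uncompressedLen int
}

func parseCompressedHeader(h [7]byte) compressedFrameHeader {
	return compressedFrameHeader{
		compressedLen:   int(uint32(h[0]) | uint32(h[1])<<8 | uint32(h[2])<<16),
		sequence:        h[3],
		uncompressedLen: int(uint32(h[4]) | uint32(h[5])<<8 | uint32(h[6])<<16),
	}
}

func main() {
	// A 5-byte payload stored uncompressed: compressed length 5,
	// sequence 0, uncompressed length 0.
	fmt.Printf("%+v\n", parseCompressedHeader([7]byte{0x05, 0, 0, 0, 0, 0, 0}))
}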
2 changes: 1 addition & 1 deletion pkg/server/internal/BUILD.bazel
@@ -23,7 +23,7 @@ go_test(
     srcs = ["packetio_test.go"],
     embed = [":internal"],
     flaky = True,
-    shard_count = 4,
+    shard_count = 7,
     deps = [
         "//pkg/parser/mysql",
         "//pkg/server/internal/testutil",
169 changes: 119 additions & 50 deletions pkg/server/internal/packetio.go
@@ -46,7 +46,7 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb/pkg/parser/mysql"
 	"github.com/pingcap/tidb/pkg/parser/terror"
-	"github.com/pingcap/tidb/pkg/server/err"
+	server_err "github.com/pingcap/tidb/pkg/server/err"
 	"github.com/pingcap/tidb/pkg/server/internal/util"
 	server_metrics "github.com/pingcap/tidb/pkg/server/metrics"
 	"github.com/pingcap/tidb/pkg/sessionctx/variable"
@@ -60,6 +60,7 @@ type PacketIO struct {
 	bufReadConn      *util.BufferedReadConn
 	bufWriter        *bufio.Writer
 	compressedWriter *compressedWriter
+	compressedReader *compressedReader
 	readTimeout      time.Duration
 	// maxAllowedPacket is the maximum size of one packet in ReadPacket.
 	maxAllowedPacket uint64
@@ -119,8 +120,10 @@ func (p *PacketIO) ResetBufWriter(w io.Writer) {
 
 // SetCompressionAlgorithm sets the compression algorithm of PacketIO.
 func (p *PacketIO) SetCompressionAlgorithm(ca int) {
 	p.compressionAlgorithm = ca
-	p.compressedWriter = newCompressedWriter(p.bufReadConn, ca)
+	p.compressedWriter = newCompressedWriter(p.bufReadConn, ca, &p.compressedSequence)
 	p.compressedWriter.zstdLevel = p.zstdLevel
+	p.compressedReader = newCompressedReader(p.bufReadConn, ca, &p.compressedSequence)
+	p.compressedReader.zstdLevel = p.zstdLevel
 	p.bufWriter.Flush()
 }
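
SetCompressionAlgorithm now hands the same &p.compressedSequence to both the writer and the new reader, so a frame processed on either side advances one shared counter. A minimal sketch of that pattern (types are illustrative, not from the patch):

package main

import "fmt"

type frameReader struct{ seq *uint8 }
type frameWriter struct{ seq *uint8 }

// Each frame handled on either side bumps the one shared counter.
func (r *frameReader) readFrame()  { *r.seq++ }
func (w *frameWriter) writeFrame() { *w.seq++ }

func main() {
	var seq uint8
	r, w := &frameReader{&seq}, &frameWriter{&seq}
	r.readFrame()
	w.writeFrame()
	fmt.Println(seq) // 2: reader and writer observed the same sequence
}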

@@ -143,56 +146,30 @@ func (p *PacketIO) readOnePacket() ([]byte, error) {
 			return nil, err
 		}
 	}
-	if p.compressionAlgorithm != mysql.CompressionNone {
-		var compressedHeader [7]byte
-		if _, err := io.ReadFull(p.bufReadConn, compressedHeader[:]); err != nil {
-			return nil, errors.Trace(err)
-		}
-		compressedSequence := compressedHeader[3]
-		if compressedSequence != p.compressedSequence {
-			return nil, err.ErrInvalidSequence.GenWithStack(
-				"invalid compressed sequence %d != %d", compressedSequence, p.compressedSequence)
-		}
-		p.compressedSequence++
-		p.compressedWriter.compressedSequence = p.compressedSequence
-		uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16)
-
-		if uncompressedLength > 0 {
-			switch p.compressionAlgorithm {
-			case mysql.CompressionZlib:
-				var err error
-				r, err = zlib.NewReader(p.bufReadConn)
-				if err != nil {
-					return nil, errors.Trace(err)
-				}
-			case mysql.CompressionZstd:
-				zstdReader, err := zstd.NewReader(p.bufReadConn, zstd.WithDecoderConcurrency(1))
-				if err != nil {
-					return nil, errors.Trace(err)
-				}
-				r = zstdReader.IOReadCloser()
-			default:
-				return nil, errors.New("Unknown compression algorithm")
-			}
-		}
-	}
-	if _, err := io.ReadFull(r, header[:]); err != nil {
-		return nil, errors.Trace(err)
+	if p.compressionAlgorithm == mysql.CompressionNone {
+		if _, err := io.ReadFull(r, header[:]); err != nil {
+			return nil, errors.Trace(err)
+		}
+	} else {
+		if _, err := io.ReadFull(p.compressedReader, header[:]); err != nil {
+			return nil, errors.Trace(err)
+		}
 	}
 
-	length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
 	sequence := header[3]
 
 	if sequence != p.sequence {
-		return nil, err.ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence)
+		return nil, server_err.ErrInvalidSequence.GenWithStack(
+			"invalid sequence, received %d while expecting %d", sequence, p.sequence)
 	}
 
 	p.sequence++
 
+	length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
 	// Accumulated payload length exceeds the limit.
 	if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket {
-		terror.Log(err.ErrNetPacketTooLarge)
-		return nil, err.ErrNetPacketTooLarge
+		terror.Log(server_err.ErrNetPacketTooLarge)
+		return nil, server_err.ErrNetPacketTooLarge
 	}
 
 	data := make([]byte, length)
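
With compression enabled, readOnePacket now reads the plain 4-byte packet header through p.compressedReader instead of parsing compressed frames inline; the next hunk applies the same dispatch to the payload. For reference, a sketch of the plain header being decoded above (the helper is illustrative):

package main

import "fmt"

// parsePacketHeader decodes a plain MySQL packet header: a 3-byte
// little-endian payload length followed by a 1-byte sequence number
// that must increase by one with every packet.
func parsePacketHeader(h [4]byte) (length int, sequence byte) {
	length = int(uint32(h[0]) | uint32(h[1])<<8 | uint32(h[2])<<16)
	return length, h[3]
}

func main() {
	length, seq := parsePacketHeader([4]byte{0x01, 0x00, 0x00, 0x02})
	fmt.Println(length, seq) // 1 2
}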
@@ -201,8 +178,14 @@ func (p *PacketIO) readOnePacket() ([]byte, error) {
 			return nil, err
 		}
 	}
-	if _, err := io.ReadFull(r, data); err != nil {
-		return nil, errors.Trace(err)
+	if p.compressionAlgorithm == mysql.CompressionNone {
+		if _, err := io.ReadFull(r, data); err != nil {
+			return nil, errors.Trace(err)
+		}
+	} else {
+		if _, err := io.ReadFull(p.compressedReader, data); err != nil {
+			return nil, errors.Trace(err)
+		}
 	}
 	err := r.Close()
 	if err != nil {
@@ -256,11 +239,7 @@ func (p *PacketIO) ReadPacket() ([]byte, error) {
 func (p *PacketIO) WritePacket(data []byte) error {
 	length := len(data) - 4
 	server_metrics.WritePacketBytes.Add(float64(len(data)))
 
 	maxPayloadLen := mysql.MaxPayloadLen
-	if p.compressionAlgorithm != mysql.CompressionNone {
-		maxPayloadLen -= 4
-	}
-
 	for length >= maxPayloadLen {
 		data[3] = p.sequence
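
WritePacket no longer shrinks the split threshold by 4 bytes under compression: payloads are split against mysql.MaxPayloadLen in both modes, and the compressed framing is applied afterwards by compressedWriter. A sketch of the splitting rule the loop implements, assuming MaxPayloadLen is 1<<24 - 1:

package main

import "fmt"

const maxPayloadLen = 1<<24 - 1 // stands in for mysql.MaxPayloadLen

// splitPayload chops a payload into protocol-sized chunks. A payload of
// exactly maxPayloadLen bytes is followed by an empty chunk, which tells
// the peer that the logical packet is complete.
func splitPayload(payload []byte) [][]byte {
	var chunks [][]byte
	for len(payload) >= maxPayloadLen {
		chunks = append(chunks, payload[:maxPayloadLen])
		payload = payload[maxPayloadLen:]
	}
	return append(chunks, payload)
}

func main() {
	chunks := splitPayload(make([]byte, maxPayloadLen))
	fmt.Println(len(chunks), len(chunks[0]), len(chunks[1])) // 2 16777215 0
}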
@@ -324,24 +303,28 @@ func (p *PacketIO) Flush() error {
 	if err != nil {
 		return errors.Trace(err)
 	}
+
+	if p.compressionAlgorithm != mysql.CompressionNone {
+		p.sequence = p.compressedSequence
+	}
 	return err
 }
 
-func newCompressedWriter(w io.Writer, ca int) *compressedWriter {
+func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter {
 	return &compressedWriter{
 		w,
 		new(bytes.Buffer),
+		seq,
 		ca,
-		0,
 		3,
 	}
 }
 
 type compressedWriter struct {
 	w                    io.Writer
 	buf                  *bytes.Buffer
+	compressedSequence   *uint8
 	compressionAlgorithm int
-	compressedSequence   uint8
 	zstdLevel            zstd.EncoderLevel
 }
 
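
After Flush drains the buffered writer, the plain packet sequence is resynchronized from the compressed one, matching the alignment of the two counters at compressed frame boundaries that this patch enforces. A toy illustration of the hand-off (type and field names are illustrative):

package main

import "fmt"

type sequences struct {
	sequence           uint8 // plain packet counter
	compressedSequence uint8 // compressed frame counter
}

// flushCompressed models writing n compressed frames and then realigning
// the plain counter, as PacketIO.Flush does above.
func (s *sequences) flushCompressed(n int) {
	s.compressedSequence += uint8(n)
	s.sequence = s.compressedSequence
}

func main() {
	s := &sequences{sequence: 5, compressedSequence: 3}
	s.flushCompressed(2)
	fmt.Println(s.sequence, s.compressedSequence) // 5 5
}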
@@ -421,15 +404,15 @@ func (cw *compressedWriter) Flush() error {
 	compressedHeader[0] = byte(compressedLength)
 	compressedHeader[1] = byte(compressedLength >> 8)
 	compressedHeader[2] = byte(compressedLength >> 16)
-	compressedHeader[3] = cw.compressedSequence
+	compressedHeader[3] = *cw.compressedSequence
 	compressedHeader[4] = byte(uncompressedLength)
 	compressedHeader[5] = byte(uncompressedLength >> 8)
 	compressedHeader[6] = byte(uncompressedLength >> 16)
 	_, err = compressedPacket.Write(compressedHeader)
 	if err != nil {
 		return errors.Trace(err)
 	}
-	cw.compressedSequence++
+	*cw.compressedSequence++
 
 	if len(data) > minCompressLength {
 		_, err = compressedPacket.Write(payload.Bytes())
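
compressedWriter.Flush packs the same 7-byte header the reader parses, and payloads at or below minCompressLength are written uncompressed, signalled by an uncompressed length of 0. A standalone sketch of the packing (the value 50 for minCompressLength is an assumption borrowed from MySQL's MIN_COMPRESS_LENGTH; see packetio.go for the real constant):

package main

import "fmt"

const minCompressLength = 50 // assumed value for this sketch

// packCompressedHeader encodes both little-endian lengths and the
// compressed sequence number into the 7-byte frame header.
func packCompressedHeader(compressedLen, uncompressedLen int, seq uint8) [7]byte {
	var h [7]byte
	h[0], h[1], h[2] = byte(compressedLen), byte(compressedLen>>8), byte(compressedLen>>16)
	h[3] = seq
	h[4], h[5], h[6] = byte(uncompressedLen), byte(uncompressedLen>>8), byte(uncompressedLen>>16)
	return h
}

func main() {
	payload := []byte("SELECT 1")
	if len(payload) <= minCompressLength {
		// Sent as-is: compressed length equals the raw length,
		// uncompressed length 0 marks the payload as uncompressed.
		fmt.Println(packCompressedHeader(len(payload), 0, 0))
	}
}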
@@ -446,3 +429,89 @@ func (cw *compressedWriter) Flush() error {
 	}
 	return nil
 }
+
+func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader {
+	return &compressedReader{
+		r,
+		seq,
+		nil,
+		ca,
+		3,
+		0,
+	}
+}
+
+type compressedReader struct {
+	r                    io.Reader
+	compressedSequence   *uint8
+	data                 []byte
+	compressionAlgorithm int
+	zstdLevel            zstd.EncoderLevel
+	pos                  uint64
+}
+
+func (cr *compressedReader) Read(data []byte) (n int, err error) {
+	if cr.data == nil {
+		var compressedHeader [7]byte
+		if _, err = io.ReadFull(cr.r, compressedHeader[:]); err != nil {
+			return
+		}
+
+		compressedLength := int(uint32(compressedHeader[0]) | uint32(compressedHeader[1])<<8 | uint32(compressedHeader[2])<<16)
+		compressedSequence := compressedHeader[3]
+		uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16)
+
+		if compressedSequence != *cr.compressedSequence {
+			return n, server_err.ErrInvalidSequence.GenWithStack(
+				"invalid compressed sequence, received %d while expecting %d", compressedSequence, *cr.compressedSequence)
+		}
+		*cr.compressedSequence++
+
+		r := io.NopCloser(cr.r)
+		if uncompressedLength > 0 {
+			switch cr.compressionAlgorithm {
+			case mysql.CompressionZlib:
+				var err error
+				lr := io.LimitReader(cr.r, int64(compressedLength))
+				r, err = zlib.NewReader(lr)
+				if err != nil {
+					return n, errors.Trace(err)
+				}
+			case mysql.CompressionZstd:
+				zstdReader, err := zstd.NewReader(cr.r, zstd.WithDecoderConcurrency(1))
+				if err != nil {
+					return n, errors.Trace(err)
+				}
+				r = zstdReader.IOReadCloser()
+			default:
+				return n, errors.New("Unknown compression algorithm")
+			}
+			cr.data = make([]byte, uncompressedLength)
+			if _, err := io.ReadFull(r, cr.data); err != nil {
+				return n, errors.Trace(err)
+			}
+			n = copy(data, cr.data)
+		} else {
+			cr.data = make([]byte, compressedLength)
+			if _, err := io.ReadFull(r, cr.data); err != nil {
+				return n, errors.Trace(err)
+			}
+			n = copy(data, cr.data)
+		}
+	} else {
+		if cr.pos > uint64(len(cr.data)) {
+			return n, io.EOF
+		}
+		n = copy(data, cr.data[cr.pos:])
+	}
+	cr.pos += uint64(n)
+	if cr.pos == uint64(len(cr.data)) {
+		cr.pos = 0
+		cr.data = nil
+	}
+	return
+}
+
+func (*compressedReader) Close() error {
+	return nil
+}
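
Taken together, compressedReader inflates one frame at a time into cr.data and serves it out through successive Read calls until the buffer is exhausted. A self-contained zlib round-trip showing the inflate step it performs after parsing the header (frame header omitted for brevity):

package main

import (
	"bytes"
	"compress/zlib"
	"fmt"
	"io"
)

func main() {
	// Writer side: deflate a payload into one frame body, as
	// compressedWriter.Flush does for payloads above minCompressLength.
	var frame bytes.Buffer
	zw := zlib.NewWriter(&frame)
	if _, err := zw.Write([]byte("hello packet")); err != nil {
		panic(err)
	}
	if err := zw.Close(); err != nil {
		panic(err)
	}

	// Reader side: inflate the frame body back into the original payload,
	// as compressedReader.Read does after checking the header and sequence.
	zr, err := zlib.NewReader(bytes.NewReader(frame.Bytes()))
	if err != nil {
		panic(err)
	}
	defer zr.Close()
	payload, err := io.ReadAll(zr)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(payload)) // hello packet
}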