diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 5363a871a50f0..6764719242b09 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -126,6 +126,7 @@ go_library( "//executor/internal/pdhelper", "//executor/internal/querywatch", "//executor/internal/util", + "//executor/internal/vecgroupchecker", "//executor/metrics", "//executor/mppcoordmanager", "//expression", @@ -294,7 +295,6 @@ go_test( "charset_test.go", "chunk_size_control_test.go", "cluster_table_test.go", - "collation_test.go", "compact_table_test.go", "concurrent_map_test.go", "copr_cache_test.go", diff --git a/executor/aggregate.go b/executor/aggregate.go index 8c272af9062ac..506d5d915de9a 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/executor/internal/exec" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -1250,7 +1251,7 @@ type StreamAggExec struct { // isChildReturnEmpty indicates whether the child executor only returns an empty input. isChildReturnEmpty bool defaultVal *chunk.Chunk - groupChecker *vecGroupChecker + groupChecker *vecgroupchecker.VecGroupChecker inputIter *chunk.Iterator4Chunk inputRow chunk.Row aggFuncs []aggfuncs.AggFunc @@ -1312,7 +1313,7 @@ func (e *StreamAggExec) Close() error { e.memTracker.Consume(-e.childResult.MemoryUsage() - e.memUsageOfInitialPartialResult) e.childResult = nil } - e.groupChecker.reset() + e.groupChecker.Reset() return e.BaseExecutor.Close() } @@ -1330,19 +1331,19 @@ func (e *StreamAggExec) Next(ctx context.Context, req *chunk.Chunk) (err error) } func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) (err error) { - if e.groupChecker.isExhausted() { + if e.groupChecker.IsExhausted() { if err = e.consumeCurGroupRowsAndFetchChild(ctx, chk); err != nil { return err } if e.executed { return nil } - _, err := e.groupChecker.splitIntoGroups(e.childResult) + _, err := e.groupChecker.SplitIntoGroups(e.childResult) if err != nil { return err } } - begin, end := e.groupChecker.getNextGroup() + begin, end := e.groupChecker.GetNextGroup() for i := begin; i < end; i++ { e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) } @@ -1353,13 +1354,13 @@ func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) ( return err } - isFirstGroupSameAsPrev, err := e.groupChecker.splitIntoGroups(e.childResult) + isFirstGroupSameAsPrev, err := e.groupChecker.SplitIntoGroups(e.childResult) if err != nil { return err } if isFirstGroupSameAsPrev { - begin, end = e.groupChecker.getNextGroup() + begin, end = e.groupChecker.GetNextGroup() for i := begin; i < end; i++ { e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) } @@ -1444,480 +1445,6 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { return nil } -// vecGroupChecker is used to split a given chunk according to the `group by` expression in a vectorized manner -// It is usually used for streamAgg -type vecGroupChecker struct { - ctx sessionctx.Context - GroupByItems []expression.Expression - - // groupOffset holds the offset of the last row in each group of the current chunk - groupOffset []int - // groupCount is the count of groups in the current chunk - groupCount int - // nextGroupID records the group id of the next group to be consumed - nextGroupID int - - // lastGroupKeyOfPrevChk is 
the groupKey of the last group of the previous chunk - lastGroupKeyOfPrevChk []byte - // firstGroupKey and lastGroupKey are used to store the groupKey of the first and last group of the current chunk - firstGroupKey []byte - lastGroupKey []byte - - // firstRowDatums and lastRowDatums store the results of the expression evaluation for the first and last rows of the current chunk in datum - // They are used to encode to get firstGroupKey and lastGroupKey - firstRowDatums []types.Datum - lastRowDatums []types.Datum - - // sameGroup is used to check whether the current row belongs to the same group as the previous row - sameGroup []bool - - // set these functions for testing - allocateBuffer func(evalType types.EvalType, capacity int) (*chunk.Column, error) - releaseBuffer func(buf *chunk.Column) -} - -func newVecGroupChecker(ctx sessionctx.Context, items []expression.Expression) *vecGroupChecker { - return &vecGroupChecker{ - ctx: ctx, - GroupByItems: items, - groupCount: 0, - nextGroupID: 0, - sameGroup: make([]bool, 1024), - } -} - -// splitIntoGroups splits a chunk into multiple groups which the row in the same group have the same groupKey -// `isFirstGroupSameAsPrev` indicates whether the groupKey of the first group of the newly passed chunk is equal to the groupKey of the last group left before -// TODO: Since all the group by items are only a column reference, guaranteed by building projection below aggregation, we can directly compare data in a chunk. -func (e *vecGroupChecker) splitIntoGroups(chk *chunk.Chunk) (isFirstGroupSameAsPrev bool, err error) { - // The numRows can not be zero. `fetchChild` is called before `splitIntoGroups` is called. - // if numRows == 0, it will be returned in `fetchChild`. See `fetchChild` for more details. - numRows := chk.NumRows() - - e.reset() - e.nextGroupID = 0 - if len(e.GroupByItems) == 0 { - e.groupOffset = append(e.groupOffset, numRows) - e.groupCount = 1 - return true, nil - } - - for _, item := range e.GroupByItems { - err = e.getFirstAndLastRowDatum(item, chk, numRows) - if err != nil { - return false, err - } - } - e.firstGroupKey, err = codec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, e.firstGroupKey, e.firstRowDatums...) - if err != nil { - return false, err - } - - e.lastGroupKey, err = codec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, e.lastGroupKey, e.lastRowDatums...) 
- if err != nil { - return false, err - } - - if len(e.lastGroupKeyOfPrevChk) == 0 { - isFirstGroupSameAsPrev = false - } else { - if bytes.Equal(e.lastGroupKeyOfPrevChk, e.firstGroupKey) { - isFirstGroupSameAsPrev = true - } else { - isFirstGroupSameAsPrev = false - } - } - - if length := len(e.lastGroupKey); len(e.lastGroupKeyOfPrevChk) >= length { - e.lastGroupKeyOfPrevChk = e.lastGroupKeyOfPrevChk[:length] - } else { - e.lastGroupKeyOfPrevChk = make([]byte, length) - } - copy(e.lastGroupKeyOfPrevChk, e.lastGroupKey) - - if bytes.Equal(e.firstGroupKey, e.lastGroupKey) { - e.groupOffset = append(e.groupOffset, numRows) - e.groupCount = 1 - return isFirstGroupSameAsPrev, nil - } - - if cap(e.sameGroup) < numRows { - e.sameGroup = make([]bool, 0, numRows) - } - e.sameGroup = append(e.sameGroup, false) - for i := 1; i < numRows; i++ { - e.sameGroup = append(e.sameGroup, true) - } - - for _, item := range e.GroupByItems { - err = e.evalGroupItemsAndResolveGroups(item, chk, numRows) - if err != nil { - return false, err - } - } - - for i := 1; i < numRows; i++ { - if !e.sameGroup[i] { - e.groupOffset = append(e.groupOffset, i) - } - } - e.groupOffset = append(e.groupOffset, numRows) - e.groupCount = len(e.groupOffset) - return isFirstGroupSameAsPrev, nil -} - -func (e *vecGroupChecker) getFirstAndLastRowDatum(item expression.Expression, chk *chunk.Chunk, numRows int) (err error) { - var firstRowDatum, lastRowDatum types.Datum - tp := item.GetType() - eType := tp.EvalType() - switch eType { - case types.ETInt: - firstRowVal, firstRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - firstRowDatum.SetInt64(firstRowVal) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - lastRowDatum.SetInt64(lastRowVal) - } else { - lastRowDatum.SetNull() - } - case types.ETReal: - firstRowVal, firstRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - firstRowDatum.SetFloat64(firstRowVal) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - lastRowDatum.SetFloat64(lastRowVal) - } else { - lastRowDatum.SetNull() - } - case types.ETDecimal: - firstRowVal, firstRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - // make a copy to avoid DATA RACE - firstDatum := types.MyDecimal{} - err := firstDatum.FromString(firstRowVal.ToString()) - if err != nil { - return err - } - firstRowDatum.SetMysqlDecimal(&firstDatum) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - // make a copy to avoid DATA RACE - lastDatum := types.MyDecimal{} - err := lastDatum.FromString(lastRowVal.ToString()) - if err != nil { - return err - } - lastRowDatum.SetMysqlDecimal(&lastDatum) - } else { - lastRowDatum.SetNull() - } - case types.ETDatetime, types.ETTimestamp: - firstRowVal, firstRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - firstRowDatum.SetMysqlTime(firstRowVal) - } else { - 
firstRowDatum.SetNull() - } - if !lastRowIsNull { - lastRowDatum.SetMysqlTime(lastRowVal) - } else { - lastRowDatum.SetNull() - } - case types.ETDuration: - firstRowVal, firstRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - firstRowDatum.SetMysqlDuration(firstRowVal) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - lastRowDatum.SetMysqlDuration(lastRowVal) - } else { - lastRowDatum.SetNull() - } - case types.ETJson: - firstRowVal, firstRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - // make a copy to avoid DATA RACE - firstRowDatum.SetMysqlJSON(firstRowVal.Copy()) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - // make a copy to avoid DATA RACE - lastRowDatum.SetMysqlJSON(lastRowVal.Copy()) - } else { - lastRowDatum.SetNull() - } - case types.ETString: - firstRowVal, firstRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(0)) - if err != nil { - return err - } - lastRowVal, lastRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(numRows-1)) - if err != nil { - return err - } - if !firstRowIsNull { - // make a copy to avoid DATA RACE - firstDatum := string([]byte(firstRowVal)) - firstRowDatum.SetString(firstDatum, tp.GetCollate()) - } else { - firstRowDatum.SetNull() - } - if !lastRowIsNull { - // make a copy to avoid DATA RACE - lastDatum := string([]byte(lastRowVal)) - lastRowDatum.SetString(lastDatum, tp.GetCollate()) - } else { - lastRowDatum.SetNull() - } - default: - err = fmt.Errorf("invalid eval type %v", eType) - return err - } - - e.firstRowDatums = append(e.firstRowDatums, firstRowDatum) - e.lastRowDatums = append(e.lastRowDatums, lastRowDatum) - return err -} - -// evalGroupItemsAndResolveGroups evaluates the chunk according to the expression item. 
-// And resolve the rows into groups according to the evaluation results -func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Expression, chk *chunk.Chunk, numRows int) (err error) { - tp := item.GetType() - eType := tp.EvalType() - if e.allocateBuffer == nil { - e.allocateBuffer = expression.GetColumn - } - if e.releaseBuffer == nil { - e.releaseBuffer = expression.PutColumn - } - col, err := e.allocateBuffer(eType, numRows) - if err != nil { - return err - } - defer e.releaseBuffer(col) - err = expression.EvalExpr(e.ctx, item, eType, chk, col) - if err != nil { - return err - } - - previousIsNull := col.IsNull(0) - switch eType { - case types.ETInt: - vals := col.Int64s() - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if e.sameGroup[i] { - switch { - case !previousIsNull && !isNull: - if vals[i] != vals[i-1] { - e.sameGroup[i] = false - } - case isNull != previousIsNull: - e.sameGroup[i] = false - } - } - previousIsNull = isNull - } - case types.ETReal: - vals := col.Float64s() - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if e.sameGroup[i] { - switch { - case !previousIsNull && !isNull: - if vals[i] != vals[i-1] { - e.sameGroup[i] = false - } - case isNull != previousIsNull: - e.sameGroup[i] = false - } - } - previousIsNull = isNull - } - case types.ETDecimal: - vals := col.Decimals() - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if e.sameGroup[i] { - switch { - case !previousIsNull && !isNull: - if vals[i].Compare(&vals[i-1]) != 0 { - e.sameGroup[i] = false - } - case isNull != previousIsNull: - e.sameGroup[i] = false - } - } - previousIsNull = isNull - } - case types.ETDatetime, types.ETTimestamp: - vals := col.Times() - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if e.sameGroup[i] { - switch { - case !previousIsNull && !isNull: - if vals[i].Compare(vals[i-1]) != 0 { - e.sameGroup[i] = false - } - case isNull != previousIsNull: - e.sameGroup[i] = false - } - } - previousIsNull = isNull - } - case types.ETDuration: - vals := col.GoDurations() - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if e.sameGroup[i] { - switch { - case !previousIsNull && !isNull: - if vals[i] != vals[i-1] { - e.sameGroup[i] = false - } - case isNull != previousIsNull: - e.sameGroup[i] = false - } - } - previousIsNull = isNull - } - case types.ETJson: - var previousKey, key types.BinaryJSON - if !previousIsNull { - previousKey = col.GetJSON(0) - } - for i := 1; i < numRows; i++ { - isNull := col.IsNull(i) - if !isNull { - key = col.GetJSON(i) - } - if e.sameGroup[i] { - if isNull == previousIsNull { - if !isNull && types.CompareBinaryJSON(previousKey, key) != 0 { - e.sameGroup[i] = false - } - } else { - e.sameGroup[i] = false - } - } - if !isNull { - previousKey = key - } - previousIsNull = isNull - } - case types.ETString: - previousKey := codec.ConvertByCollationStr(col.GetString(0), tp) - for i := 1; i < numRows; i++ { - key := codec.ConvertByCollationStr(col.GetString(i), tp) - isNull := col.IsNull(i) - if e.sameGroup[i] { - if isNull != previousIsNull || previousKey != key { - e.sameGroup[i] = false - } - } - previousKey = key - previousIsNull = isNull - } - default: - err = fmt.Errorf("invalid eval type %v", eType) - } - if err != nil { - return err - } - - return err -} - -func (e *vecGroupChecker) getNextGroup() (begin, end int) { - if e.nextGroupID == 0 { - begin = 0 - } else { - begin = e.groupOffset[e.nextGroupID-1] - } - end = e.groupOffset[e.nextGroupID] - e.nextGroupID++ - return begin, end -} - -func (e 
*vecGroupChecker) isExhausted() bool { - return e.nextGroupID >= e.groupCount -} - -func (e *vecGroupChecker) reset() { - if e.groupOffset != nil { - e.groupOffset = e.groupOffset[:0] - } - if e.sameGroup != nil { - e.sameGroup = e.sameGroup[:0] - } - if e.firstGroupKey != nil { - e.firstGroupKey = e.firstGroupKey[:0] - } - if e.lastGroupKey != nil { - e.lastGroupKey = e.lastGroupKey[:0] - } - if e.firstRowDatums != nil { - e.firstRowDatums = e.firstRowDatums[:0] - } - if e.lastRowDatums != nil { - e.lastRowDatums = e.lastRowDatums[:0] - } -} - // ActionSpill returns a AggSpillDiskAction for spilling intermediate data for hashAgg. func (e *HashAggExec) ActionSpill() *AggSpillDiskAction { if e.spillAction == nil { diff --git a/executor/builder.go b/executor/builder.go index 9bcdfadd371bb..075cdc6f32ba6 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -40,6 +40,7 @@ import ( "github.com/pingcap/tidb/executor/internal/exec" "github.com/pingcap/tidb/executor/internal/pdhelper" "github.com/pingcap/tidb/executor/internal/querywatch" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" executor_metrics "github.com/pingcap/tidb/executor/metrics" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" @@ -1807,7 +1808,7 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) exec. } e := &StreamAggExec{ BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), src), - groupChecker: newVecGroupChecker(b.ctx, v.GroupByItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, v.GroupByItems), aggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), } @@ -4955,7 +4956,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut if b.ctx.GetSessionVars().EnablePipelinedWindowExec { exec := &PipelinedWindowExec{ BaseExecutor: base, - groupChecker: newVecGroupChecker(b.ctx, groupByItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, groupByItems), numWindowFuncs: len(v.WindowFuncDescs), } @@ -5014,7 +5015,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut } return &WindowExec{BaseExecutor: base, processor: processor, - groupChecker: newVecGroupChecker(b.ctx, groupByItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx, groupByItems), numWindowFuncs: len(v.WindowFuncDescs), } } diff --git a/executor/collation_test.go b/executor/collation_test.go deleted file mode 100644 index 637b0cf162c2e..0000000000000 --- a/executor/collation_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "testing" - - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/mock" - "github.com/stretchr/testify/require" -) - -func TestVecGroupChecker(t *testing.T) { - tp := types.NewFieldTypeBuilder().SetType(mysql.TypeVarchar).BuildP() - col0 := &expression.Column{ - RetType: tp, - Index: 0, - } - ctx := mock.NewContext() - groupChecker := newVecGroupChecker(ctx, []expression.Expression{col0}) - - chk := chunk.New([]*types.FieldType{tp}, 6, 6) - chk.Reset() - chk.Column(0).AppendString("aaa") - chk.Column(0).AppendString("AAA") - chk.Column(0).AppendString("😜") - chk.Column(0).AppendString("😃") - chk.Column(0).AppendString("À") - chk.Column(0).AppendString("A") - - tp.SetCollate("bin") - groupChecker.reset() - _, err := groupChecker.splitIntoGroups(chk) - require.NoError(t, err) - for i := 0; i < 6; i++ { - b, e := groupChecker.getNextGroup() - require.Equal(t, b, i) - require.Equal(t, e, i+1) - } - require.True(t, groupChecker.isExhausted()) - - tp.SetCollate("utf8_general_ci") - groupChecker.reset() - _, err = groupChecker.splitIntoGroups(chk) - require.NoError(t, err) - for i := 0; i < 3; i++ { - b, e := groupChecker.getNextGroup() - require.Equal(t, b, i*2) - require.Equal(t, e, i*2+2) - } - require.True(t, groupChecker.isExhausted()) - - tp.SetCollate("utf8_unicode_ci") - groupChecker.reset() - _, err = groupChecker.splitIntoGroups(chk) - require.NoError(t, err) - for i := 0; i < 3; i++ { - b, e := groupChecker.getNextGroup() - require.Equal(t, b, i*2) - require.Equal(t, e, i*2+2) - } - require.True(t, groupChecker.isExhausted()) - - // test padding - tp.SetCollate("utf8_bin") - tp.SetFlen(6) - chk.Reset() - chk.Column(0).AppendString("a") - chk.Column(0).AppendString("a ") - chk.Column(0).AppendString("a ") - groupChecker.reset() - _, err = groupChecker.splitIntoGroups(chk) - require.NoError(t, err) - b, e := groupChecker.getNextGroup() - require.Equal(t, b, 0) - require.Equal(t, e, 3) - require.True(t, groupChecker.isExhausted()) -} diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 983c1e2dbe105..fb57ed290ae75 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -716,118 +716,6 @@ func TestMergeJoinRequiredRows(t *testing.T) { } } -func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) { - chkNum := len(chkRows) - numRows := 0 - inputs = make([]*chunk.Chunk, chkNum) - fts := make([]*types.FieldType, 1) - fts[0] = types.NewFieldType(mysql.TypeLonglong) - for i := 0; i < chkNum; i++ { - inputs[i] = chunk.New(fts, chkRows[i], chkRows[i]) - numRows += chkRows[i] - } - var numGroups int - if numRows%sameNum == 0 { - numGroups = numRows / sameNum - } else { - numGroups = numRows/sameNum + 1 - } - - rand.Seed(time.Now().Unix()) - nullPos := rand.Intn(numGroups) - cnt := 0 - val := rand.Int63() - for i := 0; i < chkNum; i++ { - col := inputs[i].Column(0) - col.ResizeInt64(chkRows[i], false) - i64s := col.Int64s() - for j := 0; j < chkRows[i]; j++ { - if cnt == sameNum { - val = rand.Int63() - cnt = 0 - nullPos-- - } - if nullPos == 0 { - col.SetNull(j, true) - } else { - i64s[j] = val - } - cnt++ - } - } - - expr = make([]expression.Expression, 1) - expr[0] = &expression.Column{ - RetType: 
types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlen(mysql.MaxIntWidth).BuildP(), - Index: 0, - } - return -} - -func TestVecGroupChecker4GroupCount(t *testing.T) { - testCases := []struct { - chunkRows []int - expectedGroups int - expectedFlag []bool - sameNum int - }{ - { - chunkRows: []int{1024, 1}, - expectedGroups: 1025, - expectedFlag: []bool{false, false}, - sameNum: 1, - }, - { - chunkRows: []int{1024, 1}, - expectedGroups: 1, - expectedFlag: []bool{false, true}, - sameNum: 1025, - }, - { - chunkRows: []int{1, 1}, - expectedGroups: 1, - expectedFlag: []bool{false, true}, - sameNum: 2, - }, - { - chunkRows: []int{1, 1}, - expectedGroups: 2, - expectedFlag: []bool{false, false}, - sameNum: 1, - }, - { - chunkRows: []int{2, 2}, - expectedGroups: 2, - expectedFlag: []bool{false, false}, - sameNum: 2, - }, - { - chunkRows: []int{2, 2}, - expectedGroups: 1, - expectedFlag: []bool{false, true}, - sameNum: 4, - }, - } - - ctx := mock.NewContext() - for _, testCase := range testCases { - expr, inputChks := genTestChunk4VecGroupChecker(testCase.chunkRows, testCase.sameNum) - groupChecker := newVecGroupChecker(ctx, expr) - groupNum := 0 - for i, inputChk := range inputChks { - flag, err := groupChecker.splitIntoGroups(inputChk) - require.NoError(t, err) - require.Equal(t, testCase.expectedFlag[i], flag) - if flag { - groupNum += groupChecker.groupCount - 1 - } else { - groupNum += groupChecker.groupCount - } - } - require.Equal(t, testCase.expectedGroups, groupNum) - } -} - func buildMergeJoinExec(ctx sessionctx.Context, joinType plannercore.JoinType, innerSrc, outerSrc exec.Executor) exec.Executor { if joinType == plannercore.RightOuterJoin { innerSrc, outerSrc = outerSrc, innerSrc @@ -868,67 +756,3 @@ func (mp *mockPlan) Schema() *expression.Schema { func (mp *mockPlan) MemoryUsage() (sum int64) { return } - -func TestVecGroupCheckerDATARACE(t *testing.T) { - ctx := mock.NewContext() - - mTypes := []byte{mysql.TypeVarString, mysql.TypeNewDecimal, mysql.TypeJSON} - for _, mType := range mTypes { - exprs := make([]expression.Expression, 1) - exprs[0] = &expression.Column{ - RetType: types.NewFieldTypeBuilder().SetType(mType).BuildP(), - Index: 0, - } - vgc := newVecGroupChecker(ctx, exprs) - - fts := []*types.FieldType{types.NewFieldType(mType)} - chk := chunk.New(fts, 1, 1) - vgc.allocateBuffer = func(evalType types.EvalType, capacity int) (*chunk.Column, error) { - return chk.Column(0), nil - } - vgc.releaseBuffer = func(column *chunk.Column) {} - - switch mType { - case mysql.TypeVarString: - chk.Column(0).ReserveString(1) - chk.Column(0).AppendString("abc") - case mysql.TypeNewDecimal: - chk.Column(0).ResizeDecimal(1, false) - chk.Column(0).Decimals()[0] = *types.NewDecFromInt(123) - case mysql.TypeJSON: - chk.Column(0).ReserveJSON(1) - j := new(types.BinaryJSON) - require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 123, 123)))) - chk.Column(0).AppendJSON(*j) - } - - _, err := vgc.splitIntoGroups(chk) - require.NoError(t, err) - - switch mType { - case mysql.TypeVarString: - require.Equal(t, "abc", vgc.firstRowDatums[0].GetString()) - require.Equal(t, "abc", vgc.lastRowDatums[0].GetString()) - chk.Column(0).ReserveString(1) - chk.Column(0).AppendString("edf") - require.Equal(t, "abc", vgc.firstRowDatums[0].GetString()) - require.Equal(t, "abc", vgc.lastRowDatums[0].GetString()) - case mysql.TypeNewDecimal: - require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String()) - require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String()) - 
chk.Column(0).ResizeDecimal(1, false) - chk.Column(0).Decimals()[0] = *types.NewDecFromInt(456) - require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String()) - require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String()) - case mysql.TypeJSON: - require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String()) - require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String()) - chk.Column(0).ReserveJSON(1) - j := new(types.BinaryJSON) - require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 456, 456)))) - chk.Column(0).AppendJSON(*j) - require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String()) - require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String()) - } - } -} diff --git a/executor/internal/vecgroupchecker/BUILD.bazel b/executor/internal/vecgroupchecker/BUILD.bazel new file mode 100644 index 0000000000000..b67b08c8e8705 --- /dev/null +++ b/executor/internal/vecgroupchecker/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "vecgroupchecker", + srcs = ["vec_group_checker.go"], + importpath = "github.com/pingcap/tidb/executor/internal/vecgroupchecker", + visibility = ["//executor:__subpackages__"], + deps = [ + "//expression", + "//sessionctx", + "//types", + "//util/chunk", + "//util/codec", + ], +) + +go_test( + name = "vecgroupchecker_test", + timeout = "short", + srcs = ["vec_group_checker_test.go"], + embed = [":vecgroupchecker"], + flaky = True, + shard_count = 3, + deps = [ + "//expression", + "//parser/mysql", + "//types", + "//util/chunk", + "//util/mock", + "@com_github_stretchr_testify//require", + ], +) diff --git a/executor/internal/vecgroupchecker/vec_group_checker.go b/executor/internal/vecgroupchecker/vec_group_checker.go new file mode 100644 index 0000000000000..557c3e8f914ba --- /dev/null +++ b/executor/internal/vecgroupchecker/vec_group_checker.go @@ -0,0 +1,514 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package vecgroupchecker
+
+import (
+	"bytes"
+	"fmt"
+
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/chunk"
+	"github.com/pingcap/tidb/util/codec"
+)
+
+// VecGroupChecker splits a given chunk according to the `group by` expressions in a vectorized manner.
+// It is usually used for StreamAgg.
+type VecGroupChecker struct {
+	ctx sessionctx.Context
+
+	// set these functions for testing
+	allocateBuffer func(evalType types.EvalType, capacity int) (*chunk.Column, error)
+	releaseBuffer  func(buf *chunk.Column)
+
+	// lastGroupKeyOfPrevChk is the groupKey of the last group of the previous chunk
+	lastGroupKeyOfPrevChk []byte
+	// firstGroupKey and lastGroupKey are used to store the groupKey of the first and last group of the current chunk
+	firstGroupKey []byte
+	lastGroupKey  []byte
+
+	// firstRowDatums and lastRowDatums store the results of the expression evaluation
+	// for the first and last rows of the current chunk in datum.
+	// They are encoded to produce firstGroupKey and lastGroupKey
+	firstRowDatums []types.Datum
+	lastRowDatums  []types.Datum
+
+	// sameGroup is used to check whether the current row belongs to the same group as the previous row
+	sameGroup []bool
+
+	// groupOffset holds the offset of the last row in each group of the current chunk
+	groupOffset  []int
+	GroupByItems []expression.Expression
+
+	// nextGroupID records the group id of the next group to be consumed
+	nextGroupID int
+
+	// groupCount is the count of groups in the current chunk
+	groupCount int
+}
+
+// NewVecGroupChecker creates a new VecGroupChecker
+func NewVecGroupChecker(ctx sessionctx.Context, items []expression.Expression) *VecGroupChecker {
+	return &VecGroupChecker{
+		ctx:          ctx,
+		GroupByItems: items,
+		groupCount:   0,
+		nextGroupID:  0,
+		sameGroup:    make([]bool, 1024),
+	}
+}
+
+// SplitIntoGroups splits a chunk into multiple groups such that rows in the same group have the same groupKey.
+// `isFirstGroupSameAsPrev` indicates whether the groupKey of the first group of the newly passed chunk equals the groupKey of the last group left over from the previous chunk.
+// TODO: Since all the group by items are only a column reference, guaranteed by building projection below aggregation, we can directly compare data in a chunk.
+func (e *VecGroupChecker) SplitIntoGroups(chk *chunk.Chunk) (isFirstGroupSameAsPrev bool, err error) {
+	// numRows cannot be zero here: the caller fetches child rows before calling `SplitIntoGroups`,
+	// and an empty chunk is already handled there (see the caller's `fetchChild` for details).
+	numRows := chk.NumRows()
+
+	e.Reset()
+	e.nextGroupID = 0
+	if len(e.GroupByItems) == 0 {
+		e.groupOffset = append(e.groupOffset, numRows)
+		e.groupCount = 1
+		return true, nil
+	}
+
+	for _, item := range e.GroupByItems {
+		err = e.getFirstAndLastRowDatum(item, chk, numRows)
+		if err != nil {
+			return false, err
+		}
+	}
+	e.firstGroupKey, err = codec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, e.firstGroupKey, e.firstRowDatums...)
+	if err != nil {
+		return false, err
+	}
+
+	e.lastGroupKey, err = codec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, e.lastGroupKey, e.lastRowDatums...)
+ if err != nil { + return false, err + } + + if len(e.lastGroupKeyOfPrevChk) == 0 { + isFirstGroupSameAsPrev = false + } else { + if bytes.Equal(e.lastGroupKeyOfPrevChk, e.firstGroupKey) { + isFirstGroupSameAsPrev = true + } else { + isFirstGroupSameAsPrev = false + } + } + + if length := len(e.lastGroupKey); len(e.lastGroupKeyOfPrevChk) >= length { + e.lastGroupKeyOfPrevChk = e.lastGroupKeyOfPrevChk[:length] + } else { + e.lastGroupKeyOfPrevChk = make([]byte, length) + } + copy(e.lastGroupKeyOfPrevChk, e.lastGroupKey) + + if bytes.Equal(e.firstGroupKey, e.lastGroupKey) { + e.groupOffset = append(e.groupOffset, numRows) + e.groupCount = 1 + return isFirstGroupSameAsPrev, nil + } + + if cap(e.sameGroup) < numRows { + e.sameGroup = make([]bool, 0, numRows) + } + e.sameGroup = append(e.sameGroup, false) + for i := 1; i < numRows; i++ { + e.sameGroup = append(e.sameGroup, true) + } + + for _, item := range e.GroupByItems { + err = e.evalGroupItemsAndResolveGroups(item, chk, numRows) + if err != nil { + return false, err + } + } + + for i := 1; i < numRows; i++ { + if !e.sameGroup[i] { + e.groupOffset = append(e.groupOffset, i) + } + } + e.groupOffset = append(e.groupOffset, numRows) + e.groupCount = len(e.groupOffset) + return isFirstGroupSameAsPrev, nil +} + +func (e *VecGroupChecker) getFirstAndLastRowDatum( + item expression.Expression, chk *chunk.Chunk, numRows int) (err error) { + var firstRowDatum, lastRowDatum types.Datum + tp := item.GetType() + eType := tp.EvalType() + switch eType { + case types.ETInt: + firstRowVal, firstRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + firstRowDatum.SetInt64(firstRowVal) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + lastRowDatum.SetInt64(lastRowVal) + } else { + lastRowDatum.SetNull() + } + case types.ETReal: + firstRowVal, firstRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + firstRowDatum.SetFloat64(firstRowVal) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + lastRowDatum.SetFloat64(lastRowVal) + } else { + lastRowDatum.SetNull() + } + case types.ETDecimal: + firstRowVal, firstRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstDatum := types.MyDecimal{} + err := firstDatum.FromString(firstRowVal.ToString()) + if err != nil { + return err + } + firstRowDatum.SetMysqlDecimal(&firstDatum) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastDatum := types.MyDecimal{} + err := lastDatum.FromString(lastRowVal.ToString()) + if err != nil { + return err + } + lastRowDatum.SetMysqlDecimal(&lastDatum) + } else { + lastRowDatum.SetNull() + } + case types.ETDatetime, types.ETTimestamp: + firstRowVal, firstRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + firstRowDatum.SetMysqlTime(firstRowVal) + } else { + 
firstRowDatum.SetNull() + } + if !lastRowIsNull { + lastRowDatum.SetMysqlTime(lastRowVal) + } else { + lastRowDatum.SetNull() + } + case types.ETDuration: + firstRowVal, firstRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + firstRowDatum.SetMysqlDuration(firstRowVal) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + lastRowDatum.SetMysqlDuration(lastRowVal) + } else { + lastRowDatum.SetNull() + } + case types.ETJson: + firstRowVal, firstRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstRowDatum.SetMysqlJSON(firstRowVal.Copy()) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastRowDatum.SetMysqlJSON(lastRowVal.Copy()) + } else { + lastRowDatum.SetNull() + } + case types.ETString: + firstRowVal, firstRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstDatum := string([]byte(firstRowVal)) + firstRowDatum.SetString(firstDatum, tp.GetCollate()) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastDatum := string([]byte(lastRowVal)) + lastRowDatum.SetString(lastDatum, tp.GetCollate()) + } else { + lastRowDatum.SetNull() + } + default: + err = fmt.Errorf("invalid eval type %v", eType) + return err + } + + e.firstRowDatums = append(e.firstRowDatums, firstRowDatum) + e.lastRowDatums = append(e.lastRowDatums, lastRowDatum) + return err +} + +// evalGroupItemsAndResolveGroups evaluates the chunk according to the expression item. 
+// It then resolves the rows into groups according to the evaluation results.
+func (e *VecGroupChecker) evalGroupItemsAndResolveGroups(
+	item expression.Expression, chk *chunk.Chunk, numRows int) (err error) {
+	tp := item.GetType()
+	eType := tp.EvalType()
+	if e.allocateBuffer == nil {
+		e.allocateBuffer = expression.GetColumn
+	}
+	if e.releaseBuffer == nil {
+		e.releaseBuffer = expression.PutColumn
+	}
+	col, err := e.allocateBuffer(eType, numRows)
+	if err != nil {
+		return err
+	}
+	defer e.releaseBuffer(col)
+	err = expression.EvalExpr(e.ctx, item, eType, chk, col)
+	if err != nil {
+		return err
+	}
+
+	previousIsNull := col.IsNull(0)
+	switch eType {
+	case types.ETInt:
+		vals := col.Int64s()
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				switch {
+				case !previousIsNull && !isNull:
+					if vals[i] != vals[i-1] {
+						e.sameGroup[i] = false
+					}
+				case isNull != previousIsNull:
+					e.sameGroup[i] = false
+				}
+			}
+			previousIsNull = isNull
+		}
+	case types.ETReal:
+		vals := col.Float64s()
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				switch {
+				case !previousIsNull && !isNull:
+					if vals[i] != vals[i-1] {
+						e.sameGroup[i] = false
+					}
+				case isNull != previousIsNull:
+					e.sameGroup[i] = false
+				}
+			}
+			previousIsNull = isNull
+		}
+	case types.ETDecimal:
+		vals := col.Decimals()
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				switch {
+				case !previousIsNull && !isNull:
+					if vals[i].Compare(&vals[i-1]) != 0 {
+						e.sameGroup[i] = false
+					}
+				case isNull != previousIsNull:
+					e.sameGroup[i] = false
+				}
+			}
+			previousIsNull = isNull
+		}
+	case types.ETDatetime, types.ETTimestamp:
+		vals := col.Times()
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				switch {
+				case !previousIsNull && !isNull:
+					if vals[i].Compare(vals[i-1]) != 0 {
+						e.sameGroup[i] = false
+					}
+				case isNull != previousIsNull:
+					e.sameGroup[i] = false
+				}
+			}
+			previousIsNull = isNull
+		}
+	case types.ETDuration:
+		vals := col.GoDurations()
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				switch {
+				case !previousIsNull && !isNull:
+					if vals[i] != vals[i-1] {
+						e.sameGroup[i] = false
+					}
+				case isNull != previousIsNull:
+					e.sameGroup[i] = false
+				}
+			}
+			previousIsNull = isNull
+		}
+	case types.ETJson:
+		var previousKey, key types.BinaryJSON
+		if !previousIsNull {
+			previousKey = col.GetJSON(0)
+		}
+		for i := 1; i < numRows; i++ {
+			isNull := col.IsNull(i)
+			if !isNull {
+				key = col.GetJSON(i)
+			}
+			if e.sameGroup[i] {
+				if isNull == previousIsNull {
+					if !isNull && types.CompareBinaryJSON(previousKey, key) != 0 {
+						e.sameGroup[i] = false
+					}
+				} else {
+					e.sameGroup[i] = false
+				}
+			}
+			if !isNull {
+				previousKey = key
+			}
+			previousIsNull = isNull
+		}
+	case types.ETString:
+		previousKey := codec.ConvertByCollationStr(col.GetString(0), tp)
+		for i := 1; i < numRows; i++ {
+			key := codec.ConvertByCollationStr(col.GetString(i), tp)
+			isNull := col.IsNull(i)
+			if e.sameGroup[i] {
+				if isNull != previousIsNull || previousKey != key {
+					e.sameGroup[i] = false
+				}
+			}
+			previousKey = key
+			previousIsNull = isNull
+		}
+	default:
+		err = fmt.Errorf("invalid eval type %v", eType)
+	}
+	return err
+}
+
+// GetNextGroup returns the begin and end position of the next group.
+func (e *VecGroupChecker) GetNextGroup() (begin, end int) {
+	if e.nextGroupID == 0 {
+		begin = 0
+	} else {
+		begin = e.groupOffset[e.nextGroupID-1]
+	}
+	end = e.groupOffset[e.nextGroupID]
+	e.nextGroupID++
+	return begin, end
+}
+
+// IsExhausted returns true if there are no more groups to check.
+func (e *VecGroupChecker) IsExhausted() bool {
+	return e.nextGroupID >= e.groupCount
+}
+
+// Reset resets the group checker.
+func (e *VecGroupChecker) Reset() {
+	if e.groupOffset != nil {
+		e.groupOffset = e.groupOffset[:0]
+	}
+	if e.sameGroup != nil {
+		e.sameGroup = e.sameGroup[:0]
+	}
+	if e.firstGroupKey != nil {
+		e.firstGroupKey = e.firstGroupKey[:0]
+	}
+	if e.lastGroupKey != nil {
+		e.lastGroupKey = e.lastGroupKey[:0]
+	}
+	if e.firstRowDatums != nil {
+		e.firstRowDatums = e.firstRowDatums[:0]
+	}
+	if e.lastRowDatums != nil {
+		e.lastRowDatums = e.lastRowDatums[:0]
+	}
+}
+
+// GroupCount returns the number of groups.
+func (e *VecGroupChecker) GroupCount() int {
+	return e.groupCount
+}
diff --git a/executor/internal/vecgroupchecker/vec_group_checker_test.go b/executor/internal/vecgroupchecker/vec_group_checker_test.go
new file mode 100644
index 0000000000000..602132bbc6749
--- /dev/null
+++ b/executor/internal/vecgroupchecker/vec_group_checker_test.go
@@ -0,0 +1,272 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package vecgroupchecker + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" + "github.com/stretchr/testify/require" +) + +func TestVecGroupCheckerDATARACE(t *testing.T) { + ctx := mock.NewContext() + + mTypes := []byte{mysql.TypeVarString, mysql.TypeNewDecimal, mysql.TypeJSON} + for _, mType := range mTypes { + exprs := make([]expression.Expression, 1) + exprs[0] = &expression.Column{ + RetType: types.NewFieldTypeBuilder().SetType(mType).BuildP(), + Index: 0, + } + vgc := NewVecGroupChecker(ctx, exprs) + + fts := []*types.FieldType{types.NewFieldType(mType)} + chk := chunk.New(fts, 1, 1) + vgc.allocateBuffer = func(evalType types.EvalType, capacity int) (*chunk.Column, error) { + return chk.Column(0), nil + } + vgc.releaseBuffer = func(column *chunk.Column) {} + + switch mType { + case mysql.TypeVarString: + chk.Column(0).ReserveString(1) + chk.Column(0).AppendString("abc") + case mysql.TypeNewDecimal: + chk.Column(0).ResizeDecimal(1, false) + chk.Column(0).Decimals()[0] = *types.NewDecFromInt(123) + case mysql.TypeJSON: + chk.Column(0).ReserveJSON(1) + j := new(types.BinaryJSON) + require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 123, 123)))) + chk.Column(0).AppendJSON(*j) + } + + _, err := vgc.SplitIntoGroups(chk) + require.NoError(t, err) + + switch mType { + case mysql.TypeVarString: + require.Equal(t, "abc", vgc.firstRowDatums[0].GetString()) + require.Equal(t, "abc", vgc.lastRowDatums[0].GetString()) + chk.Column(0).ReserveString(1) + chk.Column(0).AppendString("edf") + require.Equal(t, "abc", vgc.firstRowDatums[0].GetString()) + require.Equal(t, "abc", vgc.lastRowDatums[0].GetString()) + case mysql.TypeNewDecimal: + require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String()) + require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String()) + chk.Column(0).ResizeDecimal(1, false) + chk.Column(0).Decimals()[0] = *types.NewDecFromInt(456) + require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String()) + require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String()) + case mysql.TypeJSON: + require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String()) + require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String()) + chk.Column(0).ReserveJSON(1) + j := new(types.BinaryJSON) + require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 456, 456)))) + chk.Column(0).AppendJSON(*j) + require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String()) + require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String()) + } + } +} + +func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) { + chkNum := len(chkRows) + numRows := 0 + inputs = make([]*chunk.Chunk, chkNum) + fts := make([]*types.FieldType, 1) + fts[0] = types.NewFieldType(mysql.TypeLonglong) + for i := 0; i < chkNum; i++ { + inputs[i] = chunk.New(fts, chkRows[i], chkRows[i]) + numRows += chkRows[i] + } + var numGroups int + if numRows%sameNum == 0 { + numGroups = numRows / sameNum + } else { + numGroups = numRows/sameNum + 1 + } + + rand.Seed(time.Now().Unix()) + nullPos := rand.Intn(numGroups) + cnt := 0 + val := rand.Int63() + for i := 0; i < chkNum; i++ { + col := inputs[i].Column(0) + col.ResizeInt64(chkRows[i], false) + i64s := col.Int64s() + for 
j := 0; j < chkRows[i]; j++ { + if cnt == sameNum { + val = rand.Int63() + cnt = 0 + nullPos-- + } + if nullPos == 0 { + col.SetNull(j, true) + } else { + i64s[j] = val + } + cnt++ + } + } + + expr = make([]expression.Expression, 1) + expr[0] = &expression.Column{ + RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlen(mysql.MaxIntWidth).BuildP(), + Index: 0, + } + return +} + +func TestVecGroupChecker4GroupCount(t *testing.T) { + testCases := []struct { + chunkRows []int + expectedGroups int + expectedFlag []bool + sameNum int + }{ + { + chunkRows: []int{1024, 1}, + expectedGroups: 1025, + expectedFlag: []bool{false, false}, + sameNum: 1, + }, + { + chunkRows: []int{1024, 1}, + expectedGroups: 1, + expectedFlag: []bool{false, true}, + sameNum: 1025, + }, + { + chunkRows: []int{1, 1}, + expectedGroups: 1, + expectedFlag: []bool{false, true}, + sameNum: 2, + }, + { + chunkRows: []int{1, 1}, + expectedGroups: 2, + expectedFlag: []bool{false, false}, + sameNum: 1, + }, + { + chunkRows: []int{2, 2}, + expectedGroups: 2, + expectedFlag: []bool{false, false}, + sameNum: 2, + }, + { + chunkRows: []int{2, 2}, + expectedGroups: 1, + expectedFlag: []bool{false, true}, + sameNum: 4, + }, + } + + ctx := mock.NewContext() + for _, testCase := range testCases { + expr, inputChks := genTestChunk4VecGroupChecker(testCase.chunkRows, testCase.sameNum) + groupChecker := NewVecGroupChecker(ctx, expr) + groupNum := 0 + for i, inputChk := range inputChks { + flag, err := groupChecker.SplitIntoGroups(inputChk) + require.NoError(t, err) + require.Equal(t, testCase.expectedFlag[i], flag) + if flag { + groupNum += groupChecker.GroupCount() - 1 + } else { + groupNum += groupChecker.GroupCount() + } + } + require.Equal(t, testCase.expectedGroups, groupNum) + } +} + +func TestVecGroupChecker(t *testing.T) { + tp := types.NewFieldTypeBuilder().SetType(mysql.TypeVarchar).BuildP() + col0 := &expression.Column{ + RetType: tp, + Index: 0, + } + ctx := mock.NewContext() + groupChecker := NewVecGroupChecker(ctx, []expression.Expression{col0}) + + chk := chunk.New([]*types.FieldType{tp}, 6, 6) + chk.Reset() + chk.Column(0).AppendString("aaa") + chk.Column(0).AppendString("AAA") + chk.Column(0).AppendString("😜") + chk.Column(0).AppendString("😃") + chk.Column(0).AppendString("À") + chk.Column(0).AppendString("A") + + tp.SetCollate("bin") + groupChecker.Reset() + _, err := groupChecker.SplitIntoGroups(chk) + require.NoError(t, err) + for i := 0; i < 6; i++ { + b, e := groupChecker.GetNextGroup() + require.Equal(t, b, i) + require.Equal(t, e, i+1) + } + require.True(t, groupChecker.IsExhausted()) + + tp.SetCollate("utf8_general_ci") + groupChecker.Reset() + _, err = groupChecker.SplitIntoGroups(chk) + require.NoError(t, err) + for i := 0; i < 3; i++ { + b, e := groupChecker.GetNextGroup() + require.Equal(t, b, i*2) + require.Equal(t, e, i*2+2) + } + require.True(t, groupChecker.IsExhausted()) + + tp.SetCollate("utf8_unicode_ci") + groupChecker.Reset() + _, err = groupChecker.SplitIntoGroups(chk) + require.NoError(t, err) + for i := 0; i < 3; i++ { + b, e := groupChecker.GetNextGroup() + require.Equal(t, b, i*2) + require.Equal(t, e, i*2+2) + } + require.True(t, groupChecker.IsExhausted()) + + // test padding + tp.SetCollate("utf8_bin") + tp.SetFlen(6) + chk.Reset() + chk.Column(0).AppendString("a") + chk.Column(0).AppendString("a ") + chk.Column(0).AppendString("a ") + groupChecker.Reset() + _, err = groupChecker.SplitIntoGroups(chk) + require.NoError(t, err) + b, e := groupChecker.GetNextGroup() + 
require.Equal(t, b, 0) + require.Equal(t, e, 3) + require.True(t, groupChecker.IsExhausted()) +} diff --git a/executor/merge_join.go b/executor/merge_join.go index ae2a113ceb217..9ec4934ef217f 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/executor/internal/exec" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -63,7 +64,7 @@ type mergeJoinTable struct { executed bool childChunk *chunk.Chunk childChunkIter *chunk.Iterator4Chunk - groupChecker *vecGroupChecker + groupChecker *vecgroupchecker.VecGroupChecker groupRowsSelected []int groupRowsIter chunk.Iterator @@ -85,7 +86,7 @@ func (t *mergeJoinTable) init(exec *MergeJoinExec) { for _, col := range t.joinKeys { items = append(items, col) } - t.groupChecker = newVecGroupChecker(exec.Ctx(), items) + t.groupChecker = vecgroupchecker.NewVecGroupChecker(exec.Ctx(), items) t.groupRowsIter = chunk.NewIterator4Chunk(t.childChunk) if t.isInner { @@ -146,7 +147,7 @@ func (t *mergeJoinTable) finish() error { func (t *mergeJoinTable) selectNextGroup() { t.groupRowsSelected = t.groupRowsSelected[:0] - begin, end := t.groupChecker.getNextGroup() + begin, end := t.groupChecker.GetNextGroup() if t.isInner && t.hasNullInJoinKey(t.childChunk.GetRow(begin)) { return } @@ -175,7 +176,7 @@ func (t *mergeJoinTable) fetchNextInnerGroup(ctx context.Context, exec *MergeJoi } fetchNext: - if t.executed && t.groupChecker.isExhausted() { + if t.executed && t.groupChecker.IsExhausted() { // Ensure iter at the end, since sel of childChunk has been cleared. t.groupRowsIter.ReachEnd() return nil @@ -183,13 +184,13 @@ fetchNext: isEmpty := true // For inner table, rows have null in join keys should be skip by selectNextGroup. - for isEmpty && !t.groupChecker.isExhausted() { + for isEmpty && !t.groupChecker.IsExhausted() { t.selectNextGroup() isEmpty = len(t.groupRowsSelected) == 0 } // For inner table, all the rows have the same join keys should be put into one group. - for !t.executed && t.groupChecker.isExhausted() { + for !t.executed && t.groupChecker.IsExhausted() { if !isEmpty { // Group is not empty, hand over the management of childChunk to t.rowContainer. if err := t.rowContainer.Add(t.childChunk); err != nil { @@ -210,7 +211,7 @@ fetchNext: break } - isFirstGroupSameAsPrev, err := t.groupChecker.splitIntoGroups(t.childChunk) + isFirstGroupSameAsPrev, err := t.groupChecker.SplitIntoGroups(t.childChunk) if err != nil { return err } @@ -240,11 +241,11 @@ fetchNext: } func (t *mergeJoinTable) fetchNextOuterGroup(ctx context.Context, exec *MergeJoinExec, requiredRows int) error { - if t.executed && t.groupChecker.isExhausted() { + if t.executed && t.groupChecker.IsExhausted() { return nil } - if !t.executed && t.groupChecker.isExhausted() { + if !t.executed && t.groupChecker.IsExhausted() { // It's hard to calculate selectivity if there is any filter or it's inner join, // so we just push the requiredRows down when it's outer join and has no filter. 
if exec.isOuterJoin && len(t.filters) == 0 { @@ -261,7 +262,7 @@ func (t *mergeJoinTable) fetchNextOuterGroup(ctx context.Context, exec *MergeJoi return err } - _, err = t.groupChecker.splitIntoGroups(t.childChunk) + _, err = t.groupChecker.SplitIntoGroups(t.childChunk) if err != nil { return err } diff --git a/executor/pipelined_window.go b/executor/pipelined_window.go index 4811e313c3b23..705c0e476b3f1 100644 --- a/executor/pipelined_window.go +++ b/executor/pipelined_window.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/executor/internal/exec" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/planner/core" @@ -43,7 +44,7 @@ type PipelinedWindowExec struct { partialResults []aggfuncs.PartialResult start *core.FrameBound end *core.FrameBound - groupChecker *vecGroupChecker + groupChecker *vecgroupchecker.VecGroupChecker // childResult stores the child chunk. Note that even if remaining is 0, e.rows might still references rows in data[0].chk after returned it to upper executor, since there is no guarantee what the upper executor will do to the returned chunk, it might destroy the data (as in the benchmark test, it reused the chunk to pull data, and it will be chk.Reset(), causing panicking). So dataIdx, accumulated and dropped are added to ensure that chunk will only be returned if there is no row reference. childResult *chunk.Chunk @@ -176,7 +177,7 @@ func (e *PipelinedWindowExec) getRowsInPartition(ctx context.Context) (err error e.newPartition = false } - if e.groupChecker.isExhausted() { + if e.groupChecker.IsExhausted() { var drained, samePartition bool drained, err = e.fetchChild(ctx) if err != nil { @@ -187,7 +188,7 @@ func (e *PipelinedWindowExec) getRowsInPartition(ctx context.Context) (err error e.done = true return nil } - samePartition, err = e.groupChecker.splitIntoGroups(e.childResult) + samePartition, err = e.groupChecker.SplitIntoGroups(e.childResult) if samePartition { // the only case that when getRowsInPartition gets called, it is not a new partition. 
e.newPartition = false @@ -196,7 +197,7 @@ func (e *PipelinedWindowExec) getRowsInPartition(ctx context.Context) (err error return errors.Trace(err) } } - begin, end := e.groupChecker.getNextGroup() + begin, end := e.groupChecker.GetNextGroup() e.rowToConsume += uint64(end - begin) for i := begin; i < end; i++ { e.rows = append(e.rows, e.childResult.GetRow(i)) diff --git a/executor/shuffle.go b/executor/shuffle.go index 777b83642d927..8bccd2d5a83c9 100644 --- a/executor/shuffle.go +++ b/executor/shuffle.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/executor/internal/exec" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/channel" @@ -444,7 +445,7 @@ func buildPartitionHashSplitter(concurrency int, byItems []expression.Expression type partitionRangeSplitter struct { byItems []expression.Expression numWorkers int - groupChecker *vecGroupChecker + groupChecker *vecgroupchecker.VecGroupChecker idx int } @@ -452,7 +453,7 @@ func buildPartitionRangeSplitter(ctx sessionctx.Context, concurrency int, byItem return &partitionRangeSplitter{ byItems: byItems, numWorkers: concurrency, - groupChecker: newVecGroupChecker(ctx, byItems), + groupChecker: vecgroupchecker.NewVecGroupChecker(ctx, byItems), idx: 0, } } @@ -461,14 +462,14 @@ func buildPartitionRangeSplitter(ctx sessionctx.Context, concurrency int, byItem // the caller of this method should guarantee that `input` is grouped, // which means that rows with the same byItems should be continuous, the order does not matter. func (s *partitionRangeSplitter) split(_ sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) { - _, err := s.groupChecker.splitIntoGroups(input) + _, err := s.groupChecker.SplitIntoGroups(input) if err != nil { return workerIndices, err } workerIndices = workerIndices[:0] - for !s.groupChecker.isExhausted() { - begin, end := s.groupChecker.getNextGroup() + for !s.groupChecker.IsExhausted() { + begin, end := s.groupChecker.GetNextGroup() for i := begin; i < end; i++ { workerIndices = append(workerIndices, s.idx) } diff --git a/executor/window.go b/executor/window.go index 3c155b3218ede..91319d6ef0065 100644 --- a/executor/window.go +++ b/executor/window.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/executor/internal/exec" + "github.com/pingcap/tidb/executor/internal/vecgroupchecker" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/planner/core" @@ -32,7 +33,7 @@ import ( type WindowExec struct { exec.BaseExecutor - groupChecker *vecGroupChecker + groupChecker *vecgroupchecker.VecGroupChecker // childResult stores the child chunk childResult *chunk.Chunk // executed indicates the child executor is drained or something unexpected happened. 
@@ -76,7 +77,7 @@ func (e *WindowExec) preparedChunkAvailable() bool { func (e *WindowExec) consumeOneGroup(ctx context.Context) error { var groupRows []chunk.Row - if e.groupChecker.isExhausted() { + if e.groupChecker.IsExhausted() { eof, err := e.fetchChild(ctx) if err != nil { return errors.Trace(err) @@ -85,12 +86,12 @@ func (e *WindowExec) consumeOneGroup(ctx context.Context) error { e.executed = true return e.consumeGroupRows(groupRows) } - _, err = e.groupChecker.splitIntoGroups(e.childResult) + _, err = e.groupChecker.SplitIntoGroups(e.childResult) if err != nil { return errors.Trace(err) } } - begin, end := e.groupChecker.getNextGroup() + begin, end := e.groupChecker.GetNextGroup() for i := begin; i < end; i++ { groupRows = append(groupRows, e.childResult.GetRow(i)) } @@ -106,13 +107,13 @@ func (e *WindowExec) consumeOneGroup(ctx context.Context) error { return e.consumeGroupRows(groupRows) } - isFirstGroupSameAsPrev, err := e.groupChecker.splitIntoGroups(e.childResult) + isFirstGroupSameAsPrev, err := e.groupChecker.SplitIntoGroups(e.childResult) if err != nil { return errors.Trace(err) } if isFirstGroupSameAsPrev { - begin, end = e.groupChecker.getNextGroup() + begin, end = e.groupChecker.GetNextGroup() for i := begin; i < end; i++ { groupRows = append(groupRows, e.childResult.GetRow(i)) }
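
The call-site hunks above (StreamAgg, merge join, window, pipelined window, shuffle) are mechanical renames; the checker's contract is unchanged by the move into the internal package, and the deleted tests reappear verbatim in the new package. For reference, here is a minimal sketch of how a caller drives the exported API, modeled on StreamAggExec.consumeOneGroup above. consumeAllGroups and fetchChunk are hypothetical illustrations, not part of this patch, and the import only compiles from code under executor/ because the package is internal.

	package executor // illustrative placement; the package is internal to executor

	import (
		"github.com/pingcap/tidb/executor/internal/vecgroupchecker"
		"github.com/pingcap/tidb/expression"
		"github.com/pingcap/tidb/sessionctx"
		"github.com/pingcap/tidb/util/chunk"
	)

	// consumeAllGroups drains a child executor group by group.
	// fetchChunk is a hypothetical stand-in for the executor's child-fetching
	// logic; it must never yield an empty chunk (see the SplitIntoGroups doc).
	func consumeAllGroups(ctx sessionctx.Context, items []expression.Expression,
		fetchChunk func() (chk *chunk.Chunk, drained bool)) error {
		checker := vecgroupchecker.NewVecGroupChecker(ctx, items)
		defer checker.Reset()
		for {
			chk, drained := fetchChunk()
			if drained {
				return nil
			}
			// sameAsPrev reports that the first group of chk continues the last
			// group of the previous chunk; a real caller merges those rows into
			// its pending group instead of emitting a new result row.
			sameAsPrev, err := checker.SplitIntoGroups(chk)
			if err != nil {
				return err
			}
			_ = sameAsPrev
			for !checker.IsExhausted() {
				begin, end := checker.GetNextGroup() // [begin, end) row offsets in chk
				for i := begin; i < end; i++ {
					_ = chk.GetRow(i) // hand the group's rows to the agg/window logic
				}
			}
		}
	}

This chunk-at-a-time protocol is why both SplitIntoGroups and the boolean it returns exist: a group may span chunk boundaries, so the splitter only reports boundaries within one chunk and leaves cross-chunk merging to the caller.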