Skip to content

Commit

Permalink
Merge branch 'master' into issue#29993
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored Nov 23, 2021
2 parents 5314597 + 7fdafb4 commit 554649d
Show file tree
Hide file tree
Showing 17 changed files with 1,181 additions and 892 deletions.
38 changes: 22 additions & 16 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/set"
Expand Down Expand Up @@ -317,6 +318,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
iter := chunk.NewIterator4Chunk(srcChk)

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand Down Expand Up @@ -359,7 +361,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])

Expand All @@ -386,7 +388,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
Expand All @@ -409,7 +411,7 @@ func testMergePartialResult(t *testing.T, p aggTest) {
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2])
}
Expand Down Expand Up @@ -447,6 +449,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().Collate)
partialDesc, finalDesc := desc.Split([]int{0, 1})

// build partial func for partial phase.
Expand All @@ -467,7 +470,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)

Expand All @@ -488,7 +491,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zero(t, result)
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
Expand All @@ -499,7 +502,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul
require.NoError(t, err)

dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
Expand Down Expand Up @@ -570,6 +573,7 @@ func testAggFunc(t *testing.T, p aggTest) {
ctx := mock.NewContext()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand All @@ -596,7 +600,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -606,7 +610,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])

Expand Down Expand Up @@ -639,7 +643,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -649,7 +653,7 @@ func testAggFunc(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])
}
Expand All @@ -658,6 +662,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
srcChk := p.genSrcChk()

args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.Collate)
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
Expand Down Expand Up @@ -685,7 +690,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -695,7 +700,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])
}
Expand Down Expand Up @@ -749,6 +754,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().Collate)
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
Expand All @@ -762,7 +768,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -772,7 +778,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])

Expand Down Expand Up @@ -805,7 +811,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])

Expand All @@ -815,7 +821,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0])
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
Expand Down
9 changes: 0 additions & 9 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,15 +651,6 @@ func TestGroupConcatAggr(t *testing.T) {
rows.Check(testkit.Rows("01234567", "12345"))
}

func fillData(tk *testkit.TestKit, table string) {
tk.MustExec("use test")
tk.MustExec(fmt.Sprintf("create table %s(id int not null default 1, name varchar(255), PRIMARY KEY(id));", table))

// insert data
tk.MustExec(fmt.Sprintf("insert INTO %s VALUES (1, \"hello\");", table))
tk.MustExec(fmt.Sprintf("insert into %s values (2, \"hello\");", table))
}

func TestSelectDistinct(t *testing.T) {
t.Parallel()
store, clean := testkit.CreateMockStore(t)
Expand Down
6 changes: 3 additions & 3 deletions executor/analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (s *testFastAnalyze) TestAnalyzeFastSample(c *C) {
samples := mockExec.Collectors[i].Samples
c.Assert(len(samples), Equals, 20)
for j := 1; j < 20; j++ {
cmp, err := samples[j].Value.CompareDatum(tk.Se.GetSessionVars().StmtCtx, &samples[j-1].Value)
cmp, err := samples[j].Value.Compare(tk.Se.GetSessionVars().StmtCtx, &samples[j-1].Value, collate.GetBinaryCollator())
c.Assert(err, IsNil)
c.Assert(cmp, Greater, 0)
}
Expand All @@ -381,15 +381,15 @@ func (s *testFastAnalyze) TestAnalyzeFastSample(c *C) {
func checkHistogram(sc *stmtctx.StatementContext, hg *statistics.Histogram) (bool, error) {
for i := 0; i < len(hg.Buckets); i++ {
lower, upper := hg.GetLower(i), hg.GetUpper(i)
cmp, err := upper.CompareDatum(sc, lower)
cmp, err := upper.Compare(sc, lower, collate.GetBinaryCollator())
if cmp < 0 || err != nil {
return false, err
}
if i == 0 {
continue
}
previousUpper := hg.GetUpper(i - 1)
cmp, err = lower.CompareDatum(sc, previousUpper)
cmp, err = lower.Compare(sc, previousUpper, collate.GetBinaryCollator())
if cmp <= 0 || err != nil {
return false, err
}
Expand Down
3 changes: 2 additions & 1 deletion executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4260,7 +4260,8 @@ func (b *executorBuilder) buildShuffle(v *plannercore.PhysicalShuffle) *ShuffleE

for j, dataSource := range v.DataSources {
stub := plannercore.PhysicalShuffleReceiverStub{
Receiver: (unsafe.Pointer)(receivers[j]),
Receiver: (unsafe.Pointer)(receivers[j]),
DataSource: dataSource,
}.Init(b.ctx, dataSource.Stats(), dataSource.SelectBlockOffset(), nil)
stub.SetSchema(dataSource.Schema())
v.Tails[j].SetChildren(stub)
Expand Down
25 changes: 12 additions & 13 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/tablecodec"
testkit2 "github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/admin"
Expand All @@ -81,6 +82,7 @@ import (
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tidb/util/timeutil"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/testutils"
"github.com/tikv/client-go/v2/tikv"
Expand Down Expand Up @@ -112,7 +114,6 @@ var _ = SerialSuites(&testSuiteJoinSerial{&baseTestSuite{}})
var _ = Suite(&testSuite6{&baseTestSuite{}})
var _ = Suite(&testSuite7{&baseTestSuite{}})
var _ = Suite(&testSuite8{&baseTestSuite{}})
var _ = Suite(&testBypassSuite{})
var _ = Suite(&testUpdateSuite{})
var _ = Suite(&testPointGetSuite{})
var _ = SerialSuites(&testRecoverTable{})
Expand Down Expand Up @@ -617,35 +618,33 @@ type testCase struct {
}

func checkCases(tests []testCase, ld *executor.LoadDataInfo,
c *C, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string) {
t *testing.T, tk *testkit2.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string) {
origin := ld.IgnoreLines
for _, tt := range tests {
ld.IgnoreLines = origin
c.Assert(ctx.NewTxn(context.Background()), IsNil)
require.Nil(t, ctx.NewTxn(context.Background()))
ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true
ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true
ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true
ctx.GetSessionVars().StmtCtx.InDeleteStmt = false
data, reachLimit, err1 := ld.InsertData(context.Background(), tt.data1, tt.data2)
c.Assert(err1, IsNil)
c.Assert(reachLimit, IsFalse)
require.NoError(t, err1)
require.False(t, reachLimit)
err1 = ld.CheckAndInsertOneBatch(context.Background(), ld.GetRows(), ld.GetCurBatchCnt())
c.Assert(err1, IsNil)
require.NoError(t, err1)
ld.SetMaxRowsInBatch(20000)
if tt.restData == nil {
c.Assert(data, HasLen, 0,
Commentf("data1:%v, data2:%v, data:%v", string(tt.data1), string(tt.data2), string(data)))
require.Len(t, data, 0, "data1:%v, data2:%v, data:%v", string(tt.data1), string(tt.data2), string(data))
} else {
c.Assert(data, DeepEquals, tt.restData,
Commentf("data1:%v, data2:%v, data:%v", string(tt.data1), string(tt.data2), string(data)))
require.Equal(t, tt.restData, data, "data1:%v, data2:%v, data:%v", string(tt.data1), string(tt.data2), string(data))
}
ld.SetMessage()
tk.CheckLastMessage(tt.expectedMsg)
require.Equal(t, tt.expectedMsg, tk.Session().LastMessage())
ctx.StmtCommit()
txn, err := ctx.Txn(true)
c.Assert(err, IsNil)
require.NoError(t, err)
err = txn.Commit(context.Background())
c.Assert(err, IsNil)
require.NoError(t, err)
r := tk.MustQuery(selectSQL)
r.Check(testutil.RowsWithSep("|", tt.expected...))
tk.MustExec(deleteSQL)
Expand Down
11 changes: 11 additions & 0 deletions executor/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
package executor_test

import (
"fmt"
"os"
"testing"

"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/testkit/testdata"
"github.com/pingcap/tidb/testkit/testmain"
"github.com/pingcap/tidb/util/testbridge"
Expand Down Expand Up @@ -64,3 +66,12 @@ func TestMain(m *testing.M) {

goleak.VerifyTestMain(testmain.WrapTestingM(m, callback), opts...)
}

func fillData(tk *testkit.TestKit, table string) {
tk.MustExec("use test")
tk.MustExec(fmt.Sprintf("create table %s(id int not null default 1, name varchar(255), PRIMARY KEY(id));", table))

// insert data
tk.MustExec(fmt.Sprintf("insert INTO %s VALUES (1, \"hello\");", table))
tk.MustExec(fmt.Sprintf("insert into %s values (2, \"hello\");", table))
}
Loading

0 comments on commit 554649d

Please sign in to comment.