From c685c6cdfc90f6ad9603bfcfe48ed752e27b223a Mon Sep 17 00:00:00 2001 From: Chengpeng Yan <41809508+Reminiscent@users.noreply.github.com> Date: Fri, 17 Apr 2020 22:43:32 +0800 Subject: [PATCH] cherry pick #16014 to release-2.1 Signed-off-by: sre-bot --- expression/bench_test.go | 1806 ++++++++++++++++++++++++++++++++ expression/expression.go | 350 +++++++ expression/integration_test.go | 614 +++++++++++ types/datum.go | 2 +- types/datum_test.go | 4 +- 5 files changed, 2773 insertions(+), 3 deletions(-) diff --git a/expression/bench_test.go b/expression/bench_test.go index bc4a58ed81b45..6aebfdef3e1d6 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -188,3 +188,1809 @@ func BenchmarkScalarFunctionClone(b *testing.B) { } b.ReportAllocs() } +<<<<<<< HEAD +======= + +func getRandomTime(r *rand.Rand) types.CoreTime { + return types.FromDate(r.Intn(2200), r.Intn(10)+1, r.Intn(20)+1, + r.Intn(12), r.Intn(60), r.Intn(60), r.Intn(1000000)) + +} + +// dataGenerator is used to generate data for test. 
+type dataGenerator interface { + gen() interface{} +} + +type defaultRandGen struct { + *rand.Rand +} + +func newDefaultRandGen() *defaultRandGen { + return &defaultRandGen{rand.New(rand.NewSource(int64(rand.Uint64())))} +} + +type defaultGener struct { + nullRation float64 + eType types.EvalType + randGen *defaultRandGen +} + +func newDefaultGener(nullRation float64, eType types.EvalType) *defaultGener { + return &defaultGener{ + nullRation: nullRation, + eType: eType, + randGen: newDefaultRandGen(), + } +} + +func (g *defaultGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + switch g.eType { + case types.ETInt: + if g.randGen.Float64() < 0.5 { + return -g.randGen.Int63() + } + return g.randGen.Int63() + case types.ETReal: + if g.randGen.Float64() < 0.5 { + return -g.randGen.Float64() * 1000000 + } + return g.randGen.Float64() * 1000000 + case types.ETDecimal: + d := new(types.MyDecimal) + var f float64 + if g.randGen.Float64() < 0.5 { + f = g.randGen.Float64() * 100000 + } else { + f = -g.randGen.Float64() * 100000 + } + if err := d.FromFloat64(f); err != nil { + panic(err) + } + return d + case types.ETDatetime, types.ETTimestamp: + gt := getRandomTime(g.randGen.Rand) + t := types.NewTime(gt, convertETType(g.eType), 0) + return t + case types.ETDuration: + d := types.Duration{ + // use rand.Int32() to make it not overflow when AddDuration + Duration: time.Duration(g.randGen.Int31()), + } + return d + case types.ETJson: + j := new(json.BinaryJSON) + if err := j.UnmarshalJSON([]byte(fmt.Sprintf(`{"key":%v}`, g.randGen.Int()))); err != nil { + panic(err) + } + return *j + case types.ETString: + return randString(g.randGen.Rand) + } + return nil +} + +// charInt64Gener is used to generate int which is equal to char's ascii +type charInt64Gener struct{} + +func (g *charInt64Gener) gen() interface{} { + rand := time.Now().Nanosecond() + rand = rand % 1024 + return int64(rand) +} + +// charsetStringGener is used to generate "ascii" 
or "gbk" +type charsetStringGener struct{} + +func (g *charsetStringGener) gen() interface{} { + rand := time.Now().Nanosecond() % 3 + if rand == 0 { + return "ascii" + } + if rand == 1 { + return "utf8" + } + return "gbk" +} + +// selectStringGener select one string randomly from the candidates array +type selectStringGener struct { + candidates []string + randGen *defaultRandGen +} + +func newSelectStringGener(candidates []string) *selectStringGener { + return &selectStringGener{candidates, newDefaultRandGen()} +} + +func (g *selectStringGener) gen() interface{} { + if len(g.candidates) == 0 { + return nil + } + return g.candidates[g.randGen.Intn(len(g.candidates))] +} + +// selectRealGener select one real number randomly from the candidates array +type selectRealGener struct { + candidates []float64 + randGen *defaultRandGen +} + +func newSelectRealGener(candidates []float64) *selectRealGener { + return &selectRealGener{candidates, newDefaultRandGen()} +} + +func (g *selectRealGener) gen() interface{} { + if len(g.candidates) == 0 { + return nil + } + return g.candidates[g.randGen.Intn(len(g.candidates))] +} + +type constJSONGener struct { + jsonStr string +} + +func (g *constJSONGener) gen() interface{} { + j := new(json.BinaryJSON) + if err := j.UnmarshalJSON([]byte(g.jsonStr)); err != nil { + panic(err) + } + return *j +} + +type decimalJSONGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func newDecimalJSONGener(nullRation float64) *decimalJSONGener { + return &decimalJSONGener{nullRation, newDefaultRandGen()} +} + +func (g *decimalJSONGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + + var f float64 + if g.randGen.Float64() < 0.5 { + f = g.randGen.Float64() * 100000 + } else { + f = -g.randGen.Float64() * 100000 + } + if err := (&types.MyDecimal{}).FromFloat64(f); err != nil { + panic(err) + } + return json.CreateBinary(f) +} + +type jsonStringGener struct { + randGen *defaultRandGen +} + +func 
newJSONStringGener() *jsonStringGener { + return &jsonStringGener{newDefaultRandGen()} +} + +func (g *jsonStringGener) gen() interface{} { + j := new(json.BinaryJSON) + if err := j.UnmarshalJSON([]byte(fmt.Sprintf(`{"key":%v}`, g.randGen.Int()))); err != nil { + panic(err) + } + return j.String() +} + +type decimalStringGener struct { + randGen *defaultRandGen +} + +func newDecimalStringGener() *decimalStringGener { + return &decimalStringGener{newDefaultRandGen()} +} + +func (g *decimalStringGener) gen() interface{} { + tempDecimal := new(types.MyDecimal) + if err := tempDecimal.FromFloat64(g.randGen.Float64()); err != nil { + panic(err) + } + return tempDecimal.String() +} + +type realStringGener struct { + randGen *defaultRandGen +} + +func newRealStringGener() *realStringGener { + return &realStringGener{newDefaultRandGen()} +} + +func (g *realStringGener) gen() interface{} { + return fmt.Sprintf("%f", g.randGen.Float64()) +} + +type jsonTimeGener struct { + randGen *defaultRandGen +} + +func newJSONTimeGener() *jsonTimeGener { + return &jsonTimeGener{newDefaultRandGen()} +} + +func (g *jsonTimeGener) gen() interface{} { + tm := types.NewTime(getRandomTime(g.randGen.Rand), mysql.TypeDatetime, types.DefaultFsp) + return json.CreateBinary(tm.String()) +} + +type rangeDurationGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func newRangeDurationGener(nullRation float64) *rangeDurationGener { + return &rangeDurationGener{nullRation, newDefaultRandGen()} +} + +func (g *rangeDurationGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + tm := (math.Abs(g.randGen.Int63n(12))*3600 + math.Abs(g.randGen.Int63n(60))*60 + math.Abs(g.randGen.Int63n(60))) * 1000 + tu := (tm + math.Abs(g.randGen.Int63n(1000))) * 1000 + return types.Duration{ + Duration: time.Duration(tu * 1000)} +} + +type timeFormatGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func newTimeFormatGener(nullRation float64) 
*timeFormatGener { + return &timeFormatGener{nullRation, newDefaultRandGen()} +} + +func (g *timeFormatGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + switch g.randGen.Uint32() % 4 { + case 0: + return "%H %i %S" + case 1: + return "%l %i %s" + case 2: + return "%p %i %s" + case 3: + return "%I %i %S %f" + case 4: + return "%T" + default: + return nil + } +} + +// rangeRealGener is used to generate float64 items in [begin, end]. +type rangeRealGener struct { + begin float64 + end float64 + + nullRation float64 + randGen *defaultRandGen +} + +func newRangeRealGener(begin, end, nullRation float64) *rangeRealGener { + return &rangeRealGener{begin, end, nullRation, newDefaultRandGen()} +} + +func (g *rangeRealGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + if g.end < g.begin { + g.begin = -100 + g.end = 100 + } + return g.randGen.Float64()*(g.end-g.begin) + g.begin +} + +// rangeDecimalGener is used to generate decimal items in [begin, end]. +type rangeDecimalGener struct { + begin float64 + end float64 + + nullRation float64 + randGen *defaultRandGen +} + +func newRangeDecimalGener(begin, end, nullRation float64) *rangeDecimalGener { + return &rangeDecimalGener{begin, end, nullRation, newDefaultRandGen()} +} + +func (g *rangeDecimalGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + if g.end < g.begin { + g.begin = -100000 + g.end = 100000 + } + d := new(types.MyDecimal) + f := g.randGen.Float64()*(g.end-g.begin) + g.begin + if err := d.FromFloat64(f); err != nil { + panic(err) + } + return d +} + +// rangeInt64Gener is used to generate int64 items in [begin, end). 
+type rangeInt64Gener struct { + begin int + end int + randGen *defaultRandGen +} + +func newRangeInt64Gener(begin, end int) *rangeInt64Gener { + return &rangeInt64Gener{begin, end, newDefaultRandGen()} +} + +func (rig *rangeInt64Gener) gen() interface{} { + return int64(rig.randGen.Intn(rig.end-rig.begin) + rig.begin) +} + +// numStrGener is used to generate number strings. +type numStrGener struct { + rangeInt64Gener +} + +func (g *numStrGener) gen() interface{} { + return fmt.Sprintf("%v", g.rangeInt64Gener.gen()) +} + +// ipv6StrGener is used to generate ipv6 strings. +type ipv6StrGener struct { + randGen *defaultRandGen +} + +func (g *ipv6StrGener) gen() interface{} { + var ip net.IP = make([]byte, net.IPv6len) + for i := range ip { + ip[i] = uint8(g.randGen.Intn(256)) + } + return ip.String() +} + +// ipv4StrGener is used to generate ipv4 strings. For example 111.111.111.111 +type ipv4StrGener struct { + randGen *defaultRandGen +} + +func (g *ipv4StrGener) gen() interface{} { + var ip net.IP = make([]byte, net.IPv4len) + for i := range ip { + ip[i] = uint8(g.randGen.Intn(256)) + } + return ip.String() +} + +// ipv6ByteGener is used to generate ipv6 address in 16 bytes string. +type ipv6ByteGener struct { + randGen *defaultRandGen +} + +func (g *ipv6ByteGener) gen() interface{} { + var ip = make([]byte, net.IPv6len) + for i := range ip { + ip[i] = uint8(g.randGen.Intn(256)) + } + return string(ip[:net.IPv6len]) +} + +// ipv4ByteGener is used to generate ipv4 address in 4 bytes string. 
+type ipv4ByteGener struct { + randGen *defaultRandGen +} + +func (g *ipv4ByteGener) gen() interface{} { + var ip = make([]byte, net.IPv4len) + for i := range ip { + ip[i] = uint8(g.randGen.Intn(256)) + } + return string(ip[:net.IPv4len]) +} + +// ipv4Compat is used to generate ipv4 compatible ipv6 strings +type ipv4CompatByteGener struct { + randGen *defaultRandGen +} + +func (g *ipv4CompatByteGener) gen() interface{} { + var ip = make([]byte, net.IPv6len) + for i := range ip { + if i < 12 { + ip[i] = 0 + } else { + ip[i] = uint8(g.randGen.Intn(256)) + } + } + return string(ip[:net.IPv6len]) +} + +// ipv4MappedByteGener is used to generate ipv4-mapped ipv6 bytes. +type ipv4MappedByteGener struct { + randGen *defaultRandGen +} + +func (g *ipv4MappedByteGener) gen() interface{} { + var ip = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0, 0, 0, 0} + for i := 12; i < 16; i++ { + ip[i] = uint8(g.randGen.Intn(256)) // reset the last 4 bytes + } + return string(ip[:net.IPv6len]) +} + +// randLenStrGener is used to generate strings whose lengths are in [lenBegin, lenEnd). 
+type randLenStrGener struct { + lenBegin int + lenEnd int + randGen *defaultRandGen +} + +func newRandLenStrGener(lenBegin, lenEnd int) *randLenStrGener { + return &randLenStrGener{lenBegin, lenEnd, newDefaultRandGen()} +} + +func (g *randLenStrGener) gen() interface{} { + n := g.randGen.Intn(g.lenEnd-g.lenBegin) + g.lenBegin + buf := make([]byte, n) + for i := range buf { + x := g.randGen.Intn(62) + if x < 10 { + buf[i] = byte('0' + x) + } else if x-10 < 26 { + buf[i] = byte('a' + x - 10) + } else { + buf[i] = byte('A' + x - 10 - 26) + } + } + return string(buf) +} + +type randHexStrGener struct { + lenBegin int + lenEnd int + randGen *defaultRandGen +} + +func newRandHexStrGener(lenBegin, lenEnd int) *randHexStrGener { + return &randHexStrGener{lenBegin, lenEnd, newDefaultRandGen()} +} + +func (g *randHexStrGener) gen() interface{} { + n := g.randGen.Intn(g.lenEnd-g.lenBegin) + g.lenBegin + buf := make([]byte, n) + for i := range buf { + x := g.randGen.Intn(16) + if x < 10 { + buf[i] = byte('0' + x) + } else { + if x%2 == 0 { + buf[i] = byte('a' + x - 10) + } else { + buf[i] = byte('A' + x - 10) + } + } + } + return string(buf) +} + +// dateTimeGener is used to generate a dataTime +type dateTimeGener struct { + Fsp int + Year int + Month int + Day int + randGen *defaultRandGen +} + +func newDateTimeGener(fsp, year, month, day int) *dateTimeGener { + return &dateTimeGener{fsp, year, month, day, newDefaultRandGen()} +} + +func (g *dateTimeGener) gen() interface{} { + if g.Year == 0 { + g.Year = 1970 + g.randGen.Intn(100) + } + if g.Month == 0 { + g.Month = g.randGen.Intn(10) + 1 + } + if g.Day == 0 { + g.Day = g.randGen.Intn(20) + 1 + } + var gt types.CoreTime + if g.Fsp > 0 && g.Fsp <= 6 { + gt = types.FromDate(g.Year, g.Month, g.Day, g.randGen.Intn(12), g.randGen.Intn(60), g.randGen.Intn(60), g.randGen.Intn(1000000)) + } else { + gt = types.FromDate(g.Year, g.Month, g.Day, g.randGen.Intn(12), g.randGen.Intn(60), g.randGen.Intn(60), 0) + } + t := 
types.NewTime(gt, mysql.TypeDatetime, types.DefaultFsp) + return t +} + +// dateTimeStrGener is used to generate strings which are dataTime format +type dateTimeStrGener struct { + Fsp int + Year int + Month int + Day int + randGen *defaultRandGen +} + +func (g *dateTimeStrGener) gen() interface{} { + if g.Year == 0 { + g.Year = 1970 + g.randGen.Intn(100) + } + if g.Month == 0 { + g.Month = g.randGen.Intn(10) + 1 + } + if g.Day == 0 { + g.Day = g.randGen.Intn(20) + 1 + } + hour := g.randGen.Intn(12) + minute := g.randGen.Intn(60) + second := g.randGen.Intn(60) + dataTimeStr := fmt.Sprintf("%d-%d-%d %d:%d:%d", + g.Year, g.Month, g.Day, hour, minute, second) + if g.Fsp > 0 && g.Fsp <= 6 { + microFmt := fmt.Sprintf(".%%0%dd", g.Fsp) + return dataTimeStr + fmt.Sprintf(microFmt, g.randGen.Int()%(10^g.Fsp)) + } + + return dataTimeStr +} + +// dateStrGener is used to generate strings which are date format +type dateStrGener struct { + Year int + Month int + Day int + NullRation float64 + randGen *defaultRandGen +} + +func (g *dateStrGener) gen() interface{} { + if g.NullRation > 1e-6 && g.randGen.Float64() < g.NullRation { + return nil + } + + if g.Year == 0 { + g.Year = 1970 + g.randGen.Intn(100) + } + if g.Month == 0 { + g.Month = g.randGen.Intn(10) + 1 + } + if g.Day == 0 { + g.Day = g.randGen.Intn(20) + 1 + } + + return fmt.Sprintf("%d-%d-%d", g.Year, g.Month, g.Day) +} + +// timeStrGener is used to generate strings which are time format +type timeStrGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func (g *timeStrGener) gen() interface{} { + if g.nullRation > 1e-6 && g.randGen.Float64() < g.nullRation { + return nil + } + hour := g.randGen.Intn(12) + minute := g.randGen.Intn(60) + second := g.randGen.Intn(60) + + return fmt.Sprintf("%d:%d:%d", hour, minute, second) +} + +type dateTimeIntGener struct { + dateTimeGener + nullRation float64 +} + +func (g *dateTimeIntGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } 
+ + t := g.dateTimeGener.gen().(types.Time) + num, err := t.ToNumber().ToInt() + if err != nil { + panic(err) + } + return num +} + +// constStrGener always returns the given string +type constStrGener struct { + s string +} + +func (g *constStrGener) gen() interface{} { + return g.s +} + +type randDurInt struct { + randGen *defaultRandGen +} + +func newRandDurInt() *randDurInt { + return &randDurInt{newDefaultRandGen()} +} + +func (g *randDurInt) gen() interface{} { + return int64(g.randGen.Intn(types.TimeMaxHour)*10000 + g.randGen.Intn(60)*100 + g.randGen.Intn(60)) +} + +type randDurReal struct { + randGen *defaultRandGen +} + +func newRandDurReal() *randDurReal { + return &randDurReal{newDefaultRandGen()} +} + +func (g *randDurReal) gen() interface{} { + return float64(g.randGen.Intn(types.TimeMaxHour)*10000 + g.randGen.Intn(60)*100 + g.randGen.Intn(60)) +} + +type randDurDecimal struct { + randGen *defaultRandGen +} + +func newRandDurDecimal() *randDurDecimal { + return &randDurDecimal{newDefaultRandGen()} +} + +func (g *randDurDecimal) gen() interface{} { + d := new(types.MyDecimal) + return d.FromFloat64(float64(g.randGen.Intn(types.TimeMaxHour)*10000 + g.randGen.Intn(60)*100 + g.randGen.Intn(60))) +} + +type randDurString struct{} + +func (g *randDurString) gen() interface{} { + return strconv.Itoa(rand.Intn(types.TimeMaxHour)*10000 + rand.Intn(60)*100 + rand.Intn(60)) +} + +// locationGener is used to generate location for the built-in function GetFormat. 
+type locationGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func newLocationGener(nullRation float64) *locationGener { + return &locationGener{nullRation, newDefaultRandGen()} +} + +func (g *locationGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + switch g.randGen.Uint32() % 5 { + case 0: + return usaLocation + case 1: + return jisLocation + case 2: + return isoLocation + case 3: + return eurLocation + case 4: + return internalLocation + default: + return nil + } +} + +// formatGener is used to generate a format for the built-in function GetFormat. +type formatGener struct { + nullRation float64 + randGen *defaultRandGen +} + +func newFormatGener(nullRation float64) *formatGener { + return &formatGener{nullRation, newDefaultRandGen()} +} + +func (g *formatGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + switch g.randGen.Uint32() % 4 { + case 0: + return dateFormat + case 1: + return datetimeFormat + case 2: + return timestampFormat + case 3: + return timeFormat + default: + return nil + } +} + +type nullWrappedGener struct { + nullRation float64 + inner dataGenerator + randGen *defaultRandGen +} + +func newNullWrappedGener(nullRation float64, inner dataGenerator) *nullWrappedGener { + return &nullWrappedGener{nullRation, inner, newDefaultRandGen()} +} + +func (g *nullWrappedGener) gen() interface{} { + if g.randGen.Float64() < g.nullRation { + return nil + } + return g.inner.gen() +} + +type vecExprBenchCase struct { + // retEvalType is the EvalType of the expression result. + // This field is required. + retEvalType types.EvalType + // childrenTypes is the EvalTypes of the expression children(arguments). + // This field is required. + childrenTypes []types.EvalType + // childrenFieldTypes is the field types of the expression children(arguments). + // If childrenFieldTypes is not set, it will be converted from childrenTypes. + // This field is optional. 
+ childrenFieldTypes []*types.FieldType + // geners are used to generate data for children and geners[i] generates data for children[i]. + // If geners[i] is nil, the default dataGenerator will be used for its corresponding child. + // The geners slice can be shorter than the children slice, if it has 3 children, then + // geners[gen1, gen2] will be regarded as geners[gen1, gen2, nil]. + // This field is optional. + geners []dataGenerator + // aesModeAttr information, needed by encryption functions + aesModes string + // constants are used to generate constant data for children[i]. + constants []*Constant + // chunkSize is used to specify the chunk size of children, the maximum is 1024. + // This field is optional, 1024 by default. + chunkSize int +} + +type vecExprBenchCases map[string][]vecExprBenchCase + +func fillColumn(eType types.EvalType, chk *chunk.Chunk, colIdx int, testCase vecExprBenchCase) { + var gen dataGenerator + if len(testCase.geners) > colIdx && testCase.geners[colIdx] != nil { + gen = testCase.geners[colIdx] + } + fillColumnWithGener(eType, chk, colIdx, gen) +} + +func fillColumnWithGener(eType types.EvalType, chk *chunk.Chunk, colIdx int, gen dataGenerator) { + batchSize := chk.Capacity() + if gen == nil { + gen = newDefaultGener(0.2, eType) + } + + col := chk.Column(colIdx) + col.Reset(eType) + for i := 0; i < batchSize; i++ { + v := gen.gen() + if v == nil { + col.AppendNull() + continue + } + switch eType { + case types.ETInt: + col.AppendInt64(v.(int64)) + case types.ETReal: + col.AppendFloat64(v.(float64)) + case types.ETDecimal: + col.AppendMyDecimal(v.(*types.MyDecimal)) + case types.ETDatetime, types.ETTimestamp: + col.AppendTime(v.(types.Time)) + case types.ETDuration: + col.AppendDuration(v.(types.Duration)) + case types.ETJson: + col.AppendJSON(v.(json.BinaryJSON)) + case types.ETString: + col.AppendString(v.(string)) + } + } +} + +func randString(r *rand.Rand) string { + n := 10 + r.Intn(10) + buf := make([]byte, n) + for i := range 
buf { + x := r.Intn(62) + if x < 10 { + buf[i] = byte('0' + x) + } else if x-10 < 26 { + buf[i] = byte('a' + x - 10) + } else { + buf[i] = byte('A' + x - 10 - 26) + } + } + return string(buf) +} + +func eType2FieldType(eType types.EvalType) *types.FieldType { + switch eType { + case types.ETInt: + return types.NewFieldType(mysql.TypeLonglong) + case types.ETReal: + return types.NewFieldType(mysql.TypeDouble) + case types.ETDecimal: + return types.NewFieldType(mysql.TypeNewDecimal) + case types.ETDatetime, types.ETTimestamp: + return types.NewFieldType(mysql.TypeDatetime) + case types.ETDuration: + return types.NewFieldType(mysql.TypeDuration) + case types.ETJson: + return types.NewFieldType(mysql.TypeJSON) + case types.ETString: + return types.NewFieldType(mysql.TypeVarString) + default: + panic(fmt.Sprintf("EvalType=%v is not supported.", eType)) + } +} + +func genVecExprBenchCase(ctx sessionctx.Context, funcName string, testCase vecExprBenchCase) (expr Expression, fts []*types.FieldType, input *chunk.Chunk, output *chunk.Chunk) { + fts = make([]*types.FieldType, len(testCase.childrenTypes)) + for i := range fts { + if i < len(testCase.childrenFieldTypes) && testCase.childrenFieldTypes[i] != nil { + fts[i] = testCase.childrenFieldTypes[i] + } else { + fts[i] = eType2FieldType(testCase.childrenTypes[i]) + } + } + if testCase.chunkSize <= 0 || testCase.chunkSize > 1024 { + testCase.chunkSize = 1024 + } + cols := make([]Expression, len(testCase.childrenTypes)) + input = chunk.New(fts, testCase.chunkSize, testCase.chunkSize) + input.NumRows() + for i, eType := range testCase.childrenTypes { + fillColumn(eType, input, i, testCase) + if i < len(testCase.constants) && testCase.constants[i] != nil { + cols[i] = testCase.constants[i] + } else { + cols[i] = &Column{Index: i, RetType: fts[i]} + } + } + + expr, err := NewFunction(ctx, funcName, eType2FieldType(testCase.retEvalType), cols...) 
+ if err != nil { + panic(err) + } + + output = chunk.New([]*types.FieldType{eType2FieldType(expr.GetType().EvalType())}, testCase.chunkSize, testCase.chunkSize) + return expr, fts, input, output +} + +// testVectorizedEvalOneVec is used to verify that the vectorized +// expression is evaluated correctly during projection +func testVectorizedEvalOneVec(c *C, vecExprCases vecExprBenchCases) { + ctx := mock.NewContext() + for funcName, testCases := range vecExprCases { + for _, testCase := range testCases { + expr, fts, input, output := genVecExprBenchCase(ctx, funcName, testCase) + commentf := func(row int) CommentInterface { + return Commentf("func: %v, case %+v, row: %v, rowData: %v", funcName, testCase, row, input.GetRow(row).GetDatumRow(fts)) + } + output2 := output.CopyConstruct() + c.Assert(evalOneVec(ctx, expr, input, output, 0), IsNil, Commentf("func: %v, case: %+v", funcName, testCase)) + it := chunk.NewIterator4Chunk(input) + c.Assert(evalOneColumn(ctx, expr, it, output2, 0), IsNil, Commentf("func: %v, case: %+v", funcName, testCase)) + + c1, c2 := output.Column(0), output2.Column(0) + switch expr.GetType().EvalType() { + case types.ETInt: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetInt64(i), Equals, c2.GetInt64(i), commentf(i)) + } + } + case types.ETReal: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetFloat64(i), Equals, c2.GetFloat64(i), commentf(i)) + } + } + case types.ETDecimal: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetDecimal(i), DeepEquals, c2.GetDecimal(i), commentf(i)) + } + } + case types.ETDatetime, types.ETTimestamp: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetTime(i), 
DeepEquals, c2.GetTime(i), commentf(i)) + } + } + case types.ETDuration: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetDuration(i, 0), Equals, c2.GetDuration(i, 0), commentf(i)) + } + } + case types.ETJson: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetJSON(i), DeepEquals, c2.GetJSON(i), commentf(i)) + } + } + case types.ETString: + for i := 0; i < input.NumRows(); i++ { + c.Assert(c1.IsNull(i), Equals, c2.IsNull(i), commentf(i)) + if !c1.IsNull(i) { + c.Assert(c1.GetString(i), Equals, c2.GetString(i), commentf(i)) + } + } + } + } + } +} + +// benchmarkVectorizedEvalOneVec is used to get the effect of +// using the vectorized expression evaluations during projection +func benchmarkVectorizedEvalOneVec(b *testing.B, vecExprCases vecExprBenchCases) { + ctx := mock.NewContext() + for funcName, testCases := range vecExprCases { + for _, testCase := range testCases { + expr, _, input, output := genVecExprBenchCase(ctx, funcName, testCase) + exprName := expr.String() + if sf, ok := expr.(*ScalarFunction); ok { + exprName = fmt.Sprintf("%v", reflect.TypeOf(sf.Function)) + tmp := strings.Split(exprName, ".") + exprName = tmp[len(tmp)-1] + } + + b.Run(exprName+"-EvalOneVec", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := evalOneVec(ctx, expr, input, output, 0); err != nil { + b.Fatal(err) + } + } + }) + b.Run(exprName+"-EvalOneCol", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := chunk.NewIterator4Chunk(input) + if err := evalOneColumn(ctx, expr, it, output, 0); err != nil { + b.Fatal(err) + } + } + }) + } + } +} + +func genVecBuiltinFuncBenchCase(ctx sessionctx.Context, funcName string, testCase vecExprBenchCase) (baseFunc builtinFunc, fts []*types.FieldType, input *chunk.Chunk, result *chunk.Column) { + childrenNumber := 
len(testCase.childrenTypes) + fts = make([]*types.FieldType, childrenNumber) + for i := range fts { + if i < len(testCase.childrenFieldTypes) && testCase.childrenFieldTypes[i] != nil { + fts[i] = testCase.childrenFieldTypes[i] + } else { + fts[i] = eType2FieldType(testCase.childrenTypes[i]) + } + } + cols := make([]Expression, childrenNumber) + if testCase.chunkSize <= 0 || testCase.chunkSize > 1024 { + testCase.chunkSize = 1024 + } + input = chunk.New(fts, testCase.chunkSize, testCase.chunkSize) + for i, eType := range testCase.childrenTypes { + fillColumn(eType, input, i, testCase) + if i < len(testCase.constants) && testCase.constants[i] != nil { + cols[i] = testCase.constants[i] + } else { + cols[i] = &Column{Index: i, RetType: fts[i]} + } + } + if len(cols) == 0 { + input.SetNumVirtualRows(testCase.chunkSize) + } + + var err error + if funcName == ast.Cast { + var fc functionClass + tp := eType2FieldType(testCase.retEvalType) + switch testCase.retEvalType { + case types.ETInt: + fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETDecimal: + fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETReal: + fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETDatetime, types.ETTimestamp: + fc = &castAsTimeFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETDuration: + fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETJson: + fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + case types.ETString: + fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + } + baseFunc, err = fc.getFunction(ctx, cols) + } else { + baseFunc, err = funcs[funcName].getFunction(ctx, cols) + } + if err != nil { + panic(err) + } + result = chunk.NewColumn(eType2FieldType(testCase.retEvalType), testCase.chunkSize) + // Mess up the output to make sure vecEvalXXX to call ResizeXXX/ReserveXXX itself. 
+ result.AppendNull() + return baseFunc, fts, input, result +} + +// a hack way to calculate length of a chunk.Column. +func getColumnLen(col *chunk.Column, eType types.EvalType) int { + chk := chunk.New([]*types.FieldType{eType2FieldType(eType)}, 1024, 1024) + chk.SetCol(0, col) + return chk.NumRows() +} + +// removeTestOptions removes all not needed options like '-test.timeout=' from argument list +func removeTestOptions(args []string) []string { + argList := args[:0] + + // args contains '-test.timeout=' option for example + // excluding it to be able to run all tests + for _, arg := range args { + if strings.HasPrefix(arg, "builtin") || IsFunctionSupported(arg) { + argList = append(argList, arg) + } + } + return argList +} + +// testVectorizedBuiltinFunc is used to verify that the vectorized +// expression is evaluated correctly +func testVectorizedBuiltinFunc(c *C, vecExprCases vecExprBenchCases) { + testFunc := make(map[string]bool) + argList := removeTestOptions(flag.Args()) + testAll := len(argList) == 0 + for _, arg := range argList { + testFunc[arg] = true + } + for funcName, testCases := range vecExprCases { + for _, testCase := range testCases { + ctx := mock.NewContext() + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, testCase.aesModes) + c.Assert(err, IsNil) + if funcName == ast.CurrentUser || funcName == ast.User { + ctx.GetSessionVars().User = &auth.UserIdentity{ + Username: "tidb", + Hostname: "localhost", + CurrentUser: true, + AuthHostname: "localhost", + AuthUsername: "tidb", + } + } + if funcName == ast.GetParam { + testTime := time.Now() + ctx.GetSessionVars().PreparedParams = []types.Datum{ + types.NewIntDatum(1), + types.NewDecimalDatum(types.NewDecFromStringForTest("20170118123950.123")), + types.NewTimeDatum(types.NewTime(types.FromGoTime(testTime), mysql.TypeTimestamp, 6)), + types.NewDurationDatum(types.ZeroDuration), + types.NewStringDatum("{}"), + types.NewBinaryLiteralDatum(types.BinaryLiteral([]byte{1})), + 
types.NewBytesDatum([]byte{'b'}), + types.NewFloat32Datum(1.1), + types.NewFloat64Datum(2.1), + types.NewUintDatum(100), + types.NewMysqlBitDatum(types.BinaryLiteral([]byte{1})), + types.NewMysqlEnumDatum(types.Enum{Name: "n", Value: 2}), + } + } + baseFunc, fts, input, output := genVecBuiltinFuncBenchCase(ctx, funcName, testCase) + baseFuncName := fmt.Sprintf("%v", reflect.TypeOf(baseFunc)) + tmp := strings.Split(baseFuncName, ".") + baseFuncName = tmp[len(tmp)-1] + + if !testAll && (testFunc[baseFuncName] != true && testFunc[funcName] != true) { + continue + } + // do not forget to implement the vectorized method. + c.Assert(baseFunc.vectorized(), IsTrue, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + commentf := func(row int) CommentInterface { + return Commentf("func: %v, case %+v, row: %v, rowData: %v", baseFuncName, testCase, row, input.GetRow(row).GetDatumRow(fts)) + } + it := chunk.NewIterator4Chunk(input) + i := 0 + var vecWarnCnt uint16 + switch testCase.retEvalType { + case types.ETInt: + err := baseFunc.vecEvalInt(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + i64s := output.Int64s() + for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalInt(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(val, Equals, i64s[i], commentf(i)) + } + i++ + } + case types.ETReal: + err := baseFunc.vecEvalReal(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + f64s := output.Float64s() + 
for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalReal(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(val, Equals, f64s[i], commentf(i)) + } + i++ + } + case types.ETDecimal: + err := baseFunc.vecEvalDecimal(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + d64s := output.Decimals() + for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalDecimal(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(*val, Equals, d64s[i], commentf(i)) + } + i++ + } + case types.ETDatetime, types.ETTimestamp: + err := baseFunc.vecEvalTime(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + t64s := output.Times() + for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalTime(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(val, Equals, t64s[i], commentf(i)) + } + i++ + } + case types.ETDuration: + err := baseFunc.vecEvalDuration(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + d64s := output.GoDurations() + for row := it.Begin(); row != it.End(); row = 
it.Next() { + val, isNull, err := baseFunc.evalDuration(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(val.Duration, Equals, d64s[i], commentf(i)) + } + i++ + } + case types.ETJson: + err := baseFunc.vecEvalJSON(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalJSON(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + cmp := json.CompareBinary(val, output.GetJSON(i)) + c.Assert(cmp, Equals, 0, commentf(i)) + } + i++ + } + case types.ETString: + err := baseFunc.vecEvalString(input, output) + c.Assert(err, IsNil, Commentf("func: %v, case: %+v", baseFuncName, testCase)) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + vecWarnCnt = ctx.GetSessionVars().StmtCtx.WarningCount() + for row := it.Begin(); row != it.End(); row = it.Next() { + val, isNull, err := baseFunc.evalString(row) + c.Assert(err, IsNil, commentf(i)) + c.Assert(isNull, Equals, output.IsNull(i), commentf(i)) + if !isNull { + c.Assert(val, Equals, output.GetString(i), commentf(i)) + } + i++ + } + default: + c.Fatal(fmt.Sprintf("evalType=%v is not supported", testCase.retEvalType)) + } + + // check warnings + totalWarns := ctx.GetSessionVars().StmtCtx.WarningCount() + c.Assert(2*vecWarnCnt, Equals, totalWarns) + warns := ctx.GetSessionVars().StmtCtx.GetWarnings() + for i := 0; i < int(vecWarnCnt); i++ { + c.Assert(terror.ErrorEqual(warns[i].Err, warns[i+int(vecWarnCnt)].Err), IsTrue) + } + } + } +} + +// testVectorizedBuiltinFuncForRand is used to verify that the 
vectorized +// expression is evaluated correctly +func testVectorizedBuiltinFuncForRand(c *C, vecExprCases vecExprBenchCases) { + for funcName, testCases := range vecExprCases { + c.Assert(strings.EqualFold("rand", funcName), Equals, true) + + for _, testCase := range testCases { + c.Assert(len(testCase.childrenTypes), Equals, 0) + + ctx := mock.NewContext() + baseFunc, _, input, output := genVecBuiltinFuncBenchCase(ctx, funcName, testCase) + baseFuncName := fmt.Sprintf("%v", reflect.TypeOf(baseFunc)) + tmp := strings.Split(baseFuncName, ".") + baseFuncName = tmp[len(tmp)-1] + // do not forget to implement the vectorized method. + c.Assert(baseFunc.vectorized(), IsTrue, Commentf("func: %v", baseFuncName)) + switch testCase.retEvalType { + case types.ETReal: + err := baseFunc.vecEvalReal(input, output) + c.Assert(err, IsNil) + // do not forget to call ResizeXXX/ReserveXXX + c.Assert(getColumnLen(output, testCase.retEvalType), Equals, input.NumRows()) + // check result + res := output.Float64s() + for _, v := range res { + c.Assert((0 <= v) && (v < 1), Equals, true) + } + default: + c.Fatal(fmt.Sprintf("evalType=%v is not supported", testCase.retEvalType)) + } + } + } +} + +// benchmarkVectorizedBuiltinFunc is used to get the effect of +// using the vectorized expression evaluations +func benchmarkVectorizedBuiltinFunc(b *testing.B, vecExprCases vecExprBenchCases) { + ctx := mock.NewContext() + testFunc := make(map[string]bool) + argList := removeTestOptions(flag.Args()) + testAll := len(argList) == 0 + for _, arg := range argList { + testFunc[arg] = true + } + for funcName, testCases := range vecExprCases { + for _, testCase := range testCases { + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, testCase.aesModes) + if err != nil { + panic(err) + } + if funcName == ast.CurrentUser || funcName == ast.User { + ctx.GetSessionVars().User = &auth.UserIdentity{ + Username: "tidb", + Hostname: "localhost", + CurrentUser: true, + AuthHostname: 
"localhost", + AuthUsername: "tidb", + } + } + if funcName == ast.GetParam { + testTime := time.Now() + ctx.GetSessionVars().PreparedParams = []types.Datum{ + types.NewIntDatum(1), + types.NewDecimalDatum(types.NewDecFromStringForTest("20170118123950.123")), + types.NewTimeDatum(types.NewTime(types.FromGoTime(testTime), mysql.TypeTimestamp, 6)), + types.NewDurationDatum(types.ZeroDuration), + types.NewStringDatum("{}"), + types.NewBinaryLiteralDatum(types.BinaryLiteral([]byte{1})), + types.NewBytesDatum([]byte{'b'}), + types.NewFloat32Datum(1.1), + types.NewFloat64Datum(2.1), + types.NewUintDatum(100), + types.NewMysqlBitDatum(types.BinaryLiteral([]byte{1})), + types.NewMysqlEnumDatum(types.Enum{Name: "n", Value: 2}), + } + } + baseFunc, _, input, output := genVecBuiltinFuncBenchCase(ctx, funcName, testCase) + baseFuncName := fmt.Sprintf("%v", reflect.TypeOf(baseFunc)) + tmp := strings.Split(baseFuncName, ".") + baseFuncName = tmp[len(tmp)-1] + + if !testAll && testFunc[baseFuncName] != true && testFunc[funcName] != true { + continue + } + + b.Run(baseFuncName+"-VecBuiltinFunc", func(b *testing.B) { + b.ResetTimer() + switch testCase.retEvalType { + case types.ETInt: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalInt(input, output); err != nil { + b.Fatal(err) + } + } + case types.ETReal: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalReal(input, output); err != nil { + b.Fatal(err) + } + } + case types.ETDecimal: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalDecimal(input, output); err != nil { + b.Fatal(err) + } + } + case types.ETDatetime, types.ETTimestamp: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalTime(input, output); err != nil { + b.Fatal(err) + } + } + case types.ETDuration: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalDuration(input, output); err != nil { + b.Fatal(err) + } + } + case types.ETJson: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalJSON(input, output); err != nil { + 
b.Fatal(err) + } + } + case types.ETString: + for i := 0; i < b.N; i++ { + if err := baseFunc.vecEvalString(input, output); err != nil { + b.Fatal(err) + } + } + default: + b.Fatal(fmt.Sprintf("evalType=%v is not supported", testCase.retEvalType)) + } + }) + b.Run(baseFuncName+"-NonVecBuiltinFunc", func(b *testing.B) { + b.ResetTimer() + it := chunk.NewIterator4Chunk(input) + switch testCase.retEvalType { + case types.ETInt: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalInt(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendInt64(v) + } + } + } + case types.ETReal: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalReal(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendFloat64(v) + } + } + } + case types.ETDecimal: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalDecimal(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendMyDecimal(v) + } + } + } + case types.ETDatetime, types.ETTimestamp: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalTime(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendTime(v) + } + } + } + case types.ETDuration: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalDuration(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendDuration(v) + } + } + 
} + case types.ETJson: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalJSON(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendJSON(v) + } + } + } + case types.ETString: + for i := 0; i < b.N; i++ { + output.Reset(testCase.retEvalType) + for row := it.Begin(); row != it.End(); row = it.Next() { + v, isNull, err := baseFunc.evalString(row) + if err != nil { + b.Fatal(err) + } + if isNull { + output.AppendNull() + } else { + output.AppendString(v) + } + } + } + default: + b.Fatal(fmt.Sprintf("evalType=%v is not supported", testCase.retEvalType)) + } + }) + } + } +} + +func genVecEvalBool(numCols int, colTypes, eTypes []types.EvalType) (CNFExprs, *chunk.Chunk) { + gens := make([]dataGenerator, 0, len(eTypes)) + for _, eType := range eTypes { + if eType == types.ETString { + gens = append(gens, &numStrGener{*newRangeInt64Gener(0, 10)}) + } else { + gens = append(gens, newDefaultGener(0.05, eType)) + } + } + + ts := make([]types.EvalType, 0, numCols) + gs := make([]dataGenerator, 0, numCols) + fts := make([]*types.FieldType, 0, numCols) + randGen := newDefaultRandGen() + for i := 0; i < numCols; i++ { + idx := randGen.Intn(len(eTypes)) + if colTypes != nil { + for j := range eTypes { + if colTypes[i] == eTypes[j] { + idx = j + break + } + } + } + ts = append(ts, eTypes[idx]) + gs = append(gs, gens[idx]) + fts = append(fts, eType2FieldType(eTypes[idx])) + } + + input := chunk.New(fts, 1024, 1024) + exprs := make(CNFExprs, 0, numCols) + for i := 0; i < numCols; i++ { + fillColumn(ts[i], input, i, vecExprBenchCase{geners: gs}) + exprs = append(exprs, &Column{Index: i, RetType: fts[i]}) + } + return exprs, input +} + +func generateRandomSel() []int { + randGen := newDefaultRandGen() + randGen.Seed(time.Now().UnixNano()) + var sel []int + count := 0 + // Use constant 256 to make it faster to generate randomly 
arranged sel slices + num := randGen.Intn(256) + 1 + existed := make([]bool, 1024) + for i := 0; i < 1024; i++ { + existed[i] = false + } + for count < num { + val := randGen.Intn(1024) + if !existed[val] { + existed[val] = true + count++ + } + } + for i := 0; i < 1024; i++ { + if existed[i] { + sel = append(sel, i) + } + } + return sel +} + +func (s *testVectorizeSuite2) TestVecEvalBool(c *C) { + ctx := mock.NewContext() + eTypes := []types.EvalType{types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + for numCols := 1; numCols <= 5; numCols++ { + for round := 0; round < 16; round++ { + exprs, input := genVecEvalBool(numCols, nil, eTypes) + selected, nulls, err := VecEvalBool(ctx, exprs, input, nil, nil) + c.Assert(err, IsNil) + it := chunk.NewIterator4Chunk(input) + i := 0 + for row := it.Begin(); row != it.End(); row = it.Next() { + ok, null, err := EvalBool(mock.NewContext(), exprs, row) + c.Assert(err, IsNil) + c.Assert(null, Equals, nulls[i]) + c.Assert(ok, Equals, selected[i]) + i++ + } + } + } +} + +func BenchmarkVecEvalBool(b *testing.B) { + ctx := mock.NewContext() + selected := make([]bool, 0, 1024) + nulls := make([]bool, 0, 1024) + eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + tNames := []string{"int", "real", "decimal", "string", "timestamp", "datetime", "duration"} + for numCols := 1; numCols <= 2; numCols++ { + typeCombination := make([]types.EvalType, numCols) + var combFunc func(nCols int) + combFunc = func(nCols int) { + if nCols == 0 { + name := "" + for _, t := range typeCombination { + for i := range eTypes { + if t == eTypes[i] { + name += tNames[t] + "/" + } + } + } + exprs, input := genVecEvalBool(numCols, typeCombination, eTypes) + b.Run("Vec-"+name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := VecEvalBool(ctx, exprs, input, selected, nulls) + if err != nil 
{ + b.Fatal(err) + } + } + }) + b.Run("Row-"+name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := chunk.NewIterator4Chunk(input) + for row := it.Begin(); row != it.End(); row = it.Next() { + _, _, err := EvalBool(ctx, exprs, row) + if err != nil { + b.Fatal(err) + } + } + } + }) + return + } + for _, eType := range eTypes { + typeCombination[nCols-1] = eType + combFunc(nCols - 1) + } + } + + combFunc(numCols) + } +} + +func (s *testVectorizeSuite2) TestRowBasedFilterAndVectorizedFilter(c *C) { + ctx := mock.NewContext() + eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + for numCols := 1; numCols <= 5; numCols++ { + for round := 0; round < 16; round++ { + exprs, input := genVecEvalBool(numCols, nil, eTypes) + it := chunk.NewIterator4Chunk(input) + isNull := make([]bool, it.Len()) + selected, nulls, err := rowBasedFilter(ctx, exprs, it, nil, isNull) + c.Assert(err, IsNil) + selected2, nulls2, err2 := vectorizedFilter(ctx, exprs, it, nil, isNull) + c.Assert(err2, IsNil) + length := it.Len() + for i := 0; i < length; i++ { + c.Assert(nulls2[i], Equals, nulls[i]) + c.Assert(selected2[i], Equals, selected[i]) + } + } + } +} + +func BenchmarkRowBasedFilterAndVectorizedFilter(b *testing.B) { + ctx := mock.NewContext() + selected := make([]bool, 0, 1024) + nulls := make([]bool, 0, 1024) + eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + tNames := []string{"int", "real", "decimal", "string", "timestamp", "datetime", "duration"} + for numCols := 1; numCols <= 2; numCols++ { + typeCombination := make([]types.EvalType, numCols) + var combFunc func(nCols int) + combFunc = func(nCols int) { + if nCols == 0 { + name := "" + for _, t := range typeCombination { + for i := range eTypes { + if t == eTypes[i] { + name += tNames[t] + "/" + } + } + } + exprs, input := 
genVecEvalBool(numCols, typeCombination, eTypes) + it := chunk.NewIterator4Chunk(input) + b.Run("Vec-"+name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := vectorizedFilter(ctx, exprs, it, selected, nulls) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run("Row-"+name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := rowBasedFilter(ctx, exprs, it, selected, nulls) + if err != nil { + b.Fatal(err) + } + } + }) + return + } + for _, eType := range eTypes { + typeCombination[nCols-1] = eType + combFunc(nCols - 1) + } + } + combFunc(numCols) + } + + // Add special case to prove when some calculations are added, + // the vectorizedFilter for int types will be more faster than rowBasedFilter. + funcName := ast.Least + testCase := vecExprBenchCase{retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ETInt, types.ETInt}} + expr, _, input, _ := genVecExprBenchCase(ctx, funcName, testCase) + it := chunk.NewIterator4Chunk(input) + + b.Run("Vec-special case", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := vectorizedFilter(ctx, []Expression{expr}, it, selected, nulls) + if err != nil { + panic(err) + } + } + }) + b.Run("Row-special case", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := rowBasedFilter(ctx, []Expression{expr}, it, selected, nulls) + if err != nil { + panic(err) + } + } + }) +} + +func (s *testVectorizeSuite2) TestVectorizedFilterConsiderNull(c *C) { + ctx := mock.NewContext() + dafaultEnableVectorizedExpressionVar := ctx.GetSessionVars().EnableVectorizedExpression + eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + for numCols := 1; numCols <= 5; numCols++ { + for round := 0; round < 16; round++ { + exprs, input := genVecEvalBool(numCols, nil, eTypes) + it := chunk.NewIterator4Chunk(input) + isNull := make([]bool, 
it.Len()) + ctx.GetSessionVars().EnableVectorizedExpression = false + selected, nulls, err := VectorizedFilterConsiderNull(ctx, exprs, it, nil, isNull) + c.Assert(err, IsNil) + ctx.GetSessionVars().EnableVectorizedExpression = true + selected2, nulls2, err2 := VectorizedFilterConsiderNull(ctx, exprs, it, nil, isNull) + c.Assert(err2, IsNil) + length := it.Len() + for i := 0; i < length; i++ { + c.Assert(nulls2[i], Equals, nulls[i]) + c.Assert(selected2[i], Equals, selected[i]) + } + + // add test which sel is not nil + randomSel := generateRandomSel() + input.SetSel(randomSel) + it2 := chunk.NewIterator4Chunk(input) + isNull = isNull[:0] + ctx.GetSessionVars().EnableVectorizedExpression = false + selected3, nulls, err := VectorizedFilterConsiderNull(ctx, exprs, it2, nil, isNull) + c.Assert(err, IsNil) + ctx.GetSessionVars().EnableVectorizedExpression = true + selected4, nulls2, err2 := VectorizedFilterConsiderNull(ctx, exprs, it2, nil, isNull) + c.Assert(err2, IsNil) + for i := 0; i < length; i++ { + c.Assert(nulls2[i], Equals, nulls[i]) + c.Assert(selected4[i], Equals, selected3[i]) + } + + unselected := make([]bool, length) + // unselected[i] == false means that the i-th row is selected + for i := 0; i < length; i++ { + unselected[i] = true + } + for _, idx := range randomSel { + unselected[idx] = false + } + for i := range selected2 { + if selected2[i] && unselected[i] { + selected2[i] = false + } + } + for i := 0; i < length; i++ { + c.Assert(selected2[i], Equals, selected4[i]) + } + } + } + ctx.GetSessionVars().EnableVectorizedExpression = dafaultEnableVectorizedExpressionVar +} +>>>>>>> bdbdbae... 
expression: fix the issue that incorrect result for a predicat… (#16014) diff --git a/expression/expression.go b/expression/expression.go index a543d5690b51f..b6b9cca9abc81 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -174,6 +174,356 @@ func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, b return true, false, nil } +<<<<<<< HEAD +======= +var ( + defaultChunkSize = 1024 + selPool = sync.Pool{ + New: func() interface{} { + return make([]int, defaultChunkSize) + }, + } + zeroPool = sync.Pool{ + New: func() interface{} { + return make([]int8, defaultChunkSize) + }, + } +) + +func allocSelSlice(n int) []int { + if n > defaultChunkSize { + return make([]int, n) + } + return selPool.Get().([]int) +} + +func deallocateSelSlice(sel []int) { + if cap(sel) <= defaultChunkSize { + selPool.Put(sel) + } +} + +func allocZeroSlice(n int) []int8 { + if n > defaultChunkSize { + return make([]int8, n) + } + return zeroPool.Get().([]int8) +} + +func deallocateZeroSlice(isZero []int8) { + if cap(isZero) <= defaultChunkSize { + zeroPool.Put(isZero) + } +} + +// VecEvalBool does the same thing as EvalBool but it works in a vectorized manner. +func VecEvalBool(ctx sessionctx.Context, exprList CNFExprs, input *chunk.Chunk, selected, nulls []bool) ([]bool, []bool, error) { + // If input.Sel() != nil, we will call input.SetSel(nil) to clear the sel slice in input chunk. + // After the function finished, then we reset the input.Sel(). + // The caller will handle the input.Sel() and selected slices. 
+ defer input.SetSel(input.Sel()) + input.SetSel(nil) + + n := input.NumRows() + selected = selected[:0] + nulls = nulls[:0] + for i := 0; i < n; i++ { + selected = append(selected, false) + nulls = append(nulls, false) + } + + sel := allocSelSlice(n) + defer deallocateSelSlice(sel) + sel = sel[:0] + for i := 0; i < n; i++ { + sel = append(sel, i) + } + input.SetSel(sel) + + // In isZero slice, -1 means Null, 0 means zero, 1 means not zero + isZero := allocZeroSlice(n) + defer deallocateZeroSlice(isZero) + for _, expr := range exprList { + eType := expr.GetType().EvalType() + buf, err := globalColumnAllocator.get(eType, n) + if err != nil { + return nil, nil, err + } + + if err := EvalExpr(ctx, expr, input, buf); err != nil { + return nil, nil, err + } + + err = toBool(ctx.GetSessionVars().StmtCtx, eType, buf, sel, isZero) + if err != nil { + return nil, nil, err + } + + j := 0 + isEQCondFromIn := IsEQCondFromIn(expr) + for i := range sel { + if isZero[i] == -1 { + if eType != types.ETInt && !isEQCondFromIn { + continue + } + // In this case, we set this row to null and let it pass this filter. + // The null flag may be set to false later by other expressions in some cases. 
+ nulls[sel[i]] = true + sel[j] = sel[i] + j++ + continue + } + + if isZero[i] == 0 { + continue + } + sel[j] = sel[i] // this row passes this filter + j++ + } + sel = sel[:j] + input.SetSel(sel) + globalColumnAllocator.put(buf) + } + + for _, i := range sel { + if !nulls[i] { + selected[i] = true + } + } + + return selected, nulls, nil +} + +func toBool(sc *stmtctx.StatementContext, eType types.EvalType, buf *chunk.Column, sel []int, isZero []int8) error { + switch eType { + case types.ETInt: + i64s := buf.Int64s() + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + if i64s[i] == 0 { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETReal: + f64s := buf.Float64s() + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + if types.RoundFloat(f64s[i]) == 0 { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETDuration: + d64s := buf.GoDurations() + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + if d64s[i] == 0 { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETDatetime, types.ETTimestamp: + t64s := buf.Times() + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + if t64s[i].IsZero() { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETString: + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + iVal, err := types.StrToFloat(sc, buf.GetString(i)) + if err != nil { + return err + } + if iVal == 0 { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETDecimal: + d64s := buf.Decimals() + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + v, err := d64s[i].ToFloat64() + if err != nil { + return err + } + if types.RoundFloat(v) == 0 { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + case types.ETJson: + return errors.Errorf("cannot convert type json.BinaryJSON to bool") + } + return nil +} + +// EvalExpr evaluates this expr according to its type. 
+// And it selects the method for evaluating expression based on +// the environment variables and whether the expression can be vectorized. +func EvalExpr(ctx sessionctx.Context, expr Expression, input *chunk.Chunk, result *chunk.Column) (err error) { + evalType := expr.GetType().EvalType() + if expr.Vectorized() && ctx.GetSessionVars().EnableVectorizedExpression { + switch evalType { + case types.ETInt: + err = expr.VecEvalInt(ctx, input, result) + case types.ETReal: + err = expr.VecEvalReal(ctx, input, result) + case types.ETDuration: + err = expr.VecEvalDuration(ctx, input, result) + case types.ETDatetime, types.ETTimestamp: + err = expr.VecEvalTime(ctx, input, result) + case types.ETString: + err = expr.VecEvalString(ctx, input, result) + case types.ETJson: + err = expr.VecEvalJSON(ctx, input, result) + case types.ETDecimal: + err = expr.VecEvalDecimal(ctx, input, result) + default: + err = errors.New(fmt.Sprintf("invalid eval type %v", expr.GetType().EvalType())) + } + } else { + ind, n := 0, input.NumRows() + iter := chunk.NewIterator4Chunk(input) + switch evalType { + case types.ETInt: + result.ResizeInt64(n, false) + i64s := result.Int64s() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalInt(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + i64s[ind] = value + } + ind++ + } + case types.ETReal: + result.ResizeFloat64(n, false) + f64s := result.Float64s() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalReal(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + f64s[ind] = value + } + ind++ + } + case types.ETDuration: + result.ResizeGoDuration(n, false) + d64s := result.GoDurations() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalDuration(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, 
isNull) + } else { + d64s[ind] = value.Duration + } + ind++ + } + case types.ETDatetime, types.ETTimestamp: + result.ResizeTime(n, false) + t64s := result.Times() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalTime(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + t64s[ind] = value + } + ind++ + } + case types.ETString: + result.ReserveString(n) + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalString(ctx, it) + if err != nil { + return err + } + if isNull { + result.AppendNull() + } else { + result.AppendString(value) + } + } + case types.ETJson: + result.ReserveJSON(n) + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalJSON(ctx, it) + if err != nil { + return err + } + if isNull { + result.AppendNull() + } else { + result.AppendJSON(value) + } + } + case types.ETDecimal: + result.ResizeDecimal(n, false) + d64s := result.Decimals() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalDecimal(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + d64s[ind] = *value + } + ind++ + } + default: + err = errors.New(fmt.Sprintf("invalid eval type %v", expr.GetType().EvalType())) + } + } + return +} + +>>>>>>> bdbdbae... expression: fix the issue that incorrect result for a predicat… (#16014) // composeConditionWithBinaryOp composes condition with binary operator into a balance deep tree, which benefits a lot for pb decoder/encoder. 
func composeConditionWithBinaryOp(ctx sessionctx.Context, conditions []Expression, funcName string) Expression { length := len(conditions) diff --git a/expression/integration_test.go b/expression/integration_test.go index 0dd12a96e8480..1f9568cef9c6d 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -4609,6 +4609,620 @@ func (s *testIntegrationSuite) TestCacheRefineArgs(c *C) { tk.MustQuery("execute stmt using @p0").Check(testkit.Rows("0")) } +<<<<<<< HEAD +======= +func (s *testIntegrationSuite) TestCollation(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (utf8_bin_c varchar(10) charset utf8 collate utf8_bin, utf8_gen_c varchar(10) charset utf8 collate utf8_general_ci, bin_c binary, num_c int, " + + "abin char collate ascii_bin, lbin char collate latin1_bin, u4bin char collate utf8mb4_bin, u4ci char collate utf8mb4_general_ci)") + tk.MustExec("insert into t values ('a', 'b', 'c', 4, 'a', 'a', 'a', 'a')") + tk.MustQuery("select collation(null)").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(2)").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(2 + 'a')").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(2 + utf8_gen_c) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(2 + utf8_bin_c) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(utf8_bin_c, 2)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select collation(concat(utf8_gen_c, 'abc')) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(concat(utf8_gen_c, null)) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(concat(utf8_gen_c, num_c)) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(concat(utf8_bin_c, utf8_gen_c)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select 
collation(upper(utf8_bin_c)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select collation(upper(utf8_gen_c)) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(upper(bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(abin, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(lbin, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(utf8_bin_c, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(utf8_gen_c, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(u4bin, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(u4ci, bin_c)) from t").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(concat(abin, u4bin)) from t").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(concat(lbin, u4bin)) from t").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(concat(utf8_bin_c, u4bin)) from t").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(concat(utf8_gen_c, u4bin)) from t").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(concat(u4ci, u4bin)) from t").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(concat(abin, u4ci)) from t").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select collation(concat(lbin, u4ci)) from t").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select collation(concat(utf8_bin_c, u4ci)) from t").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select collation(concat(utf8_gen_c, u4ci)) from t").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select collation(concat(abin, utf8_bin_c)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select collation(concat(lbin, utf8_bin_c)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select 
collation(concat(utf8_gen_c, utf8_bin_c)) from t").Check(testkit.Rows("utf8_bin")) + tk.MustQuery("select collation(concat(abin, utf8_gen_c)) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(concat(lbin, utf8_gen_c)) from t").Check(testkit.Rows("utf8_general_ci")) + tk.MustQuery("select collation(concat(abin, lbin)) from t").Check(testkit.Rows("latin1_bin")) + + tk.MustExec("set names utf8mb4 collate utf8mb4_bin") + tk.MustQuery("select collation('a')").Check(testkit.Rows("utf8mb4_bin")) + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + tk.MustQuery("select collation('a')").Check(testkit.Rows("utf8mb4_general_ci")) + + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + tk.MustExec("set @test_collate_var = 'a'") + tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + tk.MustExec("set @test_collate_var = 1") + tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustExec("set @test_collate_var = concat(\"a\", \"b\" collate utf8mb4_bin)") + tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_bin")) +} + +func (s *testIntegrationSuite) TestCoercibility(c *C) { + tk := testkit.NewTestKit(c, s.store) + + type testCase struct { + expr string + result int + } + testFunc := func(cases []testCase, suffix string) { + for _, tc := range cases { + tk.MustQuery(fmt.Sprintf("select coercibility(%v) %v", tc.expr, suffix)).Check(testkit.Rows(fmt.Sprintf("%v", tc.result))) + } + } + testFunc([]testCase{ + // constants + {"1", 5}, {"null", 6}, {"'abc'", 4}, + // sys-constants + {"version()", 3}, {"user()", 3}, {"database()", 3}, + {"current_role()", 3}, {"current_user()", 3}, + // scalar functions after constant folding + {"1+null", 5}, {"null+'abcde'", 5}, {"concat(null, 'abcde')", 4}, + // non-deterministic functions + {"rand()", 5}, {"now()", 5}, 
{"sysdate()", 5}, + }, "") + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (i int, r real, d datetime, t timestamp, c char(10), vc varchar(10), b binary(10), vb binary(10))") + tk.MustExec("insert into t values (null, null, null, null, null, null, null, null)") + testFunc([]testCase{ + {"i", 5}, {"r", 5}, {"d", 5}, {"t", 5}, + {"c", 2}, {"b", 2}, {"vb", 2}, {"vc", 2}, + {"i+r", 5}, {"i*r", 5}, {"cos(r)+sin(i)", 5}, {"d+2", 5}, + {"t*10", 5}, {"concat(c, vc)", 2}, {"replace(c, 'x', 'y')", 2}, + }, "from t") +} + +func (s *testIntegrationSuite) TestCacheConstEval(c *C) { + tk := testkit.NewTestKit(c, s.store) + orgEnable := plannercore.PreparedPlanCacheEnabled() + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + }() + plannercore.SetPreparedPlanCache(true) + var err error + tk.Se, err = session.CreateSession4TestWithOpt(s.store, &session.Opt{ + PreparedPlanCache: kvcache.NewSimpleLRUCache(100, 0.1, math.MaxUint64), + }) + c.Assert(err, IsNil) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(col_double double)") + tk.MustExec("insert into t values (1)") + tk.Se.GetSessionVars().EnableVectorizedExpression = false + tk.MustExec("insert into mysql.expr_pushdown_blacklist values('cast', 'tikv,tiflash,tidb', 'for test')") + tk.MustExec("admin reload expr_pushdown_blacklist") + tk.MustExec("prepare stmt from 'SELECT * FROM (SELECT col_double AS c0 FROM t) t WHERE (ABS((REPEAT(?, ?) OR 5617780767323292672)) < LN(EXP(c0)) + (? ^ ?))'") + tk.MustExec("set @a1 = 'JuvkBX7ykVux20zQlkwDK2DFelgn7'") + tk.MustExec("set @a2 = 1") + tk.MustExec("set @a3 = -112990.35179796701") + tk.MustExec("set @a4 = 87997.92704840179") + // Main purpose here is checking no error is reported. 1 is the result when plan cache is disabled, it is + // incompatible with MySQL actually, update the result after fixing it. 
+ tk.MustQuery("execute stmt using @a1, @a2, @a3, @a4").Check(testkit.Rows("1")) + tk.Se.GetSessionVars().EnableVectorizedExpression = true + tk.MustExec("delete from mysql.expr_pushdown_blacklist where name = 'cast' and store_type = 'tikv,tiflash,tidb' and reason = 'for test'") + tk.MustExec("admin reload expr_pushdown_blacklist") +} + +func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk.MustExec("use test") + tk.MustExec("create table t_ci(a varchar(10) collate utf8mb4_general_ci, unique key(a))") + tk.MustExec("insert into t_ci values ('a')") + tk.MustQuery("select * from t_ci").Check(testkit.Rows("a")) + tk.MustQuery("select * from t_ci").Check(testkit.Rows("a")) + tk.MustQuery("select * from t_ci where a='a'").Check(testkit.Rows("a")) + tk.MustQuery("select * from t_ci where a='A'").Check(testkit.Rows("a")) + tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) + tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) +} + +func (s *testIntegrationSerialSuite) TestWeightString(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + type testCase struct { + input []string + result []string + resultAsChar1 []string + resultAsChar3 []string + resultAsBinary1 []string + resultAsBinary5 []string + resultExplicitCollateBin []string + } + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int, a varchar(20) collate utf8mb4_general_ci)") + cases := testCase{ + input: []string{"aAÁàãăâ", "a", "a ", "中", "中 "}, + result: []string{"\x00A\x00A\x00A\x00A\x00A\x00A\x00A", "\x00A", "\x00A", "\x4E\x2D", "\x4E\x2D"}, + resultAsChar1: []string{"\x00A", "\x00A", "\x00A", "\x4E\x2D", "\x4E\x2D"}, + resultAsChar3: []string{"\x00A\x00A\x00A", "\x00A", 
"\x00A", "\x4E\x2D", "\x4E\x2D"}, + resultAsBinary1: []string{"a", "a", "a", "\xE4", "\xE4"}, + resultAsBinary5: []string{"aA\xc3\x81\xc3", "a\x00\x00\x00\x00", "a \x00\x00", "中\x00\x00", "中 \x00"}, + resultExplicitCollateBin: []string{"aAÁàãăâ", "a", "a", "中", "中"}, + } + values := make([]string, len(cases.input)) + for i, input := range cases.input { + values[i] = fmt.Sprintf("(%d, '%s')", i, input) + } + tk.MustExec("insert into t values " + strings.Join(values, ",")) + rows := tk.MustQuery("select weight_string(a) from t order by id").Rows() + for i, out := range cases.result { + c.Assert(rows[i][0].(string), Equals, out) + } + rows = tk.MustQuery("select weight_string(a as char(1)) from t order by id").Rows() + for i, out := range cases.resultAsChar1 { + c.Assert(rows[i][0].(string), Equals, out) + } + rows = tk.MustQuery("select weight_string(a as char(3)) from t order by id").Rows() + for i, out := range cases.resultAsChar3 { + c.Assert(rows[i][0].(string), Equals, out) + } + rows = tk.MustQuery("select weight_string(a as binary(1)) from t order by id").Rows() + for i, out := range cases.resultAsBinary1 { + c.Assert(rows[i][0].(string), Equals, out) + } + rows = tk.MustQuery("select weight_string(a as binary(5)) from t order by id").Rows() + for i, out := range cases.resultAsBinary5 { + c.Assert(rows[i][0].(string), Equals, out) + } + c.Assert(tk.MustQuery("select weight_string(NULL);").Rows()[0][0], Equals, "") + c.Assert(tk.MustQuery("select weight_string(7);").Rows()[0][0], Equals, "") + c.Assert(tk.MustQuery("select weight_string(cast(7 as decimal(5)));").Rows()[0][0], Equals, "") + c.Assert(tk.MustQuery("select weight_string(cast(20190821 as date));").Rows()[0][0], Equals, "2019-08-21") + c.Assert(tk.MustQuery("select weight_string(cast(20190821 as date) as binary(5));").Rows()[0][0], Equals, "2019-") + c.Assert(tk.MustQuery("select weight_string(7.0);").Rows()[0][0], Equals, "") + c.Assert(tk.MustQuery("select weight_string(7 AS 
BINARY(2));").Rows()[0][0], Equals, "7\x00") + // test explicit collation + c.Assert(tk.MustQuery("select weight_string('中 ' collate utf8mb4_general_ci);").Rows()[0][0], Equals, "\x4E\x2D") + c.Assert(tk.MustQuery("select weight_string('中 ' collate utf8mb4_bin);").Rows()[0][0], Equals, "中") + c.Assert(tk.MustQuery("select collation(a collate utf8mb4_general_ci) from t order by id").Rows()[0][0], Equals, "utf8mb4_general_ci") + c.Assert(tk.MustQuery("select collation('中 ' collate utf8mb4_general_ci);").Rows()[0][0], Equals, "utf8mb4_general_ci") + rows = tk.MustQuery("select weight_string(a collate utf8mb4_bin) from t order by id").Rows() + for i, out := range cases.resultExplicitCollateBin { + c.Assert(rows[i][0].(string), Equals, out) + } + tk.MustGetErrMsg("select weight_string(a collate utf8_general_ci) from t order by id", "[ddl:1253]COLLATION 'utf8_general_ci' is not valid for CHARACTER SET 'utf8mb4'") + tk.MustGetErrMsg("select weight_string('中' collate utf8_bin)", "[ddl:1253]COLLATION 'utf8_bin' is not valid for CHARACTER SET 'utf8mb4'") +} + +func (s *testIntegrationSerialSuite) TestCollationCreateIndex(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(10) collate utf8mb4_general_ci);") + tk.MustExec("insert into t values ('a');") + tk.MustExec("insert into t values ('A');") + tk.MustExec("insert into t values ('b');") + tk.MustExec("insert into t values ('B');") + tk.MustExec("insert into t values ('a');") + tk.MustExec("insert into t values ('A');") + tk.MustExec("create index idx on t(a);") + tk.MustQuery("select * from t order by a").Check(testkit.Rows("a", "A", "a", "A", "b", "B")) +} + +func (s *testIntegrationSerialSuite) TestCollateConstantPropagation(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer 
collate.SetNewCollationEnabledForTest(false) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a char(10) collate utf8mb4_bin, b char(10) collate utf8mb4_general_ci);") + tk.MustExec("insert into t values ('a', 'A');") + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b='a' collate utf8mb4_general_ci;").Check(nil) + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b>='a' collate utf8mb4_general_ci;").Check(nil) + tk.MustExec("drop table t;") + tk.MustExec("create table t (a char(10) collate utf8mb4_general_ci, b char(10) collate utf8mb4_general_ci);") + tk.MustExec("insert into t values ('A', 'a');") + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b='a' collate utf8mb4_bin;").Check(testkit.Rows("A a A a")) + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b>='a' collate utf8mb4_bin;").Check(testkit.Rows("A a A a")) + tk.MustExec("drop table t;") + tk.MustExec("set names utf8mb4") + tk.MustExec("create table t (a char(10) collate utf8mb4_general_ci, b char(10) collate utf8_general_ci);") + tk.MustExec("insert into t values ('a', 'A');") + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b='A'").Check(testkit.Rows("a A a A")) + tk.MustExec("drop table t;") + tk.MustExec("create table t(a char collate utf8_general_ci, b char collate utf8mb4_general_ci, c char collate utf8_bin);") + tk.MustExec("insert into t values ('b', 'B', 'B');") + tk.MustQuery("select * from t t1, t t2 where t1.a=t2.b and t2.b=t2.c;").Check(testkit.Rows("b B B b B B")) + tk.MustExec("drop table t;") + tk.MustExec("create table t(a char collate utf8_bin, b char collate utf8_general_ci);") + tk.MustExec("insert into t values ('a', 'A');") + tk.MustQuery("select * from t t1, t t2 where t1.b=t2.b and t2.b=t1.a collate utf8_general_ci;").Check(testkit.Rows("a A a A")) + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") + 
tk.MustExec("create table t1(a char, b varchar(10)) charset utf8mb4 collate utf8mb4_general_ci;") + tk.MustExec("create table t2(a char, b varchar(10)) charset utf8mb4 collate utf8mb4_bin;") + tk.MustExec("insert into t1 values ('A', 'a');") + tk.MustExec("insert into t2 values ('a', 'a')") + tk.MustQuery("select * from t1 left join t2 on t1.a = t2.a where t1.a = 'a';").Check(testkit.Rows("A a ")) + tk.MustExec("drop table t;") + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") + tk.MustExec("create table t(a char collate utf8mb4_bin, b char collate utf8mb4_general_ci);") + tk.MustExec("insert into t values ('a', 'a');") + tk.MustQuery("select * from t t1, t t2 where t2.b = 'A' and lower(concat(t1.a , '' )) = t2.b;").Check(testkit.Rows("a a a a")) +} +func (s *testIntegrationSerialSuite) prepare4Join(c *C) *testkit.TestKit { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t_bin") + tk.MustExec("CREATE TABLE `t` ( `a` int(11) NOT NULL,`b` varchar(5) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL)") + tk.MustExec("CREATE TABLE `t_bin` ( `a` int(11) NOT NULL,`b` varchar(5) CHARACTER SET binary)") + tk.MustExec("insert into t values (1, 'a'), (2, 'À'), (3, 'á'), (4, 'à'), (5, 'b'), (6, 'c'), (7, ' ')") + tk.MustExec("insert into t_bin values (1, 'a'), (2, 'À'), (3, 'á'), (4, 'à'), (5, 'b'), (6, 'c'), (7, ' ')") + return tk +} + +func (s *testIntegrationSerialSuite) prepare4Join2(c *C) *testkit.TestKit { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1 (id int, v varchar(5) character set binary, key(v))") + tk.MustExec("create table t2 (v varchar(5) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci, key(v))") + tk.MustExec("insert into t1 values (1, 'a'), (2, 'À'), (3, 'á'), (4, 'à'), (5, 'b'), (6, 'c'), (7, ' ')") + 
tk.MustExec("insert into t2 values ('a'), ('À'), ('á'), ('à'), ('b'), ('c'), (' ')") + return tk +} + +func (s *testIntegrationSerialSuite) TestCollateHashJoin(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Join(c) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b order by t1.a").Check( + testkit.Rows("1 a", "1 a", "1 a", "1 a", "2 À", "2 À", "2 À", "2 À", "3 á", "3 á", "3 á", "3 á", "4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b order by t1.a").Check( + testkit.Rows("1 a", "2 À", "3 á", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>t2.a order by t1.a").Check( + testkit.Rows("2 À", "3 á", "3 á", "4 à", "4 à", "4 à")) + tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>t2.a order by t1.a").Check( + testkit.Rows()) +} + +func (s *testIntegrationSerialSuite) TestCollateHashJoin2(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Join2(c) + 
tk.MustQuery("select /*+ TIDB_HJ(t1, t2) */ * from t1, t2 where t1.v=t2.v order by t1.id").Check( + testkit.Rows("1 a a", "2 À À", "3 á á", "4 à à", "5 b b", "6 c c", "7 ")) +} + +func (s *testIntegrationSerialSuite) TestCollateMergeJoin(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Join(c) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b order by t1.a").Check( + testkit.Rows("1 a", "1 a", "1 a", "1 a", "2 À", "2 À", "2 À", "2 À", "3 á", "3 á", "3 á", "3 á", "4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b order by t1.a").Check( + testkit.Rows("1 a", "2 À", "3 á", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "4 à", "4 à", "4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>3 order by t1.a").Check( + testkit.Rows("4 à", "5 b", "6 c", "7 ")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t t1, t t2 where t1.b=t2.b and t1.a>t2.a order by t1.a").Check( + testkit.Rows("2 À", "3 á", "3 á", "4 à", "4 à", "4 à")) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a, t1.b from t_bin t1, t_bin t2 where t1.b=t2.b and t1.a>t2.a order by t1.a").Check( + testkit.Rows()) +} + +func (s *testIntegrationSerialSuite) TestCollateMergeJoin2(c *C) { + collate.SetNewCollationEnabledForTest(true) 
+ defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Join2(c) + tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ * from t1, t2 where t1.v=t2.v order by t1.id").Check( + testkit.Rows("1 a a", "2 À À", "3 á á", "4 à à", "5 b b", "6 c c", "7 ")) +} + +func (s *testIntegrationSerialSuite) TestCollateIndexMergeJoin(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(5) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci, b varchar(5) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci, key(a), key(b))") + tk.MustExec("insert into t values ('a', 'x'), ('x', 'À'), ('á', 'x'), ('à', 'à'), ('à', 'x')") + + tk.MustExec("set tidb_enable_index_merge=1") + tk.MustQuery("select /*+ USE_INDEX_MERGE(t, a, b) */ * from t where a = 'a' or b = 'a'").Sort().Check( + testkit.Rows("a x", "x À", "à x", "à à", "á x")) +} + +func (s *testIntegrationSerialSuite) prepare4Collation(c *C, hasIndex bool) *testkit.TestKit { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t_bin") + idxSQL := ", key(v)" + if !hasIndex { + idxSQL = "" + } + tk.MustExec(fmt.Sprintf("create table t (id int, v varchar(5) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL %v)", idxSQL)) + tk.MustExec(fmt.Sprintf("create table t_bin (id int, v varchar(5) CHARACTER SET binary %v)", idxSQL)) + tk.MustExec("insert into t values (1, 'a'), (2, 'À'), (3, 'á'), (4, 'à'), (5, 'b'), (6, 'c'), (7, ' ')") + tk.MustExec("insert into t_bin values (1, 'a'), (2, 'À'), (3, 'á'), (4, 'à'), (5, 'b'), (6, 'c'), (7, ' ')") + return tk +} + +func (s *testIntegrationSerialSuite) TestCollateSelection(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, 
false) + tk.MustQuery("select v from t where v='a' order by id").Check(testkit.Rows("a", "À", "á", "à")) + tk.MustQuery("select v from t_bin where v='a' order by id").Check(testkit.Rows("a")) + tk.MustQuery("select v from t where v<'b' and id<=3").Check(testkit.Rows("a", "À", "á")) + tk.MustQuery("select v from t_bin where v<'b' and id<=3").Check(testkit.Rows("a")) +} + +func (s *testIntegrationSerialSuite) TestCollateSort(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, false) + tk.MustQuery("select id from t order by v, id").Check(testkit.Rows("7", "1", "2", "3", "4", "5", "6")) + tk.MustQuery("select id from t_bin order by v, id").Check(testkit.Rows("7", "1", "5", "6", "2", "4", "3")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(10) collate utf8mb4_general_ci, key(a))") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustQuery("select * from t order by a collate utf8mb4_bin").Check(testkit.Rows("A", "A", "A", "a", "a", "a", "b", "b", "b")) +} + +func (s *testIntegrationSerialSuite) TestCollateHashAgg(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, false) + tk.HasPlan("select distinct(v) from t_bin", "HashAgg") + tk.MustQuery("select distinct(v) from t_bin").Sort().Check(testkit.Rows(" ", "a", "b", "c", "À", "à", "á")) + tk.HasPlan("select distinct(v) from t", "HashAgg") + tk.MustQuery("select distinct(v) from t").Sort().Check(testkit.Rows(" ", "a", "b", "c")) + tk.HasPlan("select v, count(*) from t_bin group by v", "HashAgg") + tk.MustQuery("select v, count(*) from t_bin group by v").Sort().Check(testkit.Rows(" 1", "a 1", "b 1", "c 1", "À 1", "à 1", "á 1")) + tk.HasPlan("select v, count(*) from t group by v", "HashAgg") + 
tk.MustQuery("select v, count(*) from t group by v").Sort().Check(testkit.Rows(" 1", "a 4", "b 1", "c 1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(10) collate utf8mb4_general_ci, key(a))") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustExec("insert into t values ('a'), ('A'), ('b')") + tk.MustQuery("select count(1) from t group by a collate utf8mb4_bin").Check(testkit.Rows("3", "3", "3")) +} + +func (s *testIntegrationSerialSuite) TestCollateStreamAgg(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, true) + tk.HasPlan("select distinct(v) from t_bin", "StreamAgg") + tk.MustQuery("select distinct(v) from t_bin").Sort().Check(testkit.Rows(" ", "a", "b", "c", "À", "à", "á")) + tk.HasPlan("select distinct(v) from t", "StreamAgg") + tk.MustQuery("select distinct(v) from t").Sort().Check(testkit.Rows(" ", "a", "b", "c")) + tk.HasPlan("select v, count(*) from t_bin group by v", "StreamAgg") + tk.MustQuery("select v, count(*) from t_bin group by v").Sort().Check(testkit.Rows(" 1", "a 1", "b 1", "c 1", "À 1", "à 1", "á 1")) + tk.HasPlan("select v, count(*) from t group by v", "StreamAgg") + tk.MustQuery("select v, count(*) from t group by v").Sort().Check(testkit.Rows(" 1", "a 4", "b 1", "c 1")) +} + +func (s *testIntegrationSerialSuite) TestCollateIndexReader(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, true) + tk.HasPlan("select v from t where v < 'b' order by v", "IndexReader") + tk.MustQuery("select v from t where v < 'b' order by v").Check(testkit.Rows(" ", "a", "À", "á", "à")) + tk.HasPlan("select v from t where v < 'b' and v > ' ' order by v", "IndexReader") + tk.MustQuery("select v from t where v < 'b' and v > ' ' order by v").Check(testkit.Rows("a", "À", "á", "à")) + 
tk.HasPlan("select v from t_bin where v < 'b' order by v", "IndexReader") + tk.MustQuery("select v from t_bin where v < 'b' order by v").Sort().Check(testkit.Rows(" ", "a")) + tk.HasPlan("select v from t_bin where v < 'b' and v > ' ' order by v", "IndexReader") + tk.MustQuery("select v from t_bin where v < 'b' and v > ' ' order by v").Sort().Check(testkit.Rows("a")) +} + +func (s *testIntegrationSerialSuite) TestCollateIndexLookup(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, true) + + tk.HasPlan("select id from t where v < 'b'", "IndexLookUp") + tk.MustQuery("select id from t where v < 'b'").Sort().Check(testkit.Rows("1", "2", "3", "4", "7")) + tk.HasPlan("select id from t where v < 'b' and v > ' '", "IndexLookUp") + tk.MustQuery("select id from t where v < 'b' and v > ' '").Sort().Check(testkit.Rows("1", "2", "3", "4")) + tk.HasPlan("select id from t_bin where v < 'b'", "IndexLookUp") + tk.MustQuery("select id from t_bin where v < 'b'").Sort().Check(testkit.Rows("1", "7")) + tk.HasPlan("select id from t_bin where v < 'b' and v > ' '", "IndexLookUp") + tk.MustQuery("select id from t_bin where v < 'b' and v > ' '").Sort().Check(testkit.Rows("1")) +} + +func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + + tk.MustQuery("select field('a', 'b', 'a');").Check(testkit.Rows("2")) + tk.MustQuery("select field('a', 'b', 'A');").Check(testkit.Rows("0")) + tk.MustQuery("select field('a', 'b', 'A' collate utf8mb4_bin);").Check(testkit.Rows("0")) + tk.MustQuery("select field('a', 'b', 'a ' collate utf8mb4_bin);").Check(testkit.Rows("2")) + tk.MustQuery("select field('a', 'b', 'A' collate utf8mb4_general_ci);").Check(testkit.Rows("2")) + tk.MustQuery("select field('a', 'b', 'a ' collate 
utf8mb4_general_ci);").Check(testkit.Rows("2")) + + tk.MustQuery("select FIND_IN_SET('a','b,a,c,d');").Check(testkit.Rows("2")) + tk.MustQuery("select FIND_IN_SET('a','b,A,c,d');").Check(testkit.Rows("0")) + tk.MustQuery("select FIND_IN_SET('a','b,A,c,d' collate utf8mb4_bin);").Check(testkit.Rows("0")) + tk.MustQuery("select FIND_IN_SET('a','b,a ,c,d' collate utf8mb4_bin);").Check(testkit.Rows("2")) + tk.MustQuery("select FIND_IN_SET('a','b,A,c,d' collate utf8mb4_general_ci);").Check(testkit.Rows("2")) + tk.MustQuery("select FIND_IN_SET('a','b,a ,c,d' collate utf8mb4_general_ci);").Check(testkit.Rows("2")) + + tk.MustExec("select concat('a' collate utf8mb4_bin, 'b' collate utf8mb4_bin);") + tk.MustGetErrMsg("select concat('a' collate utf8mb4_bin, 'b' collate utf8mb4_general_ci);", "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'concat'") + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char)") + tk.MustGetErrMsg("select * from t t1 join t t2 on t1.a collate utf8mb4_bin = t2.a collate utf8mb4_general_ci;", "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'eq'") +} + +func (s *testIntegrationSerialSuite) TestCollateLike(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + tk.MustQuery("select 'a' like 'A'").Check(testkit.Rows("1")) + tk.MustQuery("select 'a' like 'A' collate utf8mb4_general_ci").Check(testkit.Rows("1")) + tk.MustQuery("select 'a' like 'À'").Check(testkit.Rows("1")) + tk.MustQuery("select 'a' like '%À'").Check(testkit.Rows("1")) + tk.MustQuery("select 'a' like '%À '").Check(testkit.Rows("0")) + tk.MustQuery("select 'a' like 'À%'").Check(testkit.Rows("1")) + tk.MustQuery("select 'a' like 
'À_'").Check(testkit.Rows("0")) + tk.MustQuery("select 'a' like '%À%'").Check(testkit.Rows("1")) + tk.MustQuery("select 'aaa' like '%ÀAa%'").Check(testkit.Rows("1")) + tk.MustExec("set names utf8mb4 collate utf8mb4_bin") + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t_like;") + tk.MustExec("create table t_like(id int, b varchar(20) collate utf8mb4_general_ci);") + tk.MustExec("insert into t_like values (1, 'aaa'), (2, 'abc'), (3, 'aac');") + tk.MustQuery("select b like 'AaÀ' from t_like order by id;").Check(testkit.Rows("1", "0", "0")) + tk.MustQuery("select b like 'Aa_' from t_like order by id;").Check(testkit.Rows("1", "0", "1")) + tk.MustQuery("select b like '_A_' from t_like order by id;").Check(testkit.Rows("1", "0", "1")) + tk.MustQuery("select b from t_like where b like 'Aa_' order by id;").Check(testkit.Rows("aaa", "aac")) + tk.MustQuery("select b from t_like where b like 'A%' order by id;").Check(testkit.Rows("aaa", "abc", "aac")) + tk.MustQuery("select b from t_like where b like '%A%' order by id;").Check(testkit.Rows("aaa", "abc", "aac")) + tk.MustExec("alter table t_like add index idx_b(b);") + tk.MustQuery("select b from t_like use index(idx_b) where b like 'Aa_' order by id;").Check(testkit.Rows("aaa", "aac")) + tk.MustQuery("select b from t_like use index(idx_b) where b like 'A%' order by id;").Check(testkit.Rows("aaa", "abc", "aac")) + tk.MustQuery("select b from t_like use index(idx_b) where b like '%A%' order by id;").Check(testkit.Rows("aaa", "abc", "aac")) +} + +func (s *testIntegrationSerialSuite) TestCollateSubQuery(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := s.prepare4Collation(c, false) + tk.MustQuery("select id from t where v in (select v from t_bin) order by id").Check(testkit.Rows("1", "2", "3", "4", "5", "6", "7")) + tk.MustQuery("select id from t_bin where v in (select v from t) order by id").Check(testkit.Rows("1", "2", "3", "4", "5", "6", "7")) + 
tk.MustQuery("select id from t where v not in (select v from t_bin) order by id").Check(testkit.Rows()) + tk.MustQuery("select id from t_bin where v not in (select v from t) order by id").Check(testkit.Rows()) + tk.MustQuery("select id from t where exists (select 1 from t_bin where t_bin.v=t.v) order by id").Check(testkit.Rows("1", "2", "3", "4", "5", "6", "7")) + tk.MustQuery("select id from t_bin where exists (select 1 from t where t_bin.v=t.v) order by id").Check(testkit.Rows("1", "2", "3", "4", "5", "6", "7")) + tk.MustQuery("select id from t where not exists (select 1 from t_bin where t_bin.v=t.v) order by id").Check(testkit.Rows()) + tk.MustQuery("select id from t_bin where not exists (select 1 from t where t_bin.v=t.v) order by id").Check(testkit.Rows()) +} + +func (s *testIntegrationSerialSuite) TestCollateDDL(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database t;") + tk.MustExec("use t;") + tk.MustExec("drop database t;") +} + +func (s *testIntegrationSuite) TestIssue15986(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t0") + tk.MustExec("CREATE TABLE t0(c0 int)") + tk.MustExec("INSERT INTO t0 VALUES (0)") + tk.MustQuery("SELECT t0.c0 FROM t0 WHERE CHAR(204355900);").Check(testkit.Rows("0")) + tk.MustQuery("SELECT t0.c0 FROM t0 WHERE not CHAR(204355900);").Check(testkit.Rows()) + tk.MustQuery("SELECT t0.c0 FROM t0 WHERE '.0';").Check(testkit.Rows()) + tk.MustQuery("SELECT t0.c0 FROM t0 WHERE not '.0';").Check(testkit.Rows("0")) + // If the number does not exceed the range of float64 and its value is not 0, it will be converted to true. 
+ tk.MustQuery("select * from t0 where '.000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000009';").Check(testkit.Rows("0")) + tk.MustQuery("select * from t0 where not '.000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000009';").Check(testkit.Rows()) + + // If the number is truncated beyond the range of float64, it will be converted to false when the truncated result is 0. + tk.MustQuery("select * from t0 where '.0000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000009';").Check(testkit.Rows()) + tk.MustQuery("select * from t0 where not '.0000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000009';").Check(testkit.Rows("0")) +} + +func (s *testIntegrationSuite) TestNegativeZeroForHashJoin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t0, t1") + tk.MustExec("CREATE TABLE t0(c0 float);") + 
tk.MustExec("CREATE TABLE t1(c0 float);") + tk.MustExec("INSERT INTO t1(c0) VALUES (0);") + tk.MustExec("INSERT INTO t0(c0) VALUES (0);") + tk.MustQuery("SELECT t1.c0 FROM t1, t0 WHERE t0.c0=-t1.c0;").Check(testkit.Rows("0")) + tk.MustExec("drop TABLE t0;") + tk.MustExec("drop table t1;") +} + +func (s *testIntegrationSuite) TestIssue15725(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(2)") + tk.MustQuery("select * from t where (not not a) = a").Check(testkit.Rows()) + tk.MustQuery("select * from t where (not not not not a) = a").Check(testkit.Rows()) +} + +>>>>>>> bdbdbae... expression: fix the issue that incorrect result for a predicat… (#16014) func (s *testIntegrationSuite) TestIssue15790(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test;") diff --git a/types/datum.go b/types/datum.go index 9019fd0365834..be0bbfaa940c5 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1328,7 +1328,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { case KindFloat64: isZero = RoundFloat(d.GetFloat64()) == 0 case KindString, KindBytes: - iVal, err1 := StrToInt(sc, d.GetString()) + iVal, err1 := StrToFloat(sc, d.GetString()) isZero, err = iVal == 0, err1 case KindMysqlTime: isZero = d.GetMysqlTime().IsZero() diff --git a/types/datum_test.go b/types/datum_test.go index 62f9e1f706e82..8de31006838a3 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -63,9 +63,9 @@ func (ts *testDatumSuite) TestToBool(c *C) { testDatumToBool(c, float32(0.1), 0) testDatumToBool(c, float64(0.1), 0) testDatumToBool(c, "", 0) - testDatumToBool(c, "0.1", 0) + testDatumToBool(c, "0.1", 1) testDatumToBool(c, []byte{}, 0) - testDatumToBool(c, []byte("0.1"), 0) + testDatumToBool(c, []byte("0.1"), 1) testDatumToBool(c, NewBinaryLiteralFromUint(0, -1), 0) testDatumToBool(c, Enum{Name: "a", Value: 1}, 1) 
testDatumToBool(c, Set{Name: "a", Value: 1}, 1)