diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go
index aff3cd4e375a3..9dd3ba3daa8f9 100644
--- a/executor/benchmark_test.go
+++ b/executor/benchmark_test.go
@@ -31,7 +31,9 @@ import (
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/types"
 	"github.com/pingcap/tidb/util/chunk"
+	"github.com/pingcap/tidb/util/memory"
 	"github.com/pingcap/tidb/util/mock"
+	"github.com/pingcap/tidb/util/stringutil"
 )
 
 var (
@@ -39,11 +41,12 @@ var (
 )
 
 type mockDataSourceParameters struct {
-	schema *expression.Schema
-	ndvs   []int  // number of distinct values on columns[i] and zero represents no limit
-	orders []bool // columns[i] should be ordered if orders[i] is true
-	rows   int    // number of rows the DataSource should output
-	ctx    sessionctx.Context
+	schema      *expression.Schema
+	genDataFunc func(row int, typ *types.FieldType) interface{}
+	ndvs        []int  // number of distinct values on columns[i] and zero represents no limit
+	orders      []bool // columns[i] should be ordered if orders[i] is true
+	rows        int    // number of rows the DataSource should output
+	ctx         sessionctx.Context
 }
 
 type mockDataSource struct {
@@ -56,11 +59,21 @@ type mockDataSource struct {
 
 func (mds *mockDataSource) genColDatums(col int) (results []interface{}) {
 	typ := mds.retFieldTypes[col]
-	order := mds.p.orders[col]
+	order := false
+	if col < len(mds.p.orders) {
+		order = mds.p.orders[col]
+	}
 	rows := mds.p.rows
-	NDV := mds.p.ndvs[col]
+	NDV := 0
+	if col < len(mds.p.ndvs) {
+		NDV = mds.p.ndvs[col]
+	}
 	results = make([]interface{}, 0, rows)
-	if NDV == 0 {
+	if mds.p.genDataFunc != nil {
+		for i := 0; i < rows; i++ {
+			results = append(results, mds.p.genDataFunc(i, typ))
+		}
+	} else if NDV == 0 {
 		for i := 0; i < rows; i++ {
 			results = append(results, mds.randDatum(typ))
 		}
@@ -184,7 +197,7 @@ func (a aggTestCase) columns() []*expression.Column {
 }
 
 func (a aggTestCase) String() string {
-	return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concruuency:%v)",
+	return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concurrency:%v)",
 		a.execType, a.aggFunc, a.groupByNDV, a.hasDistinct, a.rows, a.concurrency)
 }
 
@@ -503,3 +516,178 @@ func BenchmarkWindowFunctions(b *testing.B) {
 		})
 	}
 }
+
+type hashJoinTestCase struct {
+	rows        int
+	concurrency int
+	ctx         sessionctx.Context
+	keyIdx      []int
+}
+
+func (tc hashJoinTestCase) columns() []*expression.Column {
+	return []*expression.Column{
+		{Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong)},
+		{Index: 1, RetType: types.NewFieldType(mysql.TypeVarString)},
+	}
+}
+
+func (tc hashJoinTestCase) String() string {
+	return fmt.Sprintf("(rows:%v, concurrency:%v, joinKeyIdx: %v)",
+		tc.rows, tc.concurrency, tc.keyIdx)
+}
+
+func defaultHashJoinTestCase() *hashJoinTestCase {
+	ctx := mock.NewContext()
+	ctx.GetSessionVars().InitChunkSize = variable.DefInitChunkSize
+	ctx.GetSessionVars().MaxChunkSize = variable.DefMaxChunkSize
+	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(nil, -1)
+	tc := &hashJoinTestCase{rows: 100000, concurrency: 4, ctx: ctx, keyIdx: []int{0, 1}}
+	return tc
+}
+
+func prepare4Join(testCase *hashJoinTestCase, innerExec, outerExec Executor) *HashJoinExec {
+	cols0 := testCase.columns()
+	cols1 := testCase.columns()
+	joinSchema := expression.NewSchema(cols0...)
+	joinSchema.Append(cols1...)
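+	// Both mock children share the same column layout, so the columns picked
+	// out of cols0 below serve as both the inner and the outer join keys.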
+ joinKeys := make([]*expression.Column, 0, len(testCase.keyIdx)) + for _, keyIdx := range testCase.keyIdx { + joinKeys = append(joinKeys, cols0[keyIdx]) + } + e := &HashJoinExec{ + baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), + concurrency: uint(testCase.concurrency), + joinType: 0, // InnerJoin + isOuterJoin: false, + innerKeys: joinKeys, + outerKeys: joinKeys, + innerExec: innerExec, + outerExec: outerExec, + } + defaultValues := make([]types.Datum, e.innerExec.Schema().Len()) + lhsTypes, rhsTypes := retTypes(innerExec), retTypes(outerExec) + e.joiners = make([]joiner, e.concurrency) + for i := uint(0); i < e.concurrency; i++ { + e.joiners[i] = newJoiner(testCase.ctx, e.joinType, true, defaultValues, + nil, lhsTypes, rhsTypes) + } + return e +} + +func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) { + opt := mockDataSourceParameters{ + schema: expression.NewSchema(casTest.columns()...), + rows: casTest.rows, + ctx: casTest.ctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + switch typ.Tp { + case mysql.TypeLong, mysql.TypeLonglong: + return int64(row) + case mysql.TypeVarString: + return rawData + default: + panic("not implement") + } + }, + } + dataSource1 := buildMockDataSource(opt) + dataSource2 := buildMockDataSource(opt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + exec := prepare4Join(casTest, dataSource1, dataSource2) + tmpCtx := context.Background() + chk := newFirstChunk(exec) + dataSource1.prepareChunks() + dataSource2.prepareChunks() + + b.StartTimer() + if err := exec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + for { + if err := exec.Next(tmpCtx, chk); err != nil { + b.Fatal(err) + } + if chk.NumRows() == 0 { + break + } + } + + if err := exec.Close(); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkHashJoinExec(b *testing.B) { + b.ReportAllocs() + cas := defaultHashJoinTestCase() + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) + + cas.keyIdx = []int{0} + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) +} + +func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { + opt := mockDataSourceParameters{ + schema: expression.NewSchema(casTest.columns()...), + rows: casTest.rows, + ctx: casTest.ctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + switch typ.Tp { + case mysql.TypeLong, mysql.TypeLonglong: + return int64(row) + case mysql.TypeVarString: + return rawData + default: + panic("not implement") + } + }, + } + dataSource1 := buildMockDataSource(opt) + dataSource2 := buildMockDataSource(opt) + + dataSource1.prepareChunks() + exec := prepare4Join(casTest, dataSource1, dataSource2) + tmpCtx := context.Background() + if err := exec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + innerResultCh := make(chan *chunk.Chunk, 1) + go func() { + for _, chk := range dataSource1.genData { + innerResultCh <- chk + } + close(innerResultCh) + }() + + b.StartTimer() + if err := exec.buildHashTableForList(innerResultCh); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkBuildHashTableForList(b *testing.B) { + b.ReportAllocs() + cas := defaultHashJoinTestCase() + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) + + cas.keyIdx = []int{0} + b.Run(fmt.Sprintf("%v", 
cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) +} diff --git a/executor/hash_table.go b/executor/hash_table.go new file mode 100644 index 0000000000000..af5693f236df2 --- /dev/null +++ b/executor/hash_table.go @@ -0,0 +1,106 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "github.com/pingcap/tidb/util/chunk" +) + +const maxEntrySliceLen = 8 * 1024 + +type entry struct { + ptr chunk.RowPtr + next entryAddr +} + +type entryStore struct { + slices [][]entry + sliceIdx uint32 + sliceLen uint32 +} + +func (es *entryStore) put(e entry) entryAddr { + if es.sliceLen == maxEntrySliceLen { + es.slices = append(es.slices, make([]entry, 0, maxEntrySliceLen)) + es.sliceLen = 0 + es.sliceIdx++ + } + addr := entryAddr{sliceIdx: es.sliceIdx, offset: es.sliceLen} + es.slices[es.sliceIdx] = append(es.slices[es.sliceIdx], e) + es.sliceLen++ + return addr +} + +func (es *entryStore) get(addr entryAddr) entry { + return es.slices[addr.sliceIdx][addr.offset] +} + +type entryAddr struct { + sliceIdx uint32 + offset uint32 +} + +var nullEntryAddr = entryAddr{} + +// rowHashMap stores multiple rowPtr of rows for a given key with minimum GC overhead. +// A given key can store multiple values. +// It is not thread-safe, should only be used in one goroutine. +type rowHashMap struct { + entryStore entryStore + hashTable map[uint64]entryAddr + length int +} + +// newRowHashMap creates a new rowHashMap. +func newRowHashMap() *rowHashMap { + m := new(rowHashMap) + // TODO(fengliyuan): initialize the size of map from the estimated row count for better performance. + m.hashTable = make(map[uint64]entryAddr) + m.entryStore.slices = [][]entry{make([]entry, 0, 64)} + // Reserve the first empty entry, so entryAddr{} can represent nullEntryAddr. + m.entryStore.put(entry{}) + return m +} + +// Put puts the key/rowPtr pairs to the rowHashMap, multiple rowPtrs are stored in a list. +func (m *rowHashMap) Put(hashKey uint64, rowPtr chunk.RowPtr) { + oldEntryAddr := m.hashTable[hashKey] + e := entry{ + ptr: rowPtr, + next: oldEntryAddr, + } + newEntryAddr := m.entryStore.put(e) + m.hashTable[hashKey] = newEntryAddr + m.length++ +} + +// Get gets the values of the "key" and appends them to "values". +func (m *rowHashMap) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) { + entryAddr := m.hashTable[hashKey] + for entryAddr != nullEntryAddr { + e := m.entryStore.get(entryAddr) + entryAddr = e.next + rowPtrs = append(rowPtrs, e.ptr) + } + // Keep the order of input. + for i := 0; i < len(rowPtrs)/2; i++ { + j := len(rowPtrs) - 1 - i + rowPtrs[i], rowPtrs[j] = rowPtrs[j], rowPtrs[i] + } + return +} + +// Len returns the number of rowPtrs in the rowHashMap, the number of keys may be less than Len +// if the same key is put more than once. 
+func (m *rowHashMap) Len() int { return m.length } diff --git a/executor/join.go b/executor/join.go index 2b3eae1d5d06e..6c1d4b3056c19 100644 --- a/executor/join.go +++ b/executor/join.go @@ -16,9 +16,10 @@ package executor import ( "context" "fmt" + "hash" + "hash/fnv" "sync" "sync/atomic" - "unsafe" "github.com/pingcap/errors" "github.com/pingcap/parser/terror" @@ -29,7 +30,6 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/memory" - "github.com/pingcap/tidb/util/mvmap" "github.com/pingcap/tidb/util/stringutil" ) @@ -49,10 +49,10 @@ type HashJoinExec struct { innerKeys []*expression.Column // concurrency is the number of partition, build and join workers. - concurrency uint - hashTable *mvmap.MVMap - innerFinished chan error - hashJoinBuffers []*hashJoinBuffer + concurrency uint + hashTable *rowHashMap + joinKeyBuf [][]byte + innerFinished chan error // joinWorkerWaitGroup is for sync multiple join workers. joinWorkerWaitGroup sync.WaitGroup finished atomic.Value @@ -72,7 +72,6 @@ type HashJoinExec struct { outerResultChs []chan *chunk.Chunk joinChkResourceCh []chan *chunk.Chunk joinResultCh chan *hashjoinWorkerResult - hashTableValBufs [][][]byte memTracker *memory.Tracker // track memory usage. prepared bool @@ -97,11 +96,6 @@ type hashjoinWorkerResult struct { src chan<- *chunk.Chunk } -type hashJoinBuffer struct { - data []types.Datum - bytes []byte -} - // Close implements the Executor Close interface. func (e *HashJoinExec) Close() error { close(e.closeCh) @@ -147,15 +141,9 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.prepared = false e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaHashJoin) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) - - e.hashTableValBufs = make([][][]byte, e.concurrency) - e.hashJoinBuffers = make([]*hashJoinBuffer, 0, e.concurrency) - for i := uint(0); i < e.concurrency; i++ { - buffer := &hashJoinBuffer{ - data: make([]types.Datum, len(e.outerKeys)), - bytes: make([]byte, 0, 10000), - } - e.hashJoinBuffers = append(e.hashJoinBuffers, buffer) + e.joinKeyBuf = make([][]byte, e.concurrency) + for i := range e.joinKeyBuf { + e.joinKeyBuf[i] = make([]byte, 1) } e.closeCh = make(chan struct{}) @@ -164,7 +152,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error { return nil } -func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyBuf []byte) (hasNull bool, _ []byte, err error) { +func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h hash.Hash64, buf []byte) (hasNull bool, key uint64, err error) { var keyColIdx []int var allTypes []*types.FieldType if isOuterKey { @@ -177,12 +165,19 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyB for _, i := range keyColIdx { if row.IsNull(i) { - return true, keyBuf, nil + return true, 0, nil } } - keyBuf = keyBuf[:0] - keyBuf, err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, keyBuf, row, allTypes, keyColIdx) - return false, keyBuf, err + h.Reset() + err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, h, row, allTypes, keyColIdx, buf) + return false, h.Sum64(), err +} + +func (e *HashJoinExec) matchJoinKey(inner, outer chunk.Row) (ok bool, err error) { + innerAllTypes, outerAllTypes := retTypes(e.innerExec), retTypes(e.outerExec) + return codec.EqualChunkRow(e.ctx.GetSessionVars().StmtCtx, + inner, innerAllTypes, e.innerKeyColIdx, + outer, outerAllTypes, e.outerKeyColIdx) } // fetchOuterChunks get chunks 
from the big table in a background goroutine
@@ -368,6 +363,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) {
 	var (
 		outerResult *chunk.Chunk
 		selected    = make([]bool, 0, chunk.InitialCapacity)
+		h           = fnv.New64()
 	)
 	ok, joinResult := e.getNewJoinResult(workerID)
 	if !ok {
@@ -390,7 +386,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) {
 		if !ok {
 			break
 		}
-		ok, joinResult = e.join2Chunk(workerID, outerResult, joinResult, selected)
+		ok, joinResult = e.join2Chunk(workerID, outerResult, joinResult, selected, h)
 		if !ok {
 			break
 		}
@@ -406,9 +402,8 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) {
 }
 
 func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.Row,
-	joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) {
-	buffer := e.hashJoinBuffers[workerID]
-	hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, buffer.bytes)
+	joinResult *hashjoinWorkerResult, h hash.Hash64) (bool, *hashjoinWorkerResult) {
+	hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, h, e.joinKeyBuf[workerID])
 	if err != nil {
 		joinResult.err = err
 		return false, joinResult
@@ -417,18 +412,28 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R
 		e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk)
 		return true, joinResult
 	}
-	e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0])
-	innerPtrs := e.hashTableValBufs[workerID]
+	innerPtrs := e.hashTable.Get(joinKey)
 	if len(innerPtrs) == 0 {
 		e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk)
 		return true, joinResult
 	}
 	innerRows := make([]chunk.Row, 0, len(innerPtrs))
-	for _, b := range innerPtrs {
-		ptr := *(*chunk.RowPtr)(unsafe.Pointer(&b[0]))
+	for _, ptr := range innerPtrs {
 		matchedInner := e.innerResult.GetRow(ptr)
+		ok, err := e.matchJoinKey(matchedInner, outerRow)
+		if err != nil {
+			joinResult.err = err
+			return false, joinResult
+		}
+		if !ok {
+			continue
+		}
 		innerRows = append(innerRows, matchedInner)
 	}
+	if len(innerRows) == 0 { // TODO(fengliyuan): add test case
+		e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk)
+		return true, joinResult
+	}
 	iter := chunk.NewIterator4Slice(innerRows)
 	hasMatch, hasNull := false, false
 	for iter.Begin(); iter.Current() != iter.End(); {
@@ -468,7 +473,7 @@ func (e *HashJoinExec) getNewJoinResult(workerID uint) (bool, *hashjoinWorkerRes
 }
 
 func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResult *hashjoinWorkerResult,
-	selected []bool) (ok bool, _ *hashjoinWorkerResult) {
+	selected []bool, h hash.Hash64) (ok bool, _ *hashjoinWorkerResult) {
 	var err error
 	selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(outerChk), selected)
 	if err != nil {
@@ -479,7 +484,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu
 		if !selected[i] { // process unmatched outer rows
 			e.joiners[workerID].onMissMatch(false, outerChk.GetRow(i), joinResult.chk)
 		} else { // process matched outer rows
-			ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult)
+			ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult, h)
 			if !ok {
 				return false, joinResult
 			}
@@ -537,7 +542,7 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) {
 	doneCh := make(chan struct{})
 	go util.WithRecovery(func() { e.fetchInnerRows(ctx, innerResultCh, doneCh) }, nil)
 
-	// TODO: Parallel build hash table.
Currently not support because `mvmap` is not thread-safe. + // TODO: Parallel build hash table. Currently not support because `rowHashMap` is not thread-safe. err := e.buildHashTableForList(innerResultCh) if err != nil { e.innerFinished <- errors.Trace(err) @@ -554,7 +559,7 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { // key of hash table: hash value of key columns // value of hash table: RowPtr of the corresponded row func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) error { - e.hashTable = mvmap.NewMVMap() + e.hashTable = newRowHashMap() e.innerKeyColIdx = make([]int, len(e.innerKeys)) for i := range e.innerKeys { e.innerKeyColIdx[i] = e.innerKeys[i].Index @@ -562,10 +567,11 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) var ( hasNull bool err error - keyBuf = make([]byte, 0, 64) - valBuf = make([]byte, 8) + key uint64 + buf = make([]byte, 1) ) + h := fnv.New64() chkIdx := uint32(0) for chk := range innerResultCh { if e.finished.Load().(bool) { @@ -573,7 +579,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) } numRows := chk.NumRows() for j := 0; j < numRows; j++ { - hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), keyBuf) + hasNull, key, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), h, buf) if err != nil { return errors.Trace(err) } @@ -581,8 +587,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) continue } rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} - *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr - e.hashTable.Put(keyBuf, valBuf) + e.hashTable.Put(key, rowPtr) } chkIdx++ } diff --git a/executor/join_test.go b/executor/join_test.go index 5fdcfce836ac8..dd77430d5f83d 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -273,6 +273,7 @@ func (s *testSuite2) TestJoin(c *C) { func (s *testSuite2) TestJoinCast(c *C) { tk := testkit.NewTestKit(c, s.store) + var result *testkit.Result tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -281,9 +282,49 @@ func (s *testSuite2) TestJoinCast(c *C) { tk.MustExec("create table t1(c1 int unsigned)") tk.MustExec("insert into t values (1)") tk.MustExec("insert into t1 values (1)") - result := tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") result.Check(testkit.Rows("1")) + // int64(-1) != uint64(18446744073709551615) + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 bigint)") + tk.MustExec("create table t1(c1 bigint unsigned)") + tk.MustExec("insert into t values (-1)") + tk.MustExec("insert into t1 values (18446744073709551615)") + result = tk.MustQuery("select * from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows()) + + // float(1) == double(1) + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 float)") + tk.MustExec("create table t1(c1 double)") + tk.MustExec("insert into t values (1.0)") + tk.MustExec("insert into t1 values (1.00)") + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows("1")) + + // varchar("x") == char("x") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 varchar(1))") + tk.MustExec("create table t1(c1 char(1))") + tk.MustExec(`insert into t values ("x")`) + tk.MustExec(`insert 
into t1 values ("x")`) + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows("x")) + + // varchar("x") != char("y") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 varchar(1))") + tk.MustExec("create table t1(c1 char(1))") + tk.MustExec(`insert into t values ("x")`) + tk.MustExec(`insert into t1 values ("y")`) + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows()) + tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") tk.MustExec("create table t(c1 int,c2 double)") @@ -293,6 +334,73 @@ func (s *testSuite2) TestJoinCast(c *C) { result = tk.MustQuery("select * from t a , t1 b where (a.c1, a.c2) = (b.c1, b.c2);") result.Check(testkit.Rows("1 2 1 2")) + /* Enable & fix this test after https://github.com/pingcap/tidb/issues/11895 is fixed. + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint unsigned);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(18446744073709551615);") + tk.MustExec("insert into t1 value(-1);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 1) + */ + + /* https://github.com/pingcap/tidb/issues/11896 + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(1);") + tk.MustExec("insert into t1 value(1);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 1) + */ + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(-1);") + tk.MustExec("insert into t1 value(18446744073709551615);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 0) + + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t(c1 bigint)") + tk.MustExec("create table t1(c1 bigint unsigned)") + tk.MustExec("create table t2(c1 Date)") + tk.MustExec("insert into t value(20191111)") + tk.MustExec("insert into t1 value(20191111)") + tk.MustExec("insert into t2 value('2019-11-11')") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1") + result.Check(testkit.Rows("20191111 20191111 2019-11-11")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bigint unsigned);") + tk.MustExec("create table t2(c1 enum('a', 'b', 'c', 'd'));") + tk.MustExec("insert into t value(3);") + tk.MustExec("insert into t1 value(3);") + tk.MustExec("insert into t2 value('c');") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1;") + result.Check(testkit.Rows("3 3 c")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bigint unsigned);") + tk.MustExec("create table t2 (c1 SET('a', 'b', 'c', 'd'));") + tk.MustExec("insert into t 
value(9);") + tk.MustExec("insert into t1 value(9);") + tk.MustExec("insert into t2 value('a,d');") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1;") + result.Check(testkit.Rows("9 9 a,d")) + tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") tk.MustExec("create table t(c1 int)") @@ -323,6 +431,7 @@ func (s *testSuite2) TestJoinCast(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") tk.MustExec("create table t(c1 char(10))") tk.MustExec("create table t1(c1 char(10))") tk.MustExec("create table t2(c1 char(10))") diff --git a/executor/pkg_test.go b/executor/pkg_test.go index 637661ed3e40b..bdd7a91772bed 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -3,16 +3,15 @@ package executor import ( "context" "fmt" + . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" - "github.com/pingcap/tidb/util/stringutil" ) var _ = Suite(&pkgTestSuite{}) @@ -20,36 +19,6 @@ var _ = Suite(&pkgTestSuite{}) type pkgTestSuite struct { } -type MockExec struct { - baseExecutor - - Rows []chunk.MutRow - curRowIdx int -} - -func (m *MockExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - colTypes := retTypes(m) - for ; m.curRowIdx < len(m.Rows) && req.NumRows() < req.Capacity(); m.curRowIdx++ { - curRow := m.Rows[m.curRowIdx] - for i := 0; i < curRow.Len(); i++ { - curDatum := curRow.ToRow().GetDatum(i, colTypes[i]) - req.AppendDatum(i, &curDatum) - } - } - return nil -} - -func (m *MockExec) Close() error { - m.curRowIdx = 0 - return nil -} - -func (m *MockExec) Open(ctx context.Context) error { - m.curRowIdx = 0 - return nil -} - func (s *pkgTestSuite) TestNestedLoopApply(c *C) { ctx := context.Background() sctx := mock.NewContext() @@ -57,27 +26,27 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { col1 := &expression.Column{Index: 1, RetType: types.NewFieldType(mysql.TypeLong)} con := &expression.Constant{Value: types.NewDatum(6), RetType: types.NewFieldType(mysql.TypeLong)} outerSchema := expression.NewSchema(col0) - outerExec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, outerSchema, nil), - Rows: []chunk.MutRow{ - chunk.MutRowFromDatums(types.MakeDatums(1)), - chunk.MutRowFromDatums(types.MakeDatums(2)), - chunk.MutRowFromDatums(types.MakeDatums(3)), - chunk.MutRowFromDatums(types.MakeDatums(4)), - chunk.MutRowFromDatums(types.MakeDatums(5)), - chunk.MutRowFromDatums(types.MakeDatums(6)), - }} + outerExec := buildMockDataSource(mockDataSourceParameters{ + schema: outerSchema, + rows: 6, + ctx: sctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + return int64(row + 1) + }, + }) + outerExec.prepareChunks() + innerSchema := expression.NewSchema(col1) - innerExec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, innerSchema, nil), - Rows: []chunk.MutRow{ - chunk.MutRowFromDatums(types.MakeDatums(1)), - chunk.MutRowFromDatums(types.MakeDatums(2)), - chunk.MutRowFromDatums(types.MakeDatums(3)), - chunk.MutRowFromDatums(types.MakeDatums(4)), - chunk.MutRowFromDatums(types.MakeDatums(5)), - chunk.MutRowFromDatums(types.MakeDatums(6)), - }} + innerExec := buildMockDataSource(mockDataSourceParameters{ + schema: innerSchema, + rows: 6, + ctx: sctx, 
+ genDataFunc: func(row int, typ *types.FieldType) interface{} { + return int64(row + 1) + }, + }) + innerExec.prepareChunks() + outerFilter := expression.NewFunctionInternal(sctx, ast.LT, types.NewFieldType(mysql.TypeTiny), col0, con) innerFilter := outerFilter.Clone() otherFilter := expression.NewFunctionInternal(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col1) @@ -112,39 +81,6 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { } } -func prepareOneColChildExec(sctx sessionctx.Context, rowCount int) Executor { - col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - schema := expression.NewSchema(col0) - exec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, schema, nil), - Rows: make([]chunk.MutRow, rowCount)} - for i := 0; i < len(exec.Rows); i++ { - exec.Rows[i] = chunk.MutRowFromDatums(types.MakeDatums(i % 10)) - } - return exec -} - -func prepare4RadixPartition(sctx sessionctx.Context, rowCount int) *HashJoinExec { - childExec0 := prepareOneColChildExec(sctx, rowCount) - childExec1 := prepareOneColChildExec(sctx, rowCount) - - col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - col1 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - joinSchema := expression.NewSchema(col0, col1) - hashJoinExec := &HashJoinExec{ - baseExecutor: newBaseExecutor(sctx, joinSchema, stringutil.StringerStr("HashJoin"), childExec0, childExec1), - concurrency: 4, - joinType: 0, // InnerJoin - innerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, - innerKeyColIdx: []int{0}, - outerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, - outerKeyColIdx: []int{0}, - innerExec: childExec0, - outerExec: childExec1, - } - return hashJoinExec -} - func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { dbss := [][]string{ {}, diff --git a/expression/constant_test.go b/expression/constant_test.go index f3b24dec09e8d..d495f963d5643 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -389,7 +389,7 @@ func (*testExpressionSuite) TestDeferredParamNotNull(c *C) { c.Assert(mysql.TypeBlob, Equals, cstBytes.GetType().Tp) c.Assert(mysql.TypeBit, Equals, cstBinary.GetType().Tp) c.Assert(mysql.TypeBit, Equals, cstBit.GetType().Tp) - c.Assert(mysql.TypeVarString, Equals, cstFloat32.GetType().Tp) + c.Assert(mysql.TypeFloat, Equals, cstFloat32.GetType().Tp) c.Assert(mysql.TypeDouble, Equals, cstFloat64.GetType().Tp) c.Assert(mysql.TypeEnum, Equals, cstEnum.GetType().Tp) diff --git a/types/field_type.go b/types/field_type.go index 51905169dafa8..f8d51d18a101e 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -183,6 +183,12 @@ func DefaultTypeForValue(value interface{}, tp *FieldType) { tp.Flen = len(x) tp.Decimal = UnspecifiedLength tp.Charset, tp.Collate = charset.GetDefaultCharsetAndCollate() + case float32: + tp.Tp = mysql.TypeFloat + s := strconv.FormatFloat(float64(x), 'f', -1, 32) + tp.Flen = len(s) + tp.Decimal = UnspecifiedLength + SetBinChsClnFlag(tp) case float64: tp.Tp = mysql.TypeDouble s := strconv.FormatFloat(x, 'f', -1, 64) diff --git a/util/chunk/column.go b/util/chunk/column.go index e4f1d6c54b0e3..aad386207b601 100644 --- a/util/chunk/column.go +++ b/util/chunk/column.go @@ -526,6 +526,18 @@ func (c *Column) getNameValue(rowID int) (string, uint64) { return string(hack.String(c.data[start+8 : end])), *(*uint64)(unsafe.Pointer(&c.data[start])) } +// GetRaw returns the underlying raw bytes in the specific 
row. +func (c *Column) GetRaw(rowID int) []byte { + var data []byte + if c.isFixed() { + elemLen := len(c.elemBuf) + data = c.data[rowID*elemLen : rowID*elemLen+elemLen] + } else { + data = c.data[c.offsets[rowID]:c.offsets[rowID+1]] + } + return data +} + // reconstruct reconstructs this Column by removing all filtered rows in it according to sel. func (c *Column) reconstruct(sel []int) { if sel == nil { diff --git a/util/chunk/column_test.go b/util/chunk/column_test.go index bc47fc828e170..2a105440b067e 100644 --- a/util/chunk/column_test.go +++ b/util/chunk/column_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "testing" "time" + "unsafe" "github.com/pingcap/check" "github.com/pingcap/parser/mysql" @@ -749,6 +750,36 @@ func (s *testChunkSuite) TestResizeReserve(c *check.C) { c.Assert(cStrs.length, check.Equals, 0) } +func (s *testChunkSuite) TestGetRaw(c *check.C) { + chk := NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeFloat)}, 1024) + col := chk.Column(0) + for i := 0; i < 1024; i++ { + col.AppendFloat32(float32(i)) + } + it := NewIterator4Chunk(chk) + var i int64 + for row := it.Begin(); row != it.End(); row = it.Next() { + f := float32(i) + b := (*[unsafe.Sizeof(f)]byte)(unsafe.Pointer(&f))[:] + c.Assert(row.GetRaw(0), check.DeepEquals, b) + c.Assert(col.GetRaw(int(i)), check.DeepEquals, b) + i++ + } + + chk = NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeVarString)}, 1024) + col = chk.Column(0) + for i := 0; i < 1024; i++ { + col.AppendString(fmt.Sprint(i)) + } + it = NewIterator4Chunk(chk) + i = 0 + for row := it.Begin(); row != it.End(); row = it.Next() { + c.Assert(row.GetRaw(0), check.DeepEquals, []byte(fmt.Sprint(i))) + c.Assert(col.GetRaw(int(i)), check.DeepEquals, []byte(fmt.Sprint(i))) + i++ + } +} + func BenchmarkDurationRow(b *testing.B) { chk1 := NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDuration)}, 1024) col1 := chk1.Column(0) diff --git a/util/chunk/row.go b/util/chunk/row.go index 0d83c4dec5525..620373a076c48 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -204,6 +204,11 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { return d } +// GetRaw returns the underlying raw bytes with the colIdx. +func (r Row) GetRaw(colIdx int) []byte { + return r.c.columns[colIdx].GetRaw(r.idx) +} + // IsNull returns if the datum in the chunk.Row is null. func (r Row) IsNull(colIdx int) bool { return r.c.columns[colIdx].IsNull(r.idx) diff --git a/util/codec/codec.go b/util/codec/codec.go index 8d6d733feb48d..9ba94cc3dac1c 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -14,8 +14,11 @@ package codec import ( + "bytes" "encoding/binary" + "io" "time" + "unsafe" "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" @@ -67,23 +70,14 @@ func preRealloc(b []byte, vals []types.Datum, comparable bool) []byte { // encode will encode a datum and append it to a byte slice. If comparable is true, the encoded bytes can be sorted as it's original order. // If hash is true, the encoded bytes can be checked equal as it's original value. 
-func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparable bool, hash bool) (_ []byte, err error) { +func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparable bool) (_ []byte, err error) { b = preRealloc(b, vals, comparable) for i, length := 0, len(vals); i < length; i++ { switch vals[i].Kind() { case types.KindInt64: b = encodeSignedInt(b, vals[i].GetInt64(), comparable) case types.KindUint64: - if hash { - integer := vals[i].GetInt64() - if integer < 0 { - b = encodeUnsignedInt(b, uint64(integer), comparable) - } else { - b = encodeSignedInt(b, integer, comparable) - } - } else { - b = encodeUnsignedInt(b, vals[i].GetUint64(), comparable) - } + b = encodeUnsignedInt(b, vals[i].GetUint64(), comparable) case types.KindFloat32, types.KindFloat64: b = append(b, floatFlag) b = EncodeFloat(b, vals[i].GetFloat64()) @@ -101,22 +95,11 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab b = EncodeInt(b, int64(vals[i].GetMysqlDuration().Duration)) case types.KindMysqlDecimal: b = append(b, decimalFlag) - if hash { - // If hash is true, we only consider the original value of this decimal and ignore it's precision. - dec := vals[i].GetMysqlDecimal() - var bin []byte - bin, err = dec.ToHashKey() - if err != nil { - return b, errors.Trace(err) - } - b = append(b, bin...) - } else { - b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) - if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.HandleTruncate(err) - } else if terror.ErrorEqual(err, types.ErrOverflow) { - err = sc.HandleOverflow(err, err) - } + b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) + if terror.ErrorEqual(err, types.ErrTruncated) { + err = sc.HandleTruncate(err) + } else if terror.ErrorEqual(err, types.ErrOverflow) { + err = sc.HandleOverflow(err, err) } case types.KindMysqlEnum: b = encodeUnsignedInt(b, uint64(vals[i].GetMysqlEnum().ToNumber()), comparable) @@ -284,105 +267,133 @@ func sizeInt(comparable bool) int { // slice. It guarantees the encoded value is in ascending order for comparison. // For Decimal type, datum must set datum's length and frac. func EncodeKey(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, true, false) + return encode(sc, b, v, true) } // EncodeValue appends the encoded values to byte slice b, returning the appended // slice. It does not guarantee the order for comparison. func EncodeValue(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, false, false) + return encode(sc, b, v, false) } -func encodeHashChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTypes []*types.FieldType, colIdx []int) (_ []byte, err error) { - const comparable = false - for _, i := range colIdx { - if row.IsNull(i) { - b = append(b, NilFlag) - continue - } - switch allTypes[i].Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: - if !mysql.HasUnsignedFlag(allTypes[i].Flag) { - b = encodeSignedInt(b, row.GetInt64(i), comparable) - break - } - // encode unsigned integers. 
- integer := row.GetInt64(i) - if integer < 0 { - b = encodeUnsignedInt(b, uint64(integer), comparable) - } else { - b = encodeSignedInt(b, integer, comparable) - } - case mysql.TypeFloat: - b = append(b, floatFlag) - b = EncodeFloat(b, float64(row.GetFloat32(i))) - case mysql.TypeDouble: - b = append(b, floatFlag) - b = EncodeFloat(b, row.GetFloat64(i)) - case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - b = encodeBytes(b, row.GetBytes(i), comparable) - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - b = append(b, uintFlag) - t := row.GetTime(i) - // Encoding timestamp need to consider timezone. - // If it's not in UTC, transform to UTC first. - if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC { - err = t.ConvertTimeZone(sc.TimeZone, time.UTC) - if err != nil { - return nil, errors.Trace(err) - } - } - var v uint64 - v, err = t.ToPackedUint() - if err != nil { - return nil, errors.Trace(err) +func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *types.FieldType, idx int) (flag byte, b []byte, err error) { + if row.IsNull(idx) { + flag = NilFlag + return + } + switch tp.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + flag = varintFlag + if mysql.HasUnsignedFlag(tp.Flag) { + if integer := row.GetInt64(idx); integer < 0 { + flag = uvarintFlag } - b = EncodeUint(b, v) - case mysql.TypeDuration: - // duration may have negative value, so we cannot use String to encode directly. - b = append(b, durationFlag) - b = EncodeInt(b, int64(row.GetDuration(i, 0).Duration)) - case mysql.TypeNewDecimal: - b = append(b, decimalFlag) - // If hash is true, we only consider the original value of this decimal and ignore it's precision. - dec := row.GetMyDecimal(i) - bin, err := dec.ToHashKey() + } + b = row.GetRaw(idx) + case mysql.TypeFloat: + flag = floatFlag + f := float64(row.GetFloat32(idx)) + b = (*[unsafe.Sizeof(f)]byte)(unsafe.Pointer(&f))[:] + case mysql.TypeDouble: + flag = floatFlag + b = row.GetRaw(idx) + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + flag = compactBytesFlag + b = row.GetBytes(idx) + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + flag = uintFlag + t := row.GetTime(idx) + // Encoding timestamp need to consider timezone. + // If it's not in UTC, transform to UTC first. + if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC { + err = t.ConvertTimeZone(sc.TimeZone, time.UTC) if err != nil { - return nil, errors.Trace(err) + return } - b = append(b, bin...) - case mysql.TypeEnum: - b = encodeUnsignedInt(b, uint64(row.GetEnum(i).ToNumber()), comparable) - case mysql.TypeSet: - b = encodeUnsignedInt(b, uint64(row.GetSet(i).ToNumber()), comparable) - case mysql.TypeBit: - // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - var val uint64 - val, err = types.BinaryLiteral(row.GetBytes(i)).ToInt(sc) - terror.Log(errors.Trace(err)) - b = encodeUnsignedInt(b, val, comparable) - case mysql.TypeJSON: - b = append(b, jsonFlag) - j := row.GetJSON(i) - b = append(b, j.TypeCode) - b = append(b, j.Value...) 
- default: - return nil, errors.Errorf("unsupport column type for encode %d", allTypes[i].Tp) } + var v uint64 + v, err = t.ToPackedUint() + if err != nil { + return + } + b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:] + case mysql.TypeDuration: + flag = durationFlag + // duration may have negative value, so we cannot use String to encode directly. + b = row.GetRaw(idx) + case mysql.TypeNewDecimal: + flag = decimalFlag + // If hash is true, we only consider the original value of this decimal and ignore it's precision. + dec := row.GetMyDecimal(idx) + b, err = dec.ToHashKey() + if err != nil { + return + } + case mysql.TypeEnum: + flag = uvarintFlag + v := uint64(row.GetEnum(idx).ToNumber()) + b = (*[8]byte)(unsafe.Pointer(&v))[:] + case mysql.TypeSet: + flag = uvarintFlag + v := uint64(row.GetSet(idx).ToNumber()) + b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:] + case mysql.TypeBit: + // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. + flag = uvarintFlag + v, err1 := types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc) + terror.Log(errors.Trace(err1)) + b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:] + case mysql.TypeJSON: + flag = jsonFlag + b = row.GetBytes(idx) + default: + return 0, nil, errors.Errorf("unsupport column type for encode %d", tp.Tp) } - return b, errors.Trace(err) + return } -// HashValues appends the encoded values to byte slice b, returning the appended -// slice. If two datums are equal, they will generate the same bytes. -func HashValues(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, false, true) +// HashChunkRow writes the encoded values to w. +// If two rows are logically equal, it will generate the same bytes. +func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) { + var b []byte + for _, idx := range colIdx { + buf[0], b, err = encodeHashChunkRowIdx(sc, row, allTypes[idx], idx) + if err != nil { + return errors.Trace(err) + } + _, err = w.Write(buf) + if err != nil { + return + } + _, err = w.Write(b) + if err != nil { + return + } + } + return err } -// HashChunkRow appends the encoded values to byte slice "b", returning the appended slice. -// If two rows are equal, it will generate the same bytes. -func HashChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTypes []*types.FieldType, colIdx []int) ([]byte, error) { - return encodeHashChunkRow(sc, b, row, allTypes, colIdx) +// EqualChunkRow returns a boolean reporting whether row1 and row2 +// with their types and column index are logically equal. 
+func EqualChunkRow(sc *stmtctx.StatementContext, + row1 chunk.Row, allTypes1 []*types.FieldType, colIdx1 []int, + row2 chunk.Row, allTypes2 []*types.FieldType, colIdx2 []int, +) (bool, error) { + for i := range colIdx1 { + idx1, idx2 := colIdx1[i], colIdx2[i] + flag1, b1, err := encodeHashChunkRowIdx(sc, row1, allTypes1[idx1], idx1) + if err != nil { + return false, errors.Trace(err) + } + flag2, b2, err := encodeHashChunkRowIdx(sc, row2, allTypes2[idx2], idx2) + if err != nil { + return false, errors.Trace(err) + } + if !(flag1 == flag2 && bytes.Equal(b1, b2)) { + return false, nil + } + } + return true, nil } // Decode decodes values from a byte slice generated with EncodeKey or EncodeValue diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 343046e4ed189..f23b23e7f3774 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -15,6 +15,7 @@ package codec import ( "bytes" + "hash/crc32" "math" "testing" "time" @@ -812,7 +813,7 @@ func (s *testCodecSuite) TestJSON(c *C) { } buf := make([]byte, 0, 4096) - buf, err := encode(nil, buf, datums, false, false) + buf, err := encode(nil, buf, datums, false) c.Assert(err, IsNil) datums1, err := Decode(buf, 2) @@ -1033,8 +1034,52 @@ func (s *testCodecSuite) TestDecodeRange(c *C) { } } +func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) { + sc := &stmtctx.StatementContext{TimeZone: time.Local} + buf1 := make([]byte, 1) + buf2 := make([]byte, 1) + + tp1 := new(types.FieldType) + types.DefaultTypeForValue(a, tp1) + chk1 := chunk.New([]*types.FieldType{tp1}, 1, 1) + d := types.Datum{} + d.SetValue(a) + chk1.AppendDatum(0, &d) + + tp2 := new(types.FieldType) + types.DefaultTypeForValue(b, tp2) + chk2 := chunk.New([]*types.FieldType{tp2}, 1, 1) + d = types.Datum{} + d.SetValue(b) + chk2.AppendDatum(0, &d) + + h := crc32.NewIEEE() + err1 := HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}, buf1) + sum1 := h.Sum32() + h.Reset() + err2 := HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}, buf2) + sum2 := h.Sum32() + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + if equal { + c.Assert(sum1, Equals, sum2) + } else { + c.Assert(sum1, Not(Equals), sum2) + } + e, err := EqualChunkRow(sc, + chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}, + chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) + c.Assert(err, IsNil) + if equal { + c.Assert(e, IsTrue) + } else { + c.Assert(e, IsFalse) + } +} + func (s *testCodecSuite) TestHashChunkRow(c *C) { sc := &stmtctx.StatementContext{TimeZone: time.Local} + buf := make([]byte, 1) datums, tps := datumsForTest(sc) chk := chunkForTest(c, sc, datums, tps, 1) @@ -1042,12 +1087,37 @@ func (s *testCodecSuite) TestHashChunkRow(c *C) { for i := 0; i < len(tps); i++ { colIdx[i] = i } - b1, err1 := HashChunkRow(sc, nil, chk.GetRow(0), tps, colIdx) - b2, err2 := HashValues(sc, nil, datums...) 
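+	// Hash the same row twice with a reset hash state; HashChunkRow now writes into
+	// an io.Writer, so equal rows must yield equal sums rather than equal byte slices.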
+ h := crc32.NewIEEE() + err1 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx, buf) + sum1 := h.Sum32() + h.Reset() + err2 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx, buf) + sum2 := h.Sum32() c.Assert(err1, IsNil) c.Assert(err2, IsNil) - c.Assert(b1, BytesEquals, b2) + c.Assert(sum1, Equals, sum2) + e, err := EqualChunkRow(sc, + chk.GetRow(0), tps, colIdx, + chk.GetRow(0), tps, colIdx) + c.Assert(err, IsNil) + c.Assert(e, IsTrue) + + testHashChunkRowEqual(c, uint64(1), int64(1), true) + testHashChunkRowEqual(c, uint64(18446744073709551615), int64(-1), false) + + dec1 := types.NewDecFromStringForTest("1.1") + dec2 := types.NewDecFromStringForTest("01.100") + testHashChunkRowEqual(c, dec1, dec2, true) + dec1 = types.NewDecFromStringForTest("1.1") + dec2 = types.NewDecFromStringForTest("01.200") + testHashChunkRowEqual(c, dec1, dec2, false) + + testHashChunkRowEqual(c, float32(1.0), float64(1.0), true) + testHashChunkRowEqual(c, float32(1.0), float64(1.1), false) + + testHashChunkRowEqual(c, "x", []byte("x"), true) + testHashChunkRowEqual(c, "x", []byte("y"), false) } func (s *testCodecSuite) TestValueSizeOfSignedInt(c *C) {
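Taken together, the executor side of this patch replaces `mvmap`'s encoded byte-slice keys with a fixed-width FNV-64 hash of the join key plus an exact re-check of each candidate row (`matchJoinKey` / `codec.EqualChunkRow`), so rows that merely collide on the hash can no longer join. A minimal stand-alone sketch of that hash-then-verify pattern, using only the standard library — `entry`, `hashTable`, `put`, and `get` are illustrative stand-ins for `entryStore`, `rowHashMap`, and the probe path, and the sketch stores the encoded key in each entry, whereas the real code keeps only `chunk.RowPtr`s and verifies against the inner rows themselves:

```go
package main

import (
	"bytes"
	"fmt"
	"hash/fnv"
)

// entry chains every row whose key hashed to the same uint64,
// the way entryStore chains entryAddrs for one hash bucket.
type entry struct {
	key  []byte // encoded join key, kept here for exact verification
	row  int    // stand-in for chunk.RowPtr
	next *entry
}

type hashTable map[uint64]*entry

func sum64(key []byte) uint64 {
	h := fnv.New64()
	h.Write(key)
	return h.Sum64()
}

// put prepends the new entry to the chain for this hash,
// so duplicate join keys keep all of their rows.
func (t hashTable) put(key []byte, row int) {
	sum := sum64(key)
	t[sum] = &entry{key: key, row: row, next: t[sum]}
}

// get walks the chain and keeps only byte-equal keys, filtering out
// hash collisions the way matchJoinKey filters candidate inner rows.
func (t hashTable) get(key []byte) (rows []int) {
	for e := t[sum64(key)]; e != nil; e = e.next {
		if bytes.Equal(e.key, key) {
			rows = append(rows, e.row)
		}
	}
	return
}

func main() {
	t := hashTable{}
	t.put([]byte("k1"), 0)
	t.put([]byte("k1"), 1) // duplicate join key: both rows are chained
	t.put([]byte("k2"), 2)
	fmt.Println(t.get([]byte("k1"))) // [1 0]
	fmt.Println(t.get([]byte("k3"))) // []
}
```

One difference worth noting: the sketch returns matches in reverse insertion order, while `rowHashMap.Get` reverses the chained pointers before returning so that probe results keep the order in which inner rows were put.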