From c8951081d08cee71f817fd0e6fee584c2ef6564d Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Thu, 22 Aug 2019 23:41:59 +0800 Subject: [PATCH 01/11] executor: decrease the memory usage of hashTable in HashJoinExec --- executor/hash_table.go | 36 +++++++++++ executor/join.go | 72 +++++++++++---------- executor/join_test.go | 10 +++ util/chunk/column.go | 14 ++++ util/chunk/compare.go | 40 ++++++++++++ util/chunk/row.go | 6 ++ util/codec/codec.go | 135 +++++++++++---------------------------- util/codec/codec_test.go | 72 +++++++++++++++++++-- 8 files changed, 249 insertions(+), 136 deletions(-) create mode 100644 executor/hash_table.go diff --git a/executor/hash_table.go b/executor/hash_table.go new file mode 100644 index 0000000000000..2de90c3246ae4 --- /dev/null +++ b/executor/hash_table.go @@ -0,0 +1,36 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "github.com/pingcap/tidb/util/chunk" +) + +// TODO: Consider using a fix-sized hash table. +type rowHashTable map[uint64][]chunk.RowPtr + +func newRowHashTable() rowHashTable { + t := make(rowHashTable) + return t +} + +func (t rowHashTable) Put(key uint64, rowPtr chunk.RowPtr) { + t[key] = append(t[key], rowPtr) +} + +func (t rowHashTable) Get(key uint64) []chunk.RowPtr { + return t[key] +} + +func (t rowHashTable) Len() int { return len(t) } diff --git a/executor/join.go b/executor/join.go index 2b3eae1d5d06e..53c33880baf94 100644 --- a/executor/join.go +++ b/executor/join.go @@ -16,9 +16,10 @@ package executor import ( "context" "fmt" + "hash" + "hash/fnv" "sync" "sync/atomic" - "unsafe" "github.com/pingcap/errors" "github.com/pingcap/parser/terror" @@ -29,7 +30,6 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/memory" - "github.com/pingcap/tidb/util/mvmap" "github.com/pingcap/tidb/util/stringutil" ) @@ -50,7 +50,7 @@ type HashJoinExec struct { // concurrency is the number of partition, build and join workers. concurrency uint - hashTable *mvmap.MVMap + hashTable rowHashTable innerFinished chan error hashJoinBuffers []*hashJoinBuffer // joinWorkerWaitGroup is for sync multiple join workers. 
@@ -98,7 +98,6 @@ type hashjoinWorkerResult struct { } type hashJoinBuffer struct { - data []types.Datum bytes []byte } @@ -149,14 +148,6 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) e.hashTableValBufs = make([][][]byte, e.concurrency) - e.hashJoinBuffers = make([]*hashJoinBuffer, 0, e.concurrency) - for i := uint(0); i < e.concurrency; i++ { - buffer := &hashJoinBuffer{ - data: make([]types.Datum, len(e.outerKeys)), - bytes: make([]byte, 0, 10000), - } - e.hashJoinBuffers = append(e.hashJoinBuffers, buffer) - } e.closeCh = make(chan struct{}) e.finished.Store(false) @@ -164,7 +155,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error { return nil } -func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyBuf []byte) (hasNull bool, _ []byte, err error) { +func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h hash.Hash64) (hasNull bool, key uint64, err error) { var keyColIdx []int var allTypes []*types.FieldType if isOuterKey { @@ -177,12 +168,27 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyB for _, i := range keyColIdx { if row.IsNull(i) { - return true, keyBuf, nil + return true, 0, nil + } + } + h.Reset() + err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, h, row, allTypes, keyColIdx) + return false, h.Sum64(), err +} + +func (e *HashJoinExec) matchJoinKey(inner, outer chunk.Row) bool { + innerAllTypes := retTypes(e.innerExec) + outerAllTypes := retTypes(e.outerExec) + for i := range e.innerKeyColIdx { + innerIdx := e.innerKeyColIdx[i] + outerIdx := e.outerKeyColIdx[i] + innerTp := innerAllTypes[innerIdx] + outerTp := outerAllTypes[outerIdx] + if cmp := chunk.GetCompareFuncWithTypes(innerTp, outerTp); cmp(inner, innerIdx, outer, outerIdx) != 0 { + return false } } - keyBuf = keyBuf[:0] - keyBuf, err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, keyBuf, row, allTypes, keyColIdx) - return false, keyBuf, err + return true } // fetchOuterChunks get chunks from fetches chunks from the big table in a background goroutine @@ -368,6 +374,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { var ( outerResult *chunk.Chunk selected = make([]bool, 0, chunk.InitialCapacity) + h = fnv.New64() ) ok, joinResult := e.getNewJoinResult(workerID) if !ok { @@ -390,7 +397,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { if !ok { break } - ok, joinResult = e.join2Chunk(workerID, outerResult, joinResult, selected) + ok, joinResult = e.join2Chunk(workerID, outerResult, joinResult, selected, h) if !ok { break } @@ -406,9 +413,8 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { } func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.Row, - joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { - buffer := e.hashJoinBuffers[workerID] - hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, buffer.bytes) + joinResult *hashjoinWorkerResult, h hash.Hash64) (bool, *hashjoinWorkerResult) { + hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, h) if err != nil { joinResult.err = err return false, joinResult @@ -417,16 +423,17 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } - e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) - innerPtrs := e.hashTableValBufs[workerID] + innerPtrs := 
e.hashTable.Get(joinKey) if len(innerPtrs) == 0 { e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } innerRows := make([]chunk.Row, 0, len(innerPtrs)) - for _, b := range innerPtrs { - ptr := *(*chunk.RowPtr)(unsafe.Pointer(&b[0])) + for _, ptr := range innerPtrs { matchedInner := e.innerResult.GetRow(ptr) + if !e.matchJoinKey(matchedInner, outerRow) { + continue + } innerRows = append(innerRows, matchedInner) } iter := chunk.NewIterator4Slice(innerRows) @@ -468,7 +475,7 @@ func (e *HashJoinExec) getNewJoinResult(workerID uint) (bool, *hashjoinWorkerRes } func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResult *hashjoinWorkerResult, - selected []bool) (ok bool, _ *hashjoinWorkerResult) { + selected []bool, h hash.Hash64) (ok bool, _ *hashjoinWorkerResult) { var err error selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(outerChk), selected) if err != nil { @@ -479,7 +486,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu if !selected[i] { // process unmatched outer rows e.joiners[workerID].onMissMatch(false, outerChk.GetRow(i), joinResult.chk) } else { // process matched outer rows - ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult) + ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult, h) if !ok { return false, joinResult } @@ -554,7 +561,7 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { // key of hash table: hash value of key columns // value of hash table: RowPtr of the corresponded row func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) error { - e.hashTable = mvmap.NewMVMap() + e.hashTable = newRowHashTable() e.innerKeyColIdx = make([]int, len(e.innerKeys)) for i := range e.innerKeys { e.innerKeyColIdx[i] = e.innerKeys[i].Index @@ -562,10 +569,10 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) var ( hasNull bool err error - keyBuf = make([]byte, 0, 64) - valBuf = make([]byte, 8) + key uint64 ) + h := fnv.New64() chkIdx := uint32(0) for chk := range innerResultCh { if e.finished.Load().(bool) { @@ -573,7 +580,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) } numRows := chk.NumRows() for j := 0; j < numRows; j++ { - hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), keyBuf) + hasNull, key, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), h) if err != nil { return errors.Trace(err) } @@ -581,8 +588,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) continue } rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} - *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr - e.hashTable.Put(keyBuf, valBuf) + e.hashTable.Put(key, rowPtr) } chkIdx++ } diff --git a/executor/join_test.go b/executor/join_test.go index 5fdcfce836ac8..9a19cb1d4a79b 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -284,6 +284,16 @@ func (s *testSuite2) TestJoinCast(c *C) { result := tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") result.Check(testkit.Rows("1")) + // int64(-1) != uint64(18446744073709551615) + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 bigint)") + tk.MustExec("create table t1(c1 bigint unsigned)") + tk.MustExec("insert into t values (-1)") + tk.MustExec("insert into t1 values (18446744073709551615)") + result = 
tk.MustQuery("select * from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows()) + tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") tk.MustExec("create table t(c1 int,c2 double)") diff --git a/util/chunk/column.go b/util/chunk/column.go index e4f1d6c54b0e3..c35f43596c51c 100644 --- a/util/chunk/column.go +++ b/util/chunk/column.go @@ -14,6 +14,7 @@ package chunk import ( + "io" "math/bits" "reflect" "time" @@ -605,3 +606,16 @@ func (c *Column) CopyReconstruct(sel []int, dst *Column) *Column { } return dst } + +// WriteTo writes the raw data in the specific row to w. +func (c *Column) WriteTo(rowID int, w io.Writer) (int, error) { + var data []byte + if c.isFixed() { + elemLen := len(c.elemBuf) + data = c.data[rowID*elemLen : rowID*elemLen+elemLen] + } else { + start, end := c.offsets[rowID], c.offsets[rowID] + data = c.data[start:end] + } + return w.Write(data) +} diff --git a/util/chunk/compare.go b/util/chunk/compare.go index 2ecaf196ee0c9..c754e72891b3a 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -55,6 +55,20 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc { return nil } +// GetCompareFuncWithTypes gets a compare function for the two field types. +func GetCompareFuncWithTypes(tp1, tp2 *types.FieldType) CompareFunc { + switch tp1.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + f1, f2 := mysql.HasUnsignedFlag(tp1.Flag), mysql.HasUnsignedFlag(tp2.Flag) + if f1 && !f2 { + return cmpUint64AndInt64 + } else if !f1 && f2 { + return cmpInt64AndUint64 + } + } + return GetCompareFunc(tp1) +} + func cmpNull(lNull, rNull bool) int { if lNull && rNull { return 0 @@ -81,6 +95,32 @@ func cmpUint64(l Row, lCol int, r Row, rCol int) int { return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol)) } +func cmpInt64AndUint64(l Row, lCol int, r Row, rCol int) int { + lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) + if lNull || rNull { + return cmpNull(lNull, rNull) + } + lVal := l.GetInt64(lCol) + rVal := r.GetUint64(rCol) + if lVal < 0 { + return -1 + } + return types.CompareUint64(uint64(lVal), rVal) +} + +func cmpUint64AndInt64(l Row, lCol int, r Row, rCol int) int { + lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) + if lNull || rNull { + return cmpNull(lNull, rNull) + } + lVal := l.GetUint64(lCol) + rVal := r.GetInt64(rCol) + if rVal < 0 { + return 1 + } + return types.CompareUint64(lVal, uint64(rVal)) +} + func cmpString(l Row, lCol int, r Row, rCol int) int { lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) if lNull || rNull { diff --git a/util/chunk/row.go b/util/chunk/row.go index 0d83c4dec5525..284412d3703f3 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -14,6 +14,7 @@ package chunk import ( + "io" "unsafe" "github.com/pingcap/parser/mysql" @@ -204,6 +205,11 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { return d } +// WriteTo writes the raw data with the colIdx to w. +func (r Row) WriteTo(colIdx int, w io.Writer) (int, error) { + return r.c.columns[colIdx].WriteTo(r.idx, w) +} + // IsNull returns if the datum in the chunk.Row is null. 
func (r Row) IsNull(colIdx int) bool { return r.c.columns[colIdx].IsNull(r.idx) diff --git a/util/codec/codec.go b/util/codec/codec.go index 8d6d733feb48d..678d4a54231eb 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -15,6 +15,7 @@ package codec import ( "encoding/binary" + "io" "time" "github.com/pingcap/errors" @@ -67,23 +68,14 @@ func preRealloc(b []byte, vals []types.Datum, comparable bool) []byte { // encode will encode a datum and append it to a byte slice. If comparable is true, the encoded bytes can be sorted as it's original order. // If hash is true, the encoded bytes can be checked equal as it's original value. -func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparable bool, hash bool) (_ []byte, err error) { +func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparable bool) (_ []byte, err error) { b = preRealloc(b, vals, comparable) for i, length := 0, len(vals); i < length; i++ { switch vals[i].Kind() { case types.KindInt64: b = encodeSignedInt(b, vals[i].GetInt64(), comparable) case types.KindUint64: - if hash { - integer := vals[i].GetInt64() - if integer < 0 { - b = encodeUnsignedInt(b, uint64(integer), comparable) - } else { - b = encodeSignedInt(b, integer, comparable) - } - } else { - b = encodeUnsignedInt(b, vals[i].GetUint64(), comparable) - } + b = encodeUnsignedInt(b, vals[i].GetUint64(), comparable) case types.KindFloat32, types.KindFloat64: b = append(b, floatFlag) b = EncodeFloat(b, vals[i].GetFloat64()) @@ -101,22 +93,11 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab b = EncodeInt(b, int64(vals[i].GetMysqlDuration().Duration)) case types.KindMysqlDecimal: b = append(b, decimalFlag) - if hash { - // If hash is true, we only consider the original value of this decimal and ignore it's precision. - dec := vals[i].GetMysqlDecimal() - var bin []byte - bin, err = dec.ToHashKey() - if err != nil { - return b, errors.Trace(err) - } - b = append(b, bin...) - } else { - b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) - if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.HandleTruncate(err) - } else if terror.ErrorEqual(err, types.ErrOverflow) { - err = sc.HandleOverflow(err, err) - } + b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) + if terror.ErrorEqual(err, types.ErrTruncated) { + err = sc.HandleTruncate(err) + } else if terror.ErrorEqual(err, types.ErrOverflow) { + err = sc.HandleOverflow(err, err) } case types.KindMysqlEnum: b = encodeUnsignedInt(b, uint64(vals[i].GetMysqlEnum().ToNumber()), comparable) @@ -284,105 +265,61 @@ func sizeInt(comparable bool) int { // slice. It guarantees the encoded value is in ascending order for comparison. // For Decimal type, datum must set datum's length and frac. func EncodeKey(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, true, false) + return encode(sc, b, v, true) } // EncodeValue appends the encoded values to byte slice b, returning the appended // slice. It does not guarantee the order for comparison. 
func EncodeValue(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, false, false) + return encode(sc, b, v, false) } -func encodeHashChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTypes []*types.FieldType, colIdx []int) (_ []byte, err error) { - const comparable = false +func encodeHashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int) (err error) { for _, i := range colIdx { if row.IsNull(i) { - b = append(b, NilFlag) + _, err = w.Write([]byte{NilFlag}) + if err != nil { + return errors.Trace(err) + } continue } - switch allTypes[i].Tp { + ft := allTypes[i] + switch ft.Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: - if !mysql.HasUnsignedFlag(allTypes[i].Flag) { - b = encodeSignedInt(b, row.GetInt64(i), comparable) - break - } - // encode unsigned integers. - integer := row.GetInt64(i) - if integer < 0 { - b = encodeUnsignedInt(b, uint64(integer), comparable) - } else { - b = encodeSignedInt(b, integer, comparable) - } - case mysql.TypeFloat: - b = append(b, floatFlag) - b = EncodeFloat(b, float64(row.GetFloat32(i))) - case mysql.TypeDouble: - b = append(b, floatFlag) - b = EncodeFloat(b, row.GetFloat64(i)) - case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - b = encodeBytes(b, row.GetBytes(i), comparable) - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - b = append(b, uintFlag) - t := row.GetTime(i) - // Encoding timestamp need to consider timezone. - // If it's not in UTC, transform to UTC first. - if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC { - err = t.ConvertTimeZone(sc.TimeZone, time.UTC) - if err != nil { - return nil, errors.Trace(err) + flag := varintFlag + if mysql.HasUnsignedFlag(allTypes[i].Flag) { + if integer := row.GetInt64(i); integer < 0 { + flag = uvarintFlag } } - var v uint64 - v, err = t.ToPackedUint() + _, err = w.Write([]byte{flag}) if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } - b = EncodeUint(b, v) - case mysql.TypeDuration: - // duration may have negative value, so we cannot use String to encode directly. - b = append(b, durationFlag) - b = EncodeInt(b, int64(row.GetDuration(i, 0).Duration)) + _, err = row.WriteTo(i, w) case mysql.TypeNewDecimal: - b = append(b, decimalFlag) // If hash is true, we only consider the original value of this decimal and ignore it's precision. dec := row.GetMyDecimal(i) - bin, err := dec.ToHashKey() + var bin []byte + bin, err = dec.ToHashKey() if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } - b = append(b, bin...) - case mysql.TypeEnum: - b = encodeUnsignedInt(b, uint64(row.GetEnum(i).ToNumber()), comparable) - case mysql.TypeSet: - b = encodeUnsignedInt(b, uint64(row.GetSet(i).ToNumber()), comparable) - case mysql.TypeBit: - // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - var val uint64 - val, err = types.BinaryLiteral(row.GetBytes(i)).ToInt(sc) - terror.Log(errors.Trace(err)) - b = encodeUnsignedInt(b, val, comparable) - case mysql.TypeJSON: - b = append(b, jsonFlag) - j := row.GetJSON(i) - b = append(b, j.TypeCode) - b = append(b, j.Value...) 
+ _, err = w.Write(bin) default: - return nil, errors.Errorf("unsupport column type for encode %d", allTypes[i].Tp) + _, err = row.WriteTo(i, w) + } + if err != nil { + return errors.Trace(err) } } - return b, errors.Trace(err) -} - -// HashValues appends the encoded values to byte slice b, returning the appended -// slice. If two datums are equal, they will generate the same bytes. -func HashValues(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]byte, error) { - return encode(sc, b, v, false, true) + return } -// HashChunkRow appends the encoded values to byte slice "b", returning the appended slice. +// HashChunkRow writes the encoded values to w. // If two rows are equal, it will generate the same bytes. -func HashChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTypes []*types.FieldType, colIdx []int) ([]byte, error) { - return encodeHashChunkRow(sc, b, row, allTypes, colIdx) +func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int) error { + return encodeHashChunkRow(sc, w, row, allTypes, colIdx) } // Decode decodes values from a byte slice generated with EncodeKey or EncodeValue diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 343046e4ed189..2c67e14e5d7f9 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -15,6 +15,8 @@ package codec import ( "bytes" + "fmt" + "hash/crc32" "math" "testing" "time" @@ -812,7 +814,7 @@ func (s *testCodecSuite) TestJSON(c *C) { } buf := make([]byte, 0, 4096) - buf, err := encode(nil, buf, datums, false, false) + buf, err := encode(nil, buf, datums, false) c.Assert(err, IsNil) datums1, err := Decode(buf, 2) @@ -1042,12 +1044,74 @@ func (s *testCodecSuite) TestHashChunkRow(c *C) { for i := 0; i < len(tps); i++ { colIdx[i] = i } - b1, err1 := HashChunkRow(sc, nil, chk.GetRow(0), tps, colIdx) - b2, err2 := HashValues(sc, nil, datums...) 
+ h := crc32.NewIEEE() + err1 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx) + sum1 := h.Sum32() + h.Reset() + err2 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx) + sum2 := h.Sum32() c.Assert(err1, IsNil) c.Assert(err2, IsNil) - c.Assert(b1, BytesEquals, b2) + c.Assert(sum1, Equals, sum2) + + // uint64(18446744073709551615) != int64(-1) + // uint64(1) == int64(1) + tp1 := new(types.FieldType) + types.DefaultTypeForValue(uint64(18446744073709551615), tp1) + chk1 := chunk.New([]*types.FieldType{tp1}, 1, 1) + chk1.AppendUint64(0, 18446744073709551615) + chk1.AppendUint64(0, 1) + + tp2 := new(types.FieldType) + types.DefaultTypeForValue(int64(-1), tp2) + chk2 := chunk.New([]*types.FieldType{tp2}, 1, 1) + chk2.AppendInt64(0, -1) + chk2.AppendInt64(0, 1) + + h.Reset() + err1 = HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}) + sum1 = h.Sum32() + h.Reset() + err2 = HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) + sum2 = h.Sum32() + fmt.Println(sum1, sum2) + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + c.Assert(sum1, Not(Equals), sum2) + + h.Reset() + err1 = HashChunkRow(sc, h, chk1.GetRow(1), []*types.FieldType{tp1}, []int{0}) + sum1 = h.Sum32() + h.Reset() + err2 = HashChunkRow(sc, h, chk2.GetRow(1), []*types.FieldType{tp2}, []int{0}) + sum2 = h.Sum32() + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + c.Assert(sum1, Equals, sum2) + + // Decimal(1.00) == Decimal(1.000) + tp1 = new(types.FieldType) + dec := types.NewDecFromStringForTest("1.10") + types.DefaultTypeForValue(dec, tp1) + chk1 = chunk.New([]*types.FieldType{tp1}, 1, 1) + chk1.AppendMyDecimal(0, dec) + + tp2 = new(types.FieldType) + dec = types.NewDecFromStringForTest("01.100") + types.DefaultTypeForValue(dec, tp2) + chk2 = chunk.New([]*types.FieldType{tp2}, 1, 1) + chk2.AppendMyDecimal(0, dec) + + h.Reset() + err1 = HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}) + sum1 = h.Sum32() + h.Reset() + err2 = HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) + sum2 = h.Sum32() + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + c.Assert(sum1, Not(Equals), sum2) } func (s *testCodecSuite) TestValueSizeOfSignedInt(c *C) { From bcdb7191b6f627a5a1f9cc021bd59a251ca8b4f4 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Thu, 22 Aug 2019 23:42:10 +0800 Subject: [PATCH 02/11] benchmark --- executor/benchmark_test.go | 151 ++++++++++++++++++++++++++++++++++--- executor/pkg_test.go | 106 ++++++-------------------- 2 files changed, 163 insertions(+), 94 deletions(-) diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index aff3cd4e375a3..f0d184144d5cb 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -31,7 +31,9 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/stringutil" ) var ( @@ -39,11 +41,12 @@ var ( ) type mockDataSourceParameters struct { - schema *expression.Schema - ndvs []int // number of distinct values on columns[i] and zero represents no limit - orders []bool // columns[i] should be ordered if orders[i] is true - rows int // number of rows the DataSource should output - ctx sessionctx.Context + schema *expression.Schema + genDataFunc func(row int, typ *types.FieldType) interface{} + ndvs []int // number of distinct values on columns[i] and zero represents no limit + orders []bool // columns[i] should be ordered if 
orders[i] is true
+	rows        int    // number of rows the DataSource should output
+	ctx         sessionctx.Context
 }
 
 type mockDataSource struct {
@@ -56,11 +59,21 @@ type mockDataSource struct {
 
 func (mds *mockDataSource) genColDatums(col int) (results []interface{}) {
 	typ := mds.retFieldTypes[col]
-	order := mds.p.orders[col]
+	order := false
+	if col < len(mds.p.orders) {
+		order = mds.p.orders[col]
+	}
 	rows := mds.p.rows
-	NDV := mds.p.ndvs[col]
+	NDV := 0
+	if col < len(mds.p.ndvs) {
+		NDV = mds.p.ndvs[col]
+	}
 	results = make([]interface{}, 0, rows)
-	if NDV == 0 {
+	if mds.p.genDataFunc != nil {
+		for i := 0; i < rows; i++ {
+			results = append(results, mds.p.genDataFunc(i, typ))
+		}
+	} else if NDV == 0 {
 		for i := 0; i < rows; i++ {
 			results = append(results, mds.randDatum(typ))
 		}
@@ -184,7 +197,7 @@ func (a aggTestCase) columns() []*expression.Column {
 }
 
 func (a aggTestCase) String() string {
-	return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concruuency:%v)",
+	return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concurrency:%v)",
 		a.execType, a.aggFunc, a.groupByNDV, a.hasDistinct, a.rows, a.concurrency)
 }
 
@@ -503,3 +516,123 @@ func BenchmarkWindowFunctions(b *testing.B) {
 		})
 	}
 }
+
+type hashJoinTestCase struct {
+	rows        int
+	concurrency int
+	ctx         sessionctx.Context
+	keyIdx      []int
+}
+
+func (tc hashJoinTestCase) columns() []*expression.Column {
+	return []*expression.Column{
+		{Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong)},
+		{Index: 1, RetType: types.NewFieldType(mysql.TypeVarString)},
+	}
+}
+
+func (tc hashJoinTestCase) String() string {
+	return fmt.Sprintf("(rows:%v, concurrency:%v, joinKeyIdx: %v)",
+		tc.rows, tc.concurrency, tc.keyIdx)
+}
+
+func defaultHashJoinTestCase() *hashJoinTestCase {
+	ctx := mock.NewContext()
+	ctx.GetSessionVars().InitChunkSize = variable.DefInitChunkSize
+	ctx.GetSessionVars().MaxChunkSize = variable.DefMaxChunkSize
+	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(nil, -1)
+	tc := &hashJoinTestCase{rows: 100000, concurrency: 4, ctx: ctx, keyIdx: []int{0, 1}}
+	return tc
+}
+
+func prepare4Join(testCase *hashJoinTestCase, innerExec, outerExec Executor) *HashJoinExec {
+	cols0 := testCase.columns()
+	cols1 := testCase.columns()
+	joinSchema := expression.NewSchema(cols0...)
+	joinSchema.Append(cols1...)
+ joinKeys := make([]*expression.Column, 0, len(testCase.keyIdx)) + for _, keyIdx := range testCase.keyIdx { + joinKeys = append(joinKeys, cols0[keyIdx]) + } + e := &HashJoinExec{ + baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), + concurrency: uint(testCase.concurrency), + joinType: 0, // InnerJoin + isOuterJoin: false, + innerKeys: joinKeys, + outerKeys: joinKeys, + innerExec: innerExec, + outerExec: outerExec, + } + defaultValues := make([]types.Datum, e.innerExec.Schema().Len()) + lhsTypes, rhsTypes := retTypes(innerExec), retTypes(outerExec) + e.joiners = make([]joiner, e.concurrency) + for i := uint(0); i < e.concurrency; i++ { + e.joiners[i] = newJoiner(testCase.ctx, e.joinType, true, defaultValues, + nil, lhsTypes, rhsTypes) + } + return e +} + +func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) { + opt := mockDataSourceParameters{ + schema: expression.NewSchema(casTest.columns()...), + rows: casTest.rows, + ctx: casTest.ctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + switch typ.Tp { + case mysql.TypeLong, mysql.TypeLonglong: + return int64(row) + case mysql.TypeVarString: + return rawData + default: + panic("not implement") + } + }, + } + dataSource1 := buildMockDataSource(opt) + dataSource2 := buildMockDataSource(opt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + exec := prepare4Join(casTest, dataSource1, dataSource2) + tmpCtx := context.Background() + chk := newFirstChunk(exec) + dataSource1.prepareChunks() + dataSource2.prepareChunks() + + total := 0 + b.StartTimer() + if err := exec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + for { + if err := exec.Next(tmpCtx, chk); err != nil { + b.Fatal(err) + } + total += chk.NumRows() + if chk.NumRows() == 0 { + break + } + } + + if err := exec.Close(); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkHashJoinExec(b *testing.B) { + b.ReportAllocs() + cas := defaultHashJoinTestCase() + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) + + cas.keyIdx = []int{0} + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) +} diff --git a/executor/pkg_test.go b/executor/pkg_test.go index 637661ed3e40b..bdd7a91772bed 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -3,16 +3,15 @@ package executor import ( "context" "fmt" + . 
"github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" - "github.com/pingcap/tidb/util/stringutil" ) var _ = Suite(&pkgTestSuite{}) @@ -20,36 +19,6 @@ var _ = Suite(&pkgTestSuite{}) type pkgTestSuite struct { } -type MockExec struct { - baseExecutor - - Rows []chunk.MutRow - curRowIdx int -} - -func (m *MockExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - colTypes := retTypes(m) - for ; m.curRowIdx < len(m.Rows) && req.NumRows() < req.Capacity(); m.curRowIdx++ { - curRow := m.Rows[m.curRowIdx] - for i := 0; i < curRow.Len(); i++ { - curDatum := curRow.ToRow().GetDatum(i, colTypes[i]) - req.AppendDatum(i, &curDatum) - } - } - return nil -} - -func (m *MockExec) Close() error { - m.curRowIdx = 0 - return nil -} - -func (m *MockExec) Open(ctx context.Context) error { - m.curRowIdx = 0 - return nil -} - func (s *pkgTestSuite) TestNestedLoopApply(c *C) { ctx := context.Background() sctx := mock.NewContext() @@ -57,27 +26,27 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { col1 := &expression.Column{Index: 1, RetType: types.NewFieldType(mysql.TypeLong)} con := &expression.Constant{Value: types.NewDatum(6), RetType: types.NewFieldType(mysql.TypeLong)} outerSchema := expression.NewSchema(col0) - outerExec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, outerSchema, nil), - Rows: []chunk.MutRow{ - chunk.MutRowFromDatums(types.MakeDatums(1)), - chunk.MutRowFromDatums(types.MakeDatums(2)), - chunk.MutRowFromDatums(types.MakeDatums(3)), - chunk.MutRowFromDatums(types.MakeDatums(4)), - chunk.MutRowFromDatums(types.MakeDatums(5)), - chunk.MutRowFromDatums(types.MakeDatums(6)), - }} + outerExec := buildMockDataSource(mockDataSourceParameters{ + schema: outerSchema, + rows: 6, + ctx: sctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + return int64(row + 1) + }, + }) + outerExec.prepareChunks() + innerSchema := expression.NewSchema(col1) - innerExec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, innerSchema, nil), - Rows: []chunk.MutRow{ - chunk.MutRowFromDatums(types.MakeDatums(1)), - chunk.MutRowFromDatums(types.MakeDatums(2)), - chunk.MutRowFromDatums(types.MakeDatums(3)), - chunk.MutRowFromDatums(types.MakeDatums(4)), - chunk.MutRowFromDatums(types.MakeDatums(5)), - chunk.MutRowFromDatums(types.MakeDatums(6)), - }} + innerExec := buildMockDataSource(mockDataSourceParameters{ + schema: innerSchema, + rows: 6, + ctx: sctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + return int64(row + 1) + }, + }) + innerExec.prepareChunks() + outerFilter := expression.NewFunctionInternal(sctx, ast.LT, types.NewFieldType(mysql.TypeTiny), col0, con) innerFilter := outerFilter.Clone() otherFilter := expression.NewFunctionInternal(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col1) @@ -112,39 +81,6 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { } } -func prepareOneColChildExec(sctx sessionctx.Context, rowCount int) Executor { - col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - schema := expression.NewSchema(col0) - exec := &MockExec{ - baseExecutor: newBaseExecutor(sctx, schema, nil), - Rows: make([]chunk.MutRow, rowCount)} - for i := 0; i < len(exec.Rows); i++ { - exec.Rows[i] = chunk.MutRowFromDatums(types.MakeDatums(i % 10)) 
- } - return exec -} - -func prepare4RadixPartition(sctx sessionctx.Context, rowCount int) *HashJoinExec { - childExec0 := prepareOneColChildExec(sctx, rowCount) - childExec1 := prepareOneColChildExec(sctx, rowCount) - - col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - col1 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} - joinSchema := expression.NewSchema(col0, col1) - hashJoinExec := &HashJoinExec{ - baseExecutor: newBaseExecutor(sctx, joinSchema, stringutil.StringerStr("HashJoin"), childExec0, childExec1), - concurrency: 4, - joinType: 0, // InnerJoin - innerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, - innerKeyColIdx: []int{0}, - outerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, - outerKeyColIdx: []int{0}, - innerExec: childExec0, - outerExec: childExec1, - } - return hashJoinExec -} - func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { dbss := [][]string{ {}, From 60ccc83f4cc3c07b2f8dd7a3e15fa95ee4462745 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Fri, 23 Aug 2019 00:38:46 +0800 Subject: [PATCH 03/11] fix licence --- executor/hash_table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/executor/hash_table.go b/executor/hash_table.go index 2de90c3246ae4..72e79ed849869 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -1,4 +1,4 @@ -// Copyright 2016 PingCAP, Inc. +// Copyright 2019 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 04526e26b2372598757883cab223ef1499af2226 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Fri, 23 Aug 2019 22:19:47 +0800 Subject: [PATCH 04/11] address comments --- executor/benchmark_test.go | 2 - executor/join.go | 40 +++++----- executor/join_test.go | 33 +++++++- types/field_type.go | 6 ++ util/chunk/column.go | 14 ---- util/chunk/compare.go | 40 ---------- util/chunk/row.go | 6 -- util/codec/codec.go | 154 ++++++++++++++++++++++++++++--------- util/codec/codec_test.go | 113 ++++++++++++++------------- 9 files changed, 233 insertions(+), 175 deletions(-) diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index f0d184144d5cb..1dfe4c334d055 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -602,7 +602,6 @@ func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) { dataSource1.prepareChunks() dataSource2.prepareChunks() - total := 0 b.StartTimer() if err := exec.Open(tmpCtx); err != nil { b.Fatal(err) @@ -611,7 +610,6 @@ func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) { if err := exec.Next(tmpCtx, chk); err != nil { b.Fatal(err) } - total += chk.NumRows() if chk.NumRows() == 0 { break } diff --git a/executor/join.go b/executor/join.go index 53c33880baf94..16ca35710b750 100644 --- a/executor/join.go +++ b/executor/join.go @@ -49,10 +49,9 @@ type HashJoinExec struct { innerKeys []*expression.Column // concurrency is the number of partition, build and join workers. - concurrency uint - hashTable rowHashTable - innerFinished chan error - hashJoinBuffers []*hashJoinBuffer + concurrency uint + hashTable rowHashTable + innerFinished chan error // joinWorkerWaitGroup is for sync multiple join workers. 
joinWorkerWaitGroup sync.WaitGroup finished atomic.Value @@ -97,10 +96,6 @@ type hashjoinWorkerResult struct { src chan<- *chunk.Chunk } -type hashJoinBuffer struct { - bytes []byte -} - // Close implements the Executor Close interface. func (e *HashJoinExec) Close() error { close(e.closeCh) @@ -176,19 +171,11 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h ha return false, h.Sum64(), err } -func (e *HashJoinExec) matchJoinKey(inner, outer chunk.Row) bool { - innerAllTypes := retTypes(e.innerExec) - outerAllTypes := retTypes(e.outerExec) - for i := range e.innerKeyColIdx { - innerIdx := e.innerKeyColIdx[i] - outerIdx := e.outerKeyColIdx[i] - innerTp := innerAllTypes[innerIdx] - outerTp := outerAllTypes[outerIdx] - if cmp := chunk.GetCompareFuncWithTypes(innerTp, outerTp); cmp(inner, innerIdx, outer, outerIdx) != 0 { - return false - } - } - return true +func (e *HashJoinExec) matchJoinKey(inner, outer chunk.Row) (ok bool, err error) { + innerAllTypes, outerAllTypes := retTypes(e.innerExec), retTypes(e.outerExec) + return codec.EqualChunkRow(e.ctx.GetSessionVars().StmtCtx, + inner, innerAllTypes, e.innerKeyColIdx, + outer, outerAllTypes, e.outerKeyColIdx) } // fetchOuterChunks get chunks from fetches chunks from the big table in a background goroutine @@ -431,11 +418,20 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R innerRows := make([]chunk.Row, 0, len(innerPtrs)) for _, ptr := range innerPtrs { matchedInner := e.innerResult.GetRow(ptr) - if !e.matchJoinKey(matchedInner, outerRow) { + ok, err := e.matchJoinKey(matchedInner, outerRow) + if err != nil { + joinResult.err = err + return false, joinResult + } + if !ok { continue } innerRows = append(innerRows, matchedInner) } + if len(innerRows) == 0 { // TODO(fengliyuan): add test case + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) + return true, joinResult + } iter := chunk.NewIterator4Slice(innerRows) hasMatch, hasNull := false, false for iter.Begin(); iter.Current() != iter.End(); { diff --git a/executor/join_test.go b/executor/join_test.go index 9a19cb1d4a79b..02bab0e19b674 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -273,6 +273,7 @@ func (s *testSuite2) TestJoin(c *C) { func (s *testSuite2) TestJoinCast(c *C) { tk := testkit.NewTestKit(c, s.store) + var result *testkit.Result tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -281,7 +282,7 @@ func (s *testSuite2) TestJoinCast(c *C) { tk.MustExec("create table t1(c1 int unsigned)") tk.MustExec("insert into t values (1)") tk.MustExec("insert into t1 values (1)") - result := tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") result.Check(testkit.Rows("1")) // int64(-1) != uint64(18446744073709551615) @@ -294,6 +295,36 @@ func (s *testSuite2) TestJoinCast(c *C) { result = tk.MustQuery("select * from t , t1 where t.c1 = t1.c1") result.Check(testkit.Rows()) + // float(1) == double(1) + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 float)") + tk.MustExec("create table t1(c1 double)") + tk.MustExec("insert into t values (1.0)") + tk.MustExec("insert into t1 values (1.00)") + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows("1")) + + // varchar("x") == char("x") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 
varchar(1))") + tk.MustExec("create table t1(c1 char(1))") + tk.MustExec(`insert into t values ("x")`) + tk.MustExec(`insert into t1 values ("x")`) + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows("x")) + + // varchar("x") != char("y") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 varchar(1))") + tk.MustExec("create table t1(c1 char(1))") + tk.MustExec(`insert into t values ("x")`) + tk.MustExec(`insert into t1 values ("y")`) + result = tk.MustQuery("select t.c1 from t , t1 where t.c1 = t1.c1") + result.Check(testkit.Rows()) + tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") tk.MustExec("create table t(c1 int,c2 double)") diff --git a/types/field_type.go b/types/field_type.go index 51905169dafa8..f8d51d18a101e 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -183,6 +183,12 @@ func DefaultTypeForValue(value interface{}, tp *FieldType) { tp.Flen = len(x) tp.Decimal = UnspecifiedLength tp.Charset, tp.Collate = charset.GetDefaultCharsetAndCollate() + case float32: + tp.Tp = mysql.TypeFloat + s := strconv.FormatFloat(float64(x), 'f', -1, 32) + tp.Flen = len(s) + tp.Decimal = UnspecifiedLength + SetBinChsClnFlag(tp) case float64: tp.Tp = mysql.TypeDouble s := strconv.FormatFloat(x, 'f', -1, 64) diff --git a/util/chunk/column.go b/util/chunk/column.go index c35f43596c51c..e4f1d6c54b0e3 100644 --- a/util/chunk/column.go +++ b/util/chunk/column.go @@ -14,7 +14,6 @@ package chunk import ( - "io" "math/bits" "reflect" "time" @@ -606,16 +605,3 @@ func (c *Column) CopyReconstruct(sel []int, dst *Column) *Column { } return dst } - -// WriteTo writes the raw data in the specific row to w. -func (c *Column) WriteTo(rowID int, w io.Writer) (int, error) { - var data []byte - if c.isFixed() { - elemLen := len(c.elemBuf) - data = c.data[rowID*elemLen : rowID*elemLen+elemLen] - } else { - start, end := c.offsets[rowID], c.offsets[rowID] - data = c.data[start:end] - } - return w.Write(data) -} diff --git a/util/chunk/compare.go b/util/chunk/compare.go index c754e72891b3a..2ecaf196ee0c9 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -55,20 +55,6 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc { return nil } -// GetCompareFuncWithTypes gets a compare function for the two field types. 
-func GetCompareFuncWithTypes(tp1, tp2 *types.FieldType) CompareFunc { - switch tp1.Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: - f1, f2 := mysql.HasUnsignedFlag(tp1.Flag), mysql.HasUnsignedFlag(tp2.Flag) - if f1 && !f2 { - return cmpUint64AndInt64 - } else if !f1 && f2 { - return cmpInt64AndUint64 - } - } - return GetCompareFunc(tp1) -} - func cmpNull(lNull, rNull bool) int { if lNull && rNull { return 0 @@ -95,32 +81,6 @@ func cmpUint64(l Row, lCol int, r Row, rCol int) int { return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol)) } -func cmpInt64AndUint64(l Row, lCol int, r Row, rCol int) int { - lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) - if lNull || rNull { - return cmpNull(lNull, rNull) - } - lVal := l.GetInt64(lCol) - rVal := r.GetUint64(rCol) - if lVal < 0 { - return -1 - } - return types.CompareUint64(uint64(lVal), rVal) -} - -func cmpUint64AndInt64(l Row, lCol int, r Row, rCol int) int { - lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) - if lNull || rNull { - return cmpNull(lNull, rNull) - } - lVal := l.GetUint64(lCol) - rVal := r.GetInt64(rCol) - if rVal < 0 { - return 1 - } - return types.CompareUint64(lVal, uint64(rVal)) -} - func cmpString(l Row, lCol int, r Row, rCol int) int { lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) if lNull || rNull { diff --git a/util/chunk/row.go b/util/chunk/row.go index 284412d3703f3..0d83c4dec5525 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -14,7 +14,6 @@ package chunk import ( - "io" "unsafe" "github.com/pingcap/parser/mysql" @@ -205,11 +204,6 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { return d } -// WriteTo writes the raw data with the colIdx to w. -func (r Row) WriteTo(colIdx int, w io.Writer) (int, error) { - return r.c.columns[colIdx].WriteTo(r.idx, w) -} - // IsNull returns if the datum in the chunk.Row is null. func (r Row) IsNull(colIdx int) bool { return r.c.columns[colIdx].IsNull(r.idx) diff --git a/util/codec/codec.go b/util/codec/codec.go index 678d4a54231eb..d63cfda03f56c 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -14,6 +14,7 @@ package codec import ( + "bytes" "encoding/binary" "io" "time" @@ -274,52 +275,135 @@ func EncodeValue(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]by return encode(sc, b, v, false) } -func encodeHashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int) (err error) { - for _, i := range colIdx { - if row.IsNull(i) { - _, err = w.Write([]byte{NilFlag}) - if err != nil { - return errors.Trace(err) - } - continue - } - ft := allTypes[i] - switch ft.Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: - flag := varintFlag - if mysql.HasUnsignedFlag(allTypes[i].Flag) { - if integer := row.GetInt64(i); integer < 0 { - flag = uvarintFlag - } - } - _, err = w.Write([]byte{flag}) - if err != nil { - return errors.Trace(err) - } - _, err = row.WriteTo(i, w) - case mysql.TypeNewDecimal: - // If hash is true, we only consider the original value of this decimal and ignore it's precision. 
- dec := row.GetMyDecimal(i) - var bin []byte - bin, err = dec.ToHashKey() +func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, b []byte, row chunk.Row, tp *types.FieldType, idx int) (_ [][]byte, err error) { + if row.IsNull(idx) { + return [][]byte{{NilFlag}}, nil + } + const comparable = false + b = b[:0] + switch tp.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + if !mysql.HasUnsignedFlag(tp.Flag) { + b = encodeSignedInt(b, row.GetInt64(idx), comparable) + break + } + // encode unsigned integers. + integer := row.GetInt64(idx) + if integer < 0 { + b = encodeUnsignedInt(b, uint64(integer), comparable) + } else { + b = encodeSignedInt(b, integer, comparable) + } + case mysql.TypeFloat: + b = append(b, floatFlag) + b = EncodeFloat(b, float64(row.GetFloat32(idx))) + case mysql.TypeDouble: + b = append(b, floatFlag) + b = EncodeFloat(b, row.GetFloat64(idx)) + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + return [][]byte{ + {compactBytesFlag}, + row.GetBytes(idx), + }, nil + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + b = append(b, uintFlag) + t := row.GetTime(idx) + // Encoding timestamp need to consider timezone. + // If it's not in UTC, transform to UTC first. + if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC { + err = t.ConvertTimeZone(sc.TimeZone, time.UTC) if err != nil { - return errors.Trace(err) + return nil, err } - _, err = w.Write(bin) - default: - _, err = row.WriteTo(i, w) } + var v uint64 + v, err = t.ToPackedUint() if err != nil { - return errors.Trace(err) + return nil, err } + b = EncodeUint(b, v) + case mysql.TypeDuration: + // duration may have negative value, so we cannot use String to encode directly. + b = append(b, durationFlag) + b = EncodeInt(b, int64(row.GetDuration(idx, 0).Duration)) + case mysql.TypeNewDecimal: + // If hash is true, we only consider the original value of this decimal and ignore it's precision. + dec := row.GetMyDecimal(idx) + bin, err := dec.ToHashKey() + if err != nil { + return nil, err + } + return [][]byte{ + {decimalFlag}, + bin, + }, nil + case mysql.TypeEnum: + b = encodeUnsignedInt(b, uint64(row.GetEnum(idx).ToNumber()), comparable) + case mysql.TypeSet: + b = encodeUnsignedInt(b, uint64(row.GetSet(idx).ToNumber()), comparable) + case mysql.TypeBit: + // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. + var val uint64 + val, err = types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc) + terror.Log(errors.Trace(err)) + b = encodeUnsignedInt(b, val, comparable) + case mysql.TypeJSON: + return [][]byte{ + {jsonFlag}, + row.GetBytes(idx), + }, nil + default: + return nil, errors.Errorf("unsupport column type for encode %d", tp.Tp) } - return + return [][]byte{b}, nil } // HashChunkRow writes the encoded values to w. // If two rows are equal, it will generate the same bytes. 
func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int) error { - return encodeHashChunkRow(sc, w, row, allTypes, colIdx) + var b []byte + for _, idx := range colIdx { + rets, err := encodeHashChunkRowIdx(sc, b, row, allTypes[idx], idx) + if err != nil { + return errors.Trace(err) + } + for _, ret := range rets { + _, err = w.Write(ret) + if err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +// EqualChunkRow returns a boolean reporting whether row1 and row2 +// with their types and column index are logically equal. +func EqualChunkRow(sc *stmtctx.StatementContext, + row1 chunk.Row, allTypes1 []*types.FieldType, colIdx1 []int, + row2 chunk.Row, allTypes2 []*types.FieldType, colIdx2 []int, +) (bool, error) { + var b1, b2 []byte + for i := range colIdx1 { + idx1, idx2 := colIdx1[i], colIdx2[i] + rets1, err := encodeHashChunkRowIdx(sc, b1, row1, allTypes1[idx1], idx1) + if err != nil { + return false, errors.Trace(err) + } + rets2, err := encodeHashChunkRowIdx(sc, b2, row2, allTypes2[idx2], idx2) + if err != nil { + return false, errors.Trace(err) + } + if len(rets1) != len(rets2) { + return false, nil + } + for i := range rets1 { + if !bytes.Equal(rets1[i], rets2[i]) { + return false, nil + } + } + } + return true, nil } // Decode decodes values from a byte slice generated with EncodeKey or EncodeValue diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 2c67e14e5d7f9..3b32a8acb86ce 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -15,7 +15,6 @@ package codec import ( "bytes" - "fmt" "hash/crc32" "math" "testing" @@ -1035,6 +1034,47 @@ func (s *testCodecSuite) TestDecodeRange(c *C) { } } +func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) { + sc := &stmtctx.StatementContext{TimeZone: time.Local} + + tp1 := new(types.FieldType) + types.DefaultTypeForValue(a, tp1) + chk1 := chunk.New([]*types.FieldType{tp1}, 1, 1) + d := types.Datum{} + d.SetValue(a) + chk1.AppendDatum(0, &d) + + tp2 := new(types.FieldType) + types.DefaultTypeForValue(b, tp2) + chk2 := chunk.New([]*types.FieldType{tp2}, 1, 1) + d = types.Datum{} + d.SetValue(b) + chk2.AppendDatum(0, &d) + + h := crc32.NewIEEE() + err1 := HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}) + sum1 := h.Sum32() + h.Reset() + err2 := HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) + sum2 := h.Sum32() + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + if equal { + c.Assert(sum1, Equals, sum2) + } else { + c.Assert(sum1, Not(Equals), sum2) + } + e, err := EqualChunkRow(sc, + chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}, + chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) + c.Assert(err, IsNil) + if equal { + c.Assert(e, IsTrue) + } else { + c.Assert(e, IsFalse) + } +} + func (s *testCodecSuite) TestHashChunkRow(c *C) { sc := &stmtctx.StatementContext{TimeZone: time.Local} datums, tps := datumsForTest(sc) @@ -1054,64 +1094,27 @@ func (s *testCodecSuite) TestHashChunkRow(c *C) { c.Assert(err1, IsNil) c.Assert(err2, IsNil) c.Assert(sum1, Equals, sum2) + e, err := EqualChunkRow(sc, + chk.GetRow(0), tps, colIdx, + chk.GetRow(0), tps, colIdx) + c.Assert(err, IsNil) + c.Assert(e, IsTrue) - // uint64(18446744073709551615) != int64(-1) - // uint64(1) == int64(1) - tp1 := new(types.FieldType) - types.DefaultTypeForValue(uint64(18446744073709551615), tp1) - chk1 := chunk.New([]*types.FieldType{tp1}, 1, 1) - chk1.AppendUint64(0, 18446744073709551615) - chk1.AppendUint64(0, 1) - - 
tp2 := new(types.FieldType) - types.DefaultTypeForValue(int64(-1), tp2) - chk2 := chunk.New([]*types.FieldType{tp2}, 1, 1) - chk2.AppendInt64(0, -1) - chk2.AppendInt64(0, 1) - - h.Reset() - err1 = HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}) - sum1 = h.Sum32() - h.Reset() - err2 = HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) - sum2 = h.Sum32() - fmt.Println(sum1, sum2) - c.Assert(err1, IsNil) - c.Assert(err2, IsNil) - c.Assert(sum1, Not(Equals), sum2) - - h.Reset() - err1 = HashChunkRow(sc, h, chk1.GetRow(1), []*types.FieldType{tp1}, []int{0}) - sum1 = h.Sum32() - h.Reset() - err2 = HashChunkRow(sc, h, chk2.GetRow(1), []*types.FieldType{tp2}, []int{0}) - sum2 = h.Sum32() - c.Assert(err1, IsNil) - c.Assert(err2, IsNil) - c.Assert(sum1, Equals, sum2) + testHashChunkRowEqual(c, uint64(1), int64(1), true) + testHashChunkRowEqual(c, uint64(18446744073709551615), int64(-1), false) - // Decimal(1.00) == Decimal(1.000) - tp1 = new(types.FieldType) - dec := types.NewDecFromStringForTest("1.10") - types.DefaultTypeForValue(dec, tp1) - chk1 = chunk.New([]*types.FieldType{tp1}, 1, 1) - chk1.AppendMyDecimal(0, dec) + dec1 := types.NewDecFromStringForTest("1.1") + dec2 := types.NewDecFromStringForTest("01.100") + testHashChunkRowEqual(c, dec1, dec2, true) + dec1 = types.NewDecFromStringForTest("1.1") + dec2 = types.NewDecFromStringForTest("01.200") + testHashChunkRowEqual(c, dec1, dec2, false) - tp2 = new(types.FieldType) - dec = types.NewDecFromStringForTest("01.100") - types.DefaultTypeForValue(dec, tp2) - chk2 = chunk.New([]*types.FieldType{tp2}, 1, 1) - chk2.AppendMyDecimal(0, dec) + testHashChunkRowEqual(c, float32(1.0), float64(1.0), true) + testHashChunkRowEqual(c, float32(1.0), float64(1.1), false) - h.Reset() - err1 = HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}) - sum1 = h.Sum32() - h.Reset() - err2 = HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}) - sum2 = h.Sum32() - c.Assert(err1, IsNil) - c.Assert(err2, IsNil) - c.Assert(sum1, Not(Equals), sum2) + testHashChunkRowEqual(c, "x", []byte("x"), true) + testHashChunkRowEqual(c, "x", []byte("y"), false) } func (s *testCodecSuite) TestValueSizeOfSignedInt(c *C) { From 09a53e4b5dd14155763521b66747df0196ec48d6 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Tue, 27 Aug 2019 18:36:34 +0800 Subject: [PATCH 05/11] rewrite encodeHashChunk & hashTable --- executor/benchmark_test.go | 57 ++++++++++++++++++ executor/hash_table.go | 83 ++++++++++++++++++++++---- executor/join.go | 23 ++++---- util/chunk/column.go | 12 ++++ util/chunk/column_test.go | 31 ++++++++++ util/chunk/row.go | 5 ++ util/codec/codec.go | 117 +++++++++++++++++-------------------- util/codec/codec_test.go | 11 ++-- 8 files changed, 250 insertions(+), 89 deletions(-) diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 1dfe4c334d055..9dd3ba3daa8f9 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -634,3 +634,60 @@ func BenchmarkHashJoinExec(b *testing.B) { benchmarkHashJoinExecWithCase(b, cas) }) } + +func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { + opt := mockDataSourceParameters{ + schema: expression.NewSchema(casTest.columns()...), + rows: casTest.rows, + ctx: casTest.ctx, + genDataFunc: func(row int, typ *types.FieldType) interface{} { + switch typ.Tp { + case mysql.TypeLong, mysql.TypeLonglong: + return int64(row) + case mysql.TypeVarString: + return rawData + default: + panic("not implement") + } + }, + 
} + dataSource1 := buildMockDataSource(opt) + dataSource2 := buildMockDataSource(opt) + + dataSource1.prepareChunks() + exec := prepare4Join(casTest, dataSource1, dataSource2) + tmpCtx := context.Background() + if err := exec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + innerResultCh := make(chan *chunk.Chunk, 1) + go func() { + for _, chk := range dataSource1.genData { + innerResultCh <- chk + } + close(innerResultCh) + }() + + b.StartTimer() + if err := exec.buildHashTableForList(innerResultCh); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkBuildHashTableForList(b *testing.B) { + b.ReportAllocs() + cas := defaultHashJoinTestCase() + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) + + cas.keyIdx = []int{0} + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) +} diff --git a/executor/hash_table.go b/executor/hash_table.go index 72e79ed849869..562bd1cd2524a 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -17,20 +17,83 @@ import ( "github.com/pingcap/tidb/util/chunk" ) -// TODO: Consider using a fix-sized hash table. -type rowHashTable map[uint64][]chunk.RowPtr +const maxEntrySliceLen = 8 * 1024 -func newRowHashTable() rowHashTable { - t := make(rowHashTable) - return t +type entry struct { + ptr chunk.RowPtr + next entryAddr } -func (t rowHashTable) Put(key uint64, rowPtr chunk.RowPtr) { - t[key] = append(t[key], rowPtr) +type entryStore struct { + slices [][]entry + sliceIdx uint32 + sliceLen uint32 } -func (t rowHashTable) Get(key uint64) []chunk.RowPtr { - return t[key] +func (es *entryStore) put(e entry) entryAddr { + if es.sliceLen == maxEntrySliceLen { + es.slices = append(es.slices, make([]entry, 0, maxEntrySliceLen)) + es.sliceLen = 0 + es.sliceIdx++ + } + addr := entryAddr{sliceIdx: es.sliceIdx, offset: es.sliceLen} + slice := es.slices[es.sliceIdx] + slice = append(slice, e) + es.slices[es.sliceIdx] = slice + es.sliceLen++ + return addr } -func (t rowHashTable) Len() int { return len(t) } +func (es *entryStore) get(addr entryAddr) entry { + return es.slices[addr.sliceIdx][addr.offset] +} + +type entryAddr struct { + sliceIdx uint32 + offset uint32 +} + +var nullEntryAddr = entryAddr{} + +type rowHashMap struct { + entryStore entryStore + hashTable map[uint64]entryAddr + length int +} + +func newRowHashMap() *rowHashMap { + m := new(rowHashMap) + m.hashTable = make(map[uint64]entryAddr) + m.entryStore.slices = [][]entry{make([]entry, 0, 64)} + // Reserve the first empty entry, so entryAddr{} can represent nullEntryAddr. + m.entryStore.put(entry{}) + return m +} + +func (m *rowHashMap) Put(hashKey uint64, rowPtr chunk.RowPtr) { + oldEntryAddr := m.hashTable[hashKey] + e := entry{ + ptr: rowPtr, + next: oldEntryAddr, + } + newEntryAddr := m.entryStore.put(e) + m.hashTable[hashKey] = newEntryAddr + m.length++ +} + +func (m *rowHashMap) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) { + entryAddr := m.hashTable[hashKey] + for entryAddr != nullEntryAddr { + e := m.entryStore.get(entryAddr) + entryAddr = e.next + rowPtrs = append(rowPtrs, e.ptr) + } + // Keep the order of input. 
+	for i := 0; i < len(rowPtrs)/2; i++ {
+		j := len(rowPtrs) - 1 - i
+		rowPtrs[i], rowPtrs[j] = rowPtrs[j], rowPtrs[i]
+	}
+	return
+}
+
+func (m *rowHashMap) Len() int { return m.length }
diff --git a/executor/join.go b/executor/join.go
index 16ca35710b750..e964da58229b9 100644
--- a/executor/join.go
+++ b/executor/join.go
@@ -50,7 +50,8 @@ type HashJoinExec struct {
 	// concurrency is the number of partition, build and join workers.
 	concurrency uint

-	hashTable      rowHashTable
+	hashTable      *rowHashMap
+	joinKeyBuf     [][]byte
 	innerFinished  chan error
 	// joinWorkerWaitGroup is for sync multiple join workers.
 	joinWorkerWaitGroup sync.WaitGroup
@@ -71,7 +72,6 @@ type HashJoinExec struct {
 	outerResultChs    []chan *chunk.Chunk
 	joinChkResourceCh []chan *chunk.Chunk
 	joinResultCh      chan *hashjoinWorkerResult
-	hashTableValBufs  [][][]byte

 	memTracker *memory.Tracker // track memory usage.
 	prepared   bool
@@ -141,8 +141,10 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
 	e.prepared = false
 	e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaHashJoin)
 	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
-
-	e.hashTableValBufs = make([][][]byte, e.concurrency)
+	e.joinKeyBuf = make([][]byte, e.concurrency)
+	for i := range e.joinKeyBuf {
+		e.joinKeyBuf[i] = make([]byte, 0, 1000)
+	}

 	e.closeCh = make(chan struct{})
 	e.finished.Store(false)
@@ -150,7 +152,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
 	return nil
 }

-func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h hash.Hash64) (hasNull bool, key uint64, err error) {
+func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h hash.Hash64, buf []byte) (hasNull bool, key uint64, err error) {
 	var keyColIdx []int
 	var allTypes []*types.FieldType
 	if isOuterKey {
@@ -167,7 +169,7 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h ha
 		}
 	}
 	h.Reset()
-	err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, h, row, allTypes, keyColIdx)
+	err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, h, row, allTypes, keyColIdx, buf)
 	return false, h.Sum64(), err
 }

@@ -401,7 +403,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) {
 func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.Row,
 	joinResult *hashjoinWorkerResult, h hash.Hash64) (bool, *hashjoinWorkerResult) {
-	hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, h)
+	hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, h, e.joinKeyBuf[workerID])
 	if err != nil {
 		joinResult.err = err
 		return false, joinResult
@@ -540,7 +542,7 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) {
 	doneCh := make(chan struct{})
 	go util.WithRecovery(func() { e.fetchInnerRows(ctx, innerResultCh, doneCh) }, nil)

-	// TODO: Parallel build hash table. Currently not support because `mvmap` is not thread-safe.
+	// TODO: Parallel build hash table. Currently not supported because `rowHashMap` is not thread-safe.
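+	// (A possible follow-up, not implemented in this series: build into one
+	// rowHashMap shard per worker, partitioned by hash key, and probe the
+	// matching shard at join time.)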
 	err := e.buildHashTableForList(innerResultCh)
 	if err != nil {
 		e.innerFinished <- errors.Trace(err)
@@ -557,7 +559,7 @@
 // key of hash table: hash value of key columns
 // value of hash table: RowPtr of the corresponding row
 func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) error {
-	e.hashTable = newRowHashTable()
+	e.hashTable = newRowHashMap()
 	e.innerKeyColIdx = make([]int, len(e.innerKeys))
 	for i := range e.innerKeys {
 		e.innerKeyColIdx[i] = e.innerKeys[i].Index
@@ -566,6 +568,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk)
 		hasNull bool
 		err     error
 		key     uint64
+		buf     = make([]byte, 0, 64)
 	)

 	h := fnv.New64()
@@ -576,7 +579,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk)
 		}
 		numRows := chk.NumRows()
 		for j := 0; j < numRows; j++ {
-			hasNull, key, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), h)
+			hasNull, key, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), h, buf)
 			if err != nil {
 				return errors.Trace(err)
 			}
diff --git a/util/chunk/column.go b/util/chunk/column.go
index e4f1d6c54b0e3..aad386207b601 100644
--- a/util/chunk/column.go
+++ b/util/chunk/column.go
@@ -526,6 +526,18 @@ func (c *Column) getNameValue(rowID int) (string, uint64) {
 	return string(hack.String(c.data[start+8 : end])), *(*uint64)(unsafe.Pointer(&c.data[start]))
 }

+// GetRaw returns the underlying raw bytes in the specific row.
+func (c *Column) GetRaw(rowID int) []byte {
+	var data []byte
+	if c.isFixed() {
+		elemLen := len(c.elemBuf)
+		data = c.data[rowID*elemLen : rowID*elemLen+elemLen]
+	} else {
+		data = c.data[c.offsets[rowID]:c.offsets[rowID+1]]
+	}
+	return data
+}
+
 // reconstruct reconstructs this Column by removing all filtered rows in it according to sel.
 func (c *Column) reconstruct(sel []int) {
 	if sel == nil {
diff --git a/util/chunk/column_test.go b/util/chunk/column_test.go
index bc47fc828e170..2a105440b067e 100644
--- a/util/chunk/column_test.go
+++ b/util/chunk/column_test.go
@@ -18,6 +18,7 @@ import (
 	"math/rand"
 	"testing"
 	"time"
+	"unsafe"

 	"github.com/pingcap/check"
 	"github.com/pingcap/parser/mysql"
@@ -749,6 +750,36 @@ func (s *testChunkSuite) TestResizeReserve(c *check.C) {
 	c.Assert(cStrs.length, check.Equals, 0)
 }

+func (s *testChunkSuite) TestGetRaw(c *check.C) {
+	chk := NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeFloat)}, 1024)
+	col := chk.Column(0)
+	for i := 0; i < 1024; i++ {
+		col.AppendFloat32(float32(i))
+	}
+	it := NewIterator4Chunk(chk)
+	var i int64
+	for row := it.Begin(); row != it.End(); row = it.Next() {
+		f := float32(i)
+		b := (*[unsafe.Sizeof(f)]byte)(unsafe.Pointer(&f))[:]
+		c.Assert(row.GetRaw(0), check.DeepEquals, b)
+		c.Assert(col.GetRaw(int(i)), check.DeepEquals, b)
+		i++
+	}
+
+	chk = NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeVarString)}, 1024)
+	col = chk.Column(0)
+	for i := 0; i < 1024; i++ {
+		col.AppendString(fmt.Sprint(i))
+	}
+	it = NewIterator4Chunk(chk)
+	i = 0
+	for row := it.Begin(); row != it.End(); row = it.Next() {
+		c.Assert(row.GetRaw(0), check.DeepEquals, []byte(fmt.Sprint(i)))
+		c.Assert(col.GetRaw(int(i)), check.DeepEquals, []byte(fmt.Sprint(i)))
+		i++
+	}
+}
+
 func BenchmarkDurationRow(b *testing.B) {
 	chk1 := NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDuration)}, 1024)
 	col1 := chk1.Column(0)
diff --git a/util/chunk/row.go b/util/chunk/row.go
index 0d83c4dec5525..620373a076c48 100644
--- a/util/chunk/row.go
+++ b/util/chunk/row.go
@@ -204,6 +204,11 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum {
 	return d
 }

+// GetRaw returns the underlying raw bytes at the given colIdx.
+func (r Row) GetRaw(colIdx int) []byte {
+	return r.c.columns[colIdx].GetRaw(r.idx)
+}
+
 // IsNull returns if the datum in the chunk.Row is null.
 func (r Row) IsNull(colIdx int) bool {
 	return r.c.columns[colIdx].IsNull(r.idx)
diff --git a/util/codec/codec.go b/util/codec/codec.go
index d63cfda03f56c..4dcadfd9601c9 100644
--- a/util/codec/codec.go
+++ b/util/codec/codec.go
@@ -18,6 +18,7 @@ import (
 	"encoding/binary"
 	"io"
 	"time"
+	"unsafe"

 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/mysql"
@@ -275,106 +276,98 @@ func EncodeValue(sc *stmtctx.StatementContext, b []byte, v ...types.Datum) ([]by
 	return encode(sc, b, v, false)
 }

-func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, b []byte, row chunk.Row, tp *types.FieldType, idx int) (_ [][]byte, err error) {
+func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *types.FieldType, idx int) (flag byte, b []byte, err error) {
 	if row.IsNull(idx) {
-		return [][]byte{{NilFlag}}, nil
+		flag = NilFlag
+		return
 	}
-	const comparable = false
-	b = b[:0]
 	switch tp.Tp {
 	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
-		if !mysql.HasUnsignedFlag(tp.Flag) {
-			b = encodeSignedInt(b, row.GetInt64(idx), comparable)
-			break
-		}
-		// encode unsigned integers.
-		integer := row.GetInt64(idx)
-		if integer < 0 {
-			b = encodeUnsignedInt(b, uint64(integer), comparable)
-		} else {
-			b = encodeSignedInt(b, integer, comparable)
+		flag = varintFlag
+		if mysql.HasUnsignedFlag(tp.Flag) {
+			if integer := row.GetInt64(idx); integer < 0 {
+				flag = uvarintFlag
+			}
 		}
+		b = row.GetRaw(idx)
 	case mysql.TypeFloat:
-		b = append(b, floatFlag)
-		b = EncodeFloat(b, float64(row.GetFloat32(idx)))
+		flag = floatFlag
+		f := float64(row.GetFloat32(idx))
+		b = (*[unsafe.Sizeof(f)]byte)(unsafe.Pointer(&f))[:]
 	case mysql.TypeDouble:
-		b = append(b, floatFlag)
-		b = EncodeFloat(b, row.GetFloat64(idx))
+		flag = floatFlag
+		b = row.GetRaw(idx)
 	case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
-		return [][]byte{
-			{compactBytesFlag},
-			row.GetBytes(idx),
-		}, nil
+		flag = compactBytesFlag
+		b = row.GetBytes(idx)
 	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
-		b = append(b, uintFlag)
+		flag = uintFlag
 		t := row.GetTime(idx)
 		// Encoding a timestamp needs to consider the time zone.
 		// If it's not in UTC, transform to UTC first.
 		if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC {
 			err = t.ConvertTimeZone(sc.TimeZone, time.UTC)
 			if err != nil {
-				return nil, err
+				return
 			}
 		}
 		var v uint64
 		v, err = t.ToPackedUint()
 		if err != nil {
-			return nil, err
+			return
 		}
-		b = EncodeUint(b, v)
+		b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:]
 	case mysql.TypeDuration:
+		flag = durationFlag
 		// duration may have a negative value, so we cannot use String to encode it directly.
-		b = append(b, durationFlag)
-		b = EncodeInt(b, int64(row.GetDuration(idx, 0).Duration))
+		b = row.GetRaw(idx)
 	case mysql.TypeNewDecimal:
+		flag = decimalFlag
 		// We only consider the original value of this decimal and ignore its precision.
 		dec := row.GetMyDecimal(idx)
-		bin, err := dec.ToHashKey()
+		b, err = dec.ToHashKey()
 		if err != nil {
-			return nil, err
+			return
 		}
-		return [][]byte{
-			{decimalFlag},
-			bin,
-		}, nil
 	case mysql.TypeEnum:
-		b = encodeUnsignedInt(b, uint64(row.GetEnum(idx).ToNumber()), comparable)
+		flag = uvarintFlag
+		v := uint64(row.GetEnum(idx).ToNumber())
+		b = (*[8]byte)(unsafe.Pointer(&v))[:]
 	case mysql.TypeSet:
-		b = encodeUnsignedInt(b, uint64(row.GetSet(idx).ToNumber()), comparable)
+		flag = uvarintFlag
+		v := uint64(row.GetSet(idx).ToNumber())
+		b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:]
 	case mysql.TypeBit:
 		// We don't need to handle errors here since convertToMysqlBit ensures the literal fits in a uint64.
-		var val uint64
-		val, err = types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc)
-		terror.Log(errors.Trace(err))
-		b = encodeUnsignedInt(b, val, comparable)
+		flag = uvarintFlag
+		v, err1 := types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc)
+		terror.Log(errors.Trace(err1))
+		b = (*[unsafe.Sizeof(v)]byte)(unsafe.Pointer(&v))[:]
 	case mysql.TypeJSON:
-		return [][]byte{
-			{jsonFlag},
-			row.GetBytes(idx),
-		}, nil
+		flag = jsonFlag
+		b = row.GetBytes(idx)
 	default:
-		return nil, errors.Errorf("unsupport column type for encode %d", tp.Tp)
+		return 0, nil, errors.Errorf("unsupported column type for encode %d", tp.Tp)
 	}
-	return [][]byte{b}, nil
+	return
 }

 // HashChunkRow writes the encoded values to w.
-// If two rows are equal, it will generate the same bytes.
+// If two rows are logically equal, it will generate the same bytes.
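+// For example, uint64(1) and int64(1), or float32(1.0) and float64(1.0),
+// hash to the same bytes here, because every value is first reduced to a
+// type flag plus a normalized raw representation.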
+func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) {
+	buf = buf[:0]
+	var flag byte
 	var b []byte
 	for _, idx := range colIdx {
-		rets, err := encodeHashChunkRowIdx(sc, b, row, allTypes[idx], idx)
+		flag, b, err = encodeHashChunkRowIdx(sc, row, allTypes[idx], idx)
 		if err != nil {
 			return errors.Trace(err)
 		}
-		for _, ret := range rets {
-			_, err = w.Write(ret)
-			if err != nil {
-				return errors.Trace(err)
-			}
-		}
+		buf = append(buf, flag)
+		buf = append(buf, b...)
 	}
-	return nil
+	_, err = w.Write(buf)
+	return err
 }

 // EqualChunkRow returns a boolean reporting whether row1 and row2
@@ -383,25 +376,19 @@ func EqualChunkRow(sc *stmtctx.StatementContext,
 	row1 chunk.Row, allTypes1 []*types.FieldType, colIdx1 []int,
 	row2 chunk.Row, allTypes2 []*types.FieldType, colIdx2 []int,
 ) (bool, error) {
-	var b1, b2 []byte
 	for i := range colIdx1 {
 		idx1, idx2 := colIdx1[i], colIdx2[i]
-		rets1, err := encodeHashChunkRowIdx(sc, b1, row1, allTypes1[idx1], idx1)
+		flag1, b1, err := encodeHashChunkRowIdx(sc, row1, allTypes1[idx1], idx1)
 		if err != nil {
 			return false, errors.Trace(err)
 		}
-		rets2, err := encodeHashChunkRowIdx(sc, b2, row2, allTypes2[idx2], idx2)
+		flag2, b2, err := encodeHashChunkRowIdx(sc, row2, allTypes2[idx2], idx2)
 		if err != nil {
 			return false, errors.Trace(err)
 		}
-		if len(rets1) != len(rets2) {
+		if !(flag1 == flag2 && bytes.Equal(b1, b2)) {
 			return false, nil
 		}
-		for i := range rets1 {
-			if !bytes.Equal(rets1[i], rets2[i]) {
-				return false, nil
-			}
-		}
 	}
 	return true, nil
 }
diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go
index 3b32a8acb86ce..b07dca0a8bd01 100644
--- a/util/codec/codec_test.go
+++ b/util/codec/codec_test.go
@@ -1036,6 +1036,8 @@ func (s *testCodecSuite) TestDecodeRange(c *C) {
 func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) {
 	sc := &stmtctx.StatementContext{TimeZone: time.Local}
+	buf1 := make([]byte, 0, 64)
+	buf2 := make([]byte, 0, 64)

 	tp1 := new(types.FieldType)
 	types.DefaultTypeForValue(a, tp1)
@@ -1052,10 +1054,10 @@ func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) {
 	chk2.AppendDatum(0, &d)

 	h := crc32.NewIEEE()
-	err1 := HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0})
+	err1 := HashChunkRow(sc, h, chk1.GetRow(0), []*types.FieldType{tp1}, []int{0}, buf1)
 	sum1 := h.Sum32()
 	h.Reset()
-	err2 := HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0})
+	err2 := HashChunkRow(sc, h, chk2.GetRow(0), []*types.FieldType{tp2}, []int{0}, buf2)
 	sum2 := h.Sum32()
 	c.Assert(err1, IsNil)
 	c.Assert(err2, IsNil)
@@ -1077,6 +1079,7 @@ func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) {

 func (s *testCodecSuite) TestHashChunkRow(c *C) {
 	sc := &stmtctx.StatementContext{TimeZone: time.Local}
+	buf := make([]byte, 0, 64)

 	datums, tps := datumsForTest(sc)
 	chk := chunkForTest(c, sc, datums, tps, 1)
@@ -1085,10 +1088,10 @@ func (s *testCodecSuite) TestHashChunkRow(c *C) {
 		colIdx[i] = i
 	}
 	h := crc32.NewIEEE()
-	err1 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx)
+	err1 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx, buf)
 	sum1 := h.Sum32()
 	h.Reset()
-	err2 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx)
+	err2 := HashChunkRow(sc, h, chk.GetRow(0), tps, colIdx, buf)
 	sum2 := h.Sum32()

 	c.Assert(err1, IsNil)

From a54f5397539856ebc488c38a6bdebdb1afa15d70 Mon Sep 17 00:00:00 2001
From: Feng Liyuan
Date: Tue, 27 Aug 2019 18:59:22 +0800
Subject: [PATCH 06/11] fixup

---
 executor/join.go    |  4 ++--
 util/codec/codec.go | 15 +++++++++------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/executor/join.go b/executor/join.go
index e964da58229b9..6c1d4b3056c19 100644
--- a/executor/join.go
+++ b/executor/join.go
@@ -143,7 +143,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
 	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
 	e.joinKeyBuf = make([][]byte, e.concurrency)
 	for i := range e.joinKeyBuf {
-		e.joinKeyBuf[i] = make([]byte, 0, 1000)
+		e.joinKeyBuf[i] = make([]byte, 1)
 	}

 	e.closeCh = make(chan struct{})
@@ -568,7 +568,7 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk)
 		hasNull bool
 		err     error
 		key     uint64
-		buf     = make([]byte, 0, 64)
+		buf     = make([]byte, 1)
 	)

 	h := fnv.New64()
diff --git a/util/codec/codec.go b/util/codec/codec.go
index 4dcadfd9601c9..9ba94cc3dac1c 100644
--- a/util/codec/codec.go
+++ b/util/codec/codec.go
@@ -355,18 +355,21 @@ func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *type
 // HashChunkRow writes the encoded values to w.
 // If two rows are logically equal, it will generate the same bytes.
 func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) {
-	buf = buf[:0]
-	var flag byte
 	var b []byte
 	for _, idx := range colIdx {
-		flag, b, err = encodeHashChunkRowIdx(sc, row, allTypes[idx], idx)
+		buf[0], b, err = encodeHashChunkRowIdx(sc, row, allTypes[idx], idx)
 		if err != nil {
 			return errors.Trace(err)
 		}
-		buf = append(buf, flag)
-		buf = append(buf, b...)
+		_, err = w.Write(buf)
+		if err != nil {
+			return
+		}
+		_, err = w.Write(b)
+		if err != nil {
+			return
+		}
 	}
-	_, err = w.Write(buf)
 	return err
 }

From 2da2ebbd9b228a69dfee0f37c5486f4c7e43f0c5 Mon Sep 17 00:00:00 2001
From: Feng Liyuan
Date: Tue, 27 Aug 2019 19:58:05 +0800
Subject: [PATCH 07/11] fixup

---
 util/codec/codec_test.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go
index b07dca0a8bd01..f23b23e7f3774 100644
--- a/util/codec/codec_test.go
+++ b/util/codec/codec_test.go
@@ -1036,8 +1036,8 @@ func (s *testCodecSuite) TestDecodeRange(c *C) {
 func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) {
 	sc := &stmtctx.StatementContext{TimeZone: time.Local}
-	buf1 := make([]byte, 0, 64)
-	buf2 := make([]byte, 0, 64)
+	buf1 := make([]byte, 1)
+	buf2 := make([]byte, 1)

 	tp1 := new(types.FieldType)
 	types.DefaultTypeForValue(a, tp1)
@@ -1079,7 +1079,7 @@ func testHashChunkRowEqual(c *C, a, b interface{}, equal bool) {

 func (s *testCodecSuite) TestHashChunkRow(c *C) {
 	sc := &stmtctx.StatementContext{TimeZone: time.Local}
-	buf := make([]byte, 0, 64)
+	buf := make([]byte, 1)

 	datums, tps := datumsForTest(sc)
 	chk := chunkForTest(c, sc, datums, tps, 1)

From 60f33871405709e4d1704c5ea85a7d153421a158 Mon Sep 17 00:00:00 2001
From: Feng Liyuan
Date: Tue, 27 Aug 2019 21:19:56 +0800
Subject: [PATCH 08/11] add more tests

---
 executor/join_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/executor/join_test.go b/executor/join_test.go
index 02bab0e19b674..dd77430d5f83d 100644
--- a/executor/join_test.go
+++ b/executor/join_test.go
@@ -334,6 +334,73 @@ func (s *testSuite2) TestJoinCast(c *C) {
 	result = tk.MustQuery("select * from t a , t1 b where (a.c1, a.c2) = (b.c1, b.c2);")
 	result.Check(testkit.Rows("1 2 1 2"))

+	/* Enable & fix this test after https://github.com/pingcap/tidb/issues/11895 is fixed.
+ tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint unsigned);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(18446744073709551615);") + tk.MustExec("insert into t1 value(-1);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 1) + */ + + /* https://github.com/pingcap/tidb/issues/11896 + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(1);") + tk.MustExec("insert into t1 value(1);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 1) + */ + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bit(64));") + tk.MustExec("insert into t value(-1);") + tk.MustExec("insert into t1 value(18446744073709551615);") + result = tk.MustQuery("select * from t, t1 where t.c1 = t1.c1;") + c.Check(len(result.Rows()), Equals, 0) + + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t(c1 bigint)") + tk.MustExec("create table t1(c1 bigint unsigned)") + tk.MustExec("create table t2(c1 Date)") + tk.MustExec("insert into t value(20191111)") + tk.MustExec("insert into t1 value(20191111)") + tk.MustExec("insert into t2 value('2019-11-11')") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1") + result.Check(testkit.Rows("20191111 20191111 2019-11-11")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bigint unsigned);") + tk.MustExec("create table t2(c1 enum('a', 'b', 'c', 'd'));") + tk.MustExec("insert into t value(3);") + tk.MustExec("insert into t1 value(3);") + tk.MustExec("insert into t2 value('c');") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1;") + result.Check(testkit.Rows("3 3 c")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t(c1 bigint);") + tk.MustExec("create table t1(c1 bigint unsigned);") + tk.MustExec("create table t2 (c1 SET('a', 'b', 'c', 'd'));") + tk.MustExec("insert into t value(9);") + tk.MustExec("insert into t1 value(9);") + tk.MustExec("insert into t2 value('a,d');") + result = tk.MustQuery("select * from t, t1, t2 where t.c1 = t2.c1 and t1.c1 = t2.c1;") + result.Check(testkit.Rows("9 9 a,d")) + tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") tk.MustExec("create table t(c1 int)") @@ -364,6 +431,7 @@ func (s *testSuite2) TestJoinCast(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") tk.MustExec("create table t(c1 char(10))") tk.MustExec("create table t1(c1 char(10))") tk.MustExec("create table t2(c1 char(10))") From aabe4726dded8115da3791bff5013d66b9f5b551 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Wed, 28 Aug 2019 13:13:25 +0800 Subject: [PATCH 09/11] add comment for rowHashMap --- executor/hash_table.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git 
index 562bd1cd2524a..aa6d8d58ff98c 100644
--- a/executor/hash_table.go
+++ b/executor/hash_table.go
@@ -55,14 +55,19 @@ type entryAddr struct {

 var nullEntryAddr = entryAddr{}

+// rowHashMap stores the rowPtrs of rows for a given hash key with minimal GC overhead.
+// A given key can map to multiple rowPtrs.
+// It is not thread-safe and should only be used in one goroutine.
 type rowHashMap struct {
 	entryStore entryStore
 	hashTable  map[uint64]entryAddr
 	length     int
 }

+// newRowHashMap creates a new rowHashMap.
 func newRowHashMap() *rowHashMap {
 	m := new(rowHashMap)
+	// TODO(fengliyuan): initialize the size of map from the estimated row count for better performance.
 	m.hashTable = make(map[uint64]entryAddr)
 	m.entryStore.slices = [][]entry{make([]entry, 0, 64)}
 	// Reserve the first empty entry, so entryAddr{} can represent nullEntryAddr.
@@ -70,6 +75,7 @@ func newRowHashMap() *rowHashMap {
 	return m
 }

+// Put puts the key/rowPtr pair into the rowHashMap; multiple rowPtrs for the same key are stored in a list.
 func (m *rowHashMap) Put(hashKey uint64, rowPtr chunk.RowPtr) {
 	oldEntryAddr := m.hashTable[hashKey]
 	e := entry{
@@ -81,6 +87,7 @@ func (m *rowHashMap) Put(hashKey uint64, rowPtr chunk.RowPtr) {
 	m.length++
 }

+// Get gets the rowPtrs stored for the given hashKey, in the order they were put.
 func (m *rowHashMap) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
 	entryAddr := m.hashTable[hashKey]
 	for entryAddr != nullEntryAddr {
@@ -96,4 +103,6 @@ func (m *rowHashMap) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
 	return
 }

+// Len returns the number of rowPtrs in the rowHashMap; the number of distinct keys may be smaller
+// if the same key is put more than once.
 func (m *rowHashMap) Len() int { return m.length }

From be51694267bee1ffb00d44351d8e0c389fc8c7a8 Mon Sep 17 00:00:00 2001
From: Feng Liyuan
Date: Wed, 28 Aug 2019 20:40:59 +0800
Subject: [PATCH 10/11] address comment

---
 executor/hash_table.go | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/executor/hash_table.go b/executor/hash_table.go
index aa6d8d58ff98c..af5693f236df2 100644
--- a/executor/hash_table.go
+++ b/executor/hash_table.go
@@ -37,9 +37,7 @@ func (es *entryStore) put(e entry) entryAddr {
 		es.sliceIdx++
 	}
 	addr := entryAddr{sliceIdx: es.sliceIdx, offset: es.sliceLen}
-	slice := es.slices[es.sliceIdx]
-	slice = append(slice, e)
-	es.slices[es.sliceIdx] = slice
+	es.slices[es.sliceIdx] = append(es.slices[es.sliceIdx], e)
 	es.sliceLen++
 	return addr
 }

From 134780aa8b4413cee01c3e9ee718dca6d0e67876 Mon Sep 17 00:00:00 2001
From: Feng Liyuan
Date: Wed, 28 Aug 2019 22:03:33 +0800
Subject: [PATCH 11/11] fix unit test in master branch

---
 expression/constant_test.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/expression/constant_test.go b/expression/constant_test.go
index f3b24dec09e8d..d495f963d5643 100644
--- a/expression/constant_test.go
+++ b/expression/constant_test.go
@@ -389,7 +389,7 @@ func (*testExpressionSuite) TestDeferredParamNotNull(c *C) {
 	c.Assert(mysql.TypeBlob, Equals, cstBytes.GetType().Tp)
 	c.Assert(mysql.TypeBit, Equals, cstBinary.GetType().Tp)
 	c.Assert(mysql.TypeBit, Equals, cstBit.GetType().Tp)
-	c.Assert(mysql.TypeVarString, Equals, cstFloat32.GetType().Tp)
+	c.Assert(mysql.TypeFloat, Equals, cstFloat32.GetType().Tp)
 	c.Assert(mysql.TypeDouble, Equals, cstFloat64.GetType().Tp)
 	c.Assert(mysql.TypeEnum, Equals, cstEnum.GetType().Tp)
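
A minimal sketch of how the rowHashMap introduced in this series behaves during the build phase (illustrative only; the hash key and RowPtr values below are made up):

	m := newRowHashMap()
	m.Put(42, chunk.RowPtr{ChkIdx: 0, RowIdx: 3})
	m.Put(42, chunk.RowPtr{ChkIdx: 1, RowIdx: 7}) // same hash key: chained, not overwritten
	ptrs := m.Get(42)                             // [{0 3} {1 7}], in insertion order
	total := m.Len()                              // 2: Len counts stored rowPtrs, not distinct keys
	_, _ = ptrs, total

Because all entries live in a handful of large []entry slabs and the map only holds small entryAddr values, the garbage collector has far fewer pointers to scan than with the previous map[uint64][]chunk.RowPtr layout.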