diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go
index 3f64164332ce7..542ba5d5f963c 100644
--- a/executor/benchmark_test.go
+++ b/executor/benchmark_test.go
@@ -913,7 +913,6 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
 	e := &HashJoinExec{
 		baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, 5, innerExec, outerExec),
 		hashJoinCtx: &hashJoinCtx{
-			sessCtx:         testCase.ctx,
 			joinType:        testCase.joinType, // 0 for InnerJoin, 1 for LeftOutersJoin, 2 for RightOuterJoin
 			isOuterJoin:     false,
 			useOuterToBuild: testCase.useOuterToBuild,
@@ -937,13 +936,13 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
 	for i := uint(0); i < e.concurrency; i++ {
 		e.probeWorkers[i] = &probeWorker{
 			workerID:       i,
+			sessCtx:        e.ctx,
 			hashJoinCtx:    e.hashJoinCtx,
 			joiner:         newJoiner(testCase.ctx, e.joinType, true, defaultValues, nil, lhsTypes, rhsTypes, childrenUsedSchema, false),
 			probeKeyColIdx: probeKeysColIdx,
 		}
 	}
-	e.buildWorker.hashJoinCtx = e.hashJoinCtx
 	memLimit := int64(-1)
 	if testCase.disk {
 		memLimit = 1
@@ -1201,7 +1200,7 @@ func benchmarkBuildHashTable(b *testing.B, casTest *hashJoinTestCase, dataSource
 	close(innerResultCh)
 
 	b.StartTimer()
-	if err := exec.buildWorker.buildHashTableForList(innerResultCh); err != nil {
+	if err := exec.buildHashTableForList(innerResultCh); err != nil {
 		b.Fatal(err)
 	}
 
diff --git a/executor/builder.go b/executor/builder.go
index d4270397eecd0..7b3ca2d192983 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -1417,14 +1417,12 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
 		probeWorkers: make([]*probeWorker, v.Concurrency),
 		buildWorker:  &buildWorker{},
 		hashJoinCtx: &hashJoinCtx{
-			sessCtx:         b.ctx,
 			isOuterJoin:     v.JoinType.IsOuterJoin(),
 			useOuterToBuild: v.UseOuterToBuild,
 			joinType:        v.JoinType,
 			concurrency:     v.Concurrency,
 		},
 	}
-	e.hashJoinCtx.allocPool = e.AllocPool
 	defaultValues := v.DefaultValues
 	lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec)
 	if v.InnerChildIdx == 1 {
@@ -1496,12 +1494,13 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
 		e.probeWorkers[i] = &probeWorker{
 			hashJoinCtx:      e.hashJoinCtx,
 			workerID:         i,
+			sessCtx:          e.ctx,
 			joiner:           newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin),
 			probeKeyColIdx:   probeKeyColIdx,
 			probeNAKeyColIdx: probeNAKeColIdx,
 		}
 	}
-	e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec, e.buildWorker.hashJoinCtx = buildKeyColIdx, buildNAKeyColIdx, buildSideExec, e.hashJoinCtx
+	e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec = buildKeyColIdx, buildNAKeyColIdx, buildSideExec
 	e.hashJoinCtx.isNullAware = isNAJoin
 	executorCountHashJoinExec.Inc()
 
diff --git a/executor/join.go b/executor/join.go
index 91533be2259cf..214e2edb1d440 100644
--- a/executor/join.go
+++ b/executor/join.go
@@ -47,8 +45,6 @@ var (
 )
 
 type hashJoinCtx struct {
-	sessCtx   sessionctx.Context
-	allocPool chunk.Allocator
 	// concurrency is the number of partition, build and join workers.
 	concurrency  uint
 	joinResultCh chan *hashjoinWorkerResult
@@ -67,8 +65,6 @@ type hashJoinCtx struct {
 	buildTypes  []*types.FieldType
 	outerFilter expression.CNFExprs
 	isNullAware bool
-	memTracker  *memory.Tracker // track memory usage.
-	diskTracker *disk.Tracker   // track disk usage.
 }
 
 // probeSideTupleFetcher reads tuples from probeSideExec and send them to probeWorkers.
@@ -83,6 +79,7 @@ type probeSideTupleFetcher struct {
 
 type probeWorker struct {
 	hashJoinCtx *hashJoinCtx
+	sessCtx     sessionctx.Context
 
 	workerID       uint
 	probeKeyColIdx []int
@@ -105,7 +102,6 @@ type probeWorker struct {
 }
 
 type buildWorker struct {
-	hashJoinCtx      *hashJoinCtx
 	buildSideExec    Executor
 	buildKeyColIdx   []int
 	buildNAKeyColIdx []int
@@ -120,8 +116,11 @@ type HashJoinExec struct {
 	probeWorkers []*probeWorker
 	buildWorker  *buildWorker
 
-	workerWg util.WaitGroupWrapper
-	waiterWg util.WaitGroupWrapper
+	worker util.WaitGroupWrapper
+	waiter util.WaitGroupWrapper
+
+	memTracker  *memory.Tracker // track memory usage.
+	diskTracker *disk.Tracker   // track disk usage.
 
 	prepared bool
 }
@@ -170,7 +169,7 @@ func (e *HashJoinExec) Close() error {
 		}
 		e.probeSideTupleFetcher.probeChkResourceCh = nil
 		terror.Call(e.rowContainer.Close)
-		e.waiterWg.Wait()
+		e.waiter.Wait()
 	}
 	e.outerMatchedStatus = e.outerMatchedStatus[:0]
 	for _, w := range e.probeWorkers {
@@ -199,14 +198,14 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
 		return err
 	}
 	e.prepared = false
-	e.hashJoinCtx.memTracker = memory.NewTracker(e.id, -1)
-	e.hashJoinCtx.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+	e.memTracker = memory.NewTracker(e.id, -1)
+	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
 
 	e.diskTracker = disk.NewTracker(e.id, -1)
 	e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)
 
-	e.workerWg = util.WaitGroupWrapper{}
-	e.waiterWg = util.WaitGroupWrapper{}
+	e.worker = util.WaitGroupWrapper{}
+	e.waiter = util.WaitGroupWrapper{}
 
 	e.closeCh = make(chan struct{})
 	e.finished.Store(false)
@@ -296,7 +295,7 @@ func (fetcher *probeSideTupleFetcher) wait4BuildSide() (emptyBuild bool, err err
 
 // fetchBuildSideRows fetches all rows from build side executor, and append them
 // to e.buildSideResult.
-func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
+func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
 	defer close(chkCh)
 	var err error
 	failpoint.Inject("issue30289", func(val failpoint.Value) {
@@ -306,13 +305,12 @@ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chun
 			return
 		}
 	})
-	sessVars := w.hashJoinCtx.sessCtx.GetSessionVars()
 	for {
-		if w.hashJoinCtx.finished.Load() {
+		if e.finished.Load() {
 			return
 		}
-		chk := sessVars.GetNewChunkWithCapacity(w.buildSideExec.base().retFieldTypes, sessVars.MaxChunkSize, sessVars.MaxChunkSize, w.hashJoinCtx.allocPool)
-		err = Next(ctx, w.buildSideExec, chk)
+		chk := e.ctx.GetSessionVars().GetNewChunkWithCapacity(e.buildWorker.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize, e.ctx.GetSessionVars().MaxChunkSize, e.AllocPool)
+		err = Next(ctx, e.buildWorker.buildSideExec, chk)
 		if err != nil {
 			errCh <- errors.Trace(err)
 			return
@@ -325,7 +323,7 @@ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chun
 		select {
 		case <-doneCh:
 			return
-		case <-w.hashJoinCtx.closeCh:
+		case <-e.closeCh:
 			return
 		case chkCh <- chk:
 		}
@@ -368,19 +366,19 @@ func (e *HashJoinExec) initializeForProbe() {
 
 func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
 	e.initializeForProbe()
-	e.workerWg.RunWithRecover(func() {
+	e.worker.RunWithRecover(func() {
 		defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End()
 		e.probeSideTupleFetcher.fetchProbeSideChunks(ctx, e.maxChunkSize)
 	}, e.probeSideTupleFetcher.handleProbeSideFetcherPanic)
 
 	for i := uint(0); i < e.concurrency; i++ {
 		workerID := i
-		e.workerWg.RunWithRecover(func() {
+		e.worker.RunWithRecover(func() {
 			defer trace.StartRegion(ctx, "HashJoinWorker").End()
 			e.probeWorkers[workerID].runJoinWorker()
 		}, e.probeWorkers[workerID].handleProbeWorkerPanic)
 	}
-	e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
+	e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
 }
 
 func (fetcher *probeSideTupleFetcher) handleProbeSideFetcherPanic(r interface{}) {
@@ -441,14 +439,14 @@ func (w *probeWorker) handleUnmatchedRowsFromHashTable() {
 }
 
 func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() {
-	e.workerWg.Wait()
+	e.worker.Wait()
 	if e.useOuterToBuild {
 		// Concurrently handling unmatched rows from the hash table at the tail
 		for i := uint(0); i < e.concurrency; i++ {
 			var workerID = i
-			e.workerWg.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
+			e.worker.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
 		}
-		e.workerWg.Wait()
+		e.worker.Wait()
 	}
 	close(e.joinResultCh)
 }
@@ -956,7 +954,7 @@ func (w *probeWorker) getNewJoinResult() (bool, *hashjoinWorkerResult) {
 func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, joinResult *hashjoinWorkerResult,
 	selected []bool) (ok bool, _ *hashjoinWorkerResult) {
 	var err error
-	selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
+	selected, err = expression.VectorizedFilter(w.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
 	if err != nil {
 		joinResult.err = err
 		return false, joinResult
@@ -996,7 +994,7 @@ func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, j
 	}
 
 	for i := range selected {
-		killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
+		killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
 		failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) {
 			if val.(bool) {
 				killed = true
@@ -1062,7 +1060,7 @@ func (w *probeWorker) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hCtx
 		}
 	}
 	for i := 0; i < probeSideChk.NumRows(); i++ {
-		killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
+		killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
 		failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) {
 			if val.(bool) {
 				killed = true
@@ -1112,7 +1110,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
 		for i := uint(0); i < e.concurrency; i++ {
 			e.probeWorkers[i].rowIters = chunk.NewIterator4Slice([]chunk.Row{}).(*chunk.Iterator4Slice)
 		}
-		e.workerWg.RunWithRecover(func() {
+		e.worker.RunWithRecover(func() {
 			defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End()
 			e.fetchAndBuildHashTable(ctx)
 		}, e.handleFetchAndBuildHashTablePanic)
@@ -1155,10 +1153,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
 	buildSideResultCh := make(chan *chunk.Chunk, 1)
 	doneCh := make(chan struct{})
 	fetchBuildSideRowsOk := make(chan error, 1)
-	e.workerWg.RunWithRecover(
+	e.worker.RunWithRecover(
 		func() {
 			defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End()
-			e.buildWorker.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
+			e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
 		},
 		func(r interface{}) {
 			if r != nil {
@@ -1169,7 +1167,7 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
 	)
 
 	// TODO: Parallel build hash table. Currently not support because `unsafeHashTable` is not thread-safe.
-	err := e.buildWorker.buildHashTableForList(buildSideResultCh)
+	err := e.buildHashTableForList(buildSideResultCh)
 	if err != nil {
 		e.buildFinished <- errors.Trace(err)
 		close(doneCh)
@@ -1187,42 +1185,41 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
 }
 
 // buildHashTableForList builds hash table from `list`.
-func (w *buildWorker) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
+func (e *HashJoinExec) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
 	var err error
 	var selected []bool
-	rowContainer := w.hashJoinCtx.rowContainer
-	rowContainer.GetMemTracker().AttachTo(w.hashJoinCtx.memTracker)
-	rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
-	rowContainer.GetDiskTracker().AttachTo(w.hashJoinCtx.diskTracker)
-	rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
+	e.rowContainer.GetMemTracker().AttachTo(e.memTracker)
+	e.rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
+	e.rowContainer.GetDiskTracker().AttachTo(e.diskTracker)
+	e.rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
 	if variable.EnableTmpStorageOnOOM.Load() {
-		actionSpill := rowContainer.ActionSpill()
+		actionSpill := e.rowContainer.ActionSpill()
 		failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) {
 			if val.(bool) {
-				actionSpill = rowContainer.rowContainer.ActionSpillForTest()
+				actionSpill = e.rowContainer.rowContainer.ActionSpillForTest()
 				defer actionSpill.(*chunk.SpillDiskAction).WaitForTest()
 			}
 		})
-		w.hashJoinCtx.sessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
+		e.ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
 	}
 	for chk := range buildSideResultCh {
-		if w.hashJoinCtx.finished.Load() {
+		if e.finished.Load() {
 			return nil
 		}
-		if !w.hashJoinCtx.useOuterToBuild {
-			err = rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
+		if !e.useOuterToBuild {
+			err = e.rowContainer.PutChunk(chk, e.isNullEQ)
 		} else {
 			var bitMap = bitmap.NewConcurrentBitmap(chk.NumRows())
-			w.hashJoinCtx.outerMatchedStatus = append(w.hashJoinCtx.outerMatchedStatus, bitMap)
-			w.hashJoinCtx.memTracker.Consume(bitMap.BytesConsumed())
-			if len(w.hashJoinCtx.outerFilter) == 0 {
-				err = w.hashJoinCtx.rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
+			e.outerMatchedStatus = append(e.outerMatchedStatus, bitMap)
+			e.memTracker.Consume(bitMap.BytesConsumed())
+			if len(e.outerFilter) == 0 {
+				err = e.rowContainer.PutChunk(chk, e.isNullEQ)
 			} else {
-				selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(chk), selected)
+				selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(chk), selected)
 				if err != nil {
 					return err
 				}
-				err = rowContainer.PutChunkSelected(chk, selected, w.hashJoinCtx.isNullEQ)
+				err = e.rowContainer.PutChunkSelected(chk, selected, e.isNullEQ)
 			}
 		}
 		failpoint.Inject("ConsumeRandomPanic", nil)
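Taken together, the struct changes in this diff give each `probeWorker` its own `sessCtx` reference, hang the trackers and wait groups directly off `HashJoinExec`, and fold the build-side fetch/build methods back into `HashJoinExec`, so the per-row hot paths (the `Killed` poll and the `VectorizedFilter` calls) no longer reach through the shared `hashJoinCtx` pointer. The following is a minimal, standalone sketch of that per-worker pattern, not TiDB code: `SessionCtx`, `sharedCtx`, and the row count are hypothetical stand-ins, and the real executor wiring is far richer than this.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// SessionCtx stands in for sessionctx.Context; Killed mirrors the
// session variable that join2Chunk polls once per row.
type SessionCtx struct {
	Killed uint32
}

// sharedCtx plays the role of hashJoinCtx: state genuinely shared by all workers.
type sharedCtx struct {
	finished atomic.Bool
}

// probeWorker caches its own sessCtx pointer at construction time, so the
// hot loop does a single pointer load instead of the double load that
// w.hashJoinCtx.sessCtx would cost on every iteration.
type probeWorker struct {
	hashJoinCtx *sharedCtx
	sessCtx     *SessionCtx
	workerID    int
}

func (w *probeWorker) runJoinWorker(rows int) {
	for i := 0; i < rows; i++ {
		// Per-row kill/finish check through the cached pointer.
		if atomic.LoadUint32(&w.sessCtx.Killed) == 1 || w.hashJoinCtx.finished.Load() {
			return
		}
		// ... probe the hash table for row i ...
	}
	fmt.Printf("worker %d probed %d rows\n", w.workerID, rows)
}

func main() {
	sess := &SessionCtx{}
	shared := &sharedCtx{}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		w := &probeWorker{hashJoinCtx: shared, sessCtx: sess, workerID: i}
		wg.Add(1)
		go func() {
			defer wg.Done()
			w.runJoinWorker(1 << 16)
		}()
	}
	wg.Wait()
}
```

The sketch only illustrates the field layout the diff restores: per-worker fields for what each goroutine touches every row, and a shared context reserved for state that must actually be shared.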