diff --git a/executor/adapter.go b/executor/adapter.go
index aaed323245541..e6c0d2af64d2f 100644
--- a/executor/adapter.go
+++ b/executor/adapter.go
@@ -106,7 +106,7 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error {
 		defer span1.Finish()
 	}
 
-	err := a.executor.Next(ctx, req)
+	err := Next(ctx, a.executor, req)
 	if err != nil {
 		a.lastErr = err
 		return err
@@ -126,7 +126,7 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error {
 
 // NewRecordBatch create a recordBatch base on top-level executor's newFirstChunk().
 func (a *recordSet) NewRecordBatch() *chunk.RecordBatch {
-	return chunk.NewRecordBatch(a.executor.newFirstChunk())
+	return chunk.NewRecordBatch(newFirstChunk(a.executor))
 }
 
 func (a *recordSet) Close() error {
@@ -307,7 +307,7 @@ func (c *chunkRowRecordSet) Next(ctx context.Context, req *chunk.RecordBatch) er
 }
 
 func (c *chunkRowRecordSet) NewRecordBatch() *chunk.RecordBatch {
-	return chunk.NewRecordBatch(c.e.newFirstChunk())
+	return chunk.NewRecordBatch(newFirstChunk(c.e))
 }
 
 func (c *chunkRowRecordSet) Close() error {
@@ -385,7 +385,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex
 		a.logAudit()
 	}()
 
-	err = e.Next(ctx, chunk.NewRecordBatch(e.newFirstChunk()))
+	err = Next(ctx, e, chunk.NewRecordBatch(newFirstChunk(e)))
 	if err != nil {
 		return nil, err
 	}
diff --git a/executor/admin.go b/executor/admin.go
index e6721dd976f89..f4944ba17f916 100644
--- a/executor/admin.go
+++ b/executor/admin.go
@@ -102,7 +102,7 @@ func (e *CheckIndexRangeExec) Open(ctx context.Context) error {
 		FieldType: *colTypeForHandle,
 	})
 
-	e.srcChunk = e.newFirstChunk()
+	e.srcChunk = newFirstChunk(e)
 	dagPB, err := e.buildDAGPB()
 	if err != nil {
 		return err
diff --git a/executor/aggregate.go b/executor/aggregate.go
index e8fa560797dab..ad3d224b7c6a9 100644
--- a/executor/aggregate.go
+++ b/executor/aggregate.go
@@ -239,7 +239,7 @@ func (e *HashAggExec) initForUnparallelExec() {
 	e.partialResultMap = make(aggPartialResultMapper)
 	e.groupKeyBuffer = make([]byte, 0, 8)
 	e.groupValDatums = make([]types.Datum, 0, len(e.groupKeyBuffer))
-	e.childResult = e.children[0].newFirstChunk()
+	e.childResult = newFirstChunk(e.children[0])
 }
 
 func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
@@ -275,12 +275,12 @@
 			partialResultsMap: make(aggPartialResultMapper),
 			groupByItems:      e.GroupByItems,
 			groupValDatums:    make([]types.Datum, 0, len(e.GroupByItems)),
-			chk:               e.children[0].newFirstChunk(),
+			chk:               newFirstChunk(e.children[0]),
 		}
 
 		e.partialWorkers[i] = w
 		e.inputCh <- &HashAggInput{
-			chk:        e.children[0].newFirstChunk(),
+			chk:        newFirstChunk(e.children[0]),
 			giveBackCh: w.inputCh,
 		}
 	}
@@ -295,7 +295,7 @@
 			outputCh:            e.finalOutputCh,
 			finalResultHolderCh: e.finalInputCh,
 			rowBuffer:           make([]types.Datum, 0, e.Schema().Len()),
-			mutableRow:          chunk.MutRowFromTypes(e.retTypes()),
+			mutableRow:          chunk.MutRowFromTypes(retTypes(e)),
 		}
 	}
 }
@@ -555,7 +555,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) {
 			}
 			chk = input.chk
 		}
-		err = e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err = Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			e.finalOutputCh <- &AfFinalResult{err: err}
 			return
@@ -681,7 +681,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
 func (e *HashAggExec) execute(ctx context.Context) (err error) {
 	inputIter := chunk.NewIterator4Chunk(e.childResult)
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 		if err != nil {
 			return err
 		}
@@ -772,7 +772,7 @@ func (e *StreamAggExec) Open(ctx context.Context) error {
 	if err := e.baseExecutor.Open(ctx); err != nil {
 		return err
 	}
-	e.childResult = e.children[0].newFirstChunk()
+	e.childResult = newFirstChunk(e.children[0])
 	e.executed = false
 	e.isChildReturnEmpty = true
 	e.inputIter = chunk.NewIterator4Chunk(e.childResult)
@@ -870,7 +870,7 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch
 		return err
 	}
 
-	err = e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+	err = Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 	if err != nil {
 		return err
 	}
diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go
index e13e0eaf6f137..4d38af79e887f 100644
--- a/executor/benchmark_test.go
+++ b/executor/benchmark_test.go
@@ -129,7 +129,7 @@ func (mds *mockDataSource) Next(ctx context.Context, req *chunk.RecordBatch) err
 func buildMockDataSource(opt mockDataSourceParameters) *mockDataSource {
 	baseExec := newBaseExecutor(opt.ctx, opt.schema, nil)
 	m := &mockDataSource{baseExec, opt, nil, nil, 0}
-	types := m.retTypes()
+	types := retTypes(m)
 	colData := make([][]interface{}, len(types))
 	for i := 0; i < len(types); i++ {
 		colData[i] = m.genColDatums(i)
@@ -137,12 +137,12 @@
 
 	m.genData = make([]*chunk.Chunk, (m.p.rows+m.initCap-1)/m.initCap)
 	for i := range m.genData {
-		m.genData[i] = chunk.NewChunkWithCapacity(m.retTypes(), m.ctx.GetSessionVars().MaxChunkSize)
+		m.genData[i] = chunk.NewChunkWithCapacity(retTypes(m), m.ctx.GetSessionVars().MaxChunkSize)
 	}
 
 	for i := 0; i < m.p.rows; i++ {
 		idx := i / m.maxChunkSize
-		retTypes := m.retTypes()
+		retTypes := retTypes(m)
 		for colIdx := 0; colIdx < len(types); colIdx++ {
 			switch retTypes[colIdx].Tp {
 			case mysql.TypeLong, mysql.TypeLonglong:
@@ -259,7 +259,7 @@
 		b.StopTimer()
 		// prepare a new agg-executor
 		aggExec := buildAggExecutor(b, casTest, dataSource)
 		tmpCtx := context.Background()
-		chk := aggExec.newFirstChunk()
+		chk := newFirstChunk(aggExec)
 		dataSource.prepareChunks()
 
 		b.StartTimer()
diff --git a/executor/builder.go b/executor/builder.go
index 9d424d86ea0d9..25e340ff361fe 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -872,8 +872,8 @@ func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) Execu
 			v.JoinType == plannercore.RightOuterJoin,
 			defaultValues,
 			v.OtherConditions,
-			leftExec.retTypes(),
-			rightExec.retTypes(),
+			retTypes(leftExec),
+			retTypes(rightExec),
 		),
 		isOuterJoin: v.JoinType.IsOuterJoin(),
 	}
@@ -946,7 +946,7 @@
 	}
 
 	defaultValues := v.DefaultValues
-	lhsTypes, rhsTypes := leftExec.retTypes(), rightExec.retTypes()
+	lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec)
 	if v.InnerChildIdx == 0 {
 		if len(v.LeftConditions) > 0 {
 			b.err = errors.Annotate(ErrBuildExecutor, "join's inner condition should be empty")
@@ -1020,7 +1020,7 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor
 	if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
 		e.defaultVal = nil
 	} else {
-		e.defaultVal = chunk.NewChunkWithCapacity(e.retTypes(), 1)
+		e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
 	}
 	for _, aggDesc := range v.AggFuncs {
 		if aggDesc.HasDistinct {
@@ -1079,7 +1079,7 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu
 	if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
 		e.defaultVal = nil
 	} else {
-		e.defaultVal = chunk.NewChunkWithCapacity(e.retTypes(), 1)
+		e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
 	}
 	for i, aggDesc := range v.AggFuncs {
 		aggFunc := aggfuncs.Build(b.ctx, aggDesc, i)
@@ -1220,7 +1220,7 @@ func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) *NestedLoopAp
 		defaultValues = make([]types.Datum, v.Children()[v.InnerChildIdx].Schema().Len())
 	}
 	tupleJoiner := newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0,
-		defaultValues, otherConditions, leftChild.retTypes(), rightChild.retTypes())
+		defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild))
 	outerExec, innerExec := leftChild, rightChild
 	outerFilter, innerFilter := v.LeftConditions, v.RightConditions
 	if v.InnerChildIdx == 0 {
@@ -1703,7 +1703,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin)
 	if b.err != nil {
 		return nil
 	}
-	outerTypes := outerExec.retTypes()
+	outerTypes := retTypes(outerExec)
 	innerPlan := v.Children()[1-v.OuterIndex]
 	innerTypes := make([]*types.FieldType, innerPlan.Schema().Len())
 	for i, col := range innerPlan.Schema().Columns {
@@ -1761,7 +1761,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin)
 		innerKeyCols[i] = v.InnerJoinKeys[i].Index
 	}
 	e.innerCtx.keyCols = innerKeyCols
-	e.joinResult = e.newFirstChunk()
+	e.joinResult = newFirstChunk(e)
 	executorCounterIndexLookUpJoin.Inc()
 	return e
 }
@@ -2015,7 +2015,7 @@ func (builder *dataReaderBuilder) buildUnionScanForIndexJoin(ctx context.Context
 		return nil, err
 	}
 	us := e.(*UnionScanExec)
-	us.snapshotChunkBuffer = us.newFirstChunk()
+	us.snapshotChunkBuffer = newFirstChunk(us)
 	return us, nil
 }
 
@@ -2050,7 +2050,7 @@ func (builder *dataReaderBuilder) buildTableReaderFromHandles(ctx context.Contex
 		return nil, err
 	}
 	e.resultHandler = &tableResultHandler{}
-	result, err := builder.SelectResult(ctx, builder.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans))
+	result, err := builder.SelectResult(ctx, builder.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans))
 	if err != nil {
 		return nil, err
 	}
diff --git a/executor/delete.go b/executor/delete.go
index c3c52363dff2b..f198861dfe99d 100644
--- a/executor/delete.go
+++ b/executor/delete.go
@@ -100,12 +100,12 @@ func (e *DeleteExec) deleteSingleTableByChunk(ctx context.Context) error {
 	// If tidb_batch_delete is ON and not in a transaction, we could use BatchDelete mode.
 	batchDelete := e.ctx.GetSessionVars().BatchDelete && !e.ctx.GetSessionVars().InTxn()
 	batchDMLSize := e.ctx.GetSessionVars().DMLBatchSize
-	fields := e.children[0].retTypes()
-	chk := e.children[0].newFirstChunk()
+	fields := retTypes(e.children[0])
+	chk := newFirstChunk(e.children[0])
 	for {
 		iter := chunk.NewIterator4Chunk(chk)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -183,11 +183,11 @@
 	e.initialMultiTableTblMap()
 	colPosInfos := e.getColPosInfos(e.children[0].Schema())
 	tblRowMap := make(tableRowMapType)
-	fields := e.children[0].retTypes()
-	chk := e.children[0].newFirstChunk()
+	fields := retTypes(e.children[0])
+	chk := newFirstChunk(e.children[0])
 	for {
 		iter := chunk.NewIterator4Chunk(chk)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/distsql.go b/executor/distsql.go
index 5bd77376292c1..2b380313355d0 100644
--- a/executor/distsql.go
+++ b/executor/distsql.go
@@ -330,7 +330,7 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange)
 		e.feedback.Invalidate()
 		return err
 	}
-	e.result, err = e.SelectResult(ctx, e.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans))
+	e.result, err = e.SelectResult(ctx, e.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans))
 	if err != nil {
 		e.feedback.Invalidate()
 		return err
@@ -794,7 +794,7 @@ func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) er
 	handleCnt := len(task.handles)
 	task.rows = make([]chunk.Row, 0, handleCnt)
 	for {
-		chk := tableReader.newFirstChunk()
+		chk := newFirstChunk(tableReader)
 		err = tableReader.Next(ctx, chunk.NewRecordBatch(chk))
 		if err != nil {
 			logutil.Logger(ctx).Error("table reader fetch next chunk failed", zap.Error(err))
diff --git a/executor/errors.go b/executor/errors.go
index c1d426ef71535..a48152f0acdfe 100644
--- a/executor/errors.go
+++ b/executor/errors.go
@@ -52,6 +52,7 @@ var (
 	ErrWrongObject      = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
 	ErrRoleNotGranted   = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
 	ErrDeadlock         = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock])
+	ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted])
 )
 
 func init() {
@@ -69,6 +70,7 @@ func init() {
 		mysql.ErrBadDB:            mysql.ErrBadDB,
 		mysql.ErrWrongObject:      mysql.ErrWrongObject,
 		mysql.ErrLockDeadlock:     mysql.ErrLockDeadlock,
+		mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted,
 	}
 	terror.ErrClassToMySQLCodes[terror.ClassExecutor] = tableMySQLErrCodes
 }
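Note on the new error: ErrQueryInterrupted reuses MySQL's standard code 1317 (ER_QUERY_INTERRUPTED, SQLSTATE 70100), and the init() registration above is what lets terror translate it into that wire-level error. Assuming a stock MySQL client, a statement cancelled through this path should surface roughly like this (hypothetical session, not part of the patch):

    mysql> SELECT SLEEP(100);
    ERROR 1317 (70100): Query execution was interrupted

This matches what MySQL itself returns for KILL QUERY, so drivers that special-case 1317 keep working unchanged.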
diff --git a/executor/executor.go b/executor/executor.go
index 5a26b246f3adc..4481657de1443 100644
--- a/executor/executor.go
+++ b/executor/executor.go
@@ -86,6 +86,11 @@ type baseExecutor struct {
 	runtimeStats *execdetails.RuntimeStats
 }
 
+// base returns the baseExecutor of an executor, don't override this method!
+func (e *baseExecutor) base() *baseExecutor {
+	return e
+}
+
 // Open initializes children recursively and "childrenResults" according to children's schemas.
 func (e *baseExecutor) Open(ctx context.Context) error {
 	for _, child := range e.children {
@@ -117,13 +122,15 @@ func (e *baseExecutor) Schema() *expression.Schema {
 }
 
 // newFirstChunk creates a new chunk to buffer current executor's result.
-func (e *baseExecutor) newFirstChunk() *chunk.Chunk {
-	return chunk.New(e.retTypes(), e.initCap, e.maxChunkSize)
+func newFirstChunk(e Executor) *chunk.Chunk {
+	base := e.base()
+	return chunk.New(base.retFieldTypes, base.initCap, base.maxChunkSize)
 }
 
 // retTypes returns all output column types.
-func (e *baseExecutor) retTypes() []*types.FieldType {
-	return e.retFieldTypes
+func retTypes(e Executor) []*types.FieldType {
+	base := e.base()
+	return base.retFieldTypes
 }
 
 // Next fills mutiple rows into a chunk.
@@ -166,13 +173,21 @@ func newBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id fmt.S
 // return a batch of rows, other than a single row in Volcano.
 // NOTE: Executors must call "chk.Reset()" before appending their results to it.
 type Executor interface {
+	base() *baseExecutor
 	Open(context.Context) error
 	Next(ctx context.Context, req *chunk.RecordBatch) error
 	Close() error
 	Schema() *expression.Schema
+}
+
+// Next is a wrapper function on e.Next(), it handles some common codes.
+func Next(ctx context.Context, e Executor, req *chunk.RecordBatch) error {
+	sessVars := e.base().ctx.GetSessionVars()
+	if atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) {
+		return ErrQueryInterrupted
+	}
 
-	retTypes() []*types.FieldType
-	newFirstChunk() *chunk.Chunk
+	return e.Next(ctx, req)
 }
 
 // CancelDDLJobsExec represents a cancel DDL jobs executor.
@@ -552,9 +567,9 @@ func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 	if err != nil {
 		return err
 	}
-	chk := e.src.newFirstChunk()
+	chk := newFirstChunk(e.src)
 	for {
-		err := e.src.Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.src, chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -663,7 +678,7 @@
 	}
 	req.GrowAndReset(e.maxChunkSize)
 
-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -723,7 +738,7 @@
 	for !e.meetFirstBatch {
 		// transfer req's requiredRows to childResult and then adjust it in childResult
 		e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.maxChunkSize)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
 		if err != nil {
 			return err
 		}
@@ -748,7 +763,7 @@
 		e.cursor += batchSize
 	}
 	e.adjustRequiredRows(req.Chunk)
-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -770,7 +785,7 @@
 	if err := e.baseExecutor.Open(ctx); err != nil {
 		return err
 	}
-	e.childResult = e.children[0].newFirstChunk()
+	e.childResult = newFirstChunk(e.children[0])
 	e.cursor = 0
 	e.meetFirstBatch = e.begin == 0
 	return nil
@@ -816,9 +831,9 @@
 		if err != nil {
 			return rows, err
 		}
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for {
-			err = exec.Next(ctx, chunk.NewRecordBatch(chk))
+			err = Next(ctx, exec, chunk.NewRecordBatch(chk))
 			if err != nil {
 				return rows, err
 			}
@@ -827,7 +842,7 @@
 			iter := chunk.NewIterator4Chunk(chk)
 			for r := iter.Begin(); r != iter.End(); r = iter.Next() {
-				row := r.GetDatumRow(exec.retTypes())
+				row := r.GetDatumRow(retTypes(exec))
 				rows = append(rows, row)
 			}
 			chk = chunk.Renew(chk, sctx.GetSessionVars().MaxChunkSize)
@@ -892,7 +907,7 @@ func (e *SelectionExec) Open(ctx context.Context) error {
 	if err := e.baseExecutor.Open(ctx); err != nil {
 		return err
 	}
-	e.childResult = e.children[0].newFirstChunk()
+	e.childResult = newFirstChunk(e.children[0])
 	e.batched = expression.Vectorizable(e.filters)
 	if e.batched {
 		e.selected = make([]bool, 0, chunk.InitialCapacity)
 	}
@@ -935,7 +950,7 @@
 			}
 			req.AppendRow(e.inputRow)
 		}
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 		if err != nil {
 			return err
 		}
@@ -967,7 +982,7 @@ func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) err
 			return nil
 		}
 	}
-	err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+	err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 	if err != nil {
 		return err
 	}
@@ -1011,7 +1026,7 @@ func (e *TableScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 		return err
 	}
 
-	mutableRow := chunk.MutRowFromTypes(e.retTypes())
+	mutableRow := chunk.MutRowFromTypes(retTypes(e))
 	for req.NumRows() < req.Capacity() {
 		row, err := e.getRow(handle)
 		if err != nil {
@@ -1027,12 +1042,12 @@
 func (e *TableScanExec) nextChunk4InfoSchema(ctx context.Context, chk *chunk.Chunk) error {
 	chk.GrowAndReset(e.maxChunkSize)
 	if e.virtualTableChunkList == nil {
-		e.virtualTableChunkList = chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize)
+		e.virtualTableChunkList = chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize)
 		columns := make([]*table.Column, e.schema.Len())
 		for i, colInfo := range e.columns {
 			columns[i] = table.ToColumn(colInfo)
 		}
-		mutableRow := chunk.MutRowFromTypes(e.retTypes())
+		mutableRow := chunk.MutRowFromTypes(retTypes(e))
 		err := e.t.IterRecords(e.ctx, nil, columns, func(h int64, rec []types.Datum, cols []*table.Column) (bool, error) {
 			mutableRow.SetDatums(rec...)
 			e.virtualTableChunkList.AppendRow(mutableRow.ToRow())
@@ -1115,7 +1130,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 		return nil
 	}
 	e.evaluated = true
-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -1129,8 +1144,8 @@
 		return errors.New("subquery returns more than 1 row")
 	}
 
-	childChunk := e.children[0].newFirstChunk()
-	err = e.children[0].Next(ctx, chunk.NewRecordBatch(childChunk))
+	childChunk := newFirstChunk(e.children[0])
+	err = Next(ctx, e.children[0], chunk.NewRecordBatch(childChunk))
 	if err != nil {
 		return err
 	}
@@ -1194,7 +1209,7 @@
 		return err
 	}
 	for _, child := range e.children {
-		e.childrenResults = append(e.childrenResults, child.newFirstChunk())
+		e.childrenResults = append(e.childrenResults, newFirstChunk(child))
 	}
 	e.stopFetchData.Store(false)
 	e.initialized = false
@@ -1241,7 +1256,7 @@
 			return
 		case result.chk = <-e.resourcePools[childID]:
 		}
-		result.err = e.children[childID].Next(ctx, chunk.NewRecordBatch(result.chk))
+		result.err = Next(ctx, e.children[childID], chunk.NewRecordBatch(result.chk))
 		if result.err == nil && result.chk.NumRows() == 0 {
 			return
 		}
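The executor.go hunks above are the heart of the change, and two ideas are worth separating. First, retTypes and newFirstChunk stop being interface methods: every executor already embeds baseExecutor, so exposing a single base() accessor lets package-level functions serve all implementations, and PointGetExecutor's hand-rolled copies (removed later in this diff) become unnecessary. Second, the exported Next wrapper gives every parent-to-child call one choke point where the session's Killed flag is observed. A self-contained sketch of the pattern, with simplified stand-in types rather than TiDB's real ones:

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

var errQueryInterrupted = errors.New("query interrupted")

// sessionVars stands in for TiDB's variable.SessionVars.
type sessionVars struct{ Killed uint32 }

type baseExecutor struct{ vars *sessionVars }

// base returns the embedded baseExecutor; struct embedding promotes this
// method into every concrete executor, so none of them implements it by hand.
func (e *baseExecutor) base() *baseExecutor { return e }

// executor mirrors the reshaped Executor interface: base() replaces the
// old retTypes()/newFirstChunk() methods.
type executor interface {
	base() *baseExecutor
	next() error
}

// next mirrors the new package-level Next: one place to check the kill flag
// before delegating to the concrete executor.
func next(e executor) error {
	if atomic.CompareAndSwapUint32(&e.base().vars.Killed, 1, 0) {
		return errQueryInterrupted
	}
	return e.next()
}

// tableScan is a toy executor; embedding baseExecutor is all it takes
// to satisfy the base() requirement.
type tableScan struct{ baseExecutor }

func (t *tableScan) next() error { fmt.Println("scan one batch"); return nil }

func main() {
	vars := &sessionVars{}
	exec := &tableScan{baseExecutor{vars: vars}}

	fmt.Println(next(exec))             // <nil>: flag not set, batch scanned
	atomic.StoreUint32(&vars.Killed, 1) // a concurrent KILL QUERY sets the flag
	fmt.Println(next(exec))             // query interrupted
}

The CompareAndSwap(1, 0) does double duty: it detects the kill request and atomically consumes it, so only one Next call per request returns ErrQueryInterrupted and the flag is not left poisoned for the statements that follow.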
diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go
index 23ba938192a08..47c66e4eb5798 100644
--- a/executor/executor_pkg_test.go
+++ b/executor/executor_pkg_test.go
@@ -98,7 +98,7 @@ func (s *testExecSuite) TestShowProcessList(c *C) {
 	err := e.Open(ctx)
 	c.Assert(err, IsNil)
 
-	chk := e.newFirstChunk()
+	chk := newFirstChunk(e)
 	it := chunk.NewIterator4Chunk(chk)
 	// Run test and check results.
 	for _, p := range ps {
diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go
index 5cdd3cfe2898e..70cf56031e79e 100644
--- a/executor/executor_required_rows_test.go
+++ b/executor/executor_required_rows_test.go
@@ -93,9 +93,9 @@ func (r *requiredRowsDataSource) Next(ctx context.Context, req *chunk.RecordBatc
 }
 
 func (r *requiredRowsDataSource) genOneRow() chunk.Row {
-	row := chunk.MutRowFromTypes(r.retTypes())
-	for i := range r.retTypes() {
-		row.SetValue(i, r.generator(r.retTypes()[i]))
+	row := chunk.MutRowFromTypes(retTypes(r))
+	for i, tp := range retTypes(r) {
+		row.SetValue(i, r.generator(tp))
 	}
 	return row.ToRow()
 }
@@ -177,7 +177,7 @@ func (s *testExecSuite) TestLimitRequiredRows(c *C) {
 		ds := newRequiredRowsDataSource(sctx, testCase.totalRows, testCase.expectedRowsDS)
 		exec := buildLimitExec(sctx, ds, testCase.limitOffset, testCase.limitCount)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], sctx.GetSessionVars().MaxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -260,7 +260,7 @@ func (s *testExecSuite) TestSortRequiredRows(c *C) {
 		}
 		exec := buildSortExec(sctx, byItems, ds)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -367,7 +367,7 @@ func (s *testExecSuite) TestTopNRequiredRows(c *C) {
 		}
 		exec := buildTopNExec(sctx, testCase.topNOffset, testCase.topNCount, byItems, ds)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -460,7 +460,7 @@ func (s *testExecSuite) TestSelectionRequiredRows(c *C) {
 		}
 		exec := buildSelectionExec(sctx, filters, ds)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -518,7 +518,7 @@ func (s *testExecSuite) TestProjectionUnparallelRequiredRows(c *C) {
 		}
 		exec := buildProjectionExec(sctx, exprs, ds, 0)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -574,7 +574,7 @@ func (s *testExecSuite) TestProjectionParallelRequiredRows(c *C) {
 		}
 		exec := buildProjectionExec(sctx, exprs, ds, testCase.numWorkers)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -663,7 +663,7 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) {
 		aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
 		exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -722,7 +722,7 @@ func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) {
 		aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
 		exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -758,7 +758,7 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {
 
 		exec := buildMergeJoinExec(ctx, joinType, innerSrc, outerSrc)
 		c.Assert(exec.Open(context.Background()), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range required {
 			chk.SetRequiredRows(required[i], ctx.GetSessionVars().MaxChunkSize)
 			c.Assert(exec.Next(context.Background(), chunk.NewRecordBatch(chk)), IsNil)
diff --git a/executor/explain.go b/executor/explain.go
index 61ced6d564b62..37ef8273a55c2 100644
--- a/executor/explain.go
+++ b/executor/explain.go
@@ -72,7 +72,7 @@ func (e *ExplainExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 
 func (e *ExplainExec) generateExplainInfo(ctx context.Context) ([][]string, error) {
 	if e.analyzeExec != nil {
-		chk := e.analyzeExec.newFirstChunk()
+		chk := newFirstChunk(e.analyzeExec)
 		for {
 			err := e.analyzeExec.Next(ctx, chunk.NewRecordBatch(chk))
 			if err != nil {
diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go
index 3d9371d14d396..f658522dccb0f 100644
--- a/executor/index_lookup_join.go
+++ b/executor/index_lookup_join.go
@@ -365,11 +365,11 @@ func (ow *outerWorker) pushToChan(ctx context.Context, task *lookUpJoinTask, dst
 
 // buildTask builds a lookUpJoinTask and read outer rows.
 // When err is not nil, task must not be nil to send the error to the main thread via task.
 func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) {
-	ow.executor.newFirstChunk()
+	newFirstChunk(ow.executor)
 	task := &lookUpJoinTask{
 		doneCh:            make(chan error, 1),
-		outerResult:       ow.executor.newFirstChunk(),
+		outerResult:       newFirstChunk(ow.executor),
 		encodedLookUpKeys: chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, ow.ctx.GetSessionVars().MaxChunkSize),
 		lookupMap:         mvmap.NewMVMap(),
 	}
@@ -386,7 +386,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) {
 
 	task.memTracker.Consume(task.outerResult.MemoryUsage())
 	for !task.outerResult.IsFull() {
-		err := ow.executor.Next(ctx, chunk.NewRecordBatch(ow.executorChk))
+		err := Next(ctx, ow.executor, chunk.NewRecordBatch(ow.executorChk))
 		if err != nil {
 			return task, err
 		}
@@ -582,11 +582,11 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa
 		return err
 	}
 	defer terror.Call(innerExec.Close)
-	innerResult := chunk.NewList(innerExec.retTypes(), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize)
+	innerResult := chunk.NewList(retTypes(innerExec), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize)
 	innerResult.GetMemTracker().SetLabel(innerResultLabel)
 	innerResult.GetMemTracker().AttachTo(task.memTracker)
 	for {
-		err := innerExec.Next(ctx, chunk.NewRecordBatch(iw.executorChk))
+		err := Next(ctx, innerExec, chunk.NewRecordBatch(iw.executorChk))
 		if err != nil {
 			return err
 		}
@@ -594,7 +594,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa
 			break
 		}
 		innerResult.Add(iw.executorChk)
-		iw.executorChk = innerExec.newFirstChunk()
+		iw.executorChk = newFirstChunk(innerExec)
 	}
 	task.innerResult = innerResult
 	return nil
diff --git a/executor/insert_common.go b/executor/insert_common.go
index 49ac8953b5319..2939d946ee0c5 100644
--- a/executor/insert_common.go
+++ b/executor/insert_common.go
@@ -295,8 +295,8 @@ func (e *InsertValues) setValueForRefColumn(row []types.Datum, hasValue []bool)
 
 func (e *InsertValues) insertRowsFromSelect(ctx context.Context, exec func(ctx context.Context, rows [][]types.Datum) error) error {
 	// process `insert|replace into ... select ... from ...`
 	selectExec := e.children[0]
-	fields := selectExec.retTypes()
-	chk := selectExec.newFirstChunk()
+	fields := retTypes(selectExec)
+	chk := newFirstChunk(selectExec)
 	iter := chunk.NewIterator4Chunk(chk)
 	rows := make([][]types.Datum, 0, chk.Capacity())
diff --git a/executor/join.go b/executor/join.go
index 3352a3307161a..47148c833acc1 100644
--- a/executor/join.go
+++ b/executor/join.go
@@ -178,10 +178,10 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyB
 	var allTypes []*types.FieldType
 	if isOuterKey {
 		keyColIdx = e.outerKeyColIdx
-		allTypes = e.outerExec.retTypes()
+		allTypes = retTypes(e.outerExec)
 	} else {
 		keyColIdx = e.innerKeyColIdx
-		allTypes = e.innerExec.retTypes()
+		allTypes = retTypes(e.innerExec)
 	}
 
 	for _, i := range keyColIdx {
@@ -202,6 +202,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
 		if e.finished.Load().(bool) {
 			return
 		}
+
 		var outerResource *outerChkResource
 		var ok bool
 		select {
@@ -217,7 +218,7 @@
 			required := int(atomic.LoadInt64(&e.requiredRows))
 			outerResult.SetRequiredRows(required, e.maxChunkSize)
 		}
-		err := e.outerExec.Next(ctx, chunk.NewRecordBatch(outerResult))
+		err := Next(ctx, e.outerExec, chunk.NewRecordBatch(outerResult))
 		if err != nil {
 			e.joinResultCh <- &hashjoinWorkerResult{
 				err: err,
@@ -244,6 +245,7 @@
 		if outerResult.NumRows() == 0 {
 			return
 		}
+
 		outerResource.dest <- outerResult
 	}
 }
@@ -268,7 +270,7 @@ var innerResultLabel fmt.Stringer = stringutil.StringerStr("innerResult")
 // fetchInnerRows fetches all rows from inner executor, and append them to
 // e.innerResult.
 func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error {
-	e.innerResult = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize)
+	e.innerResult = chunk.NewList(retTypes(e.innerExec), e.initCap, e.maxChunkSize)
 	e.innerResult.GetMemTracker().AttachTo(e.memTracker)
 	e.innerResult.GetMemTracker().SetLabel(innerResultLabel)
 	var err error
@@ -276,8 +278,9 @@
 		if e.finished.Load().(bool) {
 			return nil
 		}
-		chk := e.children[e.innerIdx].newFirstChunk()
-		err = e.innerExec.Next(ctx, chunk.NewRecordBatch(chk))
+
+		chk := newFirstChunk(e.children[e.innerIdx])
+		err = Next(ctx, e.innerExec, chunk.NewRecordBatch(chk))
 		if err != nil || chk.NumRows() == 0 {
 			return err
 		}
@@ -299,7 +302,7 @@ func (e *HashJoinExec) initializeForProbe() {
 	e.outerChkResourceCh = make(chan *outerChkResource, e.concurrency)
 	for i := uint(0); i < e.concurrency; i++ {
 		e.outerChkResourceCh <- &outerChkResource{
-			chk:  e.outerExec.newFirstChunk(),
+			chk:  newFirstChunk(e.outerExec),
 			dest: e.outerResultChs[i],
 		}
 	}
@@ -309,7 +312,7 @@
 	e.joinChkResourceCh = make([]chan *chunk.Chunk, e.concurrency)
 	for i := uint(0); i < e.concurrency; i++ {
 		e.joinChkResourceCh[i] = make(chan *chunk.Chunk, 1)
-		e.joinChkResourceCh[i] <- e.newFirstChunk()
+		e.joinChkResourceCh[i] <- newFirstChunk(e)
 	}
 
 	// e.joinResultCh is for transmitting the join result chunks to the main
@@ -512,6 +515,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err er
 	if e.joinResultCh == nil {
 		return nil
 	}
+
 	result, ok := <-e.joinResultCh
 	if !ok {
 		return nil
@@ -625,9 +629,9 @@ func (e *NestedLoopApplyExec) Open(ctx context.Context) error {
 	}
 	e.cursor = 0
 	e.innerRows = e.innerRows[:0]
-	e.outerChunk = e.outerExec.newFirstChunk()
-	e.innerChunk = e.innerExec.newFirstChunk()
-	e.innerList = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize)
+	e.outerChunk = newFirstChunk(e.outerExec)
+	e.innerChunk = newFirstChunk(e.innerExec)
+	e.innerList = chunk.NewList(retTypes(e.innerExec), e.initCap, e.maxChunkSize)
 
 	e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaNestedLoopApply)
 	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
@@ -642,7 +646,7 @@
 	outerIter := chunk.NewIterator4Chunk(e.outerChunk)
 	for {
 		if e.outerChunkCursor >= e.outerChunk.NumRows() {
-			err := e.outerExec.Next(ctx, chunk.NewRecordBatch(e.outerChunk))
+			err := Next(ctx, e.outerExec, chunk.NewRecordBatch(e.outerChunk))
 			if err != nil {
 				return nil, err
 			}
@@ -679,7 +683,7 @@
 	e.innerList.Reset()
 	innerIter := chunk.NewIterator4Chunk(e.innerChunk)
 	for {
-		err := e.innerExec.Next(ctx, chunk.NewRecordBatch(e.innerChunk))
+		err := Next(ctx, e.innerExec, chunk.NewRecordBatch(e.innerChunk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/merge_join.go b/executor/merge_join.go
index 2eca140bed902..bc6f597a9325d 100644
--- a/executor/merge_join.go
+++ b/executor/merge_join.go
@@ -142,7 +142,7 @@ func (t *mergeJoinInnerTable) nextRow() (chunk.Row, error) {
 		if t.curRow == t.curIter.End() {
 			t.reallocReaderResult()
 			oldMemUsage := t.curResult.MemoryUsage()
-			err := t.reader.Next(t.ctx, chunk.NewRecordBatch(t.curResult))
+			err := Next(t.ctx, t.reader, chunk.NewRecordBatch(t.curResult))
 			// error happens or no more data.
 			if err != nil || t.curResult.NumRows() == 0 {
 				t.curRow = t.curIter.End()
@@ -185,7 +185,7 @@
 	// Create a new Chunk and append it to "resourceQueue" if there is no more
 	// available chunk in "resourceQueue".
 	if len(t.resourceQueue) == 0 {
-		newChunk := t.reader.newFirstChunk()
+		newChunk := newFirstChunk(t.reader)
 		t.memTracker.Consume(newChunk.MemoryUsage())
 		t.resourceQueue = append(t.resourceQueue, newChunk)
 	}
@@ -222,7 +222,7 @@
 
 	e.childrenResults = make([]*chunk.Chunk, 0, len(e.children))
 	for _, child := range e.children {
-		e.childrenResults = append(e.childrenResults, child.newFirstChunk())
+		e.childrenResults = append(e.childrenResults, newFirstChunk(child))
 	}
 
 	e.innerTable.memTracker = memory.NewTracker(innerTableLabel, -1)
@@ -389,7 +389,7 @@
 		e.outerTable.chk.SetRequiredRows(requiredRows, e.maxChunkSize)
 	}
 
-	err = e.outerTable.reader.Next(ctx, chunk.NewRecordBatch(e.outerTable.chk))
+	err = Next(ctx, e.outerTable.reader, chunk.NewRecordBatch(e.outerTable.chk))
 	if err != nil {
 		return err
 	}
diff --git a/executor/pkg_test.go b/executor/pkg_test.go
index 74a478aadce48..f7b73e09aad04 100644
--- a/executor/pkg_test.go
+++ b/executor/pkg_test.go
@@ -35,7 +35,7 @@ type MockExec struct {
 
 func (m *MockExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 	req.Reset()
-	colTypes := m.retTypes()
+	colTypes := retTypes(m)
 	for ; m.curRowIdx < len(m.Rows) && req.NumRows() < req.Capacity(); m.curRowIdx++ {
 		curRow := m.Rows[m.curRowIdx]
 		for i := 0; i < curRow.Len(); i++ {
@@ -88,7 +88,7 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) {
 	innerFilter := outerFilter.Clone()
 	otherFilter := expression.NewFunctionInternal(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col1)
 	joiner := newJoiner(sctx, plannercore.InnerJoin, false,
-		make([]types.Datum, innerExec.Schema().Len()), []expression.Expression{otherFilter}, outerExec.retTypes(), innerExec.retTypes())
+		make([]types.Datum, innerExec.Schema().Len()), []expression.Expression{otherFilter}, retTypes(outerExec), retTypes(innerExec))
 	joinSchema := expression.NewSchema(col0, col1)
 	join := &NestedLoopApplyExec{
 		baseExecutor: newBaseExecutor(sctx, joinSchema, nil),
@@ -98,10 +98,10 @@
 		innerFilter:  []expression.Expression{innerFilter},
 		joiner:       joiner,
 	}
-	join.innerList = chunk.NewList(innerExec.retTypes(), innerExec.initCap, innerExec.maxChunkSize)
-	join.innerChunk = innerExec.newFirstChunk()
-	join.outerChunk = outerExec.newFirstChunk()
-	joinChk := join.newFirstChunk()
+	join.innerList = chunk.NewList(retTypes(innerExec), innerExec.initCap, innerExec.maxChunkSize)
+	join.innerChunk = newFirstChunk(innerExec)
+	join.outerChunk = newFirstChunk(outerExec)
+	joinChk := newFirstChunk(join)
 	it := chunk.NewIterator4Chunk(joinChk)
 	for rowIdx := 1; ; {
 		err := join.Next(ctx, chunk.NewRecordBatch(joinChk))
diff --git a/executor/point_get.go b/executor/point_get.go
index e9063386fbe0c..1d1d9881869c2 100644
--- a/executor/point_get.go
+++ b/executor/point_get.go
@@ -19,10 +19,8 @@ import (
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/parser/model"
 	"github.com/pingcap/parser/mysql"
-	"github.com/pingcap/tidb/expression"
 	"github.com/pingcap/tidb/kv"
 	plannercore "github.com/pingcap/tidb/planner/core"
-	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/table"
 	"github.com/pingcap/tidb/table/tables"
 	"github.com/pingcap/tidb/tablecodec"
@@ -38,22 +36,23 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor {
 		b.err = err
 		return nil
 	}
-	return &PointGetExecutor{
-		ctx:     b.ctx,
-		schema:  p.Schema(),
-		tblInfo: p.TblInfo,
-		idxInfo: p.IndexInfo,
-		idxVals: p.IndexValues,
-		handle:  p.Handle,
-		startTS: startTS,
-	}
+	e := &PointGetExecutor{
+		baseExecutor: newBaseExecutor(b.ctx, p.Schema(), p.ExplainID()),
+		tblInfo:      p.TblInfo,
+		idxInfo:      p.IndexInfo,
+		idxVals:      p.IndexValues,
+		handle:       p.Handle,
+		startTS:      startTS,
+	}
+	e.base().initCap = 1
+	e.base().maxChunkSize = 1
+	return e
 }
 
 // PointGetExecutor executes point select query.
 type PointGetExecutor struct {
-	ctx     sessionctx.Context
-	schema  *expression.Schema
-	tps     []*types.FieldType
+	baseExecutor
+
 	tblInfo *model.TableInfo
 	handle  int64
 	idxInfo *model.IndexInfo
@@ -241,22 +240,3 @@
 	}
 	return nil
 }
-
-// Schema implements the Executor interface.
-func (e *PointGetExecutor) Schema() *expression.Schema {
-	return e.schema
-}
-
-func (e *PointGetExecutor) retTypes() []*types.FieldType {
-	if e.tps == nil {
-		e.tps = make([]*types.FieldType, e.schema.Len())
-		for i := range e.schema.Columns {
-			e.tps[i] = e.schema.Columns[i].RetType
-		}
-	}
-	return e.tps
-}
-
-func (e *PointGetExecutor) newFirstChunk() *chunk.Chunk {
-	return chunk.New(e.retTypes(), 1, 1)
-}
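With baseExecutor embedded, PointGetExecutor inherits base(), Schema(), Open(), and Close() like every other executor, which is why the hand-written Schema/retTypes/newFirstChunk methods at the bottom of the file can simply be deleted. Setting initCap and maxChunkSize to 1 preserves the old sizing exactly: newFirstChunk(e) now computes chunk.New(base.retFieldTypes, 1, 1), the same as the removed chunk.New(e.retTypes(), 1, 1).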
diff --git a/executor/projection.go b/executor/projection.go
index 01d3aff16b7b9..f316558524044 100644
--- a/executor/projection.go
+++ b/executor/projection.go
@@ -91,7 +91,7 @@ func (e *ProjectionExec) Open(ctx context.Context) error {
 	}
 
 	if e.isUnparallelExec() {
-		e.childResult = e.children[0].newFirstChunk()
+		e.childResult = newFirstChunk(e.children[0])
 	}
 
 	return nil
@@ -179,7 +179,7 @@ func (e *ProjectionExec) isUnparallelExec() bool {
 func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error {
 	// transmit the requiredRows
 	e.childResult.SetRequiredRows(chk.RequiredRows(), e.maxChunkSize)
-	err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+	err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 	if err != nil {
 		return err
 	}
@@ -236,11 +236,11 @@
 		})
 
 		e.fetcher.inputCh <- &projectionInput{
-			chk:          e.children[0].newFirstChunk(),
+			chk:          newFirstChunk(e.children[0]),
 			targetWorker: e.workers[i],
 		}
 		e.fetcher.outputCh <- &projectionOutput{
-			chk:  e.newFirstChunk(),
+			chk:  newFirstChunk(e),
 			done: make(chan error, 1),
 		}
 	}
@@ -312,7 +312,7 @@
 
 		requiredRows := atomic.LoadInt64(&f.proj.parentReqRows)
 		input.chk.SetRequiredRows(int(requiredRows), f.proj.maxChunkSize)
-		err := f.child.Next(ctx, chunk.NewRecordBatch(input.chk))
+		err := Next(ctx, f.child, chunk.NewRecordBatch(input.chk))
 		if err != nil || input.chk.NumRows() == 0 {
 			output.done <- err
 			return
diff --git a/executor/radix_hash_join.go b/executor/radix_hash_join.go
index cc0633f391bb8..c32e229bfa2cc 100644
--- a/executor/radix_hash_join.go
+++ b/executor/radix_hash_join.go
@@ -186,7 +186,7 @@ func (e *RadixHashJoinExec) preAlloc4InnerParts() (err error) {
 func (e *RadixHashJoinExec) getPartition(idx uint32) partition {
 	if e.innerParts[idx] == nil {
 		e.numNonEmptyPart++
-		e.innerParts[idx] = chunk.New(e.innerExec.retTypes(), e.initCap, e.maxChunkSize)
+		e.innerParts[idx] = chunk.New(retTypes(e.innerExec), e.initCap, e.maxChunkSize)
 	}
 	return e.innerParts[idx]
 }
diff --git a/executor/show.go b/executor/show.go
index f25dcd603fdbe..b61941c203856 100644
--- a/executor/show.go
+++ b/executor/show.go
@@ -79,7 +79,7 @@ type ShowExec struct {
 func (e *ShowExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 	req.GrowAndReset(e.maxChunkSize)
 	if e.result == nil {
-		e.result = e.newFirstChunk()
+		e.result = newFirstChunk(e)
 		err := e.fetchAll()
 		if err != nil {
 			return errors.Trace(err)
diff --git a/executor/sort.go b/executor/sort.go
index 8e2e221a6828a..4d4ce8a3d22d6 100644
--- a/executor/sort.go
+++ b/executor/sort.go
@@ -105,13 +105,13 @@ func (e *SortExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 }
 
 func (e *SortExec) fetchRowChunks(ctx context.Context) error {
-	fields := e.retTypes()
+	fields := retTypes(e)
 	e.rowChunks = chunk.NewList(fields, e.initCap, e.maxChunkSize)
 	e.rowChunks.GetMemTracker().AttachTo(e.memTracker)
 	e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel)
 	for {
-		chk := e.children[0].newFirstChunk()
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		chk := newFirstChunk(e.children[0])
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -275,14 +275,14 @@ func (e *TopNExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 
 func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
 	e.chkHeap = &topNChunkHeap{e}
-	e.rowChunks = chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize)
+	e.rowChunks = chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize)
 	e.rowChunks.GetMemTracker().AttachTo(e.memTracker)
 	e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel)
 	for uint64(e.rowChunks.Len()) < e.totalLimit {
-		srcChk := e.children[0].newFirstChunk()
+		srcChk := newFirstChunk(e.children[0])
 		// adjust required rows by total limit
 		srcChk.SetRequiredRows(int(e.totalLimit-uint64(e.rowChunks.Len())), e.maxChunkSize)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(srcChk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(srcChk))
 		if err != nil {
 			return err
 		}
@@ -305,9 +305,9 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
 		// The number of rows we loaded may exceeds total limit, remove greatest rows by Pop.
 		heap.Pop(e.chkHeap)
 	}
-	childRowChk := e.children[0].newFirstChunk()
+	childRowChk := newFirstChunk(e.children[0])
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(childRowChk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(childRowChk))
 		if err != nil {
 			return err
 		}
@@ -349,7 +349,7 @@
 // but we want descending top N, then we will keep all data in memory.
 // But if data is distributed randomly, this function will be called log(n) times.
 func (e *TopNExec) doCompaction() error {
-	newRowChunks := chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize)
+	newRowChunks := chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize)
 	newRowPtrs := make([]chunk.RowPtr, 0, e.rowChunks.Len())
 	for _, rowPtr := range e.rowPtrs {
 		newRowPtr := newRowChunks.AppendRow(e.rowChunks.GetRow(rowPtr))
diff --git a/executor/table_reader.go b/executor/table_reader.go
index 327b148b02de7..6ad3eec52e918 100644
--- a/executor/table_reader.go
+++ b/executor/table_reader.go
@@ -175,7 +175,7 @@ func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Ra
 	if err != nil {
 		return nil, err
 	}
-	result, err := e.SelectResult(ctx, e.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans))
+	result, err := e.SelectResult(ctx, e.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans))
 	if err != nil {
 		return nil, err
 	}
diff --git a/executor/table_readers_required_rows_test.go b/executor/table_readers_required_rows_test.go
index 21819329d6a82..0d7163d431c0c 100644
--- a/executor/table_readers_required_rows_test.go
+++ b/executor/table_readers_required_rows_test.go
@@ -178,7 +178,7 @@ func (s *testExecSuite) TestTableReaderRequiredRows(c *C) {
 		ctx := mockDistsqlSelectCtxSet(testCase.totalRows, testCase.expectedRowsDS)
 		exec := buildTableReader(sctx)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
@@ -230,7 +230,7 @@ func (s *testExecSuite) TestIndexReaderRequiredRows(c *C) {
 		ctx := mockDistsqlSelectCtxSet(testCase.totalRows, testCase.expectedRowsDS)
 		exec := buildIndexReader(sctx)
 		c.Assert(exec.Open(ctx), IsNil)
-		chk := exec.newFirstChunk()
+		chk := newFirstChunk(exec)
 		for i := range testCase.requiredRows {
 			chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
 			c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil)
diff --git a/executor/union_scan.go b/executor/union_scan.go
index 9f95c88d075e0..2967953149f53 100644
--- a/executor/union_scan.go
+++ b/executor/union_scan.go
@@ -117,7 +117,7 @@ func (us *UnionScanExec) Open(ctx context.Context) error {
 	if err := us.baseExecutor.Open(ctx); err != nil {
 		return err
 	}
-	us.snapshotChunkBuffer = us.newFirstChunk()
+	us.snapshotChunkBuffer = newFirstChunk(us)
 	return nil
 }
 
@@ -133,7 +133,7 @@
 		defer func() { us.runtimeStats.Record(time.Since(start), req.NumRows()) }()
 	}
 	req.GrowAndReset(us.maxChunkSize)
-	mutableRow := chunk.MutRowFromTypes(us.retTypes())
+	mutableRow := chunk.MutRowFromTypes(retTypes(us))
 	for i, batchSize := 0, req.Capacity(); i < batchSize; i++ {
 		row, err := us.getOneRow(ctx)
 		if err != nil {
@@ -199,7 +199,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err
 	us.cursor4SnapshotRows = 0
 	us.snapshotRows = us.snapshotRows[:0]
 	for len(us.snapshotRows) == 0 {
-		err = us.children[0].Next(ctx, chunk.NewRecordBatch(us.snapshotChunkBuffer))
+		err = Next(ctx, us.children[0], chunk.NewRecordBatch(us.snapshotChunkBuffer))
 		if err != nil || us.snapshotChunkBuffer.NumRows() == 0 {
 			return nil, err
 		}
@@ -214,7 +214,7 @@
 				// commit, but for simplicity, we don't handle it here.
 				continue
 			}
-			us.snapshotRows = append(us.snapshotRows, row.GetDatumRow(us.children[0].retTypes()))
+			us.snapshotRows = append(us.snapshotRows, row.GetDatumRow(retTypes(us.children[0])))
 		}
 	}
 	return us.snapshotRows[0], nil
@@ -295,7 +295,7 @@ func (us *UnionScanExec) rowWithColsInTxn(t table.Table, h int64, cols []*table.
 
 func (us *UnionScanExec) buildAndSortAddedRows(t table.Table) error {
 	us.addedRows = make([][]types.Datum, 0, len(us.dirty.addedRows))
-	mutableRow := chunk.MutRowFromTypes(us.retTypes())
+	mutableRow := chunk.MutRowFromTypes(retTypes(us))
 	cols := t.WritableCols()
 	for h := range us.dirty.addedRows {
 		newData := make([]types.Datum, 0, us.schema.Len())
diff --git a/executor/update.go b/executor/update.go
index c2cdbaf9f9d55..9b478cc80d968 100644
--- a/executor/update.go
+++ b/executor/update.go
@@ -165,7 +165,7 @@ func (e *UpdateExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 }
 
 func (e *UpdateExec) fetchChunkRows(ctx context.Context) error {
-	fields := e.children[0].retTypes()
+	fields := retTypes(e.children[0])
 	schema := e.children[0].Schema()
 	colsInfo := make([]*table.Column, len(fields))
 	for id, cols := range schema.TblID2Handle {
@@ -178,10 +178,10 @@
 		}
 	}
 	globalRowIdx := 0
-	chk := e.children[0].newFirstChunk()
+	chk := newFirstChunk(e.children[0])
 	e.evalBuffer = chunk.MutRowFromTypes(fields)
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/window.go b/executor/window.go
index bf4e5a2dab0b1..2ba119736e564 100644
--- a/executor/window.go
+++ b/executor/window.go
@@ -130,8 +130,8 @@ func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk
 		return errors.Trace(err)
 	}
 
-	childResult := e.children[0].newFirstChunk()
-	err = e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult})
+	childResult := newFirstChunk(e.children[0])
+	err = Next(ctx, e.children[0], &chunk.RecordBatch{Chunk: childResult})
 	if err != nil {
 		return errors.Trace(err)
 	}
diff --git a/server/conn.go b/server/conn.go
index 24d16c0af9b50..290821f758c84 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -45,7 +45,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"time"
 
@@ -150,13 +149,6 @@
 	peerHost     string // peer host
 	peerPort     string // peer port
 	lastCode     uint16 // last error code
-
-	// mu is used for cancelling the execution of current transaction.
-	mu struct {
-		sync.RWMutex
-		cancelFunc context.CancelFunc
-		resultSets []ResultSet
-	}
 }
 
 func (cc *clientConn) String() string {
@@ -847,11 +839,6 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
 
 func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 	span := opentracing.StartSpan("server.dispatch")
-	ctx1, cancelFunc := context.WithCancel(ctx)
-	cc.mu.Lock()
-	cc.mu.cancelFunc = cancelFunc
-	cc.mu.Unlock()
-
 	t := time.Now()
 	cmd := data[0]
 	data = data[1:]
@@ -863,6 +850,8 @@
 		span.Finish()
 	}()
 
+	vars := cc.ctx.GetSessionVars()
+	atomic.StoreUint32(&vars.Killed, 0)
 	if cmd < mysql.ComEnd {
 		cc.ctx.SetCommandValue(cmd)
 	}
@@ -893,11 +882,11 @@
 			data = data[:len(data)-1]
 			dataStr = string(hack.String(data))
 		}
-		return cc.handleQuery(ctx1, dataStr)
+		return cc.handleQuery(ctx, dataStr)
 	case mysql.ComPing:
 		return cc.writeOK()
 	case mysql.ComInitDB:
-		if err := cc.useDB(ctx1, dataStr); err != nil {
+		if err := cc.useDB(ctx, dataStr); err != nil {
 			return err
 		}
 		return cc.writeOK()
@@ -906,9 +895,9 @@
 	case mysql.ComStmtPrepare:
 		return cc.handleStmtPrepare(dataStr)
 	case mysql.ComStmtExecute:
-		return cc.handleStmtExecute(ctx1, data)
+		return cc.handleStmtExecute(ctx, data)
 	case mysql.ComStmtFetch:
-		return cc.handleStmtFetch(ctx1, data)
+		return cc.handleStmtFetch(ctx, data)
 	case mysql.ComStmtClose:
 		return cc.handleStmtClose(data)
 	case mysql.ComStmtSendLongData:
@@ -918,7 +907,7 @@
 	case mysql.ComSetOption:
 		return cc.handleSetOption(data)
 	case mysql.ComChangeUser:
-		return cc.handleChangeUser(ctx1, data)
+		return cc.handleChangeUser(ctx, data)
 	default:
 		return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
 	}
@@ -1171,15 +1160,11 @@
 		metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
 		return err
 	}
-	cc.mu.Lock()
-	cc.mu.resultSets = rs
 	status := atomic.LoadInt32(&cc.status)
 	if status == connStatusShutdown || status == connStatusWaitShutdown {
-		cc.mu.Unlock()
 		killConn(cc)
 		return errors.New("killed by another connection")
 	}
-	cc.mu.Unlock()
 	if rs != nil {
 		if len(rs) == 1 {
 			err = cc.writeResultset(ctx, rs[0], false, 0, 0)
diff --git a/server/server.go b/server/server.go
index f45cb65e6c7a7..c258f074d0ed0 100644
--- a/server/server.go
+++ b/server/server.go
@@ -530,19 +530,8 @@ func (s *Server) Kill(connectionID uint64, query bool) {
 }
 
 func killConn(conn *clientConn) {
-	conn.mu.RLock()
-	resultSets := conn.mu.resultSets
-	cancelFunc := conn.mu.cancelFunc
-	conn.mu.RUnlock()
-	for _, resultSet := range resultSets {
-		// resultSet.Close() is reentrant so it's safe to kill a same connID multiple times
-		if err := resultSet.Close(); err != nil {
-			logutil.Logger(context.Background()).Error("close result set error", zap.Uint32("connID", conn.connectionID), zap.Error(err))
-		}
-	}
-	if cancelFunc != nil {
-		cancelFunc()
-	}
+	sessVars := conn.ctx.GetSessionVars()
+	atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
 }
 
 // KillAllConnections kills all connections when server is not gracefully shutdown.
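killConn shrinks to a single compare-and-swap: instead of closing the connection's result sets and cancelling a context, a kill is now just a request bit that the victim's own executors notice at their next Next call. The three atomic operations in this patch form the whole handshake:

    atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) // killer: raise the request
    atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) // executor Next: consume it, return ErrQueryInterrupted
    atomic.StoreUint32(&vars.Killed, 0)                 // dispatch: re-arm before each new command

The reset in dispatch means a KILL that lands after a statement has already finished cannot spill over into the connection's next command. The trade-off is that cancellation is cooperative: it takes effect at a chunk boundary, and the few call sites that still invoke the executor's own Next directly (for example the table reader loop in distsql.go above) are only interrupted once control returns to a wrapped caller.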
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index cae67c4ed4558..83d060a814f93 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -379,6 +379,9 @@ type SessionVars struct {
 
 	// LowResolutionTSO is used for reading data with low resolution TSO which is updated once every two seconds.
 	LowResolutionTSO bool
+
+	// Killed is a flag to indicate that this query is killed.
+	Killed uint32
}
 
 // ConnectionInfo present connection used by audit.
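Taken together, the flow is Server.Kill -> killConn -> SessionVars.Killed -> executor Next -> ErrQueryInterrupted back to the client. A minimal sketch of the timing, with assumed names standing in for the real ones:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// killed stands in for SessionVars.Killed on the victim session.
var killed uint32

// nextBatch mimics the wrapped executor Next: poll the flag, then fetch.
func nextBatch(batch int) error {
	if atomic.CompareAndSwapUint32(&killed, 1, 0) {
		return fmt.Errorf("query interrupted at batch %d (MySQL error 1317)", batch)
	}
	time.Sleep(10 * time.Millisecond) // pretend to scan one chunk of rows
	return nil
}

func main() {
	// "KILL QUERY" arriving from another connection, as killConn now does.
	go func() {
		time.Sleep(25 * time.Millisecond)
		atomic.CompareAndSwapUint32(&killed, 0, 1)
	}()

	// The row-fetching loop every parent executor runs.
	for batch := 0; ; batch++ {
		if err := nextBatch(batch); err != nil {
			fmt.Println(err) // surfaced to the client like any executor error
			return
		}
	}
}

Because the flag is only polled between batches, a query dies at the next chunk boundary rather than instantly; that is the price of dropping the old lock-protected resultSets/cancelFunc bookkeeping from every dispatch.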