diff --git a/ddl/backfilling_scheduler.go b/ddl/backfilling_scheduler.go index 6596b3e49d611..209bc15277224 100644 --- a/ddl/backfilling_scheduler.go +++ b/ddl/backfilling_scheduler.go @@ -308,11 +308,8 @@ func (b *ingestBackfillScheduler) setupWorkers() error { } b.copReqSenderPool = copReqSenderPool readerCnt, writerCnt := b.expectedWorkerSize() - writerPool, err := workerpool.NewWorkerPool[idxRecResult]("ingest_writer", + writerPool := workerpool.NewWorkerPool[idxRecResult]("ingest_writer", poolutil.DDL, writerCnt, b.createWorker) - if err != nil { - return errors.Trace(err) - } writerPool.Start() b.writerPool = writerPool b.copReqSenderPool.chunkSender = writerPool diff --git a/disttask/operator/BUILD.bazel b/disttask/operator/BUILD.bazel new file mode 100644 index 0000000000000..dc2a180e2da93 --- /dev/null +++ b/disttask/operator/BUILD.bazel @@ -0,0 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "operator", + srcs = [ + "compose.go", + "operator.go", + "pipeline.go", + "wrapper.go", + ], + importpath = "github.com/pingcap/tidb/disttask/operator", + visibility = ["//visibility:public"], + deps = [ + "//resourcemanager/pool/workerpool", + "//resourcemanager/util", + "@org_golang_x_sync//errgroup", + ], +) + +go_test( + name = "operator_test", + timeout = "short", + srcs = ["pipeline_test.go"], + embed = [":operator"], + flaky = True, + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/disttask/operator/compose.go b/disttask/operator/compose.go new file mode 100644 index 0000000000000..fb1637c5154e9 --- /dev/null +++ b/disttask/operator/compose.go @@ -0,0 +1,58 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package operator + +// WithSource is an interface that can be used to set the source of an operator. +type WithSource[T any] interface { + SetSource(channel DataChannel[T]) +} + +// WithSink is an interface that can be used to set the sink of an operator. +type WithSink[T any] interface { + SetSink(channel DataChannel[T]) +} + +// Compose sets the sink of op1 and the source of op2. +func Compose[T any](op1 WithSink[T], op2 WithSource[T]) { + ch := NewSimpleDataChannel(make(chan T)) + op1.SetSink(ch) + op2.SetSource(ch) +} + +// DataChannel is a channel that can be used to transfer data between operators. +type DataChannel[T any] interface { + Channel() chan T + Finish() +} + +// SimpleDataChannel is a simple implementation of DataChannel. +type SimpleDataChannel[T any] struct { + channel chan T +} + +// NewSimpleDataChannel creates a new SimpleDataChannel. +func NewSimpleDataChannel[T any](ch chan T) *SimpleDataChannel[T] { + return &SimpleDataChannel[T]{channel: ch} +} + +// Channel returns the underlying channel of the SimpleDataChannel. +func (s *SimpleDataChannel[T]) Channel() chan T { + return s.channel +} + +// Finish closes the underlying channel of the SimpleDataChannel. 
+func (s *SimpleDataChannel[T]) Finish() { + close(s.channel) +} diff --git a/disttask/operator/operator.go b/disttask/operator/operator.go new file mode 100644 index 0000000000000..c1baa8f45b1b8 --- /dev/null +++ b/disttask/operator/operator.go @@ -0,0 +1,102 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package operator + +import ( + "fmt" + + "github.com/pingcap/tidb/resourcemanager/pool/workerpool" + "github.com/pingcap/tidb/resourcemanager/util" +) + +// Operator is the basic operation unit in the task execution. +type Operator interface { + Open() error + Close() error + String() string +} + +// AsyncOperator process the data in async way. +// +// Eg: The sink of AsyncOperator op1 and the source of op2 +// use the same channel, Then op2's worker will handle +// the result from op1. +type AsyncOperator[T, R any] struct { + pool *workerpool.WorkerPool[T, R] +} + +// NewAsyncOperatorWithTransform create an AsyncOperator with a transform function. +func NewAsyncOperatorWithTransform[T, R any](name string, workerNum int, transform func(T) R) *AsyncOperator[T, R] { + pool := workerpool.NewWorkerPool(name, util.DistTask, workerNum, newAsyncWorkerCtor(transform)) + return NewAsyncOperator(pool) +} + +// NewAsyncOperator create an AsyncOperator. +func NewAsyncOperator[T, R any](pool *workerpool.WorkerPool[T, R]) *AsyncOperator[T, R] { + return &AsyncOperator[T, R]{ + pool: pool, + } +} + +// Open implements the Operator's Open interface. +func (c *AsyncOperator[T, R]) Open() error { + c.pool.Start() + return nil +} + +// Close implements the Operator's Close interface. +func (c *AsyncOperator[T, R]) Close() error { + // Wait all tasks done. + // We don't need to close the task channel because + // it is closed by the workerpool. + c.pool.Wait() + c.pool.Release() + return nil +} + +// String show the name. +func (*AsyncOperator[T, R]) String() string { + var zT T + var zR R + return fmt.Sprintf("AsyncOp[%T, %T]", zT, zR) +} + +// SetSource set the source channel. +func (c *AsyncOperator[T, R]) SetSource(ch DataChannel[T]) { + c.pool.SetTaskReceiver(ch.Channel()) +} + +// SetSink set the sink channel. +func (c *AsyncOperator[T, R]) SetSink(ch DataChannel[R]) { + c.pool.SetResultSender(ch.Channel()) +} + +type asyncWorker[T, R any] struct { + transform func(T) R +} + +func newAsyncWorkerCtor[T, R any](transform func(T) R) func() workerpool.Worker[T, R] { + return func() workerpool.Worker[T, R] { + return &asyncWorker[T, R]{ + transform: transform, + } + } +} + +func (s *asyncWorker[T, R]) HandleTask(task T) R { + return s.transform(task) +} + +func (*asyncWorker[T, R]) Close() {} diff --git a/disttask/operator/pipeline.go b/disttask/operator/pipeline.go new file mode 100644 index 0000000000000..3c9f32e2a9d23 --- /dev/null +++ b/disttask/operator/pipeline.go @@ -0,0 +1,67 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package operator + +import "strings" + +// AsyncPipeline wraps a list of Operators. +// The dataflow is from the first operator to the last operator. +type AsyncPipeline struct { + ops []Operator +} + +// Execute starts all operators waiting to handle tasks. +func (p *AsyncPipeline) Execute() error { + // Start running each operator. + for i, op := range p.ops { + err := op.Open() + if err != nil { + // Close all operators that have been opened. + for j := i - 1; j >= 0; j-- { + _ = p.ops[j].Close() + } + return err + } + } + return nil +} + +// Close waits all tasks done. +func (p *AsyncPipeline) Close() error { + var firstErr error + for _, op := range p.ops { + err := op.Close() + if firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// NewAsyncPipeline creates a new AsyncPipeline. +func NewAsyncPipeline(ops ...Operator) *AsyncPipeline { + return &AsyncPipeline{ + ops: ops, + } +} + +// String shows the pipeline. +func (p *AsyncPipeline) String() string { + opStrs := make([]string, len(p.ops)) + for i, op := range p.ops { + opStrs[i] = op.String() + } + return "AsyncPipeline[" + strings.Join(opStrs, " -> ") + "]" +} diff --git a/disttask/operator/pipeline_test.go b/disttask/operator/pipeline_test.go new file mode 100644 index 0000000000000..bba5e4d4b4379 --- /dev/null +++ b/disttask/operator/pipeline_test.go @@ -0,0 +1,101 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package operator + +import ( + "regexp" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPipelineAsyncMultiOperators(t *testing.T) { + words := `Bob hiT a ball, the hIt BALL flew far after it was hit.` + var mostCommonWord string + splitter := makeSplitter(words) + lower := makeLower() + trimmer := makeTrimmer() + counter := makeCounter() + collector := makeCollector(&mostCommonWord) + + Compose[string](splitter, lower) + Compose[string](lower, trimmer) + Compose[string](trimmer, counter) + Compose[strCnt](counter, collector) + + pipeline := NewAsyncPipeline(splitter, lower, trimmer, counter, collector) + require.Equal(t, pipeline.String(), "AsyncPipeline[simpleSource -> simpleOperator(AsyncOp[string, string]) -> simpleOperator(AsyncOp[string, string]) -> simpleOperator(AsyncOp[string, operator.strCnt]) -> simpleSink]") + err := pipeline.Execute() + require.NoError(t, err) + err = pipeline.Close() + require.NoError(t, err) + require.Equal(t, mostCommonWord, "hit") +} + +type strCnt struct { + str string + cnt int +} + +func makeSplitter(s string) *simpleSource[string] { + ss := strings.Split(s, " ") + src := newSimpleSource(func() string { + if len(ss) == 0 { + return "" + } + ret := ss[0] + ss = ss[1:] + return ret + }) + return src +} + +func makeLower() *simpleOperator[string, string] { + return newSimpleOperator(strings.ToLower, 3) +} + +func makeTrimmer() *simpleOperator[string, string] { + var nonAlphaRegex = regexp.MustCompile(`[^a-zA-Z0-9]+`) + return newSimpleOperator(func(s string) string { + return nonAlphaRegex.ReplaceAllString(s, "") + }, 3) +} + +func makeCounter() *simpleOperator[string, strCnt] { + strCntMap := make(map[string]int) + strCntMapMu := sync.Mutex{} + return newSimpleOperator(func(s string) strCnt { + strCntMapMu.Lock() + old := strCntMap[s] + strCntMap[s] = old + 1 + strCntMapMu.Unlock() + return strCnt{s, old + 1} + }, 3) +} + +func makeCollector(v *string) *simpleSink[strCnt] { + maxCnt := 0 + maxMu := sync.Mutex{} + return newSimpleSink(func(sc strCnt) { + maxMu.Lock() + if sc.cnt > maxCnt { + maxCnt = sc.cnt + *v = sc.str + } + maxMu.Unlock() + }) +} diff --git a/disttask/operator/wrapper.go b/disttask/operator/wrapper.go new file mode 100644 index 0000000000000..d77c8abbc82e8 --- /dev/null +++ b/disttask/operator/wrapper.go @@ -0,0 +1,111 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package operator + +import ( + "fmt" + + "golang.org/x/sync/errgroup" +) + +type simpleSource[T comparable] struct { + errGroup errgroup.Group + generator func() T + sink DataChannel[T] +} + +func newSimpleSource[T comparable](generator func() T) *simpleSource[T] { + return &simpleSource[T]{generator: generator} +} + +func (s *simpleSource[T]) Open() error { + s.errGroup.Go(func() error { + var zT T + for { + res := s.generator() + if res == zT { + break + } + s.sink.Channel() <- res + } + s.sink.Finish() + return nil + }) + return nil +} + +func (s *simpleSource[T]) Close() error { + return s.errGroup.Wait() +} + +func (s *simpleSource[T]) SetSink(ch DataChannel[T]) { + s.sink = ch +} + +func (*simpleSource[T]) String() string { + return "simpleSource" +} + +type simpleSink[R any] struct { + errGroup errgroup.Group + drainer func(R) + source DataChannel[R] +} + +func newSimpleSink[R any](drainer func(R)) *simpleSink[R] { + return &simpleSink[R]{ + drainer: drainer, + } +} + +func (s *simpleSink[R]) Open() error { + s.errGroup.Go(func() error { + for { + data, ok := <-s.source.Channel() + if !ok { + return nil + } + s.drainer(data) + } + }) + return nil +} + +func (s *simpleSink[R]) Close() error { + return s.errGroup.Wait() +} + +func (s *simpleSink[T]) SetSource(ch DataChannel[T]) { + s.source = ch +} + +func (*simpleSink[R]) String() string { + return "simpleSink" +} + +type simpleOperator[T, R any] struct { + *AsyncOperator[T, R] +} + +func (s *simpleOperator[T, R]) String() string { + return fmt.Sprintf("simpleOperator(%s)", s.AsyncOperator.String()) +} + +func newSimpleOperator[T, R any](transform func(task T) R, concurrency int) *simpleOperator[T, R] { + asyncOp := NewAsyncOperatorWithTransform("simple", concurrency, transform) + return &simpleOperator[T, R]{ + AsyncOperator: asyncOp, + } +} diff --git a/executor/executor.go b/executor/executor.go index 5749f74a2e550..d9a8487bf346f 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -2712,11 +2712,8 @@ func (e *FastCheckTableExec) Next(context.Context, *chunk.Chunk) error { e.Ctx().GetSessionVars().OptimizerUseInvisibleIndexes = false }() - workerPool, err := workerpool.NewWorkerPool[checkIndexTask]("checkIndex", + workerPool := workerpool.NewWorkerPool[checkIndexTask]("checkIndex", poolutil.CheckTable, 3, e.createWorker) - if err != nil { - return errors.Trace(err) - } workerPool.Start() e.wg.Add(len(e.indexInfos)) diff --git a/resourcemanager/pool/workerpool/workerpool.go b/resourcemanager/pool/workerpool/workerpool.go index dee92351b674d..ef8ed41d28690 100644 --- a/resourcemanager/pool/workerpool/workerpool.go +++ b/resourcemanager/pool/workerpool/workerpool.go @@ -55,7 +55,7 @@ type None struct{} // NewWorkerPool creates a new worker pool. func NewWorkerPool[T, R any](name string, _ util.Component, numWorkers int, - createWorker func() Worker[T, R], opts ...Option[T, R]) (*WorkerPool[T, R], error) { + createWorker func() Worker[T, R], opts ...Option[T, R]) *WorkerPool[T, R] { if numWorkers <= 0 { numWorkers = 1 } @@ -72,7 +72,7 @@ func NewWorkerPool[T, R any](name string, _ util.Component, numWorkers int, } p.createWorker = createWorker - return p, nil + return p } // SetTaskReceiver sets the task receiver for the pool. 
@@ -125,7 +125,11 @@ func (p *WorkerPool[T, R]) runAWorker() { p.wg.Run(func() { for { select { - case task := <-p.taskChan: + case task, ok := <-p.taskChan: + if !ok { + w.Close() + return + } p.handleTaskWithRecover(w, task) case <-p.quitChan: w.Close() @@ -199,6 +203,18 @@ func (p *WorkerPool[T, R]) ReleaseAndWait() { } } +// Wait waits for all workers to complete. +func (p *WorkerPool[T, R]) Wait() { + p.wg.Wait() +} + +// Release releases the pool. +func (p *WorkerPool[T, R]) Release() { + if p.resChan != nil { + close(p.resChan) + } +} + // GetOriginConcurrency return the concurrency of the pool at the init. func (p *WorkerPool[T, R]) GetOriginConcurrency() int32 { return p.originWorkers diff --git a/resourcemanager/pool/workerpool/workpool_test.go b/resourcemanager/pool/workerpool/workpool_test.go index 0d21795603d30..0de9cab63f810 100644 --- a/resourcemanager/pool/workerpool/workpool_test.go +++ b/resourcemanager/pool/workerpool/workpool_test.go @@ -50,8 +50,7 @@ func createMyWorker() Worker[int64, struct{}] { func TestWorkerPool(t *testing.T) { // Create a worker pool with 3 workers. - pool, err := NewWorkerPool[int64]("test", util.UNKNOWN, 3, createMyWorker) - require.NoError(t, err) + pool := NewWorkerPool[int64]("test", util.UNKNOWN, 3, createMyWorker) pool.Start() globalCnt.Store(0) @@ -116,45 +115,41 @@ func (d dummyWorker[T, R]) HandleTask(task T) R { func (d dummyWorker[T, R]) Close() {} func TestWorkerPoolNoneResult(t *testing.T) { - pool, err := NewWorkerPool[int64, None]( + pool := NewWorkerPool[int64, None]( "test", util.UNKNOWN, 3, func() Worker[int64, None] { return dummyWorker[int64, None]{} }) - require.NoError(t, err) pool.Start() ch := pool.GetResultChan() require.Nil(t, ch) pool.ReleaseAndWait() - pool2, err := NewWorkerPool[int64, int64]( + pool2 := NewWorkerPool[int64, int64]( "test", util.UNKNOWN, 3, func() Worker[int64, int64] { return dummyWorker[int64, int64]{} }) - require.NoError(t, err) pool2.Start() require.NotNil(t, pool2.GetResultChan()) pool2.ReleaseAndWait() - pool3, err := NewWorkerPool[int64, struct{}]( + pool3 := NewWorkerPool[int64, struct{}]( "test", util.UNKNOWN, 3, func() Worker[int64, struct{}] { return dummyWorker[int64, struct{}]{} }) - require.NoError(t, err) pool3.Start() require.NotNil(t, pool3.GetResultChan()) pool3.ReleaseAndWait() } func TestWorkerPoolCustomChan(t *testing.T) { - pool, err := NewWorkerPool[int64, int64]( + pool := NewWorkerPool[int64, int64]( "test", util.UNKNOWN, 3, func() Worker[int64, int64] { return dummyWorker[int64, int64]{} }) - require.NoError(t, err) taskCh := make(chan int64) pool.SetTaskReceiver(taskCh)
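
For reference, here is a minimal, hypothetical usage sketch of the new disttask/operator API introduced above; it is not part of the patch. The operator names, the transform functions, and the channel wiring are illustrative only, and the sketch assumes the worker pool forwards each HandleTask result to the channel configured via SetResultSender, which is what the word-count test above relies on.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/disttask/operator"
)

func main() {
	// Two async operators: square each int, then format it as a string.
	square := operator.NewAsyncOperatorWithTransform("square", 2, func(x int) int { return x * x })
	format := operator.NewAsyncOperatorWithTransform("format", 2, func(x int) string { return fmt.Sprintf("v=%d", x) })

	// External source and sink channels for the two ends of the pipeline.
	src := operator.NewSimpleDataChannel(make(chan int))
	dst := operator.NewSimpleDataChannel(make(chan string))
	square.SetSource(src)
	operator.Compose[int](square, format) // square's sink feeds format's source
	format.SetSink(dst)

	pipe := operator.NewAsyncPipeline(square, format)
	if err := pipe.Execute(); err != nil {
		panic(err)
	}

	// Feed tasks, then close the source so the first pool's workers drain and exit.
	go func() {
		for i := 1; i <= 5; i++ {
			src.Channel() <- i
		}
		src.Finish()
	}()

	// Drain results concurrently. Closing the pipeline waits for each pool
	// (pool.Wait) and then closes its result channel (pool.Release), which
	// eventually closes dst and ends this loop.
	done := make(chan struct{})
	go func() {
		for r := range dst.Channel() {
			fmt.Println(r)
		}
		close(done)
	}()

	if err := pipe.Close(); err != nil {
		panic(err)
	}
	<-done
}

Note that pipe.Close maps to pool.Wait followed by pool.Release on each operator in order, so the intermediate channel created by Compose and the final sink channel are closed only after the upstream workers have drained their task channels; this is the same shutdown path exercised by the updated runAWorker select on a closed taskChan.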