diff --git a/README.md b/README.md index 4cd3dde..ed2a9e9 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Here is a [working example in the Go Playground](https://go.dev/play/p/wCZzMDXRU - [Filtering Errors](#filtering-errors) - [Controlling Concurrency](#controlling-concurrency) - [Solo Tasks](#solo-tasks) -- [Collecting Results](#collecting-results) +- [Gathering Results](#gathering-results) ## Rationale @@ -303,67 +303,67 @@ if err != nil { doThingsWith(data) ``` -## Collecting Results +## Gathering Results One common use for a background task is accumulating the results from a batch of concurrent workers. This could be handled by a solo task, as described -above, but it is a common enough pattern that the library provides a -`Collector` type to handle it specifically. +above, but it is a common enough pattern that the library provides a `Gatherer` +type to handle it specifically. -To use it, pass a function to `Collect` to receive the values: +To use it, pass a function to `Gather` to receive the values: ```go +var g taskgroup.Group + var sum int -c := taskgroup.Collect(func(v int) { sum += v }) +c := taskgroup.Gather(g.Go, func(v int) { sum += v }) ``` -The `Call`, `Run`, and `Report` methods of `c` can now be used to wrap -functions that yield values, to deliver those values to `c`: +The `Call`, `Run`, and `Report` methods of `c` can now be used to start tasks +in `g` that yield values, and deliver those values to the accumulator: - `c.Call` takes a `func() (T, error)`, returning a value and an error. -- `c.Run` takes a `func() T`, returning only a value. + If the task reports an error, that error is returned as usual. Otherwise, + its non-error value is gathered by the callback. -If the wrapped function reports an error, that error is returned from the task -as usual. Otherwise, its non-error value is given to the accumulator callback. -As in the above example, calls to the function are serialized so that it is -safe to access state without additional locking: +- `c.Run` takes a `func() T`, returning only a value, which is gathered by the + callback. -```go -var g taskgroup.Group -// ... +- `c.Report` takes a `func(func(T)) error`, which allows a task to report + _multiple_ values to the gatherer via a "report" callback. The task itself + returns only an `error`, but it may call its argument any number of times to + gather values. + +Calls to the callback are serialized so that it is safe to access state without +additional locking: -// Report an error, no value is sent to the collector. -g.Go(c.Call(func() (int, error) { +```go +// Report an error, no value is gathered. +c.Call(func() (int, error) { return -1, errors.New("bad") -})) +}) -// No error, send the value 25 to the collector. -g.Go(c.Call(func() (int, error) { +// No error, send gather the value 25. +c.Call(func() (int, error) { return 25, nil -})) - -// Send a random integer to the collector. -g.Go(c.Run(func() int { return rand.Intn(1000) }) -``` +}) -The `Report` method allows a task to report _multiple_ values to the collector -via a callback. Here, the function returns only an `error`, but it receives a -callback it may invoke any number of times to send values: +// Gather a random integer. +c.Run(func() int { return rand.Intn(1000) }) -```go -// Send the values 10, 20, and 30 to the collector. +// Gather the values 10, 20, and 30. // -// Note that even if the function reports an error, any values it sent to -// the collector before returning are still delivered. -g.Go(c.Report(func(report func(int)) error { +// Note that even if the function reports an error, any values it sent +// before returning are still gathered. +c.Report(func(report func(int)) error { report(10) report(20) report(30) return nil -})) +}) ``` -Once all the tasks derived from the collector are done, it is safe to access +Once all the tasks passed to the gatherer are complete, it is safe to access the values accumulated by the callback: ```go diff --git a/collector.go b/collector.go index ae7dfc2..bd48545 100644 --- a/collector.go +++ b/collector.go @@ -4,6 +4,8 @@ import "sync" // A Collector collects values reported by task functions and delivers them to // an accumulator function. +// +// Deprecated: Use a [Gatherer] instead. type Collector[T any] struct { μ sync.Mutex handle func(T) @@ -22,6 +24,8 @@ func (c *Collector[T]) report(v T) { // // The tasks created from a collector do not return until all the values // reported by the underlying function have been processed by the accumulator. +// +// Deprecated: Use [Gather] instead. func Collect[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} } // Call returns a Task wrapping a call to f. If f reports an error, that error @@ -48,5 +52,62 @@ func (c *Collector[T]) Report(f func(report func(T)) error) Task { // Run returns a Task wrapping a call to f. The resulting task reports a nil // error for all calls. func (c *Collector[T]) Run(f func() T) Task { - return NoError(func() { c.report(f()) }) + return noError(func() { c.report(f()) }) +} + +// A Gatherer manages a group of [Task] functions that report values, and +// gathers the values they return. +type Gatherer[T any] struct { + run func(Task) // start the task in a goroutine + + μ sync.Mutex + gather func(T) // handle values reported by tasks +} + +func (g *Gatherer[T]) report(v T) { + g.μ.Lock() + defer g.μ.Unlock() + g.gather(v) +} + +// Gather creates a new empty gatherer that uses run to execute tasks returning +// values of type T. +// +// If gather != nil, values reported by successful tasks are passed to the +// function, otherwise such values are discarded. Calls to gather are +// synchronized to a single goroutine. +// +// If run == nil, Gather will panic. +func Gather[T any](run func(Task), gather func(T)) *Gatherer[T] { + if run == nil { + panic("run function is nil") + } + if gather == nil { + gather = func(T) {} + } + return &Gatherer[T]{run: run, gather: gather} +} + +// Call runs f in g. If f reports an error, the error is propagated to the +// runner; otherwise the non-error value reported by f is gathered. +func (g *Gatherer[T]) Call(f func() (T, error)) { + g.run(func() error { + v, err := f() + if err == nil { + g.report(v) + } + return err + }) +} + +// Run runs f in g, and gathers the value it reports. +func (g *Gatherer[T]) Run(f func() T) { + g.run(func() error { g.report(f()); return nil }) +} + +// Report runs f in g. Any values passed to report are gathered. If f reports +// an error, that error is propagated to the runner. Any values sent before f +// returns are still gathered, even if f reports an error. +func (g *Gatherer[T]) Report(f func(report func(T)) error) { + g.run(func() error { return f(g.report) }) } diff --git a/single.go b/single.go index 7f7cedb..e92aac9 100644 --- a/single.go +++ b/single.go @@ -31,12 +31,8 @@ func Go[T any](task func() T) *Single[T] { } // Run runs task in a new goroutine. The caller must call Wait to wait for the -// task to return and collect its error. This is shorthand for: -// -// taskgroup.Go(taskgroup.NoError(task)) -// -// The error reported by Wait is always nil. -func Run(task func()) *Single[error] { return Go(NoError(task)) } +// task to return. The error reported by Wait is always nil. +func Run(task func()) *Single[error] { return Go(noError(task)) } // Call starts task in a new goroutine. The caller must call Wait to wait for // the task to return and collect its result. diff --git a/taskgroup.go b/taskgroup.go index b88320f..1dc8d1b 100644 --- a/taskgroup.go +++ b/taskgroup.go @@ -106,11 +106,9 @@ func (g *Group) Go(task Task) { }() } -// Run runs task in a new goroutine in g, and returns g to permit chaining. -// This is shorthand for: -// -// g.Go(taskgroup.NoError(task)) -func (g *Group) Run(task func()) { g.Go(NoError(task)) } +// Run runs task in a new goroutine in g. +// The resulting task reports a nil error. +func (g *Group) Run(task func()) { g.Go(noError(task)) } func (g *Group) handleError(err error) { g.μ.Lock() @@ -195,6 +193,8 @@ func Listen(f func(error)) any { return f } // NoError adapts f to a Task that executes f and reports a nil error. func NoError(f func()) Task { return func() error { f(); return nil } } +func noError(f func()) Task { return func() error { f(); return nil } } + // Limit returns g and a "start" function that starts each task passed to it in // g, allowing no more than n tasks to be active concurrently. If n ≤ 0, no // limit is enforced. diff --git a/taskgroup_test.go b/taskgroup_test.go index dcf4bf3..5513b1a 100644 --- a/taskgroup_test.go +++ b/taskgroup_test.go @@ -3,6 +3,7 @@ package taskgroup_test import ( "context" "errors" + "math" "math/rand/v2" "reflect" "sync" @@ -336,6 +337,75 @@ func TestCollector_Report(t *testing.T) { } } +func TestGatherer(t *testing.T) { + defer leaktest.Check(t)() + + g, run := taskgroup.New(nil).Limit(4) + checkWait := func(t *testing.T) { + t.Helper() + if err := g.Wait(); err != nil { + t.Errorf("Unexpected error from Wait: %v", err) + } + } + + t.Run("Call", func(t *testing.T) { + var sum int + r := taskgroup.Gather(run, func(v int) { + sum += v + }) + + for _, v := range rand.Perm(15) { + r.Call(func() (int, error) { + if v > 10 { + return -100, errors.New("don't add this") + } + return v, nil + }) + } + + g.Wait() + if want := (10 * 11) / 2; sum != want { + t.Errorf("Final result: got %d, want %d", sum, want) + } + }) + + t.Run("Run", func(t *testing.T) { + var sum int + r := taskgroup.Gather(run, func(v int) { + sum += v + }) + for _, v := range rand.Perm(15) { + r.Run(func() int { return v + 1 }) + } + + checkWait(t) + if want := (15 * 16) / 2; sum != want { + t.Errorf("Final result: got %d, want %d", sum, want) + } + }) + + t.Run("Report", func(t *testing.T) { + var sum uint32 + r := taskgroup.Gather(g.Go, func(v uint32) { + sum |= v + }) + + for _, i := range rand.Perm(32) { + r.Report(func(report func(v uint32)) error { + for _, v := range rand.Perm(i + 1) { + report(uint32(1 << v)) + } + return nil + }) + } + + checkWait(t) + if sum != math.MaxUint32 { + t.Errorf("Final result: got %d, want %d", sum, math.MaxUint32) + } + }) +} + type peakValue struct { μ sync.Mutex cur, max int