Skip to content

Commit

Permalink
add some minimal Runner tests
Browse files Browse the repository at this point in the history
  • Loading branch information
creachadair committed Oct 6, 2024
1 parent 82ce343 commit 7a6a12c
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 2 deletions.
2 changes: 0 additions & 2 deletions collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ func (c *Collector[T]) Run(f func() T) Task {
}

// A Runner manages a group of [Task] functions that report values.
// At least the Go field must be populated. A Runner must not be copied, nor
// its exported fields modified, after its first use.
type Runner[T any] struct {
run func(Task) // start the task in a goroutine

Expand Down
103 changes: 103 additions & 0 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package taskgroup_test
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"reflect"
"sync"
Expand Down Expand Up @@ -336,6 +337,64 @@ func TestCollector_Report(t *testing.T) {
}
}

func TestRunner(t *testing.T) {
defer leaktest.Check(t)()

g, run := taskgroup.New(nil).Limit(4)

var sum int
r := taskgroup.NewRunner(run, func(v int) {
sum += v
})

vs := rand.Perm(15)
for i, v := range vs {
v := v
if v > 10 {
// This value should not be accumulated.
r.Call(func() (int, error) {
return -100, errors.New("don't add this")
})
} else if i%2 == 0 {
// Report a single value, no error.
r.Call(func() (int, error) { return v, nil })
} else {
// Report a single value.
r.Run(func() int { return v })
}
}
g.Wait() // wait for tasks to finish

if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
}

func TestRunner_Report(t *testing.T) {
defer leaktest.Check(t)()

var g taskgroup.Group

var sum int
r := taskgroup.NewRunner(g.Go, func(v int) {
sum += v
})

r.Report(func(report func(v int)) error {
for _, v := range rand.Perm(10) {
report(v)
}
return nil
})

if err := g.Wait(); err != nil {
t.Errorf("Unexpected error from group: %v", err)
}
if want := (9 * 10) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
}

type peakValue struct {
μ sync.Mutex
cur, max int
Expand All @@ -355,3 +414,47 @@ func (p *peakValue) dec() {
p.cur--
p.μ.Unlock()
}

func TestTree(t *testing.T) {
defer leaktest.Check(t)()

vs := rand.Perm(1000)

g, run := taskgroup.New(nil).Limit(5)

type result [3]int
r := taskgroup.NewRunner[result](run, nil)
r.Collect(func(v result) { t.Logf("+ %d at %d: %d", v[0], v[1], v[2]) })

for i := range vs {
r.Run(func() result {
// Find the location of i in the permutation.
for j, v := range vs {
if v != i {
continue
}

// Count the number of things less than i earlier in vs than i.
// Do this in the most inefficient possible way.
g, run := taskgroup.New(nil).Limit(5)
var countLess int
r := taskgroup.NewRunner(run, func(int) {
countLess++
})

for k := range j {
r.Call(func() (int, error) {
if vs[k] < v {
return k, nil
}
return -1, errors.New("no")
})
}
g.Wait()
return result{i, j, countLess}
}
panic(fmt.Sprintf("%d not found", i))
})
}
g.Wait()
}

0 comments on commit 7a6a12c

Please sign in to comment.