Skip to content

Commit

Permalink
add basic tests for the Gatherer
Browse files Browse the repository at this point in the history
  • Loading branch information
creachadair committed Oct 6, 2024
1 parent 24f9525 commit 53639c3
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package taskgroup_test
import (
"context"
"errors"
"fmt"
"math"
"math/rand/v2"
"reflect"
"sync"
Expand Down Expand Up @@ -336,6 +338,75 @@ func TestCollector_Report(t *testing.T) {
}
}

func TestGatherer(t *testing.T) {
defer leaktest.Check(t)()

g, run := taskgroup.New(nil).Limit(4)
checkWait := func(t *testing.T) {
t.Helper()
if err := g.Wait(); err != nil {
t.Errorf("Unexpected error from Wait: %v", err)
}
}

t.Run("Call", func(t *testing.T) {
var sum int
r := taskgroup.Gather(run, func(v int) {
sum += v
})

for _, v := range rand.Perm(15) {
r.Call(func() (int, error) {
if v > 10 {
return -100, errors.New("don't add this")
}
return v, nil
})
}

g.Wait()
if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
})

t.Run("Run", func(t *testing.T) {
var sum int
r := taskgroup.Gather(run, func(v int) {
sum += v
})
for _, v := range rand.Perm(15) {
r.Run(func() int { return v + 1 })
}

checkWait(t)
if want := (15 * 16) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
})

t.Run("Report", func(t *testing.T) {
var sum uint32
r := taskgroup.Gather(g.Go, func(v uint32) {
sum |= v
})

for _, i := range rand.Perm(32) {
r.Report(func(report func(v uint32)) error {
for _, v := range rand.Perm(i + 1) {
report(uint32(1 << v))
}
return nil
})
}

checkWait(t)
if sum != math.MaxUint32 {
t.Errorf("Final result: got %d, want %d", sum, math.MaxUint32)
}
})
}

type peakValue struct {
μ sync.Mutex
cur, max int
Expand All @@ -355,3 +426,48 @@ func (p *peakValue) dec() {
p.cur--
p.μ.Unlock()
}

func TestTree(t *testing.T) {
defer leaktest.Check(t)()

vs := rand.Perm(1000)

g, run := taskgroup.New(nil).Limit(5)

type result [3]int
r := taskgroup.Gather(run, func(v result) {
t.Logf("+ %d at %d: %d", v[0], v[1], v[2])
})

for i := range vs {
r.Run(func() result {
// Find the location of i in the permutation.
for j, v := range vs {
if v != i {
continue
}

// Count the number of things less than i earlier in vs than i.
// Do this in the most inefficient possible way.
g, run := taskgroup.New(nil).Limit(5)
var countLess int
r := taskgroup.Gather(run, func(int) {
countLess++
})

for k := range j {
r.Call(func() (int, error) {
if vs[k] < v {
return k, nil
}
return -1, errors.New("no")
})
}
g.Wait()
return result{i, j, countLess}
}
panic(fmt.Sprintf("%d not found", i))
})
}
g.Wait()
}

0 comments on commit 53639c3

Please sign in to comment.