diff --git a/core/workerpool/group.go b/core/workerpool/group.go index 3d9de6f7b..fdf8d273a 100644 --- a/core/workerpool/group.go +++ b/core/workerpool/group.go @@ -58,17 +58,16 @@ func (g *Group) CreatePool(name string, optsWorkerCount ...int) (pool *Unbounded return pool.Start() } +func (g *Group) Root() *Group { + return lo.Cond(g.root != nil, g.root, g) +} + func (g *Group) Wait() { g.PendingChildrenCounter.WaitIsZero() } func (g *Group) WaitAll() { - if g.root != nil { - g.root.Wait() - return - } - - g.Wait() + g.Root().Wait() } func (g *Group) Pool(name string) (pool *UnboundedWorkerPool, exists bool) { @@ -101,7 +100,7 @@ func (g *Group) Pools() (pools map[string]*UnboundedWorkerPool) { } func (g *Group) CreateGroup(name string) (group *Group) { - group = newGroupWithRoot(name, lo.Cond(g.root != nil, g.root, g)) + group = newGroupWithRoot(name, g.Root()) group.PendingChildrenCounter.Subscribe(func(oldValue, newValue int) { if oldValue == 0 { g.PendingChildrenCounter.Increase() diff --git a/core/workerpool/group_test.go b/core/workerpool/group_test.go index 43f155318..35225d52d 100644 --- a/core/workerpool/group_test.go +++ b/core/workerpool/group_test.go @@ -4,12 +4,16 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/require" ) func Test(t *testing.T) { group := NewGroup(t.Name()) _ = group.CreatePool("poolA") + require.Equal(t, group, group.Root()) + subgroup1 := group.CreateGroup("sub1") pool1 := subgroup1.CreatePool("pool1") pool2 := subgroup1.CreatePool("pool2") @@ -18,6 +22,8 @@ func Test(t *testing.T) { subSubGroup := subgroup2.CreateGroup("loop") _ = subSubGroup.CreatePool("pool3") + require.Equal(t, group, subSubGroup.Root()) + pool1.Submit(func() { time.Sleep(1 * time.Second)