Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add force cancel option to the producer-consumer #40

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions parallel/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ type Runner interface {
AddTaskWithError(TaskFunc, OnErrorFunc) (int, error)
Run()
Done()
Cancel()
Cancel(bool)
Errors() map[int]error
ActiveThreads() uint32
OpenThreads() int
OpenThreads() uint32
IsStarted() bool
SetMaxParallel(int)
GetFinishedNotification() chan bool
Expand All @@ -39,10 +39,12 @@ type runner struct {
tasks chan *task
// Tasks counter, used to give each task an identifier (task.num).
taskId uint32
// A channel that is closed when the runner is cancelled.
cancel chan struct{}
// Used to make sure the cancel channel is closed only once.
// True when Cancel was invoked
cancel atomic.Bool
// Used to make sure that cancel is called only once.
cancelOnce sync.Once
// Used to make sure that done is called only once.
doneOnce sync.Once
// The maximum number of threads running in parallel.
maxParallel int
// If true, the runner will be cancelled on the first error thrown from a task.
Expand All @@ -52,15 +54,15 @@ type runner struct {
// A WaitGroup that waits for all the threads to close.
threadsWaitGroup sync.WaitGroup
// Threads counter, used to give each thread an identifier (threadId).
threadCount uint32
threadCount atomic.Uint32
// The number of open threads.
openThreads int
openThreads atomic.Uint32
// A lock on openThreads.
openThreadsLock sync.Mutex
// The number of threads currently running tasks.
activeThreads uint32
activeThreads atomic.Uint32
// The number of tasks in the queue.
totalTasksInQueue uint32
totalTasksInQueue atomic.Uint32
// Indicate that the runner has finished.
finishedNotifier chan bool
// Indicates that the finish channel is closed.
Expand Down Expand Up @@ -91,7 +93,7 @@ func NewRunner(maxParallel int, capacity uint, failFast bool) *runner {
finishedNotifier: make(chan bool, 1),
maxParallel: consumers,
failFast: failFast,
cancel: make(chan struct{}),
cancel: atomic.Bool{},
tasks: make(chan *task, capacity),
}
r.errors = make(map[int]error)
Expand Down Expand Up @@ -122,14 +124,12 @@ func (r *runner) addTask(t TaskFunc, errorHandler OnErrorFunc) (int, error) {
nextCount := atomic.AddUint32(&r.taskId, 1)
task := &task{run: t, num: nextCount - 1, onError: errorHandler}

select {
case <-r.cancel:
if r.cancel.Load() {
return -1, errors.New("runner stopped")
default:
atomic.AddUint32(&r.totalTasksInQueue, 1)
r.tasks <- task
return int(task.num), nil
}
r.totalTasksInQueue.Add(1)
r.tasks <- task
return int(task.num), nil
}

// Run r.maxParallel go routines in order to consume all the tasks
Expand All @@ -154,7 +154,9 @@ func (r *runner) Run() {

// Done is used to notify that no more tasks will be produced.
func (r *runner) Done() {
close(r.tasks)
r.doneOnce.Do(func() {
close(r.tasks)
})
}

// GetFinishedNotification returns the finishedNotifier channel, which notifies when the runner is done.
Expand All @@ -171,10 +173,14 @@ func (r *runner) IsStarted() bool {
// Cancel stops the Runner from getting new tasks and empties the tasks queue.
// No new tasks will be executed, and tasks that already started will continue running and won't be interrupted.
// If this Runner is already cancelled, then this function will do nothing.
func (r *runner) Cancel() {
// force - If true, pending tasks in the queue will not be handled.
func (r *runner) Cancel(force bool) {
// No more adding tasks
r.cancel.Store(true)
if force {
r.Done()
}
r.cancelOnce.Do(func() {
// No more adding tasks
close(r.cancel)
// Consume all tasks left
for len(r.tasks) > 0 {
<-r.tasks
Expand All @@ -191,12 +197,12 @@ func (r *runner) Errors() map[int]error {
}

// OpenThreads returns the number of open threads (including idle threads).
func (r *runner) OpenThreads() int {
return r.openThreads
func (r *runner) OpenThreads() uint32 {
return r.openThreads.Load()
}

func (r *runner) ActiveThreads() uint32 {
return r.activeThreads
return r.activeThreads.Load()
}

func (r *runner) SetFinishedNotification(toEnable bool) {
Expand All @@ -222,28 +228,28 @@ func (r *runner) SetMaxParallel(newVal int) {

func (r *runner) addThread() {
r.threadsWaitGroup.Add(1)
nextThreadId := atomic.AddUint32(&r.threadCount, 1) - 1
nextThreadId := r.threadCount.Add(1) - 1
go func(threadId int) {
defer r.threadsWaitGroup.Done()
r.openThreadsLock.Lock()
r.openThreads++
r.openThreads.Add(1)
r.openThreadsLock.Unlock()

// Keep on taking tasks from the queue.
for t := range r.tasks {
// Increase the total of active threads.
atomic.AddUint32(&r.activeThreads, 1)
r.activeThreads.Add(1)
atomic.AddUint32(&r.started, 1)
// Run the task.
e := t.run(threadId)
// Decrease the total of active threads.
atomic.AddUint32(&r.activeThreads, ^uint32(0))
r.activeThreads.Add(^uint32(0))
// Decrease the total of in progress tasks.
atomic.AddUint32(&r.totalTasksInQueue, ^uint32(0))
r.totalTasksInQueue.Add(^uint32(0))
if r.finishedNotificationEnabled {
r.finishedNotifierLock.Lock()
// Notify that the runner has finished its job.
if r.activeThreads == 0 && r.totalTasksInQueue == 0 {
if r.activeThreads.Load() == 0 && r.totalTasksInQueue.Load() == 0 {
r.notifyFinished()
}
r.finishedNotifierLock.Unlock()
Expand All @@ -260,15 +266,15 @@ func (r *runner) addThread() {
r.errorsLock.Unlock()

if r.failFast {
r.Cancel()
r.Cancel(false)
break
}
}

r.openThreadsLock.Lock()
// If the total of open threads is larger than the maximum (maxParallel), then this thread should be closed.
if r.openThreads > r.maxParallel {
r.openThreads--
if int(r.openThreads.Load()) > r.maxParallel {
r.openThreads.Add(^uint32(0))
r.openThreadsLock.Unlock()
break
}
Expand Down
153 changes: 138 additions & 15 deletions parallel/runner_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
package parallel

import (
"errors"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestTask(t *testing.T) {
var errTest = errors.New("some error")

func TestIsStarted(t *testing.T) {
runner := NewBounedRunner(1, false)
runner.AddTask(func(i int) error {
return nil
})
runner.Done()
runner.Run()
assert.True(t, runner.IsStarted())
}

func TestAddTask(t *testing.T) {
const count = 70
results := make(chan int, 100)

Expand All @@ -17,44 +32,152 @@ func TestTask(t *testing.T) {
var expectedErrorTotal int
for i := 0; i < count; i++ {
expectedTotal += i
if float64(i) > math.Floor(float64(count)/2) {
if float64(i) > float64(count)/2 {
expectedErrorTotal += i
}

x := i
runner.AddTask(func(i int) error {
_, err := runner.AddTask(func(int) error {
results <- x
time.Sleep(time.Millisecond * time.Duration(rand.Intn(50)))
if float64(x) > math.Floor(float64(count)/2) {
if float64(x) > float64(count)/2 {
return fmt.Errorf("Second half value %d not counted", x)
}
return nil
})
assert.NoError(t, err)
}
runner.Done()
runner.Run()

errs := runner.Errors()

close(results)
var resultsTotal int
for result := range results {
resultsTotal += result
}
if resultsTotal != expectedTotal {
t.Error("Unexpected results total:", resultsTotal)
}
assert.Equal(t, expectedTotal, resultsTotal)

var errorsTotal int
for k, v := range errs {
for k, v := range runner.Errors() {
if v != nil {
errorsTotal += k
}
}
if errorsTotal != expectedErrorTotal {
t.Error("Unexpected errs total:", errorsTotal)
assert.Equal(t, expectedErrorTotal, errorsTotal)
assert.NotZero(t, errorsTotal)
}

func TestAddTaskWithError(t *testing.T) {
// Create new runner
runner := NewRunner(1, 1, false)

// Add task with error
var receivedError = new(error)
onError := func(err error) { *receivedError = err }
taskFunc := func(int) error { return errTest }
_, err := runner.AddTaskWithError(taskFunc, onError)
assert.NoError(t, err)

// Wait for task to finish
runner.Done()
runner.Run()

// Assert error captured
assert.Equal(t, errTest, *receivedError)
assert.Equal(t, errTest, runner.Errors()[0])
}

func TestCancel(t *testing.T) {
// Create new runner
runner := NewBounedRunner(1, false)

// Cancel to prevent receiving another tasks
runner.Cancel(false)

// Add task and expect error
_, err := runner.AddTask(func(int) error { return nil })
assert.ErrorContains(t, err, "runner stopped")
}

func TestForceCancel(t *testing.T) {
// Create new runner
const capacity = 10
runner := NewRunner(1, capacity, true)
// Run tasks
for i := 0; i < capacity; i++ {
taskId := i
_, err := runner.AddTask(func(int) error {
assert.Less(t, taskId, 9)
time.Sleep(100 * time.Millisecond)
return nil
})
assert.NoError(t, err)
}
if errorsTotal == 0 {
t.Error("Unexpected 0 errs total")
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
runner.Run()
}()
go func() {
time.Sleep(200 * time.Millisecond)
runner.Cancel(true)
}()
wg.Wait()

assert.InDelta(t, 5, runner.started, 4)
}

func TestFailFast(t *testing.T) {
// Create new runner with fail-fast
runner := NewBounedRunner(1, true)

// Add task that returns an error
_, err := runner.AddTask(func(int) error {
return errTest
})
assert.NoError(t, err)

// Wait for task to finish
runner.Run()

// Add another task and expect error
_, err = runner.AddTask(func(int) error {
return nil
})
assert.ErrorContains(t, err, "runner stopped")
}

func TestNotifyFinished(t *testing.T) {
// Create new runner
runner := NewBounedRunner(1, false)
runner.SetFinishedNotification(true)

// Cancel to prevent receiving another tasks
runner.Cancel(false)
<-runner.GetFinishedNotification()
}

func TestMaxParallel(t *testing.T) {
// Create new runner with capacity of 10 and max parallelism of 3
const capacity = 10
const parallelism = 3
runner := NewRunner(parallelism, capacity, false)

// Run tasks in parallel
for i := 0; i < capacity; i++ {
_, err := runner.AddTask(func(int) error {
// Assert in range between 1 and 3
assert.InDelta(t, 2, runner.ActiveThreads(), 1)
assert.InDelta(t, 2, runner.OpenThreads(), 1)
time.Sleep(100 * time.Millisecond)
return nil
})
assert.NoError(t, err)
}

// Wait for tasks to finish
runner.Done()
runner.Run()
assert.Equal(t, uint32(capacity), runner.started)
}
Loading