From 291ef6a536f05678d1fbddc6071194fe9ce65b97 Mon Sep 17 00:00:00 2001 From: Harry Bagdi Date: Fri, 21 May 2021 13:48:24 -0700 Subject: [PATCH] fix(solver) resolve a data race with ops counters --- cmd/common.go | 8 ++++---- cmd/common_konnect.go | 8 ++++---- solver/solver.go | 36 +++++++++++++++++++++++++++++------- solver/solver_test.go | 23 +++++++++++++++++++++++ 4 files changed, 60 insertions(+), 15 deletions(-) create mode 100644 solver/solver_test.go diff --git a/cmd/common.go b/cmd/common.go index b482dca8c..f758afa65 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -166,14 +166,14 @@ func syncMain(ctx context.Context, filenames []string, dry bool, parallelism, stats, errs := solver.Solve(ctx, s, wsClient, nil, parallelism, dry) printFn := color.New(color.FgGreen, color.Bold).PrintfFunc() printFn("Summary:\n") - printFn(" Created: %v\n", stats.CreateOps) - printFn(" Updated: %v\n", stats.UpdateOps) - printFn(" Deleted: %v\n", stats.DeleteOps) + printFn(" Created: %v\n", stats.CreateOps.Count()) + printFn(" Updated: %v\n", stats.UpdateOps.Count()) + printFn(" Deleted: %v\n", stats.DeleteOps.Count()) if errs != nil { return utils.ErrArray{Errors: errs} } if diffCmdNonZeroExitCode && - stats.CreateOps+stats.UpdateOps+stats.DeleteOps != 0 { + stats.CreateOps.Count()+stats.UpdateOps.Count()+stats.DeleteOps.Count() != 0 { os.Exit(exitCodeDiffDetection) } return nil diff --git a/cmd/common_konnect.go b/cmd/common_konnect.go index 37aa493f7..2d375bf98 100644 --- a/cmd/common_konnect.go +++ b/cmd/common_konnect.go @@ -90,14 +90,14 @@ func syncKonnect(ctx context.Context, stats, errs := solver.Solve(ctx, s, kongClient, konnectClient, parallelism, dry) printFn := color.New(color.FgGreen, color.Bold).PrintfFunc() printFn("Summary:\n") - printFn(" Created: %v\n", stats.CreateOps) - printFn(" Updated: %v\n", stats.UpdateOps) - printFn(" Deleted: %v\n", stats.DeleteOps) + printFn(" Created: %v\n", stats.CreateOps.Count()) + printFn(" Updated: %v\n", stats.UpdateOps.Count()) + printFn(" Deleted: %v\n", stats.DeleteOps.Count()) if errs != nil { return utils.ErrArray{Errors: errs} } if diffCmdNonZeroExitCode && - stats.CreateOps+stats.UpdateOps+stats.DeleteOps != 0 { + stats.CreateOps.Count()+stats.UpdateOps.Count()+stats.DeleteOps.Count() != 0 { os.Exit(exitCodeDiffDetection) } diff --git a/solver/solver.go b/solver/solver.go index 6d71e8136..c308f351d 100644 --- a/solver/solver.go +++ b/solver/solver.go @@ -2,6 +2,7 @@ package solver import ( "context" + "sync" "github.com/kong/deck/crud" "github.com/kong/deck/diff" @@ -14,9 +15,26 @@ import ( // Stats holds the stats related to a Solve. type Stats struct { - CreateOps int - UpdateOps int - DeleteOps int + CreateOps *AtomicInt32Counter + UpdateOps *AtomicInt32Counter + DeleteOps *AtomicInt32Counter +} + +type AtomicInt32Counter struct { + counter int32 + lock sync.RWMutex +} + +func (a *AtomicInt32Counter) Increment(delta int32) { + a.lock.Lock() + defer a.lock.Unlock() + a.counter += delta +} + +func (a *AtomicInt32Counter) Count() int32 { + a.lock.RLock() + defer a.lock.RUnlock() + return a.counter } // Solve generates a diff and walks the graph. @@ -26,15 +44,19 @@ func Solve(ctx context.Context, syncer *diff.Syncer, r := buildRegistry(client, konnectClient) - var stats Stats + stats := Stats{ + CreateOps: &AtomicInt32Counter{}, + UpdateOps: &AtomicInt32Counter{}, + DeleteOps: &AtomicInt32Counter{}, + } recordOp := func(op crud.Op) { switch op { case crud.Create: - stats.CreateOps = stats.CreateOps + 1 + stats.CreateOps.Increment(1) case crud.Update: - stats.UpdateOps = stats.UpdateOps + 1 + stats.UpdateOps.Increment(1) case crud.Delete: - stats.DeleteOps = stats.DeleteOps + 1 + stats.DeleteOps.Increment(1) } } diff --git a/solver/solver_test.go b/solver/solver_test.go new file mode 100644 index 000000000..427b96504 --- /dev/null +++ b/solver/solver_test.go @@ -0,0 +1,23 @@ +package solver + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicInt32Counter(t *testing.T) { + var a AtomicInt32Counter + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + a.Increment(int32(1)) + }() + } + wg.Wait() + assert.Equal(t, int32(10), a.Count()) +}