diff --git a/qdb/memqdb.go b/qdb/memqdb.go index 011db356b..fd1822ba8 100644 --- a/qdb/memqdb.go +++ b/qdb/memqdb.go @@ -278,6 +278,7 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range") q.mu.Lock() defer q.mu.Unlock() + defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: lock key range") krs, ok := q.Krs[id] if !ok { @@ -285,7 +286,15 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { } err := ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, true), - NewCustomCommand(q.Locks[id].Lock, q.Locks[id].Unlock)) + NewCustomCommand(func() { + if lock, ok := q.Locks[id]; ok { + lock.Lock() + } + }, func() { + if lock, ok := q.Locks[id]; ok { + lock.Unlock() + } + })) if err != nil { return nil, err } @@ -294,16 +303,25 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { } func (q *MemQDB) UnlockKeyRange(_ context.Context, id string) error { - spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range") + spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: unlock key range") q.mu.Lock() defer q.mu.Unlock() + defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: unlock key range") if !q.Freq[id] { return fmt.Errorf("key range %v not locked", id) } return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, false), - NewCustomCommand(q.Locks[id].Unlock, q.Locks[id].Lock)) + NewCustomCommand(func() { + if lock, ok := q.Locks[id]; ok { + lock.Unlock() + } + }, func() { + if lock, ok := q.Locks[id]; ok { + lock.Lock() + } + })) } func (q *MemQDB) CheckLockedKeyRange(ctx context.Context, id string) (*KeyRange, error) { diff --git a/qdb/memqdb_test.go b/qdb/memqdb_test.go index ec47ca926..5b59001e0 100644 --- a/qdb/memqdb_test.go +++ b/qdb/memqdb_test.go @@ -55,70 +55,48 @@ func TestMemqdbRacing(t *testing.T) { var wg sync.WaitGroup ctx := context.TODO() - methods := []func() error{ - func() error { return memqdb.AddDataspace(ctx, mockDataspace) }, - func() error { return memqdb.AddKeyRange(ctx, mockKeyRange) }, - func() error { return memqdb.AddRouter(ctx, mockRouter) }, - func() error { return memqdb.AddShard(ctx, mockShard) }, - func() error { return memqdb.AddShardingRule(ctx, mockShardingRule) }, - func() error { - return memqdb.RecordTransferTx(ctx, mockDataTransferTransaction.FromShardId, mockDataTransferTransaction) + methods := []func(){ + func() { _ = memqdb.AddDataspace(ctx, mockDataspace) }, + func() { _ = memqdb.AddKeyRange(ctx, mockKeyRange) }, + func() { _ = memqdb.AddRouter(ctx, mockRouter) }, + func() { _ = memqdb.AddShard(ctx, mockShard) }, + func() { _ = memqdb.AddShardingRule(ctx, mockShardingRule) }, + func() { + _ = memqdb.RecordTransferTx(ctx, mockDataTransferTransaction.FromShardId, mockDataTransferTransaction) }, - func() error { - _, err_local := memqdb.ListDataspaces(ctx) - return err_local + func() { _, _ = memqdb.ListDataspaces(ctx) }, + func() { _, _ = memqdb.ListKeyRanges(ctx) }, + func() { _, _ = memqdb.ListRouters(ctx) }, + func() { _, _ = memqdb.ListShardingRules(ctx) }, + func() { _, _ = memqdb.ListShards(ctx) }, + func() { _, _ = memqdb.GetKeyRange(ctx, mockKeyRange.KeyRangeID) }, + func() { _, _ = memqdb.GetShard(ctx, mockShard.ID) }, + func() { _, _ = memqdb.GetShardingRule(ctx, mockShardingRule.ID) }, + func() { _, _ = memqdb.GetTransferTx(ctx, mockDataTransferTransaction.FromShardId) }, + func() { _ = memqdb.ShareKeyRange(mockKeyRange.KeyRangeID) }, + func() { _ = memqdb.DropKeyRange(ctx, mockKeyRange.KeyRangeID) }, + func() { _ = memqdb.DropKeyRangeAll(ctx) }, + func() { _ = memqdb.DropShardingRule(ctx, mockShardingRule.ID) }, + func() { _, _ = memqdb.DropShardingRuleAll(ctx) }, + func() { _ = memqdb.RemoveTransferTx(ctx, mockDataTransferTransaction.FromShardId) }, + func() { _ = memqdb.LockRouter(ctx, mockRouter.ID) }, + func() { + _, _ = memqdb.LockKeyRange(ctx, mockKeyRange.KeyRangeID) + _ = memqdb.UnlockKeyRange(ctx, mockKeyRange.KeyRangeID) }, - func() error { - _, err_local := memqdb.ListKeyRanges(ctx) - return err_local - }, - func() error { - _, err_local := memqdb.ListRouters(ctx) - return err_local - }, - func() error { - _, err_local := memqdb.ListShardingRules(ctx) - return err_local - }, - func() error { - _, err_local := memqdb.ListShards(ctx) - return err_local - }, - func() error { - _, err_local := memqdb.GetKeyRange(ctx, mockKeyRange.KeyRangeID) - return err_local - }, - func() error { - _, err_local := memqdb.GetShard(ctx, mockShard.ID) - return err_local - }, - func() error { - _, err_local := memqdb.GetShardingRule(ctx, mockShardingRule.ID) - return err_local - }, - func() error { - _, err_local := memqdb.GetTransferTx(ctx, mockDataTransferTransaction.FromShardId) - return err_local - }, - func() error { return memqdb.ShareKeyRange(mockKeyRange.KeyRangeID) }, - func() error { return memqdb.DropKeyRange(ctx, mockKeyRange.KeyRangeID) }, - func() error { return memqdb.DropKeyRangeAll(ctx) }, - func() error { return memqdb.DropShardingRule(ctx, mockShardingRule.ID) }, - func() error { - _, err_local := memqdb.DropShardingRuleAll(ctx) - return err_local - }, - func() error { return memqdb.RemoveTransferTx(ctx, mockDataTransferTransaction.FromShardId) }, + func() { _ = memqdb.UpdateKeyRange(ctx, mockKeyRange) }, + func() { _ = memqdb.DeleteRouter(ctx, mockRouter.ID) }, } - for i := 0; i < 10; i++ { for _, m := range methods { wg.Add(1) - go func(m func() error) { - _ = m() + go func(m func()) { + m() wg.Done() }(m) } + wg.Wait() + } wg.Wait() }