Skip to content

Commit

Permalink
add missing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
nikifkon committed Aug 3, 2023
1 parent 3b31d87 commit b9ad88e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 58 deletions.
24 changes: 21 additions & 3 deletions qdb/memqdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,23 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) {
spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range")
q.mu.Lock()
defer q.mu.Unlock()
defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: lock key range")

krs, ok := q.Krs[id]
if !ok {
return nil, fmt.Errorf("key range '%s' does not exist", id)
}

err := ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, true),
NewCustomCommand(q.Locks[id].Lock, q.Locks[id].Unlock))
NewCustomCommand(func() {
if lock, ok := q.Locks[id]; ok {
lock.Lock()
}
}, func() {
if lock, ok := q.Locks[id]; ok {
lock.Unlock()
}
}))
if err != nil {
return nil, err
}
Expand All @@ -294,16 +303,25 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) {
}

func (q *MemQDB) UnlockKeyRange(_ context.Context, id string) error {
spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range")
spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: unlock key range")
q.mu.Lock()
defer q.mu.Unlock()
defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: unlock key range")

if !q.Freq[id] {
return fmt.Errorf("key range %v not locked", id)
}

return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, false),
NewCustomCommand(q.Locks[id].Unlock, q.Locks[id].Lock))
NewCustomCommand(func() {
if lock, ok := q.Locks[id]; ok {
lock.Unlock()
}
}, func() {
if lock, ok := q.Locks[id]; ok {
lock.Lock()
}
}))
}

func (q *MemQDB) CheckLockedKeyRange(ctx context.Context, id string) (*KeyRange, error) {
Expand Down
88 changes: 33 additions & 55 deletions qdb/memqdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,70 +55,48 @@ func TestMemqdbRacing(t *testing.T) {
var wg sync.WaitGroup
ctx := context.TODO()

methods := []func() error{
func() error { return memqdb.AddDataspace(ctx, mockDataspace) },
func() error { return memqdb.AddKeyRange(ctx, mockKeyRange) },
func() error { return memqdb.AddRouter(ctx, mockRouter) },
func() error { return memqdb.AddShard(ctx, mockShard) },
func() error { return memqdb.AddShardingRule(ctx, mockShardingRule) },
func() error {
return memqdb.RecordTransferTx(ctx, mockDataTransferTransaction.FromShardId, mockDataTransferTransaction)
methods := []func(){
func() { _ = memqdb.AddDataspace(ctx, mockDataspace) },
func() { _ = memqdb.AddKeyRange(ctx, mockKeyRange) },
func() { _ = memqdb.AddRouter(ctx, mockRouter) },
func() { _ = memqdb.AddShard(ctx, mockShard) },
func() { _ = memqdb.AddShardingRule(ctx, mockShardingRule) },
func() {
_ = memqdb.RecordTransferTx(ctx, mockDataTransferTransaction.FromShardId, mockDataTransferTransaction)
},
func() error {
_, err_local := memqdb.ListDataspaces(ctx)
return err_local
func() { _, _ = memqdb.ListDataspaces(ctx) },
func() { _, _ = memqdb.ListKeyRanges(ctx) },
func() { _, _ = memqdb.ListRouters(ctx) },
func() { _, _ = memqdb.ListShardingRules(ctx) },
func() { _, _ = memqdb.ListShards(ctx) },
func() { _, _ = memqdb.GetKeyRange(ctx, mockKeyRange.KeyRangeID) },
func() { _, _ = memqdb.GetShard(ctx, mockShard.ID) },
func() { _, _ = memqdb.GetShardingRule(ctx, mockShardingRule.ID) },
func() { _, _ = memqdb.GetTransferTx(ctx, mockDataTransferTransaction.FromShardId) },
func() { _ = memqdb.ShareKeyRange(mockKeyRange.KeyRangeID) },
func() { _ = memqdb.DropKeyRange(ctx, mockKeyRange.KeyRangeID) },
func() { _ = memqdb.DropKeyRangeAll(ctx) },
func() { _ = memqdb.DropShardingRule(ctx, mockShardingRule.ID) },
func() { _, _ = memqdb.DropShardingRuleAll(ctx) },
func() { _ = memqdb.RemoveTransferTx(ctx, mockDataTransferTransaction.FromShardId) },
func() { _ = memqdb.LockRouter(ctx, mockRouter.ID) },
func() {
_, _ = memqdb.LockKeyRange(ctx, mockKeyRange.KeyRangeID)
_ = memqdb.UnlockKeyRange(ctx, mockKeyRange.KeyRangeID)
},
func() error {
_, err_local := memqdb.ListKeyRanges(ctx)
return err_local
},
func() error {
_, err_local := memqdb.ListRouters(ctx)
return err_local
},
func() error {
_, err_local := memqdb.ListShardingRules(ctx)
return err_local
},
func() error {
_, err_local := memqdb.ListShards(ctx)
return err_local
},
func() error {
_, err_local := memqdb.GetKeyRange(ctx, mockKeyRange.KeyRangeID)
return err_local
},
func() error {
_, err_local := memqdb.GetShard(ctx, mockShard.ID)
return err_local
},
func() error {
_, err_local := memqdb.GetShardingRule(ctx, mockShardingRule.ID)
return err_local
},
func() error {
_, err_local := memqdb.GetTransferTx(ctx, mockDataTransferTransaction.FromShardId)
return err_local
},
func() error { return memqdb.ShareKeyRange(mockKeyRange.KeyRangeID) },
func() error { return memqdb.DropKeyRange(ctx, mockKeyRange.KeyRangeID) },
func() error { return memqdb.DropKeyRangeAll(ctx) },
func() error { return memqdb.DropShardingRule(ctx, mockShardingRule.ID) },
func() error {
_, err_local := memqdb.DropShardingRuleAll(ctx)
return err_local
},
func() error { return memqdb.RemoveTransferTx(ctx, mockDataTransferTransaction.FromShardId) },
func() { _ = memqdb.UpdateKeyRange(ctx, mockKeyRange) },
func() { _ = memqdb.DeleteRouter(ctx, mockRouter.ID) },
}

for i := 0; i < 10; i++ {
for _, m := range methods {
wg.Add(1)
go func(m func() error) {
_ = m()
go func(m func()) {
m()
wg.Done()
}(m)
}
wg.Wait()

}
wg.Wait()
}

0 comments on commit b9ad88e

Please sign in to comment.