From 459ab9d41a647e3d21d1a053baa159889616b054 Mon Sep 17 00:00:00 2001 From: Nikiforov Konstantin Date: Mon, 7 Aug 2023 10:31:40 +0000 Subject: [PATCH] fix deadlock: dropkeyrange with unlockkeyrange --- pkg/coord/local/clocal.go | 9 +--- qdb/command.go | 53 ++++++++++++-------- qdb/memqdb.go | 103 ++++++++++++++++++++++++++++---------- qdb/memqdb_test.go | 2 +- 4 files changed, 112 insertions(+), 55 deletions(-) diff --git a/pkg/coord/local/clocal.go b/pkg/coord/local/clocal.go index 28c080355..e6b9db1ad 100644 --- a/pkg/coord/local/clocal.go +++ b/pkg/coord/local/clocal.go @@ -202,16 +202,9 @@ func (qr *LocalCoordinator) Unite(ctx context.Context, req *kr.UniteKeyRange) er }(qr.qdb, ctx, req.KeyRangeIDLeft) // TODO: krRight seems to be empty. - if krright, err = qr.qdb.LockKeyRange(ctx, req.KeyRangeIDRight); err != nil { + if krright, err = qr.qdb.GetKeyRange(ctx, req.KeyRangeIDRight); err != nil { return err } - defer func(qdb qdb.QDB, ctx context.Context, keyRangeID string) { - err := qdb.UnlockKeyRange(ctx, keyRangeID) - if err != nil { - spqrlog.Zero.Error().Err(err).Msg("") - return - } - }(qr.qdb, ctx, req.KeyRangeIDRight) if err = qr.qdb.DropKeyRange(ctx, krright.KeyRangeID); err != nil { return err diff --git a/qdb/command.go b/qdb/command.go index da8039433..b2aec49f7 100644 --- a/qdb/command.go +++ b/qdb/command.go @@ -3,8 +3,8 @@ package qdb import "github.com/pg-sharding/spqr/pkg/spqrlog" type Command interface { - Do() - Undo() + Do() error + Undo() error } func NewDeleteCommand[T any](m map[string]T, key string) *DeleteCommand[T] { @@ -18,17 +18,19 @@ type DeleteCommand[T any] struct { present bool } -func (c *DeleteCommand[T]) Do() { +func (c *DeleteCommand[T]) Do() error { c.value, c.present = c.m[c.key] delete(c.m, c.key) + return nil } -func (c *DeleteCommand[T]) Undo() { +func (c *DeleteCommand[T]) Undo() error { if !c.present { delete(c.m, c.key) } else { c.m[c.key] = c.value } + return nil } func NewUpdateCommand[T any](m map[string]T, key string, value T) *UpdateCommand[T] { @@ -43,17 +45,19 @@ type UpdateCommand[T any] struct { present bool } -func (c *UpdateCommand[T]) Do() { +func (c *UpdateCommand[T]) Do() error { c.prevValue, c.present = c.m[c.key] c.m[c.key] = c.value + return nil } -func (c *UpdateCommand[T]) Undo() { +func (c *UpdateCommand[T]) Undo() error { if !c.present { delete(c.m, c.key) } else { c.m[c.key] = c.prevValue } + return nil } func NewDropCommand[T any](m map[string]T) *DropCommand[T] { @@ -65,7 +69,7 @@ type DropCommand[T any] struct { copy map[string]T } -func (c *DropCommand[T]) Do() { +func (c *DropCommand[T]) Do() error { c.copy = make(map[string]T) for k, v := range c.m { c.copy[k] = v @@ -73,40 +77,49 @@ func (c *DropCommand[T]) Do() { for k := range c.m { delete(c.m, k) } + return nil } -func (c *DropCommand[T]) Undo() { +func (c *DropCommand[T]) Undo() error { for k, v := range c.copy { c.m[k] = v } + return nil } -func NewCustomCommand(do func(), undo func()) *CustomCommand { +func NewCustomCommand(do func() error, undo func() error) *CustomCommand { return &CustomCommand{do: do, undo: undo} } type CustomCommand struct { - do func() - undo func() + do func() error + undo func() error } -func (c *CustomCommand) Do() { - c.do() +func (c *CustomCommand) Do() error { + return c.do() } -func (c *CustomCommand) Undo() { - c.undo() +func (c *CustomCommand) Undo() error { + return c.undo() } func ExecuteCommands(saver func() error, commands ...Command) error { - for _, c := range commands { - c.Do() + firstError := len(commands) + var err error + for i, c := range commands { + err = c.Do() + if err != nil { + firstError = i + } + } + if err == nil { + err = saver() } - err := saver() if err != nil { spqrlog.Zero.Info().Msg("memqdb: undo commands") - for _, c := range commands { - c.Undo() + for _, c := range commands[:firstError] { + err = c.Undo() } return err } diff --git a/qdb/memqdb.go b/qdb/memqdb.go index ca565f44f..5430b3e24 100644 --- a/qdb/memqdb.go +++ b/qdb/memqdb.go @@ -13,8 +13,10 @@ import ( type MemQDB struct { // TODO create more mutex per map if needed - mu sync.RWMutex + mu sync.RWMutex + muDeletedKrs sync.RWMutex + deletedKrs map[string]bool Locks map[string]*sync.RWMutex `json:"locks"` Freq map[string]bool `json:"freq"` Krs map[string]*KeyRange `json:"krs"` @@ -40,6 +42,7 @@ func NewMemQDB(backupPath string) (*MemQDB, error) { Dataspaces: map[string]*Dataspace{}, Routers: map[string]*Router{}, Transactions: map[string]*DataTransferTransaction{}, + deletedKrs: map[string]bool{}, backupPath: backupPath, }, nil @@ -224,6 +227,23 @@ func (q *MemQDB) UpdateKeyRange(ctx context.Context, keyRange *KeyRange) error { func (q *MemQDB) DropKeyRange(ctx context.Context, id string) error { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: drop key range") + q.muDeletedKrs.Lock() + if q.deletedKrs[id] { + q.muDeletedKrs.Unlock() + return fmt.Errorf("key range '%s' already deleted", id) + } + q.deletedKrs[id] = true + q.muDeletedKrs.Unlock() + + q.mu.RLock() + + var lock *sync.RWMutex + if _, ok := q.Krs[id]; ok { + q.Locks[id].Lock() + lock = q.Locks[id] + defer lock.Unlock() + } + q.mu.RUnlock() q.mu.Lock() defer q.mu.Unlock() @@ -233,27 +253,35 @@ func (q *MemQDB) DropKeyRange(ctx context.Context, id string) error { func (q *MemQDB) DropKeyRangeAll(ctx context.Context) error { spqrlog.Zero.Debug().Msg("memqdb: drop all key ranges") - q.mu.Lock() - defer q.mu.Unlock() + q.muDeletedKrs.Lock() + for id := range q.Locks { + if q.deletedKrs[id] { + q.muDeletedKrs.Unlock() + return fmt.Errorf("key range '%s' already deleted", id) + } + q.deletedKrs[id] = true + } + q.muDeletedKrs.Unlock() + + q.mu.RLock() var locks []*sync.RWMutex + for _, l := range q.Locks { + l.Lock() + locks = append(locks, l) + } + defer func() { + for _, l := range locks { + l.Unlock() + } + }() + spqrlog.Zero.Debug().Msg("memqdb: acquired all locks") - return ExecuteCommands(q.DumpState, - NewCustomCommand(func() { - for _, l := range q.Locks { - l.Lock() - locks = append(locks, l) - } - spqrlog.Zero.Debug().Msg("memqdb: acquired all locks") - }, func() {}), - NewDropCommand(q.Krs), NewDropCommand(q.Locks), - NewCustomCommand(func() { - for _, l := range locks { - l.Unlock() - } - }, - func() {}), - ) + q.mu.RUnlock() + q.mu.Lock() + defer q.mu.Unlock() + + return ExecuteCommands(q.DumpState, NewDropCommand(q.Krs), NewDropCommand(q.Locks)) } func (q *MemQDB) ListKeyRanges(_ context.Context) ([]*KeyRange, error) { @@ -274,6 +302,22 @@ func (q *MemQDB) ListKeyRanges(_ context.Context) ([]*KeyRange, error) { return ret, nil } +func (q *MemQDB) TryLockKeyRange(lock *sync.RWMutex, id string, read bool) error { + if _, ok := q.deletedKrs[id]; ok { + return fmt.Errorf("key range '%s' deleted", id) + } + if read { + lock.RLock() + } else { + lock.Lock() + } + + if _, ok := q.Krs[id]; !ok { + return fmt.Errorf("key range '%s' deleted after lock acuired", id) + } + return nil +} + func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range") q.mu.RLock() @@ -286,14 +330,16 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { } err := ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, true), - NewCustomCommand(func() { + NewCustomCommand(func() error { if lock, ok := q.Locks[id]; ok { - lock.Lock() + return q.TryLockKeyRange(lock, id, false) } - }, func() { + return nil + }, func() error { if lock, ok := q.Locks[id]; ok { lock.Unlock() } + return nil })) if err != nil { return nil, err @@ -313,14 +359,16 @@ func (q *MemQDB) UnlockKeyRange(_ context.Context, id string) error { } return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, false), - NewCustomCommand(func() { + NewCustomCommand(func() error { if lock, ok := q.Locks[id]; ok { lock.Unlock() } - }, func() { + return nil + }, func() error { if lock, ok := q.Locks[id]; ok { - lock.Lock() + return q.TryLockKeyRange(lock, id, false) } + return nil })) } @@ -352,7 +400,10 @@ func (q *MemQDB) ShareKeyRange(id string) error { return fmt.Errorf("no such key") } - lock.RLock() + err := q.TryLockKeyRange(lock, id, true) + if err != nil { + return err + } defer lock.RUnlock() return nil diff --git a/qdb/memqdb_test.go b/qdb/memqdb_test.go index 5b59001e0..5e4451e65 100644 --- a/qdb/memqdb_test.go +++ b/qdb/memqdb_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -const MemQDBPath = "memqdb.json" +const MemQDBPath = "" var mockDataspace *qdb.Dataspace = &qdb.Dataspace{"123"} var mockShard *qdb.Shard = &qdb.Shard{