Skip to content

Commit

Permalink
Rewrite Lock/Delete key range ops to prevent deadlocks (#575)
Browse files Browse the repository at this point in the history
* Added TryLocks

* TryLocks in DropKeyRange

* More thorough check for deadlocks

* Deleted redundant deletedKrs field of MemQDB

* Marked unused context.Context args

* Fixed feature tests

* Feature tests fix attempt

* Some refactoring

* Fix coordinator feature test
  • Loading branch information
EinKrebs authored Apr 3, 2024
1 parent a1372e3 commit 76ea934
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 99 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ pooler_run:
####################### TESTS #######################

unittest:
go test -race ./cmd/... ./pkg/... ./router/... ./qdb/... ./coordinator/... ./yacc/console...
go test ./cmd/... ./pkg/... ./router/... ./coordinator/... ./yacc/console...
go test -race -count 20 -timeout 30s ./qdb/...

regress_local: proxy_2sh_run
./script/regress_local.sh
Expand Down
151 changes: 56 additions & 95 deletions qdb/memqdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ import (
type MemQDB struct {
ShardingSchemaKeeper
// TODO create more mutex per map if needed
mu sync.RWMutex
muDeletedKrs sync.RWMutex
mu sync.RWMutex

deletedKrs map[string]bool
Locks map[string]*sync.RWMutex `json:"locks"`
Freq map[string]bool `json:"freq"`
Krs map[string]*KeyRange `json:"krs"`
Expand Down Expand Up @@ -47,7 +45,6 @@ func NewMemQDB(backupPath string) (*MemQDB, error) {
RelationDistribution: map[string]string{},
Routers: map[string]*Router{},
Transactions: map[string]*DataTransferTransaction{},
deletedKrs: map[string]bool{},

backupPath: backupPath,
}, nil
Expand Down Expand Up @@ -127,8 +124,16 @@ func (q *MemQDB) DumpState() error {
// ==============================================================================

// TODO : unit tests
func (q *MemQDB) CreateKeyRange(ctx context.Context, keyRange *KeyRange) error {
func (q *MemQDB) CreateKeyRange(_ context.Context, keyRange *KeyRange) error {
spqrlog.Zero.Debug().Interface("key-range", keyRange).Msg("memqdb: add key range")

q.mu.RLock()
if _, ok := q.Krs[keyRange.KeyRangeID]; ok {
q.mu.RUnlock()
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range \"%s\" already exists", keyRange.KeyRangeID)
}
q.mu.RUnlock()

q.mu.Lock()
defer q.mu.Unlock()

Expand All @@ -144,7 +149,7 @@ func (q *MemQDB) CreateKeyRange(ctx context.Context, keyRange *KeyRange) error {
}

// TODO : unit tests
func (q *MemQDB) GetKeyRange(ctx context.Context, id string) (*KeyRange, error) {
func (q *MemQDB) GetKeyRange(_ context.Context, id string) (*KeyRange, error) {
spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: get key range")
q.mu.RLock()
defer q.mu.RUnlock()
Expand All @@ -158,7 +163,7 @@ func (q *MemQDB) GetKeyRange(ctx context.Context, id string) (*KeyRange, error)
}

// TODO : unit tests
func (q *MemQDB) UpdateKeyRange(ctx context.Context, keyRange *KeyRange) error {
func (q *MemQDB) UpdateKeyRange(_ context.Context, keyRange *KeyRange) error {
spqrlog.Zero.Debug().Interface("key-range", keyRange).Msg("memqdb: update key range")
q.mu.Lock()
defer q.mu.Unlock()
Expand All @@ -167,86 +172,46 @@ func (q *MemQDB) UpdateKeyRange(ctx context.Context, keyRange *KeyRange) error {
}

// TODO : unit tests
func (q *MemQDB) DropKeyRange(ctx context.Context, id string) error {
func (q *MemQDB) DropKeyRange(_ context.Context, id string) error {
spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: drop key range")

// Do not allow new locks on key range we want to delete
q.muDeletedKrs.Lock()
if q.deletedKrs[id] {
q.muDeletedKrs.Unlock()
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range '%s' already deleted", id)
}
q.deletedKrs[id] = true
q.muDeletedKrs.Unlock()

defer func() {
q.muDeletedKrs.Lock()
defer q.muDeletedKrs.Unlock()
delete(q.deletedKrs, id)
}()

q.mu.RLock()

// Wait until key range will be unlocked
if lock, ok := q.Locks[id]; ok {
lock.Lock()
defer lock.Unlock()
}
q.mu.RUnlock()

q.mu.Lock()
defer q.mu.Unlock()

lock, ok := q.Locks[id]
if !ok {
return nil
}
if !lock.TryLock() {
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range \"%s\" is locked", id)
}
defer lock.Unlock()

return ExecuteCommands(q.DumpState, NewDeleteCommand(q.Krs, id),
NewDeleteCommand(q.Freq, id), NewDeleteCommand(q.Locks, id))
}

// TODO : unit tests
func (q *MemQDB) DropKeyRangeAll(ctx context.Context) error {
func (q *MemQDB) DropKeyRangeAll(_ context.Context) error {
spqrlog.Zero.Debug().Msg("memqdb: drop all key ranges")
q.mu.RLock()

// Do not allow new locks on key range we want to delete
q.muDeletedKrs.Lock()
ids := make([]string, 0)
for id := range q.Locks {
if q.deletedKrs[id] {
q.muDeletedKrs.Unlock()
q.mu.RUnlock()
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range '%s' already deleted", id)
}
ids = append(ids, id)
q.deletedKrs[id] = true
}
q.muDeletedKrs.Unlock()
defer func() {
q.muDeletedKrs.Lock()
defer q.muDeletedKrs.Unlock()
spqrlog.Zero.Debug().Msg("delete previous marks")

for _, id := range ids {
delete(q.deletedKrs, id)
}
}()
q.mu.Lock()
defer q.mu.Unlock()

// Wait until key range will be unlocked
var locks []*sync.RWMutex
for _, l := range q.Locks {
l.Lock()
locks = append(locks, l)
}
defer func() {
for _, l := range locks {
l.Unlock()
}
}()
for krId, l := range q.Locks {
if !l.TryLock() {
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range \"%s\" is locked", krId)
}
locks = append(locks, l)
}
spqrlog.Zero.Debug().Msg("memqdb: acquired all locks")

q.mu.RUnlock()

q.mu.Lock()
defer q.mu.Unlock()

return ExecuteCommands(q.DumpState, NewDropCommand(q.Krs), NewDropCommand(q.Locks))
}

Expand Down Expand Up @@ -294,22 +259,18 @@ func (q *MemQDB) ListAllKeyRanges(_ context.Context) ([]*KeyRange, error) {

// TODO : unit tests
func (q *MemQDB) TryLockKeyRange(lock *sync.RWMutex, id string, read bool) error {
q.muDeletedKrs.RLock()

if _, ok := q.deletedKrs[id]; ok {
q.muDeletedKrs.RUnlock()
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range '%s' deleted", id)
}
q.muDeletedKrs.RUnlock()

res := false
if read {
lock.RLock()
res = lock.TryRLock()
} else {
lock.Lock()
res = lock.TryLock()
}
if !res {
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range \"%s\" is locked", id)
}

if _, ok := q.Krs[id]; !ok {
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range '%s' deleted after lock acuired", id)
return spqrerror.Newf(spqrerror.SPQR_KEYRANGE_ERROR, "key range '%s' deleted after lock acquired", id)
}
return nil
}
Expand Down Expand Up @@ -414,14 +375,14 @@ func (q *MemQDB) ShareKeyRange(id string) error {
// ==============================================================================

// TODO : unit tests
func (q *MemQDB) RecordTransferTx(ctx context.Context, key string, info *DataTransferTransaction) error {
func (q *MemQDB) RecordTransferTx(_ context.Context, key string, info *DataTransferTransaction) error {
q.mu.Lock()
defer q.mu.Unlock()
return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Transactions, key, info))
}

// TODO : unit tests
func (q *MemQDB) GetTransferTx(ctx context.Context, key string) (*DataTransferTransaction, error) {
func (q *MemQDB) GetTransferTx(_ context.Context, key string) (*DataTransferTransaction, error) {
q.mu.RLock()
defer q.mu.RUnlock()

Expand All @@ -433,7 +394,7 @@ func (q *MemQDB) GetTransferTx(ctx context.Context, key string) (*DataTransferTr
}

// TODO : unit tests
func (q *MemQDB) RemoveTransferTx(ctx context.Context, key string) error {
func (q *MemQDB) RemoveTransferTx(_ context.Context, key string) error {
q.mu.Lock()
defer q.mu.Unlock()
return ExecuteCommands(q.DumpState, NewDeleteCommand(q.Transactions, key))
Expand All @@ -443,12 +404,12 @@ func (q *MemQDB) RemoveTransferTx(ctx context.Context, key string) error {
// COORDINATOR LOCK
// ==============================================================================

func (q *MemQDB) TryCoordinatorLock(ctx context.Context) error {
func (q *MemQDB) TryCoordinatorLock(_ context.Context) error {
return nil
}

// TODO : unit tests
func (q *MemQDB) UpdateCoordinator(ctx context.Context, address string) error {
func (q *MemQDB) UpdateCoordinator(_ context.Context, address string) error {
spqrlog.Zero.Debug().Str("address", address).Msg("memqdb: update coordinator address")

q.mu.Lock()
Expand All @@ -468,7 +429,7 @@ func (q *MemQDB) GetCoordinator(ctx context.Context) (string, error) {
// ==============================================================================

// TODO : unit tests
func (q *MemQDB) AddRouter(ctx context.Context, r *Router) error {
func (q *MemQDB) AddRouter(_ context.Context, r *Router) error {
spqrlog.Zero.Debug().Interface("router", r).Msg("memqdb: add router")
q.mu.Lock()
defer q.mu.Unlock()
Expand All @@ -477,7 +438,7 @@ func (q *MemQDB) AddRouter(ctx context.Context, r *Router) error {
}

// TODO : unit tests
func (q *MemQDB) DeleteRouter(ctx context.Context, id string) error {
func (q *MemQDB) DeleteRouter(_ context.Context, id string) error {
spqrlog.Zero.Debug().Str("router", id).Msg("memqdb: delete router")
q.mu.Lock()
defer q.mu.Unlock()
Expand All @@ -486,7 +447,7 @@ func (q *MemQDB) DeleteRouter(ctx context.Context, id string) error {
}

// TODO : unit tests
func (q *MemQDB) OpenRouter(ctx context.Context, id string) error {
func (q *MemQDB) OpenRouter(_ context.Context, id string) error {
spqrlog.Zero.Debug().
Str("router", id).
Msg("memqdb: open router")
Expand All @@ -499,7 +460,7 @@ func (q *MemQDB) OpenRouter(ctx context.Context, id string) error {
}

// TODO : unit tests
func (q *MemQDB) CloseRouter(ctx context.Context, id string) error {
func (q *MemQDB) CloseRouter(_ context.Context, id string) error {
spqrlog.Zero.Debug().
Str("router", id).
Msg("memqdb: open router")
Expand All @@ -512,7 +473,7 @@ func (q *MemQDB) CloseRouter(ctx context.Context, id string) error {
}

// TODO : unit tests
func (q *MemQDB) ListRouters(ctx context.Context) ([]*Router, error) {
func (q *MemQDB) ListRouters(_ context.Context) ([]*Router, error) {
spqrlog.Zero.Debug().Msg("memqdb: list routers")
q.mu.RLock()
defer q.mu.RUnlock()
Expand All @@ -535,7 +496,7 @@ func (q *MemQDB) ListRouters(ctx context.Context) ([]*Router, error) {
// ==============================================================================

// TODO : unit tests
func (q *MemQDB) AddShard(ctx context.Context, shard *Shard) error {
func (q *MemQDB) AddShard(_ context.Context, shard *Shard) error {
spqrlog.Zero.Debug().Interface("shard", shard).Msg("memqdb: add shard")
q.mu.Lock()
defer q.mu.Unlock()
Expand All @@ -544,7 +505,7 @@ func (q *MemQDB) AddShard(ctx context.Context, shard *Shard) error {
}

// TODO : unit tests
func (q *MemQDB) ListShards(ctx context.Context) ([]*Shard, error) {
func (q *MemQDB) ListShards(_ context.Context) ([]*Shard, error) {
spqrlog.Zero.Debug().Msg("memqdb: list shards")
q.mu.RLock()
defer q.mu.RUnlock()
Expand All @@ -563,7 +524,7 @@ func (q *MemQDB) ListShards(ctx context.Context) ([]*Shard, error) {
}

// TODO : unit tests
func (q *MemQDB) GetShard(ctx context.Context, id string) (*Shard, error) {
func (q *MemQDB) GetShard(_ context.Context, id string) (*Shard, error) {
spqrlog.Zero.Debug().Str("shard", id).Msg("memqdb: get shard")
q.mu.RLock()
defer q.mu.RUnlock()
Expand All @@ -590,7 +551,7 @@ func (q *MemQDB) DropShard(_ context.Context, id string) error {
// ==============================================================================

// TODO : unit tests
func (q *MemQDB) CreateDistribution(ctx context.Context, distribution *Distribution) error {
func (q *MemQDB) CreateDistribution(_ context.Context, distribution *Distribution) error {
spqrlog.Zero.Debug().Interface("distribution", distribution).Msg("memqdb: add distribution")
q.mu.Lock()
defer q.mu.Unlock()
Expand All @@ -604,7 +565,7 @@ func (q *MemQDB) CreateDistribution(ctx context.Context, distribution *Distribut
}

// TODO : unit tests
func (q *MemQDB) ListDistributions(ctx context.Context) ([]*Distribution, error) {
func (q *MemQDB) ListDistributions(_ context.Context) ([]*Distribution, error) {
spqrlog.Zero.Debug().Msg("memqdb: list distributions")
q.mu.RLock()
defer q.mu.RUnlock()
Expand All @@ -621,7 +582,7 @@ func (q *MemQDB) ListDistributions(ctx context.Context) ([]*Distribution, error)
}

// TODO : unit tests
func (q *MemQDB) DropDistribution(ctx context.Context, id string) error {
func (q *MemQDB) DropDistribution(_ context.Context, id string) error {
spqrlog.Zero.Debug().Str("distribution", id).Msg("memqdb: delete distribution")
q.mu.Lock()
defer q.mu.Unlock()
Expand Down Expand Up @@ -689,7 +650,7 @@ func (q *MemQDB) AlterDistributionDetach(_ context.Context, id string, relName s
}

// TODO : unit tests
func (q *MemQDB) GetDistribution(ctx context.Context, id string) (*Distribution, error) {
func (q *MemQDB) GetDistribution(_ context.Context, id string) (*Distribution, error) {
spqrlog.Zero.Debug().Msg("memqdb: get distribution")
q.mu.RLock()
defer q.mu.RUnlock()
Expand Down Expand Up @@ -742,7 +703,7 @@ func (q *MemQDB) WriteTaskGroup(_ context.Context, group *TaskGroup) error {
return nil
}

func (q *MemQDB) RemoveTaskGroup(ctx context.Context) error {
func (q *MemQDB) RemoveTaskGroup(_ context.Context) error {
spqrlog.Zero.Debug().Msg("memqdb: remove task group")
q.mu.Lock()
defer q.mu.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion qdb/memqdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestMemqdbRacing(t *testing.T) {
func() { _ = memqdb.UpdateKeyRange(ctx, mockKeyRange) },
func() { _ = memqdb.DeleteRouter(ctx, mockRouter.ID) },
}
for i := 0; i < 10; i++ {
for i := 0; i < 1000; i++ {
for _, m := range methods {
wg.Add(1)
go func(m func()) {
Expand Down
2 changes: 1 addition & 1 deletion test/feature/features/coordinator.feature
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ Feature: Coordinator test
"""
Then SQL error on host "router" should match regexp
"""
context deadline exceeded
key range .* is locked
"""

Given I run SQL on host "coordinator"
Expand Down
2 changes: 1 addition & 1 deletion test/feature/features/proxy_console.feature
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Feature: Proxy console
"""
Then SQL error on host "router2" should match regexp
"""
context deadline exceeded
key range .* is locked
"""

When I run SQL on host "router-admin"
Expand Down

0 comments on commit 76ea934

Please sign in to comment.