diff --git a/Makefile b/Makefile index 717d043df..78c719245 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ pooler_run: ####################### TESTS ####################### unittest: - go test ./cmd/... ./pkg/... ./router/... ./qdb/... ./coordinator/... + go test -race ./cmd/... ./pkg/... ./router/... ./qdb/... ./coordinator/... regress_local: proxy_2sh_run ./script/regress_local.sh diff --git a/pkg/coord/local/clocal.go b/pkg/coord/local/clocal.go index 28c080355..e6b9db1ad 100644 --- a/pkg/coord/local/clocal.go +++ b/pkg/coord/local/clocal.go @@ -202,16 +202,9 @@ func (qr *LocalCoordinator) Unite(ctx context.Context, req *kr.UniteKeyRange) er }(qr.qdb, ctx, req.KeyRangeIDLeft) // TODO: krRight seems to be empty. - if krright, err = qr.qdb.LockKeyRange(ctx, req.KeyRangeIDRight); err != nil { + if krright, err = qr.qdb.GetKeyRange(ctx, req.KeyRangeIDRight); err != nil { return err } - defer func(qdb qdb.QDB, ctx context.Context, keyRangeID string) { - err := qdb.UnlockKeyRange(ctx, keyRangeID) - if err != nil { - spqrlog.Zero.Error().Err(err).Msg("") - return - } - }(qr.qdb, ctx, req.KeyRangeIDRight) if err = qr.qdb.DropKeyRange(ctx, krright.KeyRangeID); err != nil { return err diff --git a/qdb/command.go b/qdb/command.go index 3a073306c..df855174e 100644 --- a/qdb/command.go +++ b/qdb/command.go @@ -1,8 +1,14 @@ package qdb +import ( + "fmt" + + "github.com/pg-sharding/spqr/pkg/spqrlog" +) + type Command interface { - Do() - Undo() + Do() error + Undo() error } func NewDeleteCommand[T any](m map[string]T, key string) *DeleteCommand[T] { @@ -16,17 +22,19 @@ type DeleteCommand[T any] struct { present bool } -func (c *DeleteCommand[T]) Do() { +func (c *DeleteCommand[T]) Do() error { c.value, c.present = c.m[c.key] delete(c.m, c.key) + return nil } -func (c *DeleteCommand[T]) Undo() { +func (c *DeleteCommand[T]) Undo() error { if !c.present { delete(c.m, c.key) } else { c.m[c.key] = c.value } + return nil } func NewUpdateCommand[T any](m map[string]T, key string, value T) *UpdateCommand[T] { @@ -41,17 +49,19 @@ type UpdateCommand[T any] struct { present bool } -func (c *UpdateCommand[T]) Do() { +func (c *UpdateCommand[T]) Do() error { c.prevValue, c.present = c.m[c.key] c.m[c.key] = c.value + return nil } -func (c *UpdateCommand[T]) Undo() { +func (c *UpdateCommand[T]) Undo() error { if !c.present { delete(c.m, c.key) } else { c.m[c.key] = c.prevValue } + return nil } func NewDropCommand[T any](m map[string]T) *DropCommand[T] { @@ -63,7 +73,7 @@ type DropCommand[T any] struct { copy map[string]T } -func (c *DropCommand[T]) Do() { +func (c *DropCommand[T]) Do() error { c.copy = make(map[string]T) for k, v := range c.m { c.copy[k] = v @@ -71,39 +81,63 @@ func (c *DropCommand[T]) Do() { for k := range c.m { delete(c.m, k) } + return nil } -func (c *DropCommand[T]) Undo() { +func (c *DropCommand[T]) Undo() error { for k, v := range c.copy { c.m[k] = v } + return nil } -func NewCustomCommand(do func(), undo func()) *CustomCommand { +func NewCustomCommand(do func() error, undo func() error) *CustomCommand { return &CustomCommand{do: do, undo: undo} } type CustomCommand struct { - do func() - undo func() + do func() error + undo func() error } -func (c *CustomCommand) Do() { - c.do() +func (c *CustomCommand) Do() error { + return c.do() } -func (c *CustomCommand) Undo() { - c.undo() +func (c *CustomCommand) Undo() error { + return c.undo() } -func ExecuteCommands(saver func() error, commands ...Command) error { +func doCommands(commands ...Command) (int, error) { + for i, c := range commands { + err := c.Do() + if err != nil { + return i, err + } + } + return len(commands), nil +} + +func undoCommands(commands ...Command) error { + spqrlog.Zero.Info().Msg("memqdb: undo commands") for _, c := range commands { - c.Do() + err := c.Undo() + if err != nil { + return err + } + } + return nil +} + +func ExecuteCommands(saver func() error, commands ...Command) error { + completed, err := doCommands(commands...) + if err == nil { + err = saver() } - err := saver() if err != nil { - for _, c := range commands { - c.Undo() + undoErr := undoCommands(commands[:completed]...) + if undoErr != nil { + return fmt.Errorf("failed to undo command %s while: %s", undoErr.Error(), err.Error()) } return err } diff --git a/qdb/memqdb.go b/qdb/memqdb.go index 6720b6803..e38c982c8 100644 --- a/qdb/memqdb.go +++ b/qdb/memqdb.go @@ -12,8 +12,11 @@ import ( ) type MemQDB struct { - mu sync.RWMutex + // TODO create more mutex per map if needed + mu sync.RWMutex + muDeletedKrs sync.RWMutex + deletedKrs map[string]bool Locks map[string]*sync.RWMutex `json:"locks"` Freq map[string]bool `json:"freq"` Krs map[string]*KeyRange `json:"krs"` @@ -39,6 +42,7 @@ func NewMemQDB(backupPath string) (*MemQDB, error) { Dataspaces: map[string]*Dataspace{}, Routers: map[string]*Router{}, Transactions: map[string]*DataTransferTransaction{}, + deletedKrs: map[string]bool{}, backupPath: backupPath, }, nil @@ -154,6 +158,8 @@ func (q *MemQDB) DropShardingRuleAll(ctx context.Context) ([]*ShardingRule, erro func (q *MemQDB) GetShardingRule(ctx context.Context, id string) (*ShardingRule, error) { spqrlog.Zero.Debug().Str("rule", id).Msg("memqdb: get sharding rule") + q.mu.RLock() + defer q.mu.RUnlock() rule, ok := q.Shrules[id] if ok { return rule, nil @@ -200,8 +206,8 @@ func (q *MemQDB) AddKeyRange(ctx context.Context, keyRange *KeyRange) error { func (q *MemQDB) GetKeyRange(ctx context.Context, id string) (*KeyRange, error) { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: get key range") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() krs, ok := q.Krs[id] if !ok { @@ -221,6 +227,31 @@ func (q *MemQDB) UpdateKeyRange(ctx context.Context, keyRange *KeyRange) error { func (q *MemQDB) DropKeyRange(ctx context.Context, id string) error { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: drop key range") + + // Do not allow new locks on key range we want to delete + q.muDeletedKrs.Lock() + if q.deletedKrs[id] { + q.muDeletedKrs.Unlock() + return fmt.Errorf("key range '%s' already deleted", id) + } + q.deletedKrs[id] = true + q.muDeletedKrs.Unlock() + + defer func() { + q.muDeletedKrs.Lock() + defer q.muDeletedKrs.Unlock() + delete(q.deletedKrs, id) + }() + + q.mu.RLock() + + // Wait until key range will be unlocked + if lock, ok := q.Locks[id]; ok { + lock.Lock() + defer lock.Unlock() + } + q.mu.RUnlock() + q.mu.Lock() defer q.mu.Unlock() @@ -230,27 +261,50 @@ func (q *MemQDB) DropKeyRange(ctx context.Context, id string) error { func (q *MemQDB) DropKeyRangeAll(ctx context.Context) error { spqrlog.Zero.Debug().Msg("memqdb: drop all key ranges") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + + // Do not allow new locks on key range we want to delete + q.muDeletedKrs.Lock() + ids := make([]string, 0) + for id := range q.Locks { + if q.deletedKrs[id] { + q.muDeletedKrs.Unlock() + q.mu.RUnlock() + return fmt.Errorf("key range '%s' already deleted", id) + } + ids = append(ids, id) + q.deletedKrs[id] = true + } + q.muDeletedKrs.Unlock() + defer func() { + q.muDeletedKrs.Lock() + defer q.muDeletedKrs.Unlock() + spqrlog.Zero.Debug().Msg("delete previous marks") + + for _, id := range ids { + delete(q.deletedKrs, id) + } + }() + // Wait until key range will be unlocked var locks []*sync.RWMutex + for _, l := range q.Locks { + l.Lock() + locks = append(locks, l) + } + defer func() { + for _, l := range locks { + l.Unlock() + } + }() + spqrlog.Zero.Debug().Msg("memqdb: acquired all locks") - return ExecuteCommands(q.DumpState, - NewCustomCommand(func() { - for _, l := range q.Locks { - l.Lock() - locks = append(locks, l) - } - spqrlog.Zero.Debug().Msg("memqdb: acquired all locks") - }, func() {}), - NewDropCommand(q.Krs), NewDropCommand(q.Locks), - NewCustomCommand(func() { - for _, l := range locks { - l.Unlock() - } - }, - func() {}), - ) + q.mu.RUnlock() + + q.mu.Lock() + defer q.mu.Unlock() + + return ExecuteCommands(q.DumpState, NewDropCommand(q.Krs), NewDropCommand(q.Locks)) } func (q *MemQDB) ListKeyRanges(_ context.Context) ([]*KeyRange, error) { @@ -271,10 +325,32 @@ func (q *MemQDB) ListKeyRanges(_ context.Context) ([]*KeyRange, error) { return ret, nil } +func (q *MemQDB) TryLockKeyRange(lock *sync.RWMutex, id string, read bool) error { + q.muDeletedKrs.RLock() + + if _, ok := q.deletedKrs[id]; ok { + q.muDeletedKrs.RUnlock() + return fmt.Errorf("key range '%s' deleted", id) + } + q.muDeletedKrs.RUnlock() + + if read { + lock.RLock() + } else { + lock.Lock() + } + + if _, ok := q.Krs[id]; !ok { + return fmt.Errorf("key range '%s' deleted after lock acuired", id) + } + return nil +} + func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() + defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: lock key range") krs, ok := q.Krs[id] if !ok { @@ -282,7 +358,17 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { } err := ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, true), - NewCustomCommand(q.Locks[id].Lock, q.Locks[id].Unlock)) + NewCustomCommand(func() error { + if lock, ok := q.Locks[id]; ok { + return q.TryLockKeyRange(lock, id, false) + } + return nil + }, func() error { + if lock, ok := q.Locks[id]; ok { + lock.Unlock() + } + return nil + })) if err != nil { return nil, err } @@ -291,22 +377,33 @@ func (q *MemQDB) LockKeyRange(_ context.Context, id string) (*KeyRange, error) { } func (q *MemQDB) UnlockKeyRange(_ context.Context, id string) error { - spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: lock key range") - q.mu.Lock() - defer q.mu.Unlock() + spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: unlock key range") + q.mu.RLock() + defer q.mu.RUnlock() + defer spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: exit: unlock key range") if !q.Freq[id] { return fmt.Errorf("key range %v not locked", id) } return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Freq, id, false), - NewCustomCommand(q.Locks[id].Unlock, q.Locks[id].Lock)) + NewCustomCommand(func() error { + if lock, ok := q.Locks[id]; ok { + lock.Unlock() + } + return nil + }, func() error { + if lock, ok := q.Locks[id]; ok { + return q.TryLockKeyRange(lock, id, false) + } + return nil + })) } func (q *MemQDB) CheckLockedKeyRange(ctx context.Context, id string) (*KeyRange, error) { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: check locked key range") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() krs, ok := q.Krs[id] if !ok { @@ -323,8 +420,19 @@ func (q *MemQDB) CheckLockedKeyRange(ctx context.Context, id string) (*KeyRange, func (q *MemQDB) ShareKeyRange(id string) error { spqrlog.Zero.Debug().Str("key-range", id).Msg("memqdb: sharing key with key") - q.Locks[id].RLock() - defer q.Locks[id].RUnlock() + q.mu.RLock() + defer q.mu.RUnlock() + + lock, ok := q.Locks[id] + if !ok { + return fmt.Errorf("no such key") + } + + err := q.TryLockKeyRange(lock, id, true) + if err != nil { + return err + } + defer lock.RUnlock() return nil } @@ -336,13 +444,12 @@ func (q *MemQDB) ShareKeyRange(id string) error { func (q *MemQDB) RecordTransferTx(ctx context.Context, key string, info *DataTransferTransaction) error { q.mu.Lock() defer q.mu.Unlock() - return ExecuteCommands(q.DumpState, NewUpdateCommand(q.Transactions, key, info)) } func (q *MemQDB) GetTransferTx(ctx context.Context, key string) (*DataTransferTransaction, error) { - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() ans, ok := q.Transactions[key] if !ok { @@ -354,7 +461,6 @@ func (q *MemQDB) GetTransferTx(ctx context.Context, key string) (*DataTransferTr func (q *MemQDB) RemoveTransferTx(ctx context.Context, key string) error { q.mu.Lock() defer q.mu.Unlock() - return ExecuteCommands(q.DumpState, NewDeleteCommand(q.Transactions, key)) } @@ -392,8 +498,8 @@ func (q *MemQDB) OpenRouter(ctx context.Context, id string) error { func (q *MemQDB) ListRouters(ctx context.Context) ([]*Router, error) { spqrlog.Zero.Debug().Msg("memqdb: list routers") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() var ret []*Router for _, v := range q.Routers { @@ -426,8 +532,8 @@ func (q *MemQDB) AddShard(ctx context.Context, shard *Shard) error { func (q *MemQDB) ListShards(ctx context.Context) ([]*Shard, error) { spqrlog.Zero.Debug().Msg("memqdb: list shards") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() var ret []*Shard for _, v := range q.Shards { @@ -444,8 +550,8 @@ func (q *MemQDB) ListShards(ctx context.Context) ([]*Shard, error) { func (q *MemQDB) GetShard(ctx context.Context, id string) (*Shard, error) { spqrlog.Zero.Debug().Str("shard", id).Msg("memqdb: get shard") - q.mu.Lock() - defer q.mu.Unlock() + q.mu.RLock() + defer q.mu.RUnlock() if _, ok := q.Shards[id]; ok { return &Shard{ID: id}, nil diff --git a/qdb/memqdb_test.go b/qdb/memqdb_test.go new file mode 100644 index 000000000..5e4451e65 --- /dev/null +++ b/qdb/memqdb_test.go @@ -0,0 +1,102 @@ +package qdb_test + +import ( + "context" + "sync" + "testing" + + "github.com/pg-sharding/spqr/qdb" + "github.com/stretchr/testify/assert" +) + +const MemQDBPath = "" + +var mockDataspace *qdb.Dataspace = &qdb.Dataspace{"123"} +var mockShard *qdb.Shard = &qdb.Shard{ + ID: "shard_id", + Hosts: []string{"host1", "host2"}, +} +var mockKeyRange *qdb.KeyRange = &qdb.KeyRange{ + LowerBound: []byte{1, 2}, + UpperBound: []byte{3, 4}, + ShardID: mockShard.ID, + KeyRangeID: "key_range_id", +} +var mockRouter *qdb.Router = &qdb.Router{ + Address: "address", + ID: "router_id", + State: qdb.CLOSED, +} +var mockShardingRule *qdb.ShardingRule = &qdb.ShardingRule{ + ID: "sharding_rule_id", + TableName: "fake_table", + Entries: []qdb.ShardingRuleEntry{ + { + Column: "i", + }, + }, +} +var mockDataTransferTransaction *qdb.DataTransferTransaction = &qdb.DataTransferTransaction{ + ToShardId: mockShard.ID, + FromShardId: mockShard.ID, + FromTxName: "fake_tx_1", + ToTxName: "fake_tx_2", + FromStatus: "fake_st_1", + ToStatus: "fake_st_2", +} + +// must run with -race +func TestMemqdbRacing(t *testing.T) { + assert := assert.New(t) + + memqdb, err := qdb.RestoreQDB(MemQDBPath) + assert.NoError(err) + + var wg sync.WaitGroup + ctx := context.TODO() + + methods := []func(){ + func() { _ = memqdb.AddDataspace(ctx, mockDataspace) }, + func() { _ = memqdb.AddKeyRange(ctx, mockKeyRange) }, + func() { _ = memqdb.AddRouter(ctx, mockRouter) }, + func() { _ = memqdb.AddShard(ctx, mockShard) }, + func() { _ = memqdb.AddShardingRule(ctx, mockShardingRule) }, + func() { + _ = memqdb.RecordTransferTx(ctx, mockDataTransferTransaction.FromShardId, mockDataTransferTransaction) + }, + func() { _, _ = memqdb.ListDataspaces(ctx) }, + func() { _, _ = memqdb.ListKeyRanges(ctx) }, + func() { _, _ = memqdb.ListRouters(ctx) }, + func() { _, _ = memqdb.ListShardingRules(ctx) }, + func() { _, _ = memqdb.ListShards(ctx) }, + func() { _, _ = memqdb.GetKeyRange(ctx, mockKeyRange.KeyRangeID) }, + func() { _, _ = memqdb.GetShard(ctx, mockShard.ID) }, + func() { _, _ = memqdb.GetShardingRule(ctx, mockShardingRule.ID) }, + func() { _, _ = memqdb.GetTransferTx(ctx, mockDataTransferTransaction.FromShardId) }, + func() { _ = memqdb.ShareKeyRange(mockKeyRange.KeyRangeID) }, + func() { _ = memqdb.DropKeyRange(ctx, mockKeyRange.KeyRangeID) }, + func() { _ = memqdb.DropKeyRangeAll(ctx) }, + func() { _ = memqdb.DropShardingRule(ctx, mockShardingRule.ID) }, + func() { _, _ = memqdb.DropShardingRuleAll(ctx) }, + func() { _ = memqdb.RemoveTransferTx(ctx, mockDataTransferTransaction.FromShardId) }, + func() { _ = memqdb.LockRouter(ctx, mockRouter.ID) }, + func() { + _, _ = memqdb.LockKeyRange(ctx, mockKeyRange.KeyRangeID) + _ = memqdb.UnlockKeyRange(ctx, mockKeyRange.KeyRangeID) + }, + func() { _ = memqdb.UpdateKeyRange(ctx, mockKeyRange) }, + func() { _ = memqdb.DeleteRouter(ctx, mockRouter.ID) }, + } + for i := 0; i < 10; i++ { + for _, m := range methods { + wg.Add(1) + go func(m func()) { + m() + wg.Done() + }(m) + } + wg.Wait() + + } + wg.Wait() +} diff --git a/test/feature/spqr_test.go b/test/feature/spqr_test.go index d62debefb..7f1688e7d 100644 --- a/test/feature/spqr_test.go +++ b/test/feature/spqr_test.go @@ -288,6 +288,7 @@ func (tctx *testContext) queryPostgresql(host string, query string, args interfa tctx.commandRetcode = 0 if err != nil { tctx.commandRetcode = 1 + tctx.commandOutput = err.Error() tctx.sqlUserQueryError.Store(host, err.Error()) } tctx.sqlQueryResult = result