From 5127a2c95c391003f712396c67b0f8b397fa7ef5 Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Sat, 1 Feb 2025 23:22:30 +0530 Subject: [PATCH] feat: Add Lua Locking for Redis < 7 Compatibility (#731) * feat: Add Lua Locking for Redis < 7 Compatibility * feat: address comment * feat: add LuaLock flag to clientOption * feat: add tests for client with lua lock --------- Co-authored-by: Anuragkillswitch <70265851+Anuragkillswitch@users.noreply.github.com> --- rueidisaside/aside.go | 38 ++++-- rueidisaside/aside_test.go | 238 +++++++++++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+), 12 deletions(-) diff --git a/rueidisaside/aside.go b/rueidisaside/aside.go index e3efd04d..b276f27b 100644 --- a/rueidisaside/aside.go +++ b/rueidisaside/aside.go @@ -19,6 +19,7 @@ type ClientOption struct { ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) ClientOption rueidis.ClientOption ClientTTL time.Duration // TTL for the client marker, refreshed every 1/2 TTL. Defaults to 10s. The marker allows other client to know if this client is still alive. + UseLuaLock bool } type CacheAsideClient interface { @@ -33,8 +34,9 @@ func NewClient(option ClientOption) (cc CacheAsideClient, err error) { option.ClientTTL = 10 * time.Second } ca := &Client{ - waits: make(map[string]chan struct{}), - ttl: option.ClientTTL, + waits: make(map[string]chan struct{}), + ttl: option.ClientTTL, + useLuaLock: option.UseLuaLock, } option.ClientOption.OnInvalidations = ca.onInvalidation if option.ClientBuilder != nil { @@ -50,13 +52,14 @@ func NewClient(option ClientOption) (cc CacheAsideClient, err error) { } type Client struct { - client rueidis.Client - ctx context.Context - waits map[string]chan struct{} - cancel context.CancelFunc - id string - ttl time.Duration - mu sync.Mutex + client rueidis.Client + ctx context.Context + waits map[string]chan struct{} + cancel context.CancelFunc + id string + ttl time.Duration + mu sync.Mutex + useLuaLock bool } func (c *Client) onInvalidation(messages []rueidis.RedisMessage) { @@ -144,14 +147,21 @@ func randStr() string { func (c *Client) Get(ctx context.Context, ttl time.Duration, key string, fn func(ctx context.Context, key string) (val string, err error)) (string, error) { ctx, cancel := context.WithTimeout(ctx, ttl) defer cancel() + retry: wait := c.register(key) resp := c.client.DoCache(ctx, c.client.B().Get().Key(key).Cache(), ttl) val, err := resp.ToString() + if rueidis.IsRedisNil(err) && fn != nil { // cache miss, prepare to populate the value by fn() var id string if id, err = c.keepalive(); err == nil { // acquire client id - val, err = c.client.Do(ctx, c.client.B().Set().Key(key).Value(id).Nx().Get().Px(ttl).Build()).ToString() + if c.useLuaLock { + val, err = acquireLock.Exec(ctx, c.client, []string{key}, []string{id, strconv.FormatInt(ttl.Milliseconds(), 10)}).ToString() + } else { + val, err = c.client.Do(ctx, c.client.B().Set().Key(key).Value(id).Nx().Get().Px(ttl).Build()).ToString() + } + if rueidis.IsRedisNil(err) { // successfully set client id on the key as a lock if val, err = fn(ctx, key); err == nil { err = setkey.Exec(ctx, c.client, []string{key}, []string{id, val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error() @@ -162,9 +172,11 @@ retry: } } } + if err != nil { return val, err } + if strings.HasPrefix(val, PlaceholderPrefix) { ph := c.register(val) err = c.client.DoCache(ctx, c.client.B().Get().Key(val).Cache(), c.ttl).Error() @@ -184,6 +196,7 @@ retry: goto retry } } + return val, err } @@ -210,6 +223,7 @@ func (c *Client) Close() { const PlaceholderPrefix = "rueidisid:" var ( - delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) - setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) + delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) + setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) + acquireLock = rueidis.NewLuaScript(`if redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2]) then return nil else return redis.call("GET", KEYS[1]) end`) ) diff --git a/rueidisaside/aside_test.go b/rueidisaside/aside_test.go index 23784e4c..317b5f31 100644 --- a/rueidisaside/aside_test.go +++ b/rueidisaside/aside_test.go @@ -25,6 +25,18 @@ func makeClient(t *testing.T, addr []string) CacheAsideClient { return client } +func makeClientWithLuaLock(t *testing.T, addr []string) CacheAsideClient { + client, err := NewClient(ClientOption{ + UseLuaLock: true, + ClientOption: rueidis.ClientOption{InitAddress: addr, PipelineMultiplex: -1}, + ClientTTL: time.Second, + }) + if err != nil { + t.Fatal(err) + } + return client +} + func TestClientErr(t *testing.T) { if _, err := NewClient(ClientOption{}); err == nil { t.Error(err) @@ -72,6 +84,29 @@ func TestCacheFilled(t *testing.T) { } } +func TestCacheFilledLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr) + defer client.Close() + key := strconv.Itoa(rand.Int()) + for i := 0; i < 2; i++ { + val, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + return "1", nil + }) + if err != nil || val != "1" { + t.Fatal(err) + } + val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) + if err != nil || val != "1" { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 600) + val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) // should miss + if !rueidis.IsRedisNil(err) { + t.Fatal(err) + } + } +} + func TestCacheDel(t *testing.T) { client := makeClient(t, addr) defer client.Close() @@ -98,6 +133,32 @@ func TestCacheDel(t *testing.T) { } } +func TestCacheDelLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr) + defer client.Close() + key := strconv.Itoa(rand.Int()) + for i := 0; i < 2; i++ { + val, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + return "1", nil + }) + if err != nil || val != "1" { + t.Fatal(err) + } + val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) + if err != nil || val != "1" { + t.Fatal(err) + } + if err = client.Del(context.Background(), key); err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) + val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) // should miss + if !rueidis.IsRedisNil(err) { + t.Fatal(err) + } + } +} + func TestClientRefresh(t *testing.T) { client := makeClient(t, addr).(*Client) defer client.Close() @@ -118,6 +179,26 @@ func TestClientRefresh(t *testing.T) { }) } +func TestClientRefreshLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr).(*Client) + defer client.Close() + key := strconv.Itoa(rand.Int()) + _, _ = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString() + if err != nil { + t.Error(err) + } + for i := 0; i < 2; i++ { + err = client.client.Do(context.Background(), client.client.B().Get().Key(id).Build()).Error() + if err != nil { + t.Error(err) + } + time.Sleep(client.ttl) + } + return "1", nil + }) +} + func TestCloseCleanup(t *testing.T) { client := makeClient(t, addr).(*Client) key := strconv.Itoa(rand.Int()) @@ -143,6 +224,31 @@ func TestCloseCleanup(t *testing.T) { } } +func TestCloseCleanupLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr).(*Client) + key := strconv.Itoa(rand.Int()) + ch := make(chan string, 1) + _, _ = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString() + if err != nil { + t.Error(err) + } + err = client.client.Do(context.Background(), client.client.B().Get().Key(id).Build()).Error() + if err != nil { + t.Error(err) + } + ch <- id + return "1", nil + }) + client.Close() + client = makeClient(t, addr).(*Client) + defer client.Close() + err := client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error() + if !rueidis.IsRedisNil(err) { + t.Error(err) + } +} + func TestWriteCancel(t *testing.T) { client := makeClient(t, addr).(*Client) defer client.Close() @@ -170,6 +276,33 @@ func TestWriteCancel(t *testing.T) { } } +func TestWriteCancelLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr).(*Client) + defer client.Close() + key := strconv.Itoa(rand.Int()) + ch := make(chan string, 1) + ctx, cancel := context.WithCancel(context.Background()) + val, err := client.Get(ctx, time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString() + if err != nil { + t.Error(err) + } + cancel() + ch <- id + return "1", nil + }) + if val != "1" { + t.Fatal(err) + } + if err != context.Canceled { + t.Fatal(err) + } + err = client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).Error() + if !rueidis.IsRedisNil(err) { + t.Error(err) + } +} + func TestTimeout(t *testing.T) { client := makeClient(t, addr).(*Client) defer client.Close() @@ -188,6 +321,24 @@ func TestTimeout(t *testing.T) { } } +func TestTimeoutLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr).(*Client) + defer client.Close() + key := strconv.Itoa(rand.Int()) + _, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + _, err = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) { + return "1", nil + }) + if err != context.DeadlineExceeded { + t.Error(err) + } + return "", err + }) + if err != context.DeadlineExceeded { + t.Fatal(err) + } +} + func TestDisconnect(t *testing.T) { client := makeClient(t, addr).(*Client) defer client.Close() @@ -238,6 +389,56 @@ func TestDisconnect(t *testing.T) { time.Sleep(client.ttl) // wait old refresh goroutine exit } +func TestDisconnectLL(t *testing.T) { + client := makeClientWithLuaLock(t, addr).(*Client) + defer client.Close() + key := strconv.Itoa(rand.Int()) + ch := make(chan string, 2) + val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) { + id1, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString() + if err != nil { + t.Error(err) + } + go func() { + val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) { + id2, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString() + if err != nil { + t.Error(err) + } + ch <- id2 + return "2", nil + }) + if val != "2" { + t.Error(err) + } + }() + client.onInvalidation(nil) // simulate disconnection + id2 := <-ch + if id1 == id2 { + t.Error("id not changed") + } + ch <- id1 + ch <- id2 + return "1", nil + }) + if val != "1" { + t.Fatal(err) + } + val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) + if val != "2" { + t.Error(err) + } + err = client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error() // id1 + if !rueidis.IsRedisNil(err) { + t.Error(err) + } + err = client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error() // id2 + if err != nil { + t.Error(err) + } + time.Sleep(client.ttl) // wait old refresh goroutine exit +} + func TestMultipleClient(t *testing.T) { clients := make([]CacheAsideClient, 10) for i := 0; i < len(clients); i++ { @@ -274,3 +475,40 @@ func TestMultipleClient(t *testing.T) { } } } + +func TestMultipleClientLL(t *testing.T) { + clients := make([]CacheAsideClient, 10) + for i := 0; i < len(clients); i++ { + clients[i] = makeClientWithLuaLock(t, addr) + } + defer func() { + for _, client := range clients { + client.Close() + } + }() + cnt := 1000 + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(len(clients)) + key := strconv.Itoa(rand.Int()) + sum := int64(0) + for i, c := range clients { + go func(i int, c CacheAsideClient) { + defer wg.Done() + for j := 0; j < cnt; j++ { + v, err := c.Get(context.Background(), time.Second, key, func(ctx context.Context, key string) (val string, err error) { + atomic.AddInt64(&sum, 1) + return "1", nil + }) + if err != nil || v != "1" { + t.Error(err) + } + } + }(i, c) + } + wg.Wait() + if atomic.LoadInt64(&sum) != 1 { + t.Fatalf("unexpected sum") + } + } +}