diff --git a/docs/src/content/docs/commands/SET.md b/docs/src/content/docs/commands/SET.md index c65db76c9..2ea5806ad 100644 --- a/docs/src/content/docs/commands/SET.md +++ b/docs/src/content/docs/commands/SET.md @@ -8,7 +8,7 @@ The `SET` command in DiceDB is used to set the value of a key. If the key alread ## Syntax ```bash -SET key value [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix-time-milliseconds | KEEPTTL] [NX | XX] +SET key value [NX | XX] [GET] [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix-time-milliseconds | KEEPTTL] ``` ## Parameters @@ -24,6 +24,7 @@ SET key value [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix | `NX` | Only set the key if it does not already exist. | None | No | | `XX` | Only set the key if it already exists. | None | No | | `KEEPTTL` | Retain the time-to-live associated with the key. | None | No | +| `GET` | Return the value of the key before setting it. | None | No | ## Return values @@ -32,6 +33,8 @@ SET key value [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix | Command is successful | `OK` | | `NX` or `XX` conditions are not met | `nil` | | Syntax or specified constraints are invalid | error | +| If the `GET` option is provided | The value of the key before setting it or error if value cannot be returned as a string | + ## Behaviour @@ -41,6 +44,7 @@ SET key value [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix - Using the `EX`, `EXAT`, `PX` or `PXAT` options together with `KEEPTTL` is not allowed and will result in an error. - When provided, `EX` sets the expiry time in seconds and `PX` sets the expiry time in milliseconds. - The `KEEPTTL` option ensures that the key's existing TTL is retained. +- The `GET` option can be used to return the value of the key before setting it. If the key does not exist, `nil` is returned. If the key exists but does not contain a value which can be returned as a string, an error is returned. The set operation is not performed in this case. ## Errors @@ -131,3 +135,27 @@ Trying to set key `foo` with both `EX` and `KEEPTTL` will result in an error 127.0.0.1:7379> SET foo bar EX 10 KEEPTTL (error) ERR syntax error ``` + +### Set with GET option + +```bash +127.0.0.1:7379> set foo bar +OK +127.0.0.1:7379> set foo bazz get +"bar" +``` +### Set with GET option when key does not exist + +```bash +127.0.0.1:7379> set foo bazz get +(nil) +127.0.0.1:7379> get foo +(nil) +``` + +### Set with Get with wrong type of value +```bash +127.0.0.1:7379> sadd foo item1 +(integer) 1 +127.0.0.1:7379> set foo bazz get +(error) WRONGTYPE Operation against a key holding the wrong kind of value \ No newline at end of file diff --git a/integration_tests/commands/http/bloom_test.go b/integration_tests/commands/http/bloom_test.go index f0cae9a12..d6c47e41b 100644 --- a/integration_tests/commands/http/bloom_test.go +++ b/integration_tests/commands/http/bloom_test.go @@ -376,5 +376,10 @@ func TestBFEdgeCasesAndErrors(t *testing.T) { Body: map[string]interface{}{"key": "foo"}, }) }) + exec.FireCommand(HTTPCommand{ + Command: "FLUSHDB", + Body: map[string]interface{}{"values": []interface{}{}}, + }, + ) } } diff --git a/integration_tests/commands/http/set_test.go b/integration_tests/commands/http/set_test.go index f4b3f3781..84c10f8d8 100644 --- a/integration_tests/commands/http/set_test.go +++ b/integration_tests/commands/http/set_test.go @@ -188,6 +188,29 @@ func TestSetWithOptions(t *testing.T) { }, expected: []interface{}{nil, nil, "OK", nil, nil, nil}, }, + { + name: "GET with Existing Value", + commands: []HTTPCommand{ + {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v"}}, + {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "vv", "get": true}}, + }, + expected: []interface{}{"OK", "v"}, + }, + { + name: "GET with Non-Existing Value", + commands: []HTTPCommand{ + {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "vv", "get": true}}, + }, + expected: []interface{}{nil}, + }, + { + name: "GET with wrong type of value", + commands: []HTTPCommand{ + {Command: "SADD", Body: map[string]interface{}{"key": "k", "value": "b"}}, + {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "get": true}}, + }, + expected: []interface{}{float64(1), "WRONGTYPE Operation against a key holding the wrong kind of value"}, + }, } for _, tc := range testCases { @@ -207,7 +230,7 @@ func TestWithKeepTTLFlag(t *testing.T) { exec := NewHTTPCommandExecutor() expiryTime := strconv.FormatInt(time.Now().Add(1*time.Minute).UnixMilli(), 10) - testCases := []TestCase { + testCases := []TestCase{ { name: "SET WITH KEEP TTL", commands: []HTTPCommand{ @@ -228,7 +251,7 @@ func TestWithKeepTTLFlag(t *testing.T) { }, { name: "SET WITH KEEPTTL with PX", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "px": 2000, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, @@ -236,7 +259,7 @@ func TestWithKeepTTLFlag(t *testing.T) { }, { name: "SET WITH KEEPTTL with EX", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "ex": 3, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, @@ -244,7 +267,7 @@ func TestWithKeepTTLFlag(t *testing.T) { }, { name: "SET WITH KEEPTTL with NX", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "nx": true, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, @@ -252,7 +275,7 @@ func TestWithKeepTTLFlag(t *testing.T) { }, { name: "SET WITH KEEPTTL with XX", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "xx": true, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, @@ -260,7 +283,7 @@ func TestWithKeepTTLFlag(t *testing.T) { }, { name: "SET WITH KEEPTTL with PXAT", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "pxat": expiryTime, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, @@ -269,7 +292,7 @@ func TestWithKeepTTLFlag(t *testing.T) { { name: "SET WITH KEEPTTL with EXAT", - commands: []HTTPCommand { + commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v", "exat": expiryTime, "keepttl": true}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, diff --git a/integration_tests/commands/resp/append_test.go b/integration_tests/commands/resp/append_test.go index fc70851cf..b1b024b96 100644 --- a/integration_tests/commands/resp/append_test.go +++ b/integration_tests/commands/resp/append_test.go @@ -8,6 +8,7 @@ import ( func TestAPPEND(t *testing.T) { conn := getLocalConnection() + FireCommand(conn, "FLUSHDB") defer conn.Close() testCases := []struct { diff --git a/integration_tests/commands/resp/bloom_test.go b/integration_tests/commands/resp/bloom_test.go index bda85c284..53ba8b506 100644 --- a/integration_tests/commands/resp/bloom_test.go +++ b/integration_tests/commands/resp/bloom_test.go @@ -218,5 +218,6 @@ func TestBFEdgeCasesAndErrors(t *testing.T) { FireCommand(conn, cmd) } }) + FireCommand(conn, "FLUSHDB") } } diff --git a/integration_tests/commands/resp/getunwatch_test.go b/integration_tests/commands/resp/getunwatch_test.go index 790bc6fdb..addcfadb7 100644 --- a/integration_tests/commands/resp/getunwatch_test.go +++ b/integration_tests/commands/resp/getunwatch_test.go @@ -14,7 +14,6 @@ import ( const ( getUnwatchKey = "getunwatchkey" - fingerprint = "426696421" ) type getUnwatchTestCase struct { @@ -78,16 +77,17 @@ func TestGETUNWATCH(t *testing.T) { if !ok { t.Errorf("Type assertion to []interface{} failed for value: %v", v) } + fmt.Println(castedValue) assert.Equal(t, 3, len(castedValue)) assert.Equal(t, "GET", castedValue[0]) - assert.Equal(t, fingerprint, castedValue[1]) + assert.Equal(t, "426696421", castedValue[1]) assert.Equal(t, tc.val, castedValue[2]) } } // unsubscribe from updates for _, subscriber := range subscribers { - rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf("GET.UNWATCH %s", fingerprint)) + rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf("GET.UNWATCH %s", "426696421")) assert.NotNil(t, rp) v, err := rp.DecodeOne() @@ -98,7 +98,6 @@ func TestGETUNWATCH(t *testing.T) { } assert.Equal(t, castedValue, "OK") } - // Test updates are not sent after unsubscribing for _, tc := range getUnwatchTestCases[2:] { res := FireCommand(publisher, fmt.Sprintf("SET %s %s", tc.key, tc.val)) @@ -144,7 +143,7 @@ func TestGETUNWATCHWithSDK(t *testing.T) { firstMsg, err := watch.Watch(context.Background(), "GET", getUnwatchKey) assert.Nil(t, err) assert.Equal(t, firstMsg.Command, "GET") - assert.Equal(t, firstMsg.Fingerprint, fingerprint) + assert.Equal(t, "426696421", firstMsg.Fingerprint) channels[i] = watch.Channel() } @@ -155,13 +154,13 @@ func TestGETUNWATCHWithSDK(t *testing.T) { for _, channel := range channels { v := <-channel assert.Equal(t, "GET", v.Command) // command - assert.Equal(t, fingerprint, v.Fingerprint) // Fingerprint + assert.Equal(t, "426696421", v.Fingerprint) // Fingerprint assert.Equal(t, "check", v.Data.(string)) // data } // unsubscribe from updates for _, subscriber := range subscribers { - err := subscriber.watch.Unwatch(context.Background(), "GET", fingerprint) + err := subscriber.watch.Unwatch(context.Background(), "GET", "426696421") assert.Nil(t, err) } diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go index 739a26819..da5fa9042 100644 --- a/integration_tests/commands/resp/getwatch_test.go +++ b/integration_tests/commands/resp/getwatch_test.go @@ -17,7 +17,9 @@ type WatchSubscriber struct { watch *dicedb.WatchConn } -const getWatchKey = "getwatchkey" +const ( + getWatchKey = "getwatchkey" +) type getWatchTestCase struct { key string @@ -34,7 +36,6 @@ var getWatchTestCases = []getWatchTestCase{ func TestGETWATCH(t *testing.T) { publisher := getLocalConnection() subscribers := []net.Conn{getLocalConnection(), getLocalConnection(), getLocalConnection()} - FireCommand(publisher, fmt.Sprintf("DEL %s", getWatchKey)) defer func() { @@ -103,7 +104,7 @@ func TestGETWATCHWithSDK(t *testing.T) { firstMsg, err := watch.Watch(context.Background(), "GET", getWatchKey) assert.Nil(t, err) assert.Equal(t, firstMsg.Command, "GET") - assert.Equal(t, firstMsg.Fingerprint, "2714318480") + assert.Equal(t, "2714318480", firstMsg.Fingerprint) channels[i] = watch.Channel() } @@ -113,9 +114,9 @@ func TestGETWATCHWithSDK(t *testing.T) { for _, channel := range channels { v := <-channel - assert.Equal(t, "GET", v.Command) // command + assert.Equal(t, "GET", v.Command) // command assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint - assert.Equal(t, tc.val, v.Data.(string)) // data + assert.Equal(t, tc.val, v.Data.(string)) // data } } } @@ -134,7 +135,7 @@ func TestGETWATCHWithSDK2(t *testing.T) { firstMsg, err := watch.GetWatch(context.Background(), getWatchKey) assert.Nil(t, err) assert.Equal(t, firstMsg.Command, "GET") - assert.Equal(t, firstMsg.Fingerprint, "2714318480") + assert.Equal(t, "2714318480", firstMsg.Fingerprint) channels[i] = watch.Channel() } @@ -144,9 +145,9 @@ func TestGETWATCHWithSDK2(t *testing.T) { for _, channel := range channels { v := <-channel - assert.Equal(t, "GET", v.Command) // command + assert.Equal(t, "GET", v.Command) // command assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint - assert.Equal(t, tc.val, v.Data.(string)) // data + assert.Equal(t, tc.val, v.Data.(string)) // data } } } diff --git a/integration_tests/commands/resp/set_test.go b/integration_tests/commands/resp/set_test.go index e5ddb036c..781e4d347 100644 --- a/integration_tests/commands/resp/set_test.go +++ b/integration_tests/commands/resp/set_test.go @@ -120,6 +120,21 @@ func TestSetWithOptions(t *testing.T) { commands: []string{"SET k v XX EX 1", "GET k", "SLEEP 2", "GET k", "SET k v XX EX 1", "GET k"}, expected: []interface{}{"(nil)", "(nil)", "OK", "(nil)", "(nil)", "(nil)"}, }, + { + name: "GET with Existing Value", + commands: []string{"SET k v", "SET k vv GET"}, + expected: []interface{}{"OK", "v"}, + }, + { + name: "GET with Non-Existing Value", + commands: []string{"SET k vv GET"}, + expected: []interface{}{"(nil)"}, + }, + { + name: "GET with wrong type of value", + commands: []string{"sadd k v", "SET k vv GET"}, + expected: []interface{}{int64(1), "WRONGTYPE Operation against a key holding the wrong kind of value"}, + }, } for _, tc := range testCases { @@ -134,6 +149,8 @@ func TestSetWithOptions(t *testing.T) { } }) } + + FireCommand(conn, "FLUSHDB") } func TestSetWithExat(t *testing.T) { diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index ac7d906a8..a71d5a51d 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -12,6 +12,7 @@ import ( "time" "github.com/dicedb/dice/internal/server/resp" + "github.com/dicedb/dice/internal/wal" "github.com/dicedb/dice/internal/watchmanager" "github.com/dicedb/dice/internal/worker" @@ -128,7 +129,8 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec) workerManager := worker.NewWorkerManager(20000, shardManager) // Initialize the RESP Server - testServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, nil) + wl, _ := wal.NewNullWAL() + testServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, wl) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.AsyncServer.Port) diff --git a/integration_tests/commands/websocket/bloom_test.go b/integration_tests/commands/websocket/bloom_test.go index 44c28cb08..7198505df 100644 --- a/integration_tests/commands/websocket/bloom_test.go +++ b/integration_tests/commands/websocket/bloom_test.go @@ -211,5 +211,7 @@ func TestBFEdgeCasesAndErrors(t *testing.T) { exec.FireCommand(conn, cmd) } }) + conn := exec.ConnectToServer() + exec.FireCommandAndReadResponse(conn, "FLUSHDB") } } diff --git a/integration_tests/commands/websocket/set_test.go b/integration_tests/commands/websocket/set_test.go index 94480ebde..8dac72f15 100644 --- a/integration_tests/commands/websocket/set_test.go +++ b/integration_tests/commands/websocket/set_test.go @@ -121,6 +121,21 @@ func TestSetWithOptions(t *testing.T) { commands: []string{"SET k v XX EX 1", "GET k", "SLEEP 2", "GET k", "SET k v XX EX 1", "GET k"}, expected: []interface{}{nil, nil, "OK", nil, nil, nil}, }, + { + name: "GET with Existing Value", + commands: []string{"SET k v", "SET k vv GET"}, + expected: []interface{}{"OK", "v"}, + }, + { + name: "GET with Non-Existing Value", + commands: []string{"SET k vv GET"}, + expected: []interface{}{nil}, + }, + { + name: "GET with wrong type of value", + commands: []string{"sadd k v", "SET k vv GET"}, + expected: []interface{}{float64(1), "WRONGTYPE Operation against a key holding the wrong kind of value"}, + }, } for _, tc := range testCases { diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index e0768573e..1a832e846 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -187,204 +187,129 @@ func testEvalHELLO(t *testing.T, store *dstore.Store) { } func testEvalSET(t *testing.T, store *dstore.Store) { - tests := []evalTestCase{ - { - name: "nil value", + tests := map[string]evalTestCase{ + "nil value": { input: nil, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'set' command")}, }, - { - name: "empty array", + "empty array": { input: []string{}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'set' command")}, }, - { - name: "one value", + "one value": { input: []string{"KEY"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'set' command")}, }, - { - name: "key val pair", + "key val pair": { input: []string{"KEY", "VAL"}, migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, }, - { - name: "key val pair with int val", + "key val pair with int val": { input: []string{"KEY", "123456"}, migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, }, - { - name: "key val pair and expiry key", + "key val pair and expiry key": { input: []string{"KEY", "VAL", Px}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, - { - name: "key val pair and EX no val", + "key val pair and EX no val": { input: []string{"KEY", "VAL", Ex}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, - { - name: "key val pair and valid EX", + "key val pair and valid EX": { input: []string{"KEY", "VAL", Ex, "2"}, migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, }, - { - name: "key val pair and invalid negative EX", + "key val pair and invalid negative EX": { input: []string{"KEY", "VAL", Ex, "-2"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and invalid float EX", + "key val pair and invalid float EX": { input: []string{"KEY", "VAL", Ex, "2.0"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, - { - name: "key val pair and invalid out of range int EX", + "key val pair and invalid out of range int EX": { input: []string{"KEY", "VAL", Ex, "9223372036854775807"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and invalid greater than max duration EX", + "key val pair and invalid greater than max duration EX": { input: []string{"KEY", "VAL", Ex, "9223372036854775"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and invalid EX", + "key val pair and invalid EX": { input: []string{"KEY", "VAL", Ex, "invalid_expiry_val"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, - { - name: "key val pair and PX no val", + "key val pair and PX no val": { input: []string{"KEY", "VAL", Px}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, - { - name: "key val pair and valid PX", + "key val pair and valid PX": { input: []string{"KEY", "VAL", Px, "2000"}, migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, }, - { - name: "key val pair and invalid PX", + "key val pair and invalid PX": { input: []string{"KEY", "VAL", Px, "invalid_expiry_val"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, - { - name: "key val pair and invalid negative PX", + "key val pair and invalid negative PX": { input: []string{"KEY", "VAL", Px, "-2"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and invalid float PX", + "key val pair and invalid float PX": { input: []string{"KEY", "VAL", Px, "2.0"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, - { - name: "key val pair and invalid out of range int PX", + + "key val pair and invalid out of range int PX": { input: []string{"KEY", "VAL", Px, "9223372036854775807"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and invalid greater than max duration PX", + "key val pair and invalid greater than max duration PX": { input: []string{"KEY", "VAL", Px, "9223372036854775"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, }, - { - name: "key val pair and both EX and PX", + "key val pair and both EX and PX": { input: []string{"KEY", "VAL", Ex, "2", Px, "2000"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, - { - name: "key val pair and PXAT no val", + "key val pair and PXAT no val": { input: []string{"KEY", "VAL", Pxat}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, - { - name: "key val pair and invalid PXAT", + "key val pair and invalid PXAT": { input: []string{"KEY", "VAL", Pxat, "invalid_expiry_val"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, - { - name: "key val pair and expired PXAT", - input: []string{"KEY", "VAL", Pxat, "2"}, - migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, - }, - { - name: "key val pair and negative PXAT", - input: []string{"KEY", "VAL", Pxat, "-123456"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'set' command")}, - }, - { - name: "key val pair and valid PXAT", - input: []string{"KEY", "VAL", Pxat, strconv.FormatInt(time.Now().Add(2*time.Minute).UnixMilli(), 10)}, - migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, - }, - { - name: "key val pair and invalid EX and PX", - input: []string{"KEY", "VAL", Ex, "2", Px, "2000"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, - }, - { - name: "key val pair and invalid EX and PXAT", - input: []string{"KEY", "VAL", Ex, "2", Pxat, "2"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, - }, - { - name: "key val pair and invalid PX and PXAT", - input: []string{"KEY", "VAL", Px, "2000", Pxat, "2"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, - }, - { - name: "key val pair and KeepTTL", - input: []string{"KEY", "VAL", KeepTTL}, - migratedOutput: EvalResponse{Result: clientio.OK, Error: nil}, - }, - { - name: "key val pair and invalid KeepTTL", - input: []string{"KEY", "VAL", KeepTTL, "2"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, - }, - { - name: "key val pair and KeepTTL, EX", - input: []string{"KEY", "VAL", Ex, "2", KeepTTL}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, - }, - { - name: "key val pair and KeepTTL, PX", - input: []string{"KEY", "VAL", Px, "2000", KeepTTL}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, + "key val with get": { + input: []string{"key", "bazz", "GET"}, + setup: func() { + key := "key" + value := "bar" + obj := store.NewObj(value, -1, object.ObjTypeString, object.ObjEncodingEmbStr) + store.Put(key, obj) + }, + migratedOutput: EvalResponse{Result: "bar", Error: nil}, }, - { - name: "key val pair and KeepTTL, PXAT", - input: []string{"KEY", "VAL", Pxat, strconv.FormatInt(time.Now().Add(2*time.Minute).UnixMilli(), 10), KeepTTL}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, + "key val with get and nil get": { + input: []string{"key", "bar", "GET"}, + migratedOutput: EvalResponse{Result: clientio.NIL, Error: nil}, }, - { - name: "key val pair and KeepTTL, invalid PXAT", - input: []string{"KEY", "VAL", Pxat, "invalid_expiry_val", KeepTTL}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, + "key val with get and but value is json": { + input: []string{"key", "bar", "GET"}, + setup: func() { + key := "key" + value := "{\"a\":2}" + var rootData interface{} + _ = sonic.Unmarshal([]byte(value), &rootData) + obj := store.NewObj(rootData, -1, object.ObjTypeJSON, object.ObjEncodingJSON) + store.Put(key, obj) + }, + migratedOutput: EvalResponse{Result: nil, Error: diceerrors.ErrWrongTypeOperation}, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - response := evalSET(tt.input, store) - - // Handle comparison for byte slices - if b, ok := response.Result.([]byte); ok && tt.migratedOutput.Result != nil { - if expectedBytes, ok := tt.migratedOutput.Result.([]byte); ok { - assert.True(t, bytes.Equal(b, expectedBytes), "expected and actual byte slices should be equal") - } - } else { - assert.Equal(t, tt.migratedOutput.Result, response.Result) - } - - if tt.migratedOutput.Error != nil { - assert.EqualError(t, response.Error, tt.migratedOutput.Error.Error()) - } else { - assert.NoError(t, response.Error) - } - }) - } + runMigratedEvalTests(t, tests, evalSET, store) } func testEvalGETEX(t *testing.T, store *dstore.Store) { diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index e19a3c513..79d63182a 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -183,16 +183,14 @@ func evalEXPIRETIME(args []string, store *dstore.Store) *EvalResponse { // If the key already exists then the value will be overwritten and expiry will be discarded func evalSET(args []string, store *dstore.Store) *EvalResponse { if len(args) <= 1 { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongArgumentCount("SET"), - } + return makeEvalError(diceerrors.ErrWrongArgumentCount("SET")) } var key, value string var exDurationMs int64 = -1 var state exDurationState = Uninitialized var keepttl bool = false + var oldVal *interface{} key, value = args[0], args[1] oType, oEnc := deduceTypeEncoding(value) @@ -202,38 +200,23 @@ func evalSET(args []string, store *dstore.Store) *EvalResponse { switch arg { case Ex, Px: if state != Uninitialized { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } if keepttl { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } i++ if i == len(args) { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } exDuration, err := strconv.ParseInt(args[i], 10, 64) if err != nil { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrIntegerOutOfRange, - } + return makeEvalError(diceerrors.ErrIntegerOutOfRange) } if exDuration <= 0 || exDuration >= maxExDuration { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrInvalidExpireTime("SET"), - } + return makeEvalError(diceerrors.ErrInvalidExpireTime("SET")) } // converting seconds to milliseconds @@ -245,37 +228,22 @@ func evalSET(args []string, store *dstore.Store) *EvalResponse { case Pxat, Exat: if state != Uninitialized { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } if keepttl { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } i++ if i == len(args) { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } exDuration, err := strconv.ParseInt(args[i], 10, 64) if err != nil { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrIntegerOutOfRange, - } + return makeEvalError(diceerrors.ErrIntegerOutOfRange) } if exDuration < 0 { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrInvalidExpireTime("SET"), - } + return makeEvalError(diceerrors.ErrInvalidExpireTime("SET")) } if arg == Exat { @@ -295,32 +263,26 @@ func evalSET(args []string, store *dstore.Store) *EvalResponse { // if key does not exist, return RESP encoded nil if obj == nil { - return &EvalResponse{ - Result: clientio.NIL, - Error: nil, - } + return makeEvalResult(clientio.NIL) } case NX: obj := store.Get(key) if obj != nil { - return &EvalResponse{ - Result: clientio.NIL, - Error: nil, - } + return makeEvalResult(clientio.NIL) } case KeepTTL: if state != Uninitialized { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, - } + return makeEvalError(diceerrors.ErrSyntax) } keepttl = true - default: - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrSyntax, + case GET: + getResult := evalGET([]string{key}, store) + if getResult.Error != nil { + return makeEvalError(diceerrors.ErrWrongTypeOperation) } + oldVal = &getResult.Result + default: + return makeEvalError(diceerrors.ErrSyntax) } } @@ -332,19 +294,15 @@ func evalSET(args []string, store *dstore.Store) *EvalResponse { case object.ObjEncodingEmbStr, object.ObjEncodingRaw: storedValue = value default: - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrUnsupportedEncoding(int(oEnc)), - } + return makeEvalError(diceerrors.ErrUnsupportedEncoding(int(oEnc))) } // putting the k and value in a Hash Table store.Put(key, store.NewObj(storedValue, exDurationMs, oType, oEnc), dstore.WithKeepTTL(keepttl)) - - return &EvalResponse{ - Result: clientio.OK, - Error: nil, + if oldVal != nil { + return makeEvalResult(*oldVal) } + return makeEvalResult(clientio.OK) } // evalGET returns the value for the queried key in args