diff --git a/integration_tests/commands/async/hrandfield_test.go b/integration_tests/commands/async/hrandfield_test.go new file mode 100644 index 000000000..2dd951f90 --- /dev/null +++ b/integration_tests/commands/async/hrandfield_test.go @@ -0,0 +1,105 @@ +package async + +import ( + "github.com/google/go-cmp/cmp/cmpopts" + "gotest.tools/v3/assert" + "testing" +) + +func TestHRANDFIELD(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + defer FireCommand(conn, "DEL key_hrandfield key_hrandfield02 key_hrandfield03") + + testCases := []struct { + name string + commands []string + expected []interface{} + }{ + { + name: "Basic HRANDFIELD operations", + commands: []string{"HSET key_hrandfield field value", "HSET key_hrandfield field2 value2", "HRANDFIELD key_hrandfield"}, + expected: []interface{}{ONE, ONE, []string{"field", "field2"}}, + }, + { + name: "HRANDFIELD with count", + commands: []string{"HSET key_hrandfield field3 value3", "HRANDFIELD key_hrandfield 2"}, + expected: []interface{}{ONE, []string{"field", "field2", "field3"}}, + }, + { + name: "HRANDFIELD with WITHVALUES", + commands: []string{"HRANDFIELD key_hrandfield 2 WITHVALUES"}, + expected: []interface{}{[]string{"field", "value", "field2", "value2", "field3", "value3"}}, + }, + { + name: "HRANDFIELD on non-existent key", + commands: []string{"HRANDFIELD key_hrandfield_nonexistent"}, + expected: []interface{}{"(nil)"}, + }, + { + name: "HRANDFIELD with wrong number of arguments", + commands: []string{"HRANDFIELD"}, + expected: []interface{}{"ERR wrong number of arguments for 'hrandfield' command"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i, cmd := range tc.commands { + result := FireCommand(conn, cmd) + expected := tc.expected[i] + + switch expected := expected.(type) { + case []string: + assertRandomFieldResult(t, result, expected) + case int: + assert.Equal(t, result, expected, "Unexpected result for command: %s", cmd) + case string: + assert.Equal(t, result, expected, "Unexpected result for command: %s", cmd) + default: + if str, ok := result.(string); ok { + assert.Equal(t, str, expected, "Unexpected result for command: %s", cmd) + } else { + assert.DeepEqual(t, result, expected, cmpopts.EquateEmpty()) + } + } + } + }) + } +} + +// assertRandomFieldResult asserts that the result contains all expected values or a single valid result +func assertRandomFieldResult(t *testing.T, result interface{}, expected []string) { + t.Helper() + + var resultsList []string + switch r := result.(type) { + case []interface{}: + resultsList = make([]string, len(r)) + for i, v := range r { + resultsList[i] = v.(string) + } + case string: + resultsList = []string{r} + default: + t.Fatalf("Expected result to be []interface{} or string, got %T", result) + } + + // generate a map of expected values for easy lookup + expectedMap := make(map[string]struct{}) + for _, exp := range expected { + expectedMap[exp] = struct{}{} + } + + // count the number of results that are in the expected set + count := 0 + for _, res := range resultsList { + if _, exists := expectedMap[res]; exists { + count++ + } + } + + // assert that all results are in the expected set or that there is a single valid result + assert.Assert(t, count == len(resultsList) || count == 1, + "Expected all results to be in the expected set or a single valid result. Got %d out of %d", + count, len(resultsList)) +} diff --git a/internal/eval/commands.go b/internal/eval/commands.go index cd80e8947..d828025b9 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -901,6 +901,13 @@ var ( IsMigrated: true, NewEval: evalSETEX, } + hrandfieldCmdMeta = DiceCmdMeta{ + Name: "HRANDFIELD", + Info: `Returns one or more random fields from a hash.`, + Eval: evalHRANDFIELD, + Arity: -2, + KeySpecs: KeySpecs{BeginIndex: 1}, + } ) func init() { @@ -1001,6 +1008,7 @@ func init() { DiceCmds["INCRBY"] = incrbyCmdMeta DiceCmds["GETRANGE"] = getRangeCmdMeta DiceCmds["SETEX"] = setexCmdMeta + DiceCmds["HRANDFIELD"] = hrandfieldCmdMeta DiceCmds["HDEL"] = hdelCmdMeta } diff --git a/internal/eval/constants.go b/internal/eval/constants.go index 5769c8ac7..97f06c123 100644 --- a/internal/eval/constants.go +++ b/internal/eval/constants.go @@ -9,27 +9,28 @@ const ( XOR string = "XOR" NOT string = "NOT" - Ex string = "EX" - Px string = "PX" - Pxat string = "PXAT" - Exat string = "EXAT" - XX string = "XX" - NX string = "NX" - Xx string = "xx" - Nx string = "nx" - GT string = "GT" - LT string = "LT" - KEEPTTL string = "KEEPTTL" - Keepttl string = "keepttl" - Sync string = "SYNC" - Async string = "ASYNC" - Help string = "HELP" - Memory string = "MEMORY" - Count string = "COUNT" - GetKeys string = "GETKEYS" - List string = "LIST" - Info string = "INFO" - Null string = "null" - null string = "null" - NULL string = "null" + Ex string = "EX" + Px string = "PX" + Pxat string = "PXAT" + Exat string = "EXAT" + XX string = "XX" + NX string = "NX" + Xx string = "xx" + Nx string = "nx" + GT string = "GT" + LT string = "LT" + KEEPTTL string = "KEEPTTL" + Keepttl string = "keepttl" + Sync string = "SYNC" + Async string = "ASYNC" + Help string = "HELP" + Memory string = "MEMORY" + Count string = "COUNT" + GetKeys string = "GETKEYS" + List string = "LIST" + Info string = "INFO" + Null string = "null" + null string = "null" + NULL string = "null" + WITHVALUES string = "WITHVALUES" ) diff --git a/internal/eval/eval.go b/internal/eval/eval.go index a0ecdfb07..87483a522 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -2,10 +2,12 @@ package eval import ( "bytes" + "crypto/rand" "errors" "fmt" "log/slog" "math" + "math/big" "math/bits" "regexp" "sort" @@ -4064,3 +4066,94 @@ func evalGETRANGE(args []string, store *dstore.Store) []byte { return clientio.Encode(str[start:end+1], false) } + +// evalHRANDFIELD returns random fields from a hash stored at key. +// If only the key is provided, one random field is returned. +// If count is provided, it returns that many unique random fields. A negative count allows repeated selections. +// The "WITHVALUES" option returns both fields and values. +// Returns nil if the key doesn't exist or the hash is empty. +// Errors: arity error, type error for non-hash, syntax error for "WITHVALUES", or count format error. +func evalHRANDFIELD(args []string, store *dstore.Store) []byte { + if len(args) < 1 || len(args) > 3 { + return diceerrors.NewErrArity("HRANDFIELD") + } + + key := args[0] + obj := store.Get(key) + if obj == nil { + return clientio.RespNIL + } + + if err := object.AssertTypeAndEncoding(obj.TypeEncoding, object.ObjTypeHashMap, object.ObjEncodingHashMap); err != nil { + return diceerrors.NewErrWithMessage(diceerrors.WrongTypeErr) + } + + hashMap := obj.Value.(HashMap) + if len(hashMap) == 0 { + return clientio.Encode([]string{}, false) + } + + count := 1 + withValues := false + + if len(args) > 1 { + var err error + // The second argument is the count. + count, err = strconv.Atoi(args[1]) + if err != nil { + return diceerrors.NewErrWithFormattedMessage(diceerrors.IntOrOutOfRangeErr) + } + + // The third argument is the "WITHVALUES" option. + if len(args) == 3 { + if strings.ToUpper(args[2]) != WITHVALUES { + return diceerrors.NewErrWithFormattedMessage(diceerrors.SyntaxErr) + } + withValues = true + } + } + + return selectRandomFields(hashMap, count, withValues) +} + +// selectRandomFields returns random fields from a hashmap. +func selectRandomFields(hashMap HashMap, count int, withValues bool) []byte { + keys := make([]string, 0, len(hashMap)) + for k := range hashMap { + keys = append(keys, k) + } + + var results []string + resultSet := make(map[string]struct{}) + + abs := func(x int) int { + if x < 0 { + return -x + } + return x + } + + for i := 0; i < abs(count); i++ { + if count > 0 && len(resultSet) == len(keys) { + break + } + + randomIndex, _ := rand.Int(rand.Reader, big.NewInt(int64(len(keys)))) + randomField := keys[randomIndex.Int64()] + + if count > 0 { + if _, exists := resultSet[randomField]; exists { + i-- + continue + } + resultSet[randomField] = struct{}{} + } + + results = append(results, randomField) + if withValues { + results = append(results, hashMap[randomField]) + } + } + + return clientio.Encode(results, false) +} diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index 5a19d652a..55ed89759 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -3213,7 +3213,7 @@ func testEvalJSONOBJKEYS(t *testing.T, store *dstore.Store) { "key does not exist": { setup: func() {}, input: []string{"NONEXISTENT_KEY"}, - output: []byte("-ERR could not perform this operation on a key that doesn't exist\r\n"), + output: []byte("-ERR could not perform this operation on a key that doesn't exist\r\n"), }, "root not object": { setup: func() { @@ -3511,8 +3511,8 @@ func testEvalHSETNX(t *testing.T, store *dstore.Store) { output: []byte("-ERR wrong number of arguments for 'hsetnx' command\r\n"), }, "more than one field and value passed": { - setup: func() {}, - input: []string{"KEY", "field1", "value1", "field2", "value2"}, + setup: func() {}, + input: []string{"KEY", "field1", "value1", "field2", "value2"}, output: []byte("-ERR wrong number of arguments for 'hsetnx' command\r\n"), }, "key, field and value passed": { @@ -3969,3 +3969,114 @@ func BenchmarkEvalBITOP(b *testing.B) { }) } } + +func testEvalHRANDFIELD(t *testing.T, store *dstore.Store) { + tests := map[string]evalTestCase{ + "wrong number of args passed": { + setup: func() {}, + input: nil, + output: []byte("-ERR wrong number of arguments for 'hrandfield' command\r\n"), + }, + "key doesn't exist": { + setup: func() {}, + input: []string{"KEY"}, + output: clientio.RespNIL, + }, + "key exists with fields and no count argument": { + setup: func() { + key := "KEY_MOCK" + newMap := make(HashMap) + newMap["field1"] = "Value1" + newMap["field2"] = "Value2" + + obj := &object.Obj{ + TypeEncoding: object.ObjTypeHashMap | object.ObjEncodingHashMap, + Value: newMap, + LastAccessedAt: uint32(time.Now().Unix()), + } + + store.Put(key, obj) + }, + input: []string{"KEY_MOCK"}, + validator: func(output []byte) { + assert.Assert(t, output != nil) + resultString := string(output) + parts := strings.SplitN(resultString, "\n", 2) + if len(parts) < 2 { + t.Errorf("Unexpected output format: %s", resultString) + return + } + decodedResult := strings.TrimSpace(parts[1]) + fmt.Printf("Decoded Result: '%s'\n", decodedResult) + assert.Assert(t, decodedResult == "field1" || decodedResult == "field2") + }, + }, + "key exists with fields and count argument": { + setup: func() { + key := "KEY_MOCK" + newMap := make(HashMap) + newMap["field1"] = "value1" + newMap["field2"] = "value2" + newMap["field3"] = "value3" + + obj := &object.Obj{ + TypeEncoding: object.ObjTypeHashMap | object.ObjEncodingHashMap, + Value: newMap, + LastAccessedAt: uint32(time.Now().Unix()), + } + + store.Put(key, obj) + + }, + input: []string{"KEY_MOCK", "2"}, + validator: func(output []byte) { + assert.Assert(t, output != nil) + decodedResult := string(output) + fields := []string{"field1", "field2", "field3"} + count := 0 + + for _, field := range fields { + if strings.Contains(decodedResult, field) { + count++ + } + } + + assert.Assert(t, count == 2) + }, + }, + "key exists with count and WITHVALUES argument": { + setup: func() { + key := "KEY_MOCK" + newMap := make(HashMap) + newMap["field1"] = "value1" + newMap["field2"] = "value2" + newMap["field3"] = "value3" + + obj := &object.Obj{ + TypeEncoding: object.ObjTypeHashMap | object.ObjEncodingHashMap, + Value: newMap, + LastAccessedAt: uint32(time.Now().Unix()), + } + + store.Put(key, obj) + + }, + input: []string{"KEY_MOCK", "2", "WITHVALUES"}, + validator: func(output []byte) { + assert.Assert(t, output != nil) + decodedResult := string(output) + fieldsAndValues := []string{"field1", "value1", "field2", "value2", "field3", "value3"} + count := 0 + for _, item := range fieldsAndValues { + if strings.Contains(decodedResult, item) { + count++ + } + } + + assert.Assert(t, count == 4) + }, + }, + } + + runEvalTests(t, tests, evalHRANDFIELD, store) +}