Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#652: feature: HRANDFIELD command #696

Merged
merged 9 commits into from
Sep 26, 2024
105 changes: 105 additions & 0 deletions integration_tests/commands/async/hrandfield_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package async

import (
"github.com/google/go-cmp/cmp/cmpopts"
"gotest.tools/v3/assert"
"testing"
)

func TestHRANDFIELD(t *testing.T) {
conn := getLocalConnection()
defer conn.Close()
defer FireCommand(conn, "DEL key_hrandfield key_hrandfield02 key_hrandfield03")

testCases := []struct {
name string
commands []string
expected []interface{}
}{
{
name: "Basic HRANDFIELD operations",
commands: []string{"HSET key_hrandfield field value", "HSET key_hrandfield field2 value2", "HRANDFIELD key_hrandfield"},
expected: []interface{}{ONE, ONE, []string{"field", "field2"}},
},
{
name: "HRANDFIELD with count",
commands: []string{"HSET key_hrandfield field3 value3", "HRANDFIELD key_hrandfield 2"},
expected: []interface{}{ONE, []string{"field", "field2", "field3"}},
},
{
name: "HRANDFIELD with WITHVALUES",
commands: []string{"HRANDFIELD key_hrandfield 2 WITHVALUES"},
expected: []interface{}{[]string{"field", "value", "field2", "value2", "field3", "value3"}},
},
{
name: "HRANDFIELD on non-existent key",
commands: []string{"HRANDFIELD key_hrandfield_nonexistent"},
expected: []interface{}{"(nil)"},
},
{
name: "HRANDFIELD with wrong number of arguments",
commands: []string{"HRANDFIELD"},
expected: []interface{}{"ERR wrong number of arguments for 'hrandfield' command"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for i, cmd := range tc.commands {
result := FireCommand(conn, cmd)
expected := tc.expected[i]

switch expected := expected.(type) {
case []string:
assertRandomFieldResult(t, result, expected)
case int:
assert.Equal(t, result, expected, "Unexpected result for command: %s", cmd)
case string:
assert.Equal(t, result, expected, "Unexpected result for command: %s", cmd)
default:
if str, ok := result.(string); ok {
assert.Equal(t, str, expected, "Unexpected result for command: %s", cmd)
} else {
assert.DeepEqual(t, result, expected, cmpopts.EquateEmpty())
}
}
}
})
}
}

// assertRandomFieldResult asserts that the result contains all expected values or a single valid result
func assertRandomFieldResult(t *testing.T, result interface{}, expected []string) {
t.Helper()

var resultsList []string
switch r := result.(type) {
case []interface{}:
resultsList = make([]string, len(r))
for i, v := range r {
resultsList[i] = v.(string)
}
case string:
resultsList = []string{r}
default:
t.Fatalf("Expected result to be []interface{} or string, got %T", result)
}

// generate a map of expected values for easy lookup
expectedMap := make(map[string]struct{})
for _, exp := range expected {
expectedMap[exp] = struct{}{}
}

// count the number of results that are in the expected set
count := 0
for _, res := range resultsList {
if _, exists := expectedMap[res]; exists {
count++
}
}

// assert that all results are in the expected set or that there is a single valid result
assert.Assert(t, count == len(resultsList) || count == 1,
"Expected all results to be in the expected set or a single valid result. Got %d out of %d",
count, len(resultsList))
}
8 changes: 8 additions & 0 deletions internal/eval/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,13 @@ var (
IsMigrated: true,
NewEval: evalSETEX,
}
hrandfieldCmdMeta = DiceCmdMeta{
Name: "HRANDFIELD",
Info: `Returns one or more random fields from a hash.`,
Eval: evalHRANDFIELD,
Arity: -2,
KeySpecs: KeySpecs{BeginIndex: 1},
}
)

func init() {
Expand Down Expand Up @@ -1001,6 +1008,7 @@ func init() {
DiceCmds["INCRBY"] = incrbyCmdMeta
DiceCmds["GETRANGE"] = getRangeCmdMeta
DiceCmds["SETEX"] = setexCmdMeta
DiceCmds["HRANDFIELD"] = hrandfieldCmdMeta
DiceCmds["HDEL"] = hdelCmdMeta
}

Expand Down
47 changes: 24 additions & 23 deletions internal/eval/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,28 @@ const (
XOR string = "XOR"
NOT string = "NOT"

Ex string = "EX"
Px string = "PX"
Pxat string = "PXAT"
Exat string = "EXAT"
XX string = "XX"
NX string = "NX"
Xx string = "xx"
Nx string = "nx"
GT string = "GT"
LT string = "LT"
KEEPTTL string = "KEEPTTL"
Keepttl string = "keepttl"
Sync string = "SYNC"
Async string = "ASYNC"
Help string = "HELP"
Memory string = "MEMORY"
Count string = "COUNT"
GetKeys string = "GETKEYS"
List string = "LIST"
Info string = "INFO"
Null string = "null"
null string = "null"
NULL string = "null"
Ex string = "EX"
Px string = "PX"
Pxat string = "PXAT"
Exat string = "EXAT"
XX string = "XX"
NX string = "NX"
Xx string = "xx"
Nx string = "nx"
GT string = "GT"
LT string = "LT"
KEEPTTL string = "KEEPTTL"
Keepttl string = "keepttl"
Sync string = "SYNC"
Async string = "ASYNC"
Help string = "HELP"
Memory string = "MEMORY"
Count string = "COUNT"
GetKeys string = "GETKEYS"
List string = "LIST"
Info string = "INFO"
Null string = "null"
null string = "null"
NULL string = "null"
WITHVALUES string = "WITHVALUES"
)
93 changes: 93 additions & 0 deletions internal/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package eval

import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"log/slog"
"math"
"math/big"
"math/bits"
"regexp"
"sort"
Expand Down Expand Up @@ -4064,3 +4066,94 @@ func evalGETRANGE(args []string, store *dstore.Store) []byte {

return clientio.Encode(str[start:end+1], false)
}

// evalHRANDFIELD returns random fields from a hash stored at key.
// If only the key is provided, one random field is returned.
// If count is provided, it returns that many unique random fields. A negative count allows repeated selections.
// The "WITHVALUES" option returns both fields and values.
// Returns nil if the key doesn't exist or the hash is empty.
// Errors: arity error, type error for non-hash, syntax error for "WITHVALUES", or count format error.
func evalHRANDFIELD(args []string, store *dstore.Store) []byte {
if len(args) < 1 || len(args) > 3 {
return diceerrors.NewErrArity("HRANDFIELD")
}

key := args[0]
obj := store.Get(key)
if obj == nil {
return clientio.RespNIL
}

if err := object.AssertTypeAndEncoding(obj.TypeEncoding, object.ObjTypeHashMap, object.ObjEncodingHashMap); err != nil {
return diceerrors.NewErrWithMessage(diceerrors.WrongTypeErr)
}

hashMap := obj.Value.(HashMap)
if len(hashMap) == 0 {
return clientio.Encode([]string{}, false)
}

count := 1
withValues := false

if len(args) > 1 {
var err error
// The second argument is the count.
count, err = strconv.Atoi(args[1])
if err != nil {
return diceerrors.NewErrWithFormattedMessage(diceerrors.IntOrOutOfRangeErr)
}

// The third argument is the "WITHVALUES" option.
if len(args) == 3 {
if strings.ToUpper(args[2]) != WITHVALUES {
return diceerrors.NewErrWithFormattedMessage(diceerrors.SyntaxErr)
}
withValues = true
}
}

return selectRandomFields(hashMap, count, withValues)
}

// selectRandomFields returns random fields from a hashmap.
func selectRandomFields(hashMap HashMap, count int, withValues bool) []byte {
keys := make([]string, 0, len(hashMap))
for k := range hashMap {
keys = append(keys, k)
}

var results []string
resultSet := make(map[string]struct{})

abs := func(x int) int {
if x < 0 {
return -x
}
return x
}

for i := 0; i < abs(count); i++ {
if count > 0 && len(resultSet) == len(keys) {
break
}

randomIndex, _ := rand.Int(rand.Reader, big.NewInt(int64(len(keys))))
randomField := keys[randomIndex.Int64()]

if count > 0 {
if _, exists := resultSet[randomField]; exists {
i--
continue
}
resultSet[randomField] = struct{}{}
}

results = append(results, randomField)
if withValues {
results = append(results, hashMap[randomField])
}
}

return clientio.Encode(results, false)
}
Loading