Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WATCH implementation in IronHawk #1442

Merged
merged 4 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions internal/cmd/cmd_get_watch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package cmd

import (
"strconv"

dstore "github.com/dicedb/dice/internal/store"
"google.golang.org/protobuf/types/known/structpb"
)

var cGETWATCH = &DiceDBCommand{
Name: "GET.WATCH",
HelpShort: "GET.WATCH creates a query subscription over the GET command",
Eval: evalGETWATCH,
}

func init() {
commandRegistry.AddCommand(cGETWATCH)
}

func evalGETWATCH(c *Cmd, s *dstore.Store) (*CmdRes, error) {
if len(c.C.Args) != 1 {
return cmdResNil, errWrongArgumentCount("GET.WATCH")
}

r, err := evalGET(c, s)
if err != nil {
return nil, err
}

if r.R.Attrs == nil {
r.R.Attrs = &structpb.Struct{
Fields: make(map[string]*structpb.Value),
}
}

r.R.Attrs.Fields["fingerprint"] = structpb.NewStringValue(strconv.FormatUint(uint64(c.GetFingerprint()), 10))
return r, nil
}
23 changes: 23 additions & 0 deletions internal/cmd/cmd_unwatch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package cmd

import (
dstore "github.com/dicedb/dice/internal/store"
)

var cUNWATCH = &DiceDBCommand{
Name: "UNWATCH",
HelpShort: "UNWATCH removes the previously created query subscription",
Eval: evalUNWATCH,
}

func init() {
commandRegistry.AddCommand(cUNWATCH)
}

func evalUNWATCH(c *Cmd, s *dstore.Store) (*CmdRes, error) {
if len(c.C.Args) != 1 {
return cmdResNil, errWrongArgumentCount("UNWATCH")
}

return cmdResOK, nil
}
45 changes: 31 additions & 14 deletions internal/cmd/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@ type Cmd struct {
ThreadID string
}

func (c *Cmd) String() string {
return fmt.Sprintf("%s %s", c.C.Cmd, strings.Join(c.C.Args, " "))
}

func (c *Cmd) GetFingerprint() uint32 {
return farm.Fingerprint32([]byte(c.String()))
}

func (c *Cmd) Key() string {
if len(c.C.Args) > 0 {
return c.C.Args[0]
}
return ""
}

type CmdRes struct {
R *wire.Response
ThreadID string
Expand Down Expand Up @@ -49,22 +64,24 @@ var commandRegistry CmdRegistry = CmdRegistry{
}

func Execute(c *Cmd, s *dstore.Store) (*CmdRes, error) {
// TODO: Replace this iteration with a HashTable lookup.
for _, cmd := range commandRegistry.cmds {
if cmd.Name == c.C.Cmd {
start := time.Now()
resp, err := cmd.Eval(c, s)
if err != nil {
resp.R.Err = err.Error()
}

slog.Debug("command executed",
slog.Any("cmd", c.C.Cmd),
slog.String("args", strings.Join(c.C.Args, " ")),
slog.String("thread_id", c.ThreadID),
slog.Int("shard_id", s.ShardID),
slog.Any("took_ns", time.Since(start).Nanoseconds()))
return resp, err
if cmd.Name != c.C.Cmd {
continue
}

start := time.Now()
resp, err := cmd.Eval(c, s)
if err != nil {
resp.R.Err = err.Error()
}

slog.Debug("command executed",
slog.Any("cmd", c.String()),
slog.String("thread_id", c.ThreadID),
slog.Int("shard_id", s.ShardID),
slog.Any("took_ns", time.Since(start).Nanoseconds()))
return resp, err
}
return cmdResNil, errors.New("command not found")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/eval/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ func init() {
DiceCmds["DECRBY"] = decrByCmdMeta
DiceCmds["DEL"] = delCmdMeta
DiceCmds["DUMP"] = dumpkeyCMmdMeta
DiceCmds["ECHO"] = echoCmdMeta
DiceCmds["ECHO"] = echoCmdMeta // moved to ironhawk
DiceCmds["EXISTS"] = existsCmdMeta
DiceCmds["EXPIRE"] = expireCmdMeta
DiceCmds["EXPIREAT"] = expireatCmdMeta
Expand Down
30 changes: 23 additions & 7 deletions internal/iothread/iothread.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package iothread
import (
"context"
"log/slog"
"strings"

"github.com/dicedb/dice/internal/auth"
"github.com/dicedb/dice/internal/clientio/iohandler"
Expand All @@ -14,7 +15,7 @@ import (

type IOThread struct {
id string
ioHandler iohandler.IOHandler
IoHandler iohandler.IOHandler
Session *auth.Session
ioThreadReadChan chan []byte // Channel to send data to the command handler
ioThreadWriteChan chan interface{} // Channel to receive data from the command handler
Expand All @@ -26,7 +27,7 @@ func NewIOThread(id string, ioHandler iohandler.IOHandler,
ioThreadErrChan chan error) *IOThread {
return &IOThread{
id: id,
ioHandler: ioHandler,
IoHandler: ioHandler,
Session: auth.NewSession(),
ioThreadReadChan: ioThreadReadChan,
ioThreadWriteChan: ioThreadWriteChan,
Expand Down Expand Up @@ -64,7 +65,7 @@ func (t *IOThread) Start(ctx context.Context) error {
t.ioThreadErrChan <- err
return err
case resp := <-t.ioThreadWriteChan:
err := t.ioHandler.Write(ctx, resp)
err := t.IoHandler.Write(ctx, resp)
if err != nil {
slog.Debug("error while sending response to the client", slog.String("id", t.id), slog.Any("error", err))
continue
Expand All @@ -74,10 +75,14 @@ func (t *IOThread) Start(ctx context.Context) error {
}
}

func (t *IOThread) StartSync(ctx context.Context, execute func(c *cmd.Cmd) (*cmd.CmdRes, error)) error {
func (t *IOThread) StartSync(
ctx context.Context, execute func(c *cmd.Cmd) (*cmd.CmdRes, error),
handleWatch func(c *cmd.Cmd, t *IOThread),
handleUnwatch func(c *cmd.Cmd, t *IOThread),
notifyWatchers func(c *cmd.Cmd, execute func(c *cmd.Cmd) (*cmd.CmdRes, error))) error {
slog.Debug("io thread started", slog.String("id", t.id))
for {
c, err := t.ioHandler.ReadSync()
c, err := t.IoHandler.ReadSync()
if err != nil {
return err
}
Expand All @@ -86,10 +91,21 @@ func (t *IOThread) StartSync(ctx context.Context, execute func(c *cmd.Cmd) (*cmd
if err != nil {
res.R.Err = err.Error()
}
err = t.ioHandler.WriteSync(ctx, res)

if strings.HasSuffix(c.C.Cmd, ".WATCH") {
handleWatch(c, t)
}

if strings.HasSuffix(c.C.Cmd, "UNWATCH") {
handleUnwatch(c, t)
}

err = t.IoHandler.WriteSync(ctx, res)
if err != nil {
return err
}

go notifyWatchers(c, execute)
}
}

Expand All @@ -99,7 +115,7 @@ func (t *IOThread) startInputReader(ctx context.Context, incomingDataChan chan [
defer close(readErrChan)

for {
data, err := t.ioHandler.Read(ctx)
data, err := t.IoHandler.Read(ctx)
if err != nil {
select {
case readErrChan <- err:
Expand Down
3 changes: 2 additions & 1 deletion internal/server/ironhawk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,10 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou

func (s *Server) startIOThread(ctx context.Context, wg *sync.WaitGroup, thread *iothread.IOThread) {
wg.Done()
err := thread.StartSync(ctx, s.shardManager.Execute)
err := thread.StartSync(ctx, s.shardManager.Execute, HandleWatch, HandleUnwatch, NotifyWatchers)
if err != nil {
if err == io.EOF {
CleanupThreadWatchSubscriptions(thread)
slog.Debug("client disconnected. io-thread stopped", slog.String("id", thread.ID()))
} else {
slog.Debug("io-thread errored out", slog.String("id", thread.ID()), slog.Any("error", err))
Expand Down
113 changes: 113 additions & 0 deletions internal/server/ironhawk/watch_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package ironhawk

import (
"context"
"log/slog"
"strconv"

"github.com/dicedb/dice/internal/cmd"
"github.com/dicedb/dice/internal/iothread"
)

var (
keyFPMap map[string]map[uint32]bool // keyFPMap is a map of Key -> [fingerprint1, fingerprint2, ...]
fpThreadMap map[uint32]map[*iothread.IOThread]bool // fpConnMap is a map of fingerprint -> [client1Chan, client2Chan, ...]
fpCmdMap map[uint32]*cmd.Cmd // fpCmdMap is a map of fingerprint -> command
)

func init() {
keyFPMap = make(map[string]map[uint32]bool)
fpThreadMap = make(map[uint32]map[*iothread.IOThread]bool)
fpCmdMap = make(map[uint32]*cmd.Cmd)
}

func HandleWatch(c *cmd.Cmd, t *iothread.IOThread) {
fp := c.GetFingerprint()
key := c.Key()
slog.Debug("creating a new subscription",
slog.String("key", key),
slog.String("cmd", c.String()),
slog.Any("fingerprint", fp))

if _, ok := keyFPMap[key]; !ok {
keyFPMap[key] = make(map[uint32]bool)
}
keyFPMap[key][fp] = true

if _, ok := fpThreadMap[fp]; !ok {
fpThreadMap[fp] = make(map[*iothread.IOThread]bool)
}
fpThreadMap[fp][t] = true
fpCmdMap[fp] = c
}

func HandleUnwatch(c *cmd.Cmd, t *iothread.IOThread) {
if len(c.C.Args) != 1 {
return
}

_fp, err := strconv.ParseUint(c.C.Args[0], 10, 32)
if err != nil {
return
}
fp := uint32(_fp)

delete(fpThreadMap[fp], t)
if len(fpThreadMap[fp]) == 0 {
delete(fpThreadMap, fp)
}

for key, fpMap := range keyFPMap {
if _, ok := fpMap[fp]; ok {
delete(keyFPMap[key], fp)
}
if len(keyFPMap[key]) == 0 {
delete(keyFPMap, key)
}
}

// TODO: Maintain ref count for gp -> cmd mapping
// delete it from delete(fpCmdMap, fp) only when ref count is 0
// check if any easier way to do this
}

func CleanupThreadWatchSubscriptions(t *iothread.IOThread) {
for fp, threadMap := range fpThreadMap {
if _, ok := threadMap[t]; ok {
delete(fpThreadMap[fp], t)
}
if len(fpThreadMap[fp]) == 0 {
delete(fpThreadMap, fp)
}
}
}

func NotifyWatchers(c *cmd.Cmd, execute func(c *cmd.Cmd) (*cmd.CmdRes, error)) {
// TODO: During first WATCH call, we are getting the response multiple times on the Client
// Check if this is happening because of the way we are notifying the watchers
key := c.Key()
for fp := range keyFPMap[key] {
_c := fpCmdMap[fp]
if _c == nil {
// TODO: We might want to remove the key from keyFPMap if we don't have a command for it.
continue
}

r, err := execute(_c)
if err != nil {
slog.Error("failed to execute command as part of watch notification",
slog.Any("cmd", _c.String()),
slog.Any("error", err))
continue
}

for thread := range fpThreadMap[fp] {
err := thread.IoHandler.WriteSync(context.Background(), r)
if err != nil {
slog.Error("failed to write response to thread", slog.Any("thread", thread.ID()), slog.Any("error", err))
}
}

slog.Debug("notifying watchers for key", slog.String("key", key), slog.Int("watchers", len(fpThreadMap[fp])))
}
}
Loading