Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix subscription id mismtach #231

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 54 additions & 27 deletions rpc/ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/buger/jsonparser"
Expand All @@ -40,9 +41,10 @@ type Client struct {
connCtx context.Context
connCtxCancel context.CancelFunc
lock sync.RWMutex
subscriptionByRequestID map[uint64]*Subscription
subscriptionByWSSubID map[uint64]*Subscription
subscriptionByRequestID map[string]*Subscription
subscriptionByWSSubID map[string]*Subscription
reconnectOnErr bool
idCounter atomic.Uint32
}

const (
Expand All @@ -66,10 +68,9 @@ func Connect(ctx context.Context, rpcEndpoint string) (c *Client, err error) {
func ConnectWithOptions(ctx context.Context, rpcEndpoint string, opt *Options) (c *Client, err error) {
c = &Client{
rpcURL: rpcEndpoint,
subscriptionByRequestID: map[uint64]*Subscription{},
subscriptionByWSSubID: map[uint64]*Subscription{},
subscriptionByRequestID: map[string]*Subscription{},
subscriptionByWSSubID: map[string]*Subscription{},
}

dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: DefaultHandshakeTimeout,
Expand Down Expand Up @@ -147,6 +148,27 @@ func (c *Client) receiveMessages() {
}
}

// GetString returns the value retrieved by `Get`, cast to a string if possible.
// If key data type do not match, it will return an error.
func getString(data []byte, keys ...string) (val string, err error) {
v, t, _, e := jsonparser.Get(data, keys...)
if e != nil {
return "", e
}
if t != jsonparser.String {
return "", fmt.Errorf("Value is not a string: %s", string(v))
}
return string(v), nil
}

func getStringWithOk(data []byte, path ...string) (string, bool) {
val, err := getString(data, path...)
if err == nil {
return val, true
}
return "", false
}

// GetUint64 returns the value retrieved by `Get`, cast to a uint64 if possible.
// If key data type do not match, it will return an error.
func getUint64(data []byte, keys ...string) (val uint64, err error) {
Expand All @@ -172,59 +194,59 @@ func (c *Client) handleMessage(message []byte) {
// when receiving message with id. the result will be a subscription number.
// that number will be associated to all future message destine to this request

requestID, ok := getUint64WithOk(message, "id")
requestID, ok := getStringWithOk(message, "id")
if ok {
subID, _ := getUint64WithOk(message, "result")
subID, _ := getStringWithOk(message, "result")
c.handleNewSubscriptionMessage(requestID, subID)
return
}

subID, _ := getUint64WithOk(message, "params", "subscription")
subID, _ := getStringWithOk(message, "params", "subscription")
c.handleSubscriptionMessage(subID, message)
}

func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
func (c *Client) handleNewSubscriptionMessage(requestID, subID string) {
c.lock.Lock()
defer c.lock.Unlock()

if traceEnabled {
zlog.Debug("received new subscription message",
zap.Uint64("message_id", requestID),
zap.Uint64("subscription_id", subID),
zap.String("message_id", requestID),
zap.String("subscription_id", subID),
)
}

callBack, found := c.subscriptionByRequestID[requestID]
if !found {
zlog.Error("cannot find websocket message handler for a new stream.... this should not happen",
zap.Uint64("request_id", requestID),
zap.Uint64("subscription_id", subID),
zap.String("request_id", requestID),
zap.String("subscription_id", subID),
)
return
}
callBack.subID = subID
c.subscriptionByWSSubID[subID] = callBack

zlog.Debug("registered ws subscription",
zap.Uint64("subscription_id", subID),
zap.Uint64("request_id", requestID),
zap.String("subscription_id", subID),
zap.String("request_id", requestID),
zap.Int("subscription_count", len(c.subscriptionByWSSubID)),
)
return
}

func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
func (c *Client) handleSubscriptionMessage(subID string, message []byte) {
if traceEnabled {
zlog.Debug("received subscription message",
zap.Uint64("subscription_id", subID),
zap.String("subscription_id", subID),
)
}

c.lock.RLock()
sub, found := c.subscriptionByWSSubID[subID]
c.lock.RUnlock()
if !found {
zlog.Warn("unable to find subscription for ws message", zap.Uint64("subscription_id", subID))
zlog.Warn("unable to find subscription for ws message", zap.String("subscription_id", subID))
return
}

Expand All @@ -240,7 +262,7 @@ func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
// we will no read any other message
if len(sub.stream) >= cap(sub.stream) {
zlog.Warn("closing ws client subscription... not consuming fast en ought",
zap.Uint64("request_id", sub.req.ID),
zap.String("request_id", sub.req.ID),
)
c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream)))
return
Expand All @@ -260,11 +282,11 @@ func (c *Client) closeAllSubscription(err error) {
sub.err <- err
}

c.subscriptionByRequestID = map[uint64]*Subscription{}
c.subscriptionByWSSubID = map[uint64]*Subscription{}
c.subscriptionByRequestID = map[string]*Subscription{}
c.subscriptionByWSSubID = map[string]*Subscription{}
}

func (c *Client) closeSubscription(reqID uint64, err error) {
func (c *Client) closeSubscription(reqID string, err error) {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -286,17 +308,17 @@ func (c *Client) closeSubscription(reqID uint64, err error) {
delete(c.subscriptionByWSSubID, sub.subID)
}

func (c *Client) unsubscribe(subID uint64, method string) error {
req := newRequest([]interface{}{subID}, method, nil)
func (c *Client) unsubscribe(subID string, method string) error {
req := newRequest([]interface{}{subID}, method, nil, c.nextID())
data, err := req.encode()
if err != nil {
return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method)
return fmt.Errorf("unable to encode unsubscription message for subID %s and method %s", subID, method)
}

c.conn.SetWriteDeadline(time.Now().Add(writeWait))
err = c.conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method)
return fmt.Errorf("unable to send unsubscription message for subID %s and method %s", subID, method)
}
return nil
}
Expand All @@ -311,7 +333,7 @@ func (c *Client) subscribe(
c.lock.Lock()
defer c.lock.Unlock()

req := newRequest(params, subscriptionMethod, conf)
req := newRequest(params, subscriptionMethod, conf, c.nextID())
data, err := req.encode()
if err != nil {
return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err)
Expand Down Expand Up @@ -339,6 +361,11 @@ func (c *Client) subscribe(
return sub, nil
}

func (c *Client) nextID() string {
id := c.idCounter.Add(1)
return strconv.FormatUint(uint64(id), 10)
}

func decodeResponseFromReader(r io.Reader, reply interface{}) (err error) {
var c *response
if err := json.NewDecoder(r).Decode(&c); err != nil {
Expand Down
3 changes: 1 addition & 2 deletions rpc/ws/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package ws

type Subscription struct {
req *request
subID uint64
subID string
stream chan result
err chan error
closeFunc func(err error)
Expand All @@ -38,7 +38,6 @@ func newSubscription(
) *Subscription {
return &Subscription{
req: req,
subID: 0,
stream: make(chan result, 200_000),
err: make(chan error, 100_000),
closeFunc: closeFunc,
Expand Down
7 changes: 3 additions & 4 deletions rpc/ws/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package ws
import (
stdjson "encoding/json"
"fmt"
"math/rand"
"net/http"
"time"
)
Expand All @@ -29,18 +28,18 @@ type request struct {
Version string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
ID uint64 `json:"id"`
ID string `json:"id"`
}

func newRequest(params []interface{}, method string, configuration map[string]interface{}) *request {
func newRequest(params []interface{}, method string, configuration map[string]interface{}, requestID string) *request {
if params != nil && configuration != nil {
params = append(params, configuration)
}
return &request{
Version: "2.0",
Method: method,
Params: params,
ID: uint64(rand.Int63()),
ID: requestID,
}
}

Expand Down