Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client/core: put the core together #120

Merged
merged 2 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 81 additions & 46 deletions client/comms/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
Expand All @@ -29,17 +28,22 @@ const (
writeWait = time.Second * 3
)

// When the DEX sends a request to the client, a responseHandler is created
// to wait for the response.
type responseHandler struct {
expiration time.Time
f func(*msgjson.Message)
}

// WsCfg is the configuration struct for initializing a WsConn.
type WsCfg struct {
// The websocket host.
Host string
// The websocket api path.
Path string
// URL is the websocket endpoint URL.
URL string
// The maximum time in seconds to wait for a ping from the server.
PingWait time.Duration
// The rpc certificate file path.
RpcCert string
// ReconnectSync runs the needed reconnection synchronisation after
// ReconnectSync runs the needed reconnection synchronization after
// a disconnect.
ReconnectSync func()
// The dex client context.
Expand All @@ -57,12 +61,12 @@ type WsConn struct {
readCh chan *msgjson.Message
sendCh chan *msgjson.Message
reconnectCh chan struct{}
req map[uint64]*msgjson.Message
reqMtx sync.RWMutex
connected bool
connectedMtx sync.RWMutex
once sync.Once
wg sync.WaitGroup
respHandlers map[uint64]*responseHandler
}

// filesExists reports whether the named file or directory exists.
Expand Down Expand Up @@ -105,12 +109,12 @@ func NewWsConn(cfg *WsCfg) (*WsConn, error) {
}

conn := &WsConn{
cfg: cfg,
tlsCfg: tlsConfig,
readCh: make(chan *msgjson.Message, readBuffSize),
sendCh: make(chan *msgjson.Message),
reconnectCh: make(chan struct{}),
req: make(map[uint64]*msgjson.Message),
cfg: cfg,
tlsCfg: tlsConfig,
readCh: make(chan *msgjson.Message, readBuffSize),
sendCh: make(chan *msgjson.Message),
reconnectCh: make(chan struct{}),
respHandlers: make(map[uint64]*responseHandler),
}

conn.wg.Add(1)
Expand All @@ -134,28 +138,6 @@ func (conn *WsConn) setConnected(connected bool) {
conn.connectedMtx.Unlock()
}

// logRoute logs a request keyed by its id.
func (conn *WsConn) logRequest(id uint64, req *msgjson.Message) {
conn.reqMtx.Lock()
conn.req[id] = req
conn.reqMtx.Unlock()
}

// FetchRequest fetches the request associated with the id. The returned
// request is removed from the cache.
func (conn *WsConn) FetchRequest(id uint64) (*msgjson.Message, error) {
conn.reqMtx.Lock()
defer conn.reqMtx.Unlock()
req := conn.req[id]
if req == nil {
return nil, fmt.Errorf("no request found for id %d", id)
}

delete(conn.req, id)

return req, nil
}

// NextID returns the next request id.
func (conn *WsConn) NextID() uint64 {
return atomic.AddUint64(&conn.rID, 1)
Expand All @@ -178,19 +160,13 @@ func (conn *WsConn) close() {

// connect attempts to establish a websocket connection.
func (conn *WsConn) connect() error {
url := url.URL{
Scheme: "wss",
Host: conn.cfg.Host,
Path: conn.cfg.Path,
}

dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: conn.tlsCfg,
}

ws, _, err := dialer.Dial(url.String(), nil)
ws, _, err := dialer.Dial(conn.cfg.URL, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -266,6 +242,17 @@ func (conn *WsConn) read() {
return
}

// If the message is a response, find the handler.
if msg.Type == msgjson.Response {
handler := conn.respHandler(msg.ID)
if handler == nil {
b, _ := json.Marshal(msg)
log.Errorf("no handler found for response", string(b))
}
handler.f(msg)
continue
}

conn.readCh <- msg
}
}
Expand Down Expand Up @@ -336,17 +323,65 @@ func (conn *WsConn) Send(msg *msgjson.Message) error {
log.Errorf("write error: %v", err)
return err
}
return nil
}

// Request sends the message with Send, but keeps a record of the callback
// function to run when a response is received.
func (conn *WsConn) Request(msg *msgjson.Message, f func(*msgjson.Message)) error {
// Log the message sent if it is a request.
if msg.Type == msgjson.Request {
conn.logRequest(msg.ID, msg)
conn.logReq(msg.ID, f)
}
return conn.Send(msg)
}

return nil
// logReq stores the response handler in the respHandlers map. Requests to the
// client are associated with a response handler.
func (conn *WsConn) logReq(id uint64, respHandler func(*msgjson.Message)) {
conn.reqMtx.Lock()
defer conn.reqMtx.Unlock()
conn.respHandlers[id] = &responseHandler{
expiration: time.Now().Add(time.Minute * 5),
f: respHandler,
}
// clean up the response map.
if len(conn.respHandlers) > 1 {
go conn.cleanUpExpired()
}
}

// cleanUpExpired cleans up the response handler map.
func (conn *WsConn) cleanUpExpired() {
conn.reqMtx.Lock()
defer conn.reqMtx.Unlock()
var expired []uint64
for id, cb := range conn.respHandlers {
if time.Until(cb.expiration) < 0 {
expired = append(expired, id)
}
}
for _, id := range expired {
delete(conn.respHandlers, id)
}
}

// respHandler extracts the response handler for the provided request ID if it
// exists, else nil. If the handler exists, it will be deleted from the map.
func (conn *WsConn) respHandler(id uint64) *responseHandler {
conn.reqMtx.Lock()
defer conn.reqMtx.Unlock()
cb, ok := conn.respHandlers[id]
if ok {
delete(conn.respHandlers, id)
}
return cb
}

// FetchReadSource returns the connection's read source only once.
func (conn *WsConn) FetchReadSource() <-chan *msgjson.Message {
// MessageSource returns the connection's read source only once. The returned
// chan will receive requests and notifications from the server, but not
// responses, which have handlers associated with their request.
func (conn *WsConn) MessageSource() <-chan *msgjson.Message {
var ch <-chan *msgjson.Message

conn.once.Do(func() {
Expand Down
35 changes: 18 additions & 17 deletions client/comms/wsconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ func TestWsConn(t *testing.T) {
}()

cfg := &WsCfg{
Host: host,
Path: "ws",
URL: "wss://" + host + "/ws",
PingWait: pingWait,
RpcCert: certFile.Name(),
Ctx: ctx,
Expand Down Expand Up @@ -225,13 +224,13 @@ func TestWsConn(t *testing.T) {
readPumpCh <- sent

// Fetch the read source.
readSource := wsc.FetchReadSource()
readSource := wsc.MessageSource()
if readSource == nil {
t.Fatal("expected a non-nil read source")
}

// Ensure th read source can be fetched once.
rSource := wsc.FetchReadSource()
// Ensure the read source can be fetched once.
rSource := wsc.MessageSource()
if rSource != nil {
t.Fatal("expected a nil read source")
}
Expand Down Expand Up @@ -276,7 +275,10 @@ func TestWsConn(t *testing.T) {
// Send a message from the client.
mId := wsc.NextID()
sent = makeRequest(mId, msgjson.InitRoute, init)
err = wsc.Send(sent)
handlerRun := false
err = wsc.Request(sent, func(*msgjson.Message) {
handlerRun = true
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -307,19 +309,19 @@ func TestWsConn(t *testing.T) {
t.Fatalf("expected next id to be %d, got %d", 2, next)
}

// Ensure the request sent got logged.
req, err := wsc.FetchRequest(mId)
if err != nil {
t.Fatalf("unexpected error: %v", err)
// Ensure the request got logged.
hndlr := wsc.respHandler(mId)
if hndlr == nil {
t.Fatalf("no handler found")
}

if req.Route != sent.Route {
t.Fatalf("expected %s route, got %s", sent.Route, req.Route)
hndlr.f(nil)
if !handlerRun {
t.Fatalf("wrong handler retrieved")
}

// Lookup an unlogged request id.
_, err = wsc.FetchRequest(next)
if err == nil {
hndlr = wsc.respHandler(next)
if hndlr != nil {
t.Fatal("expected an error for unlogged id")
}

Expand Down Expand Up @@ -375,8 +377,7 @@ func TestFailingConnection(t *testing.T) {

host := "127.0.0.1:6060"
cfg := &WsCfg{
Host: host,
Path: "ws",
URL: "wss://" + host + "/ws",
PingWait: pingWait,
RpcCert: certFile.Name(),
Ctx: ctx,
Expand Down
Loading