Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FMS-1249 #33

Merged
merged 2 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ jobs:
restore-keys: |
${{ runner.os }}-go
- name: Run Test Suite
run: go test -p=1 ./...
run: go test -p=1 -race ./...
52 changes: 46 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ func DialWithContext(ctx context.Context, target string, opts ...DialOption) (*C
if err := addrConn.connect(); err != nil {
return nil, fmt.Errorf("error connecting: %w", err)
}
cc.mu.Lock()
cc.conn = addrConn
cc.mu.Unlock()

if cc.dopts.block {
for {
Expand Down Expand Up @@ -207,9 +209,27 @@ func (cc *ClientConn) listenForRead() {
// handleRead listens to the transport read channel and passes the message to the
// readFn handler.
func (cc *ClientConn) handleRead(done <-chan struct{}) {
var tr transport.ClientTransport
var conn *addrConn

cc.mu.RLock()
conn = cc.conn

// if connection has been closed, then conn can be nil
if conn == nil {
cc.mu.RUnlock()

return
}

conn.mu.RLock()
tr = cc.conn.transport
conn.mu.RUnlock()
cc.mu.RUnlock()

for {
select {
case in := <-cc.conn.transport.Read():
case in := <-tr.Read():
// Unmarshal the message
msg := &message.Message{}
if err := UnmarshalProtoMessage(in, msg); err != nil {
Expand Down Expand Up @@ -253,7 +273,14 @@ func (cc *ClientConn) handleMessageRequest(r *message.Request) {
return
}

if err := cc.conn.transport.Write(replyMsg); err != nil {
var tr transport.ClientTransport
cc.mu.RLock()
cc.conn.mu.RLock()
tr = cc.conn.transport
cc.conn.mu.RUnlock()
cc.mu.RUnlock()

if err := tr.Write(replyMsg); err != nil {
cc.dopts.logger.Errorf("error writing to transport: %s", err)
}
}
Expand Down Expand Up @@ -296,9 +323,9 @@ func (cc *ClientConn) register(sd *ServiceDesc, ss interface{}) {

// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() {
conn := cc.conn

cc.mu.Lock()
conn := cc.conn
cc.conn = nil
cc.mu.Unlock()

Expand All @@ -309,7 +336,13 @@ func (cc *ClientConn) Close() {
// received.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}) error {
// Ensure the connection state is ready
if cc.conn.state != connectivity.Ready {
cc.mu.RLock()
cc.conn.mu.RLock()
state := cc.conn.state
cc.conn.mu.RUnlock()
cc.mu.RUnlock()

if state != connectivity.Ready {
return errors.New("connection is not ready")
}

Expand All @@ -336,7 +369,14 @@ func (cc *ClientConn) Invoke(ctx context.Context, method string, args interface{
cc.mu.Unlock()
}()

if err := cc.conn.transport.Write(reqB); err != nil {
var tr transport.ClientTransport
cc.mu.RLock()
cc.conn.mu.RLock()
tr = cc.conn.transport
cc.conn.mu.RUnlock()
cc.mu.RUnlock()
Comment on lines +373 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use getters to simplify this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. that is the clean up we should do. the first priority was to make it work since it's blocking many merges in the chainlink repo


if err := tr.Write(reqB); err != nil {
return err
}

Expand Down Expand Up @@ -398,7 +438,7 @@ type addrConn struct {
// after transport is closed, ac has been torn down).
transport transport.ClientTransport // The current transport.

mu sync.Mutex
mu sync.RWMutex

// Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State
Expand Down
28 changes: 23 additions & 5 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"math/big"
"sync"
)

type StaticSizedPublicKey [ed25519.PublicKeySize]byte
Expand Down Expand Up @@ -82,7 +83,20 @@ func newMinimalX509Cert(priv ed25519.PrivateKey) (tls.Certificate, error) {
}

// PublicKeys wraps a slice of keys so we can update the keys dynamically.
type PublicKeys []ed25519.PublicKey
type PublicKeys struct {
mu sync.RWMutex
keys []ed25519.PublicKey
}

func NewPublicKeys(keys ...ed25519.PublicKey) *PublicKeys {
return &PublicKeys{
keys: keys,
}
}

func (r *PublicKeys) Keys() []ed25519.PublicKey {
return r.keys
}

// Verifies that the certificate's public key matches with one of the keys in
// our list of registered keys.
Expand All @@ -100,7 +114,7 @@ func (r *PublicKeys) VerifyPeerCertificate() func(rawCerts [][]byte, verifiedCha
return err
}

ok := isValidPublicKey(*r, pk)
ok := r.isValidPublicKey(pk)
if !ok {
return fmt.Errorf("unknown public key on cert %x", pk)
}
Expand All @@ -112,12 +126,16 @@ func (r *PublicKeys) VerifyPeerCertificate() func(rawCerts [][]byte, verifiedCha
// Replace replaces the existing keys with new keys. Use this to dynamically
// update the allowable keys at runtime.
func (r *PublicKeys) Replace(pubs []ed25519.PublicKey) {
*r = PublicKeys(pubs)
r.mu.Lock()
defer r.mu.Unlock()
r.keys = pubs
}

// isValidPublicKey checks the public key against a list of valid keys.
func isValidPublicKey(valid []ed25519.PublicKey, pub ed25519.PublicKey) bool {
for _, vpub := range valid {
func (r *PublicKeys) isValidPublicKey(pub ed25519.PublicKey) bool {
r.mu.RLock()
defer r.mu.RUnlock()
for _, vpub := range r.keys {
if pub.Equal(vpub) {
return true
}
Expand Down
4 changes: 2 additions & 2 deletions credentials/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func Test_NewClientTLSConfig(t *testing.T) {
spub, spriv, err := ed25519.GenerateKey(nil)
require.NoError(t, err)

tlsCfg, err := NewClientTLSConfig(cpriv, &PublicKeys{spub})
tlsCfg, err := NewClientTLSConfig(cpriv, NewPublicKeys(spub))
require.NoError(t, err)
require.Len(t, tlsCfg.Certificates, 1)

Expand Down Expand Up @@ -53,7 +53,7 @@ func Test_NewServerTLSConfig(t *testing.T) {
cpub, cpriv, err := ed25519.GenerateKey(nil)
require.NoError(t, err)

tlsCfg, err := NewServerTLSConfig(spriv, &PublicKeys{cpub})
tlsCfg, err := NewServerTLSConfig(spriv, NewPublicKeys(cpub))
require.NoError(t, err)
require.Len(t, tlsCfg.Certificates, 1)

Expand Down
6 changes: 3 additions & 3 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
// level security credentials (e.g., TLS/SSL).
func WithTransportCreds(privKey ed25519.PrivateKey, serverPubKey ed25519.PublicKey) DialOption {
return newFuncDialOption(func(o *dialOptions) {
pubs := credentials.PublicKeys{serverPubKey}
pubs := credentials.NewPublicKeys(serverPubKey)

// Generate the TLS config for the client
config, err := credentials.NewClientTLSConfig(privKey, &pubs)
config, err := credentials.NewClientTLSConfig(privKey, pubs)
if err != nil {
log.Println(err)

return
}

o.copts.TransportCredentials = credentials.NewTLS(config, &pubs)
o.copts.TransportCredentials = credentials.NewTLS(config, pubs)
})
}

Expand Down
13 changes: 6 additions & 7 deletions intgtest/uni_client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,12 @@ func Test_ClientServer_ConcurrentCalls(t *testing.T) {
respCh := make(chan *pb.EchoResponse)
defer close(respCh)

processEchos(t, c,
[]*echoReq{
{message: &pb.EchoRequest{Body: "call1", DelayMs: 500}},
{message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond},
},
respCh,
)
reqs := []echoReq{
{message: &pb.EchoRequest{Body: "call1", DelayMs: 500}},
{message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond},
}

processEchos(t, c, reqs, respCh)

actual := waitForResponses(t, respCh, 2)

Expand Down
13 changes: 6 additions & 7 deletions intgtest/uni_server_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,12 @@ func Test_ServerClient_ConcurrentCalls(t *testing.T) {
defer close(respCh)

pk := keypairs.Client1.StaticallySizedPublicKey(t)
processEchos(t, c,
[]*echoReq{
{message: &pb.EchoRequest{Body: "call1", DelayMs: 500}, pubKey: &pk},
{message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond, pubKey: &pk},
},
respCh,
)
reqs := []echoReq{
{message: &pb.EchoRequest{Body: "call1", DelayMs: 500}, pubKey: &pk},
{message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond, pubKey: &pk},
}

processEchos(t, c, reqs, respCh)

actual := waitForResponses(t, respCh, 2)

Expand Down
6 changes: 3 additions & 3 deletions intgtest/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ type echoReq struct {

func processEchos(t *testing.T,
c pb.EchoClient,
reqs []*echoReq,
reqs []echoReq,
ch chan<- *pb.EchoResponse,
) {
t.Helper()

wg := sync.WaitGroup{}
for _, req := range reqs {
wg.Add(1)
go func() {
go func(req echoReq) {
wg.Done()

ctx := context.Background()
Expand All @@ -143,7 +143,7 @@ func processEchos(t *testing.T,
require.NoError(t, err)

ch <- resp
}()
}(req)

wg.Wait()
}
Expand Down
29 changes: 22 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) {
onClose := func() {
// There is no connection manager when we are shutting down, so
// we can ignore removing the connection.
s.mu.RLock()
defer s.mu.RUnlock()
if s.connMgr != nil {
s.connMgr.mu.Lock()
s.connMgr.removeConnection(pubKey)
Expand All @@ -157,7 +159,9 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) {
}

// Register the transport against the public key
s.mu.RLock()
s.connMgr.registerConnection(pubKey, tr)
s.mu.RUnlock()

s.serveWG.Add(1)

Expand All @@ -175,7 +179,9 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) {
// sendMsg writes the message to the connection which matches the public key.
func (s *Server) sendMsg(pub [32]byte, msg []byte) error {
// Find the transport matching the public key
s.mu.RLock()
tr, err := s.connMgr.getTransport(pub)
s.mu.RUnlock()
if err != nil {
return err
}
Expand All @@ -186,7 +192,9 @@ func (s *Server) sendMsg(pub [32]byte, msg []byte) error {
// handleRead listens to the transport read channel and passes the message to the
// readFn handler.
func (s *Server) handleRead(pubKey credentials.StaticSizedPublicKey, done <-chan struct{}) {
s.mu.RLock()
tr, err := s.connMgr.getTransport(pubKey)
s.mu.RUnlock()
if err != nil {
return
}
Expand Down Expand Up @@ -340,7 +348,7 @@ func (s *Server) UpdatePublicKeys(pubKeys []ed25519.PublicKey) {
s.mu.Lock()
defer s.mu.Unlock()

s.opts.creds.PublicKeys.Replace(pubKeys)
s.opts.creds.PublicKeys.Replace(pubKeys) //credentials.NewPublicKeys(pubKeys...)
s.removeConnectionsToDeletedKeys(pubKeys)
}

Expand All @@ -349,6 +357,8 @@ func (s *Server) UpdatePublicKeys(pubKeys []ed25519.PublicKey) {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (s *Server) GetConnectionNotifyChan() <-chan struct{} {
s.mu.RLock()
defer s.mu.Unlock()
return s.connMgr.getNotifyChan()
}

Expand All @@ -358,6 +368,8 @@ func (s *Server) GetConnectionNotifyChan() <-chan struct{} {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (s *Server) GetConnectedPeerPublicKeys() []credentials.StaticSizedPublicKey {
s.mu.RLock()
defer s.mu.Unlock()
return s.connMgr.getConnectionPublicKeys()
}

Expand Down Expand Up @@ -413,8 +425,9 @@ func (s *Server) ensureSingleClientConnection(cert *x509.Certificate) ([ed25519.
if err != nil {
return pubKey, errors.New("could not extracting public key from certificate")
}

s.mu.RLock()
_, err = s.connMgr.getTransport(pubKey)
s.mu.RUnlock()
if err == nil {
return pubKey, errors.New("only one connection allowed per client")
}
Expand All @@ -441,7 +454,7 @@ func (s *Server) removeMethodCall(id string) {

// connectionsManager manages the active clients connections.
type connectionsManager struct {
mu sync.Mutex
mu sync.RWMutex
// Holds a list of the open connections mapped to a buffered channel of
// outbound messages.
conns map[credentials.StaticSizedPublicKey]transport.ServerTransport
Expand All @@ -457,8 +470,8 @@ func newConnectionsManager() *connectionsManager {

// getTransport fetches the transport which matches the public key.
func (cm *connectionsManager) getTransport(key credentials.StaticSizedPublicKey) (transport.ServerTransport, error) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.mu.RLock()
defer cm.mu.RUnlock()

tr, ok := cm.conns[key]
if !ok {
Expand Down Expand Up @@ -496,8 +509,8 @@ func (cm *connectionsManager) removeConnection(key credentials.StaticSizedPublic

// getConnectionPublicKeys gets the public keys of the active connections.
func (cm *connectionsManager) getConnectionPublicKeys() []credentials.StaticSizedPublicKey {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.mu.RLock()
defer cm.mu.RUnlock()

keys := []credentials.StaticSizedPublicKey{}
for k := range cm.conns {
Expand All @@ -521,6 +534,8 @@ func (cm *connectionsManager) getNotifyChan() <-chan struct{} {

// close closes all registered connections.
func (cm *connectionsManager) close() {
cm.mu.RLock()
defer cm.mu.RUnlock()
for _, conn := range cm.conns {
conn.Close()
}
Expand Down
Loading