Skip to content

Commit

Permalink
Merge pull request #48 from pingcap/qiuyesuifeng/node-conns
Browse files Browse the repository at this point in the history
add keep alive connection support for sending command to raft server
  • Loading branch information
qiuyesuifeng committed Mar 30, 2016
2 parents f9b4217 + 40a047a commit cf3894e
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 42 deletions.
8 changes: 8 additions & 0 deletions server/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type raftCluster struct {
// store cache
stores map[uint64]metapb.Store
}

// for node conns
nodeConns *nodeConns
}

func (s *Server) newCluster(clusterID uint64, meta metapb.Cluster) (*raftCluster, error) {
Expand All @@ -76,8 +79,11 @@ func (s *Server) newCluster(clusterID uint64, meta metapb.Cluster) (*raftCluster
clusterRoot: s.getClusterRootPath(clusterID),
askJobCh: make(chan struct{}, askJobChannelSize),
quitCh: make(chan struct{}),
nodeConns: newNodeConns(defaultConnFunc),
}

c.nodeConns.SetIdleTimeout(idleTimeout)

// Force checking the pending job.
c.askJobCh <- struct{}{}

Expand All @@ -104,6 +110,8 @@ func (s *Server) newCluster(clusterID uint64, meta metapb.Cluster) (*raftCluster
// Close stops the cluster's background worker and releases pooled
// node connections: it signals quitCh, waits for in-flight goroutines,
// then closes every keep-alive connection to raft nodes.
func (c *raftCluster) Close() {
close(c.quitCh)
c.wg.Wait()

// Shut down the keep-alive connection pool last, after workers that
// might still be using it have exited.
c.nodeConns.Close()
}

func (s *Server) getClusterRootPath(clusterID uint64) string {
Expand Down
17 changes: 7 additions & 10 deletions server/cluster_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package server
import (
"math"
"math/rand"
"net"
"sync/atomic"
"time"

Expand All @@ -25,9 +24,8 @@ import (
const (
checkJobInterval = 10 * time.Second

connectTimeout = 3 * time.Second
readTimeout = 3 * time.Second
writeTimeout = 3 * time.Second
readTimeout = 3 * time.Second
writeTimeout = 3 * time.Second

maxSendRetry = 10
)
Expand Down Expand Up @@ -667,26 +665,25 @@ func (c *raftCluster) callCommand(request *raft_cmdpb.RaftCmdRequest) (*raft_cmd
return nil, errors.Trace(err)
}

// Connect the node.
// TODO: use connection pool
conn, err := net.DialTimeout("tcp", node.GetAddress(), connectTimeout)
nc, err := c.nodeConns.GetConn(node.GetAddress())
if err != nil {
return nil, errors.Trace(err)
}
defer conn.Close()

msg := &raft_serverpb.Message{
MsgType: raft_serverpb.MessageType_Cmd.Enum(),
CmdReq: request,
}

msgID := atomic.AddUint64(&c.s.msgID, 1)
if err = util.WriteMessage(conn, msgID, msg); err != nil {
if err = util.WriteMessage(nc.conn, msgID, msg); err != nil {
c.nodeConns.RemoveConn(node.GetAddress())
return nil, errors.Trace(err)
}

msg.Reset()
if _, err = util.ReadMessage(conn, msg); err != nil {
if _, err = util.ReadMessage(nc.conn, msg); err != nil {
c.nodeConns.RemoveConn(node.GetAddress())
return nil, errors.Trace(err)
}

Expand Down
62 changes: 36 additions & 26 deletions server/cluster_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,32 +250,42 @@ func (n *mockRaftNode) runCmd(c *C) {
return
}

msg := &raft_serverpb.Message{}
msgID, err := util.ReadMessage(conn, msg)
c.Assert(err, IsNil)

req := msg.GetCmdReq()
c.Assert(req, NotNil)

resp := n.proposeCommand(c, req)
if resp.Header == nil {
resp.Header = &raft_cmdpb.RaftResponseHeader{}
}
resp.Header.Uuid = req.Header.Uuid

respMsg := &raft_serverpb.Message{
MsgType: raft_serverpb.MessageType_CmdResp.Enum(),
CmdResp: resp,
}

if rand.Intn(2) == 1 && resp.StatusResponse == nil {
// randomly close the connection to force
// cluster work retry
conn.Close()
} else {
err = util.WriteMessage(conn, msgID, respMsg)
c.Assert(err, IsNil)
}
go func() {
for {
msg := &raft_serverpb.Message{}
msgID, err := util.ReadMessage(conn, msg)
if err != nil {
c.Log(err)
return
}

req := msg.GetCmdReq()
c.Assert(req, NotNil)

resp := n.proposeCommand(c, req)
if resp.Header == nil {
resp.Header = &raft_cmdpb.RaftResponseHeader{}
}
resp.Header.Uuid = req.Header.Uuid

respMsg := &raft_serverpb.Message{
MsgType: raft_serverpb.MessageType_CmdResp.Enum(),
CmdResp: resp,
}

if rand.Intn(2) == 1 && resp.StatusResponse == nil {
// randomly close the connection to force
// cluster work retry
conn.Close()
return
}

err = util.WriteMessage(conn, msgID, respMsg)
if err != nil {
c.Log(err)
}
}
}()
}
}

Expand Down
4 changes: 2 additions & 2 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ func updateResponse(req *pdpb.Request, resp *pdpb.Response) {
resp.Header.ClusterId = req.Header.ClusterId
}

func (c *conn) Close() {
c.conn.Close()
// Close closes the underlying client connection and returns the
// close error (traced) so callers can log or propagate it.
func (c *conn) Close() error {
return errors.Trace(c.conn.Close())
}

func (c *conn) handleRequest(req *pdpb.Request) (*pdpb.Response, error) {
Expand Down
3 changes: 1 addition & 2 deletions server/leader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ import (
"sync/atomic"
"time"

"golang.org/x/net/context"

"github.com/coreos/etcd/clientv3"
storagepb "github.com/coreos/etcd/storage/storagepb"
"github.com/golang/protobuf/proto"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/kvproto/pkg/pdpb"
"golang.org/x/net/context"
)

// IsLeader returns whether server is leader or not.
Expand Down
127 changes: 127 additions & 0 deletions server/node_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package server

import (
"net"
"sync"
"time"

"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/ngaut/sync2"
)

const (
// connectTimeout bounds the TCP dial when opening a node connection.
connectTimeout = 3 * time.Second
// idleTimeout is the default idle period after which a pooled
// connection is considered stale and recycled by GetConn.
idleTimeout = 30 * time.Second
)

// nodeConn is one pooled TCP connection to a raft node, stamped with
// the last time it was handed out so idle connections can be recycled.
type nodeConn struct {
conn net.Conn
// touchedTime is refreshed on every successful GetConn.
touchedTime time.Time
}

// close shuts down the wrapped network connection, tracing any error.
func (nc *nodeConn) close() error {
	err := nc.conn.Close()
	return errors.Trace(err)
}

// newNodeConn dials addr over TCP (bounded by connectTimeout) and
// wraps the resulting connection together with its creation time.
func newNodeConn(addr string) (*nodeConn, error) {
	c, err := net.DialTimeout("tcp", addr, connectTimeout)
	if err != nil {
		return nil, errors.Trace(err)
	}

	nc := &nodeConn{
		conn:        c,
		touchedTime: time.Now(),
	}
	return nc, nil
}

// createConnFunc produces a nodeConn for addr; tests substitute a stub
// implementation to avoid real network dials.
type createConnFunc func(addr string) (*nodeConn, error)

// defaultConnFunc dials a real TCP connection to the node.
var defaultConnFunc = newNodeConn

// nodeConns is a mutex-guarded pool of keep-alive node connections
// keyed by node address.
type nodeConns struct {
m sync.Mutex
conns map[string]*nodeConn
// idleTimeout of zero disables idle recycling in GetConn.
idleTimeout sync2.AtomicDuration
f createConnFunc
}

// newNodeConns creates an empty connection pool whose connections are
// produced by f.
func newNodeConns(f createConnFunc) *nodeConns {
	return &nodeConns{
		f:     f,
		conns: make(map[string]*nodeConn),
	}
}

// createNewConn dials addr via the configured factory and caches the
// result. Not thread-safe: callers must hold ncs.m.
func (ncs *nodeConns) createNewConn(addr string) (*nodeConn, error) {
	nc, err := ncs.f(addr)
	if err != nil {
		return nil, errors.Trace(err)
	}

	ncs.conns[addr] = nc
	return nc, nil
}

// SetIdleTimeout sets how long a pooled connection may sit unused
// before GetConn recycles it; zero disables recycling.
func (ncs *nodeConns) SetIdleTimeout(idleTimeout time.Duration) {
ncs.idleTimeout.Set(idleTimeout)
}

// GetConn returns the pooled connection for addr, dialing a new one
// when none is cached or the cached one has been idle longer than the
// configured idle timeout. The returned conn stays owned by the pool;
// callers use RemoveConn to discard a broken one.
func (ncs *nodeConns) GetConn(addr string) (*nodeConn, error) {
	ncs.m.Lock()
	defer ncs.m.Unlock()

	conn, ok := ncs.conns[addr]
	if !ok {
		return ncs.createNewConn(addr)
	}

	// Recycle the connection if it has sat idle past the timeout.
	timeout := ncs.idleTimeout.Get()
	if timeout > 0 && time.Since(conn.touchedTime) > timeout {
		// Evict the stale entry before closing it. Previously a failed
		// close left the dead connection in the map, so every later
		// GetConn for this addr re-failed on a double close.
		delete(ncs.conns, addr)
		if err := conn.close(); err != nil {
			return nil, errors.Trace(err)
		}

		return ncs.createNewConn(addr)
	}

	conn.touchedTime = time.Now()
	return conn, nil
}

// RemoveConn closes and evicts the connection cached for addr, if any.
// A close failure is logged, not returned: the entry is gone either way.
func (ncs *nodeConns) RemoveConn(addr string) {
	ncs.m.Lock()
	defer ncs.m.Unlock()

	nc, ok := ncs.conns[addr]
	if !ok {
		return
	}

	delete(ncs.conns, addr)
	if err := nc.close(); err != nil {
		log.Warnf("close node conn failed - %v", err)
	}
}

// Close closes every pooled connection and resets the pool to empty.
// Close failures are logged individually and do not stop the sweep.
func (ncs *nodeConns) Close() {
	ncs.m.Lock()
	defer ncs.m.Unlock()

	for _, nc := range ncs.conns {
		if err := nc.close(); err != nil {
			log.Warnf("close node conn failed - %v", err)
		}
	}

	ncs.conns = make(map[string]*nodeConn)
}
90 changes: 90 additions & 0 deletions server/node_conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package server

import (
"net"
"time"

. "github.com/pingcap/check"
)

// testNodeConnSuite exercises the nodeConns pool with stubbed
// connections (no real network I/O).
type testNodeConnSuite struct {
}

var _ = Suite(&testNodeConnSuite{})

// testConn is a no-op net.Conn stub: reads and writes "succeed"
// without touching the network, and Close never fails.
type testConn struct {
}

func (c *testConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *testConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *testConn) Close() error { return nil }
func (c *testConn) LocalAddr() net.Addr { return nil }
func (c *testConn) RemoteAddr() net.Addr { return nil }
func (c *testConn) SetDeadline(t time.Time) error { return nil }
func (c *testConn) SetReadDeadline(t time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(t time.Time) error { return nil }

// testNodeConn is a createConnFunc that hands back a stub connection,
// so the pool tests never dial anything.
func testNodeConn(addr string) (*nodeConn, error) {
	nc := &nodeConn{
		conn:        &testConn{},
		touchedTime: time.Now(),
	}
	return nc, nil
}

// TestNodeConns covers the pool's whole lifecycle: caching, reuse,
// eviction, shutdown, and lazy idle-timeout recycling.
func (s *testNodeConnSuite) TestNodeConns(c *C) {
	ncs := newNodeConns(testNodeConn)
	c.Assert(ncs.conns, HasLen, 0)

	// Repeated GetConn on the same address must reuse the cached conn.
	addr1 := "127.0.0.1:1"
	first, err := ncs.GetConn(addr1)
	c.Assert(err, IsNil)
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr1)

	second, err := ncs.GetConn(addr1)
	c.Assert(err, IsNil)
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr1)
	c.Assert(first, Equals, second)

	// RemoveConn evicts the entry.
	ncs.RemoveConn(addr1)
	c.Assert(ncs.conns, HasLen, 0)

	addr2 := "127.0.0.1:2"
	ncs.GetConn(addr2)
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr2)

	// Close empties the whole pool.
	ncs.Close()
	c.Assert(ncs.conns, HasLen, 0)

	// With an idle timeout set, an expired conn is replaced on the
	// next GetConn for the same address.
	idle := 100 * time.Millisecond
	ncs.SetIdleTimeout(idle)

	addr3 := "127.0.0.1:3"
	first, err = ncs.GetConn(addr3)
	c.Assert(err, IsNil)
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr3)

	time.Sleep(2 * idle)

	// Expiry is lazy: the stale conn stays cached until asked for.
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr3)

	second, err = ncs.GetConn(addr3)
	c.Assert(err, IsNil)
	c.Assert(ncs.conns, HasLen, 1)
	c.Assert(ncs.conns, HasKey, addr3)
	c.Assert(first, Not(Equals), second)

	addr4 := "127.0.0.1:4"
	ncs.GetConn(addr4)
	c.Assert(ncs.conns, HasLen, 2)
	c.Assert(ncs.conns, HasKey, addr4)

	ncs.Close()
	c.Assert(ncs.conns, HasLen, 0)
}
Loading

0 comments on commit cf3894e

Please sign in to comment.