Skip to content

Commit

Permalink
Merge pull request #28 from rancher/testing
Browse files Browse the repository at this point in the history
Merge testing
  • Loading branch information
ibuildthecloud authored Mar 18, 2021
2 parents 6838081 + 8b1b7bb commit d13c0bd
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 248 deletions.
14 changes: 7 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ConnectAuthorizer func(proto, address string) bool

// ClientConnect connect to WS and wait 5 seconds when error
func ClientConnect(ctx context.Context, wsURL string, headers http.Header, dialer *websocket.Dialer,
auth ConnectAuthorizer, onConnect func(context.Context) error) error {
auth ConnectAuthorizer, onConnect func(context.Context, *Session) error) error {
if err := ConnectToProxy(ctx, wsURL, headers, auth, dialer, onConnect); err != nil {
logrus.WithError(err).Error("Remotedialer proxy error")
time.Sleep(time.Duration(5) * time.Second)
Expand All @@ -25,13 +25,13 @@ func ClientConnect(ctx context.Context, wsURL string, headers http.Header, diale
}

// ConnectToProxy connect to websocket server
func ConnectToProxy(rootCtx context.Context, proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context) error) error {
func ConnectToProxy(rootCtx context.Context, proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context, *Session) error) error {
logrus.WithField("url", proxyURL).Info("Connecting to proxy")

if dialer == nil {
dialer = &websocket.Dialer{Proxy: http.ProxyFromEnvironment, HandshakeTimeout: HandshakeTimeOut}
}
ws, resp, err := dialer.Dial(proxyURL, headers)
ws, resp, err := dialer.DialContext(rootCtx, proxyURL, headers)
if err != nil {
if resp == nil {
logrus.WithError(err).Errorf("Failed to connect to proxy. Empty dialer response")
Expand All @@ -52,17 +52,17 @@ func ConnectToProxy(rootCtx context.Context, proxyURL string, headers http.Heade
ctx, cancel := context.WithCancel(rootCtx)
defer cancel()

session := NewClientSession(auth, ws)
defer session.Close()

if onConnect != nil {
go func() {
if err := onConnect(ctx); err != nil {
if err := onConnect(ctx, session); err != nil {
result <- err
}
}()
}

session := NewClientSession(auth, ws)
defer session.Close()

go func() {
_, err = session.Serve(ctx)
result <- err
Expand Down
10 changes: 7 additions & 3 deletions client_dialer.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
package remotedialer

import (
"context"
"io"
"net"
"sync"
"time"
)

func clientDial(dialer Dialer, conn *connection, message *message) {
func clientDial(ctx context.Context, dialer Dialer, conn *connection, message *message) {
defer conn.Close()

var (
netConn net.Conn
err error
)

ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
if dialer == nil {
netConn, err = net.DialTimeout(message.proto, message.address, time.Duration(message.deadline)*time.Millisecond)
d := net.Dialer{}
netConn, err = d.DialContext(ctx, message.proto, message.address)
} else {
netConn, err = dialer(message.proto, message.address)
netConn, err = dialer(ctx, message.proto, message.address)
}
cancel()

if err != nil {
conn.tunnelClose(err)
Expand Down
118 changes: 12 additions & 106 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,56 +1,31 @@
package remotedialer

import (
"context"
"errors"
"io"
"net"
"os"
"strconv"
"sync"
"time"

"github.com/rancher/remotedialer/metrics"
)

var (
backupTimeout = 15 * time.Second
)

func init() {
t := os.Getenv("REMOTEDIALER_BACKUP_TIMEOUT_SECONDS")
if t != "" {
i, err := strconv.Atoi(t)
if err != nil {
panic("invalid number " + t + " for REMOTEDIALER_BACKUP_TIMEOUT_SECONDS")
}
backupTimeout = time.Duration(i) * time.Second
}
}

type connection struct {
sync.Mutex

ctx context.Context
cancel func()
err error
writeDeadline time.Time
buf chan []byte
readBuf []byte
buffer *readBuffer
addr addr
session *Session
connID int64
}

func newConnection(connID int64, session *Session, proto, address string) *connection {
c := &connection{
buffer: newReadBuffer(),
addr: addr{
proto: proto,
address: address,
},
connID: connID,
session: session,
buf: make(chan []byte, 1024),
}
metrics.IncSMTotalAddConnectionsForWS(session.clientKey, proto, address)
return c
Expand All @@ -63,9 +38,6 @@ func (c *connection) tunnelClose(err error) {
}

func (c *connection) doTunnelClose(err error) {
c.Lock()
defer c.Unlock()

if c.err != nil {
return
}
Expand All @@ -75,74 +47,38 @@ func (c *connection) doTunnelClose(err error) {
c.err = io.ErrClosedPipe
}

close(c.buf)
c.buffer.Close(c.err)
}

func (c *connection) tunnelWriter() io.Writer {
return chanWriter{conn: c, C: c.buf}
func (c *connection) OnData(m *message) error {
return c.buffer.Offer(m.body)
}

func (c *connection) Close() error {
c.session.closeConnection(c.connID, io.EOF)
return nil
}

func (c *connection) copyData(b []byte) int {
n := copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
return n
}

func (c *connection) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}

n := c.copyData(b)
if n > 0 {
metrics.AddSMTotalReceiveBytesOnWS(c.session.clientKey, float64(n))
return n, nil
}

next, ok := <-c.buf
if !ok {
err := io.EOF
c.Lock()
if c.err != nil {
err = c.err
}
c.Unlock()
return 0, err
}

c.readBuf = next
n = c.copyData(b)
n, err := c.buffer.Read(b)
metrics.AddSMTotalReceiveBytesOnWS(c.session.clientKey, float64(n))
return n, nil
return n, err
}

func (c *connection) Write(b []byte) (int, error) {
c.Lock()
if c.err != nil {
defer c.Unlock()
return 0, c.err
}
c.Unlock()

deadline := int64(0)
if !c.writeDeadline.IsZero() {
deadline = c.writeDeadline.Sub(time.Now()).Nanoseconds() / 1000000
return 0, io.ErrClosedPipe
}
msg := newMessage(c.connID, deadline, b)
msg := newMessage(c.connID, b)
metrics.AddSMTotalTransmitBytesOnWS(c.session.clientKey, float64(len(msg.Bytes())))
return c.session.writeMessage(msg)
return c.session.writeMessage(c.writeDeadline, msg)
}

func (c *connection) writeErr(err error) {
if err != nil {
msg := newErrorMessage(c.connID, err)
metrics.AddSMTotalTransmitErrorBytesOnWS(c.session.clientKey, float64(len(msg.Bytes())))
c.session.writeMessage(msg)
c.session.writeMessage(c.writeDeadline, msg)
}
}

Expand All @@ -162,6 +98,7 @@ func (c *connection) SetDeadline(t time.Time) error {
}

func (c *connection) SetReadDeadline(t time.Time) error {
c.buffer.deadline = t
return nil
}

Expand All @@ -182,34 +119,3 @@ func (a addr) Network() string {
func (a addr) String() string {
return a.address
}

type chanWriter struct {
conn *connection
C chan []byte
}

func (c chanWriter) Write(buf []byte) (int, error) {
c.conn.Lock()
defer c.conn.Unlock()

if c.conn.err != nil {
return 0, c.conn.err
}

newBuf := make([]byte, len(buf))
copy(newBuf, buf)
buf = newBuf

select {
// must copy the buffer
case c.C <- buf:
return len(buf), nil
default:
select {
case c.C <- buf:
return len(buf), nil
case <-time.After(backupTimeout):
return 0, errors.New("backed up reader")
}
}
}
24 changes: 10 additions & 14 deletions dialer.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
package remotedialer

import (
"context"
"net"
"time"
)

type Dialer func(network, address string) (net.Conn, error)
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

func (s *Server) HasSession(clientKey string) bool {
_, err := s.sessions.getDialer(clientKey, 0)
_, err := s.sessions.getDialer(clientKey)
return err == nil
}

func (s *Server) Dial(clientKey string, deadline time.Duration, proto, address string) (net.Conn, error) {
d, err := s.sessions.getDialer(clientKey, deadline)
if err != nil {
return nil, err
}

return d(proto, address)
}
func (s *Server) Dialer(clientKey string) Dialer {
return func(ctx context.Context, network, address string) (net.Conn, error) {
d, err := s.sessions.getDialer(clientKey)
if err != nil {
return nil, err
}

func (s *Server) Dialer(clientKey string, deadline time.Duration) Dialer {
return func(proto, address string) (net.Conn, error) {
return s.Dial(clientKey, deadline, proto, address)
return d(ctx, network, address)
}
}
Loading

0 comments on commit d13c0bd

Please sign in to comment.