From 048f48ab75c7e8b7014ea248dec462a7d6c095d0 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Fri, 9 Dec 2016 10:23:40 -0800 Subject: [PATCH] Allow timeout in connections, retries to be configurable --- smtp.go | 47 ++++++++++++++++++++++++++++++++++++++--------- smtp_test.go | 40 +++++++++++++++++++++++++++++++--------- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/smtp.go b/smtp.go index 2aa49c8..aea7efd 100644 --- a/smtp.go +++ b/smtp.go @@ -33,17 +33,24 @@ type Dialer struct { // LocalName is the hostname sent to the SMTP server with the HELO command. // By default, "localhost" is sent. LocalName string + // Timeout to use for read/write operations. Defaults to 10 seconds, can + // be set to 0 to disable timeouts. + Timeout time.Duration + // Whether we should retry mailing if the connection returned an error. + RetryFailure bool } // NewDialer returns a new SMTP Dialer. The given parameters are used to connect // to the SMTP server. func NewDialer(host string, port int, username, password string) *Dialer { return &Dialer{ - Host: host, - Port: port, - Username: username, - Password: password, - SSL: port == 465, + Host: host, + Port: port, + Username: username, + Password: password, + SSL: port == 465, + Timeout: 10 * time.Second, + RetryFailure: true, } } @@ -58,7 +65,7 @@ func NewPlainDialer(host string, port int, username, password string) *Dialer { // Dial dials and authenticates to an SMTP server. The returned SendCloser // should be closed when done using it. func (d *Dialer) Dial() (SendCloser, error) { - conn, err := netDialTimeout("tcp", addr(d.Host, d.Port), 10*time.Second) + conn, err := netDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout) if err != nil { return nil, err } @@ -72,6 +79,10 @@ func (d *Dialer) Dial() (SendCloser, error) { return nil, err } + if d.Timeout > 0 { + conn.SetDeadline(time.Now().Add(d.Timeout)) + } + if d.LocalName != "" { if err := c.Hello(d.LocalName); err != nil { return nil, err @@ -111,7 +122,7 @@ func (d *Dialer) Dial() (SendCloser, error) { } } - return &smtpSender{c, d}, nil + return &smtpSender{c, conn, d}, nil } func (d *Dialer) tlsConfig() *tls.Config { @@ -139,12 +150,29 @@ func (d *Dialer) DialAndSend(m ...*Message) error { type smtpSender struct { smtpClient - d *Dialer + conn net.Conn + d *Dialer +} + +func (c *smtpSender) retryError(err error) bool { + if !c.d.RetryFailure { + return false + } + + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return true + } + + return err == io.EOF } func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error { + if c.d.Timeout > 0 { + c.conn.SetDeadline(time.Now().Add(c.d.Timeout)) + } + if err := c.Mail(from); err != nil { - if err == io.EOF { + if c.retryError(err) { // This is probably due to a timeout, so reconnect and try again. sc, derr := c.d.Dial() if derr == nil { @@ -154,6 +182,7 @@ func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error { } } } + return err } diff --git a/smtp_test.go b/smtp_test.go index b6f9155..a611164 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -18,7 +18,7 @@ const ( var ( testConn = &net.TCPConn{} - testTLSConn = &tls.Conn{} + testTLSConn = tls.Client(testConn, &tls.Config{InsecureSkipVerify: true}) testConfig = &tls.Config{InsecureSkipVerify: true} testAuth = smtp.PlainAuth("", testUser, testPwd, testHost) ) @@ -118,8 +118,9 @@ func TestDialerNoAuth(t *testing.T) { func TestDialerTimeout(t *testing.T) { d := &Dialer{ - Host: testHost, - Port: testPort, + Host: testHost, + Port: testPort, + RetryFailure: true, } testSendMailTimeout(t, d, []string{ "Extension STARTTLS", @@ -138,6 +139,25 @@ func TestDialerTimeout(t *testing.T) { }) } +func TestDialerTimeoutNoRetry(t *testing.T) { + d := &Dialer{ + Host: testHost, + Port: testPort, + RetryFailure: false, + } + + err := doTestSendMail(t, d, []string{ + "Extension STARTTLS", + "StartTLS", + "Mail " + testFrom, + "Quit", + }, true) + + if err.Error() != "gomail: could not send email 1: EOF" { + t.Error("expected to have got EOF, but got:", err) + } +} + type mockClient struct { t *testing.T i int @@ -232,14 +252,18 @@ func (w *mockWriter) Close() error { } func testSendMail(t *testing.T, d *Dialer, want []string) { - doTestSendMail(t, d, want, false) + if err := doTestSendMail(t, d, want, false); err != nil { + t.Error(err) + } } func testSendMailTimeout(t *testing.T, d *Dialer, want []string) { - doTestSendMail(t, d, want, true) + if err := doTestSendMail(t, d, want, true); err != nil { + t.Error(err) + } } -func doTestSendMail(t *testing.T, d *Dialer, want []string, timeout bool) { +func doTestSendMail(t *testing.T, d *Dialer, want []string, timeout bool) error { testClient := &mockClient{ t: t, want: want, @@ -274,9 +298,7 @@ func doTestSendMail(t *testing.T, d *Dialer, want []string, timeout bool) { return testClient, nil } - if err := d.DialAndSend(getTestMessage()); err != nil { - t.Error(err) - } + return d.DialAndSend(getTestMessage()) } func assertConfig(t *testing.T, got, want *tls.Config) {