diff --git a/README.md b/README.md index 4920a3c..627705c 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,15 @@ func main() { // use MessagesAndPurge() method server.MessagesAndPurge() + // In case with flaky test environment you can wait for the specified number + // of messages to arrive or until timeout is reached use WaitForMessages() method + server.WaitForMessages(42, 1 * time.Millisecond) + + // In case with flaky test environment you can wait for the specified number + // of messages to arrive or until timeout is reached and purge it on server + // after use WaitForMessagesAndPurge() method + server.WaitForMessagesAndPurge(42, 1 * time.Millisecond) + // To stop the server use Stop() method. Please note, smtpmock uses graceful shutdown. // It means that smtpmock will end all sessions after client responses or by session // timeouts immediately. diff --git a/message.go b/message.go index 4b3cea5..a200bf3 100644 --- a/message.go +++ b/message.go @@ -178,3 +178,11 @@ func (messages *messages) purge() []Message { return copiedMessages } + +// Clears the messages slice +func (messages *messages) clear() { + messages.Lock() + defer messages.Unlock() + + messages.items = nil +} diff --git a/message_test.go b/message_test.go index 8b5da6f..df51ea5 100644 --- a/message_test.go +++ b/message_test.go @@ -267,3 +267,13 @@ func TestMessagesPurge(t *testing.T) { assert.Len(t, messages.copy(), 0) }) } + +func TestMessagesClear(t *testing.T) { + t.Run("clears messages from items slice", func(t *testing.T) { + message, messages := new(Message), new(messages) + messages.append(message) + messages.clear() + + assert.Len(t, messages.copy(), 0) + }) +} diff --git a/server.go b/server.go index a8284c8..804ea04 100644 --- a/server.go +++ b/server.go @@ -124,6 +124,12 @@ func (server *Server) Messages() []Message { return server.messages.copy() } +// WaitForMessages waits for the specified number of messages to arrive or until timeout is reached. +// Returns the messages and an error if timeout occurs before receiving expected number of messages. +func (server *Server) WaitForMessages(count int, timeout time.Duration) ([]Message, error) { + return server.fetchMessages(count, timeout, false) +} + // Public interface to get access to server messages // and at the same time removes them. // Returns slice with copy of messages @@ -131,6 +137,13 @@ func (server *Server) MessagesAndPurge() []Message { return server.messages.purge() } +// WaitForMessagesAndPurge waits for the specified number of messages to arrive or until timeout is reached. +// Returns the messages and an error if timeout occurs before receiving expected number of messages. +// At the same time removes the messages from the server. +func (server *Server) WaitForMessagesAndPurge(count int, timeout time.Duration) ([]Message, error) { + return server.fetchMessages(count, timeout, true) +} + // Thread-safe getter of server port. // Returns server.portNumber func (server *Server) PortNumber() int { @@ -139,6 +152,29 @@ func (server *Server) PortNumber() int { return server.portNumber } +// fetchMessages fetches messages with timeout from the server with or without purging. +// Returns messages and an error if timeout occurs before receiving expected number of messages. +func (server *Server) fetchMessages(count int, timeout time.Duration, withPurge bool) ([]Message, error) { + deadline := time.Now().Add(timeout) + for { + messages := server.Messages() + messageCount := len(messages) + + if messageCount >= count { + if withPurge { + server.messages.clear() + } + return messages, nil + } + + if time.Now().After(deadline) { + return messages, fmt.Errorf("timeout waiting for %d messages, got %d", count, messageCount) + } + + time.Sleep(1 * time.Millisecond) + } +} + // Thread-safe getter to check if server has been started. // Returns server.started func (server *Server) isStarted() bool { diff --git a/server_test.go b/server_test.go index fb74307..a8a0f9c 100644 --- a/server_test.go +++ b/server_test.go @@ -6,6 +6,7 @@ import ( "net" "strings" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -163,10 +164,32 @@ func TestServerMessages(t *testing.T) { assert.NotSame(t, server.messages.items, server.Messages()) server.messages.RUnlock() }) +} - t.Run("no messages after purge", func(t *testing.T) { - server := newServer(configuration) - message := new(Message) +func TestServerWaitForMessages(t *testing.T) { + timeout := 1 * time.Millisecond + + t.Run("when expected number of messages is received without timeout", func(t *testing.T) { + server, message := newServer(createConfiguration()), new(Message) + server.messages.append(message) + messages, err := server.WaitForMessages(len(server.messages.copy()), timeout) + + assert.Equal(t, []Message{*message}, messages) + assert.NoError(t, err) + }) + + t.Run("when timeout occurs before receiving expected number of messages", func(t *testing.T) { + server := newServer(createConfiguration()) + messages, err := server.WaitForMessages(1, timeout) + + assert.EqualError(t, err, fmt.Sprintf("timeout waiting for %d messages, got %d", 1, 0)) + assert.Empty(t, messages) + }) +} + +func TestServerMessagesAndPurge(t *testing.T) { + t.Run("returns empty messages after purge", func(t *testing.T) { + server, message := newServer(createConfiguration()), new(Message) server.messages.append(message) assert.NotEmpty(t, server.Messages()) @@ -175,6 +198,28 @@ func TestServerMessages(t *testing.T) { }) } +func TestServerWaitForMessagesAndPurge(t *testing.T) { + timeout := 1 * time.Millisecond + + t.Run("when expected number of messages is received without timeout", func(t *testing.T) { + server, message := newServer(createConfiguration()), new(Message) + server.messages.append(message) + messages, err := server.WaitForMessagesAndPurge(len(server.messages.copy()), timeout) + + assert.Equal(t, []Message{*message}, messages) + assert.NoError(t, err) + assert.Empty(t, server.Messages()) + }) + + t.Run("when timeout occurs before receiving expected number of messages", func(t *testing.T) { + server := newServer(createConfiguration()) + messages, err := server.WaitForMessagesAndPurge(1, timeout) + + assert.EqualError(t, err, fmt.Sprintf("timeout waiting for %d messages, got %d", 1, 0)) + assert.Empty(t, messages) + }) +} + func TestServerPortNumber(t *testing.T) { t.Run("returns server port number", func(t *testing.T) { portNumber := 2525 @@ -184,6 +229,38 @@ func TestServerPortNumber(t *testing.T) { }) } +func TestServerFetchMessages(t *testing.T) { + timeout := 1 * time.Millisecond + + t.Run("when expected number of messages is received without timeout", func(t *testing.T) { + server, message := newServer(createConfiguration()), new(Message) + server.messages.append(message) + messages, err := server.fetchMessages(len(server.messages.copy()), timeout, false) + + assert.Equal(t, []Message{*message}, messages) + assert.NoError(t, err) + assert.NotEmpty(t, server.Messages()) + }) + + t.Run("when expected number of messages is received with purging", func(t *testing.T) { + server, message := newServer(createConfiguration()), new(Message) + server.messages.append(message) + messages, err := server.fetchMessages(len(server.messages.copy()), timeout, true) + + assert.Equal(t, []Message{*message}, messages) + assert.NoError(t, err) + assert.Empty(t, server.Messages()) + }) + + t.Run("when timeout occurs before receiving expected number of messages", func(t *testing.T) { + server := newServer(createConfiguration()) + messages, err := server.fetchMessages(1, timeout, false) + + assert.EqualError(t, err, fmt.Sprintf("timeout waiting for %d messages, got %d", 1, 0)) + assert.Empty(t, messages) + }) +} + func TestServerIsStarted(t *testing.T) { t.Run("returns current server started-flag status", func(t *testing.T) { server := &Server{started: true}