Skip to content

Commit

Permalink
feat: mailfilter.transaction now only uses a temporary file to captur…
Browse files Browse the repository at this point in the history
…e the mail body when it is bigger than 200KB
  • Loading branch information
d--j committed Apr 12, 2023
1 parent 0a56da8 commit c008a2d
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 7 deletions.
108 changes: 108 additions & 0 deletions internal/body/body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package body implements a write-once read-multiple [io.ReadSeekCloser] that is backed by a temporary file when too much data gets written into it.
package body

import (
"bytes"
"io"
"os"
)

// New returns an empty Body that keeps written data in memory until more than
// maxMem bytes have been written and stores it in a temporary file from then on.
//
// With maxMem below 1 the first non-empty write already moves the data to a temporary file.
func New(maxMem int) *Body {
	b := new(Body)
	b.maxMem = maxMem
	return b
}

// Body is an [io.ReadSeekCloser] and [io.Writer] that starts out buffering all written data in memory
// and switches to writing to a temporary file once more than a configured number of bytes has been written.
//
// After a call to Read or Seek no more data can be written to Body; Write panics in that case.
// Body is an [io.Seeker], so it can be read multiple times and its size can be determined by seeking.
type Body struct {
// maxMem is the in-memory limit in bytes: Write moves the data to a
// temporary file as soon as buf holds more than this many bytes.
maxMem int
// buf collects the written data for as long as no temporary file is in use.
buf bytes.Buffer
// mem reads from the bytes held in buf; it is created by the first Read or
// Seek when no temporary file is in use.
mem *bytes.Reader
// file is the temporary file holding the data; nil while the data is kept in memory.
file *os.File
// reading is set by the first Read or Seek; Write panics once it is true.
reading bool
}

// Write implements the io.Writer interface.
//
// Data is collected in memory until the buffer holds more than the configured
// number of bytes. At that point the whole buffer is copied into a newly
// created temporary file and every further Write goes directly to that file.
//
// If the temporary file cannot be created or filled, Write returns that error
// together with n == len(p): the data, including p, stays buffered in memory,
// any partially written temporary file is removed, and the next Write tries
// the switch again.
//
// Write panics when it is called after Read or Seek.
func (b *Body) Write(p []byte) (n int, err error) {
	if b.reading {
		panic("cannot write after read")
	}
	if b.file != nil {
		return b.file.Write(p)
	}
	// bytes.Buffer.Write always returns a nil error.
	n, _ = b.buf.Write(p)
	if b.buf.Len() <= b.maxMem {
		return n, nil
	}
	f, err := os.CreateTemp("", "body-*")
	if err != nil {
		return n, err
	}
	// Write from a view of the buffer instead of draining it, so the in-memory
	// copy is still complete when the file write fails part-way.
	if _, err = f.Write(b.buf.Bytes()); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(f.Name())
		return n, err
	}
	// Only a completely written file becomes the backing store.
	b.file = f
	b.buf.Reset()
	return n, nil
}

// switchToReading moves b into read mode on its first call and does nothing on
// later calls. A file-backed Body is rewound to the start of its temporary
// file; a memory-backed Body gets a reader over the buffered bytes.
func (b *Body) switchToReading() error {
	if b.reading {
		return nil
	}
	b.reading = true
	if b.file == nil {
		b.mem = bytes.NewReader(b.buf.Bytes())
		return nil
	}
	_, err := b.file.Seek(0, io.SeekStart)
	return err
}

// Read implements the io.Reader interface.
// It reads from the temporary file when one is in use and from memory otherwise.
// Once Read has been called, Write must not be called anymore.
func (b *Body) Read(p []byte) (n int, err error) {
	if err = b.switchToReading(); err != nil {
		return 0, err
	}
	if b.file == nil {
		return b.mem.Read(p)
	}
	return b.file.Read(p)
}

// Close implements the io.Closer interface.
//
// For a file-backed Body the temporary file is closed and deleted. An error
// from closing the file takes precedence over an error from deleting it, and a
// file that is already gone is not reported as an error.
//
// For a memory-backed Body the buffered data is released. Reading from such a
// Body after Close yields no data instead of panicking.
func (b *Body) Close() error {
	if b.file == nil {
		// Read and Seek dereference b.mem once reading has started, so point
		// an existing reader at an empty slice rather than setting it to nil.
		if b.mem != nil {
			b.mem.Reset(nil)
		}
		b.buf.Reset()
		return nil
	}
	closeErr := b.file.Close()
	// Always try to delete the file, even when closing it failed.
	removeErr := os.Remove(b.file.Name())
	switch {
	case closeErr != nil:
		return closeErr
	case os.IsNotExist(removeErr):
		return nil
	default:
		return removeErr
	}
}

// Seek implements the io.Seeker interface.
// It seeks in the temporary file when one is in use and in memory otherwise.
// Once Seek has been called, Write must not be called anymore.
func (b *Body) Seek(offset int64, whence int) (int64, error) {
	err := b.switchToReading()
	if err != nil {
		return 0, err
	}
	if b.file == nil {
		return b.mem.Seek(offset, whence)
	}
	return b.file.Seek(offset, whence)
}
158 changes: 158 additions & 0 deletions internal/body/body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package body

import (
"bytes"
"io"
"os"
"testing"
)

// getBody returns a Body with the given memory limit after writing data to it.
func getBody(maxMem int, data []byte) *Body {
	fresh := New(maxMem)
	// The write result is dropped on purpose; tests inspect the Body itself.
	_, _ = fresh.Write(data)
	return fresh
}

func TestBody_Close(t *testing.T) {
fileAlreadyRemoved := getBody(2, []byte("test"))
_ = os.Remove(fileAlreadyRemoved.file.Name())
tests := []struct {
name string
body *Body
wantErr bool
}{
{"noop", getBody(10, nil), false},
{"mem", getBody(10, []byte("test")), false},
{"file", getBody(2, []byte("test")), false},
{"file-already-removed", fileAlreadyRemoved, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.body.Close(); (err != nil) != tt.wantErr {
t.Errorf("Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func TestBody(t *testing.T) {
t.Run("mem", func(t *testing.T) {
b := getBody(10, []byte("test"))
defer b.Close()
_, err := b.Write([]byte("test"))
if err != nil {
t.Fatal("b.Write got error", err)
}
if b.file != nil {
t.Fatal("b.file needs to be nil")
}
var buf [10]byte
n, err := b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
pos, err := b.Seek(0, io.SeekStart)
if err != nil {
t.Fatal("b.Seek got error", err)
}
if pos != 0 {
t.Fatal("b.Seek got pos", pos)
}
n, err = b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
})
t.Run("file", func(t *testing.T) {
b := getBody(2, []byte("test"))
defer func() {
if b != nil {
b.Close()
}
}()
if b.file == nil {
t.Fatal("b.file is nil")
}
_, err := b.Write([]byte("test"))
if err != nil {
t.Fatal("b.Write got error", err)
}
var buf [10]byte
n, err := b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
pos, err := b.Seek(0, io.SeekStart)
if err != nil {
t.Fatal("b.Seek got error", err)
}
if pos != 0 {
t.Fatal("b.Seek got pos", pos)
}
n, err = b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
name := b.file.Name()
err = b.Close()
b = nil
if err != nil {
t.Fatal("b.Close got error", err)
}
_, err = os.Stat(name)
if err == nil || !os.IsNotExist(err) {
t.Fatalf("got %v expected to not find file", err)
}
})
t.Run("panic on Write after Read", func(t *testing.T) {
defer func() { _ = recover() }()
b := getBody(10, []byte("test"))
var buf [10]byte
_, _ = b.Read(buf[:])
_, _ = b.Write([]byte("test"))
t.Errorf("did not panic")
})
t.Run("panic on Write after Seek", func(t *testing.T) {
defer func() { _ = recover() }()
b := getBody(10, []byte("test"))
_, _ = b.Seek(0, io.SeekEnd)
_, _ = b.Write([]byte("test"))
t.Errorf("did not panic")
})
t.Run("temp file fail", func(t *testing.T) {
tmpdir := os.Getenv("TMPDIR")
tmp := os.Getenv("TMP")
_ = os.Setenv("TMPDIR", "/this does not exist")
_ = os.Setenv("TMP", "/this does not exist")
defer func() {
_ = os.Setenv("TMPDIR", tmpdir)
_ = os.Setenv("TMP", tmp)
}()
b := getBody(6, []byte("test"))
_, err := b.Write([]byte("test"))
if err == nil {
_ = b.Close()
t.Fatal("b.Write got nil error")
}
})
t.Run("file close fail", func(t *testing.T) {
b := getBody(2, []byte("test"))
_ = b.file.Close()
err := b.Close()
if err == nil {
t.Fatal("b.Close got nil error")
}
})
}
10 changes: 3 additions & 7 deletions mailfilter/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"bytes"
"context"
"io"
"os"
"regexp"

"github.com/d--j/go-milter"
"github.com/d--j/go-milter/internal/body"
"github.com/d--j/go-milter/internal/header"
"github.com/d--j/go-milter/internal/rcptto"
"github.com/d--j/go-milter/mailfilter/addr"
Expand Down Expand Up @@ -58,7 +58,7 @@ type transaction struct {
headers *header.Header
origHeaders *header.Header
enforceHeaderOrder bool
body *os.File
body *body.Body
replacementBody io.Reader
queueId string
hasDecision bool
Expand Down Expand Up @@ -92,7 +92,6 @@ func (t *transaction) cleanup() {
t.closeReplacementBody()
if t.body != nil {
_ = t.body.Close()
_ = os.Remove(t.body.Name())
t.body = nil
}
}
Expand Down Expand Up @@ -249,10 +248,7 @@ func (t *transaction) addHeader(key string, raw []byte) {

func (t *transaction) addBodyChunk(chunk []byte) (err error) {
if t.body == nil {
t.body, err = os.CreateTemp("", "body-*")
if err != nil {
return
}
t.body = body.New(200 * 1024)
}
_, err = t.body.Write(chunk)
return
Expand Down

0 comments on commit c008a2d

Please sign in to comment.