From c008a2d440bf51ba73b4eb14ea6a2f74a723413d Mon Sep 17 00:00:00 2001
From: Daniel Jagszent
Date: Wed, 12 Apr 2023 18:45:51 +0200
Subject: [PATCH] feat: mailfilter.transaction now only uses a temporary file to capture the mail body when it is bigger than 200KB

---
 internal/body/body.go      | 108 +++++++++++++++++++++++++
 internal/body/body_test.go | 158 +++++++++++++++++++++++++++++++++++++
 mailfilter/transaction.go  |  10 +--
 3 files changed, 269 insertions(+), 7 deletions(-)
 create mode 100644 internal/body/body.go
 create mode 100644 internal/body/body_test.go

diff --git a/internal/body/body.go b/internal/body/body.go
new file mode 100644
index 0000000..584e7c8
--- /dev/null
+++ b/internal/body/body.go
@@ -0,0 +1,108 @@
+// Package body implements a write-once, read-many [io.ReadSeekCloser] that is backed by a temporary file when too much data gets written into it.
+package body
+
+import (
+	"bytes"
+	"io"
+	"os"
+)
+
+// New creates a new Body that switches from memory-backed storage to file-backed storage
+// when more than maxMem bytes were written to it.
+//
+// If maxMem is less than 1 the first non-empty Write already switches to a temporary file.
+func New(maxMem int) *Body {
+	return &Body{maxMem: maxMem}
+}
+
+// Body is an [io.ReadSeekCloser] and [io.Writer] that starts out buffering all data written to it in memory.
+// Once more than the configured number of bytes got written, Body switches to writing to a temporary file.
+//
+// After a call to Read or Seek no more data can be written to Body.
+// Body is an [io.Seeker] so you can read it multiple times or get the size of the Body.
+type Body struct {
+	maxMem  int           // largest number of bytes kept in memory before a temporary file is used
+	buf     bytes.Buffer  // in-memory storage while no temporary file exists
+	mem     *bytes.Reader // reader over buf, created when a memory-backed Body switches to reading
+	file    *os.File      // temporary file; nil while the data still fits into memory
+	reading bool          // true once Read or Seek switched to reading; Write panics afterwards
+}
+
+// Write implements the io.Writer interface. It panics when called after Read or Seek.
+// Write creates the temporary file on the fly as soon as more than the configured number of bytes got written.
+func (b *Body) Write(p []byte) (n int, err error) {
+	if b.reading {
+		panic("cannot write after read")
+	}
+	if b.file != nil {
+		return b.file.Write(p)
+	}
+	n, _ = b.buf.Write(p) // bytes.Buffer.Write never returns an error
+	if b.buf.Len() <= b.maxMem {
+		return n, nil
+	}
+	// Over the limit: move everything buffered so far into a temporary file.
+	if b.file, err = os.CreateTemp("", "body-*"); err != nil {
+		return n, err
+	}
+	_, err = io.Copy(b.file, &b.buf)
+	b.buf = bytes.Buffer{} // Reset would keep the backing array allocated
+	return n, err
+}
+
+// switchToReading puts b into read mode the first time Read or Seek is called.
+// reading is only set once that worked, so a failed rewind is retried next time.
+func (b *Body) switchToReading() error {
+	if b.reading {
+		return nil
+	}
+	if b.file == nil {
+		b.mem = bytes.NewReader(b.buf.Bytes())
+	} else if _, err := b.file.Seek(0, io.SeekStart); err != nil {
+		return err
+	}
+	b.reading = true
+	return nil
+}
+
+// Read implements the io.Reader interface.
+// After calling Read you cannot call Write anymore.
+func (b *Body) Read(p []byte) (n int, err error) {
+	if err := b.switchToReading(); err != nil {
+		return 0, err
+	}
+	if b.file != nil {
+		return b.file.Read(p)
+	}
+	return b.mem.Read(p)
+}
+
+// Close implements the io.Closer interface.
+// If a temporary file got created it will be deleted.
+func (b *Body) Close() error {
+	if b.file == nil {
+		b.mem, b.buf = nil, bytes.Buffer{} // drop the buffered data
+		return nil
+	}
+	closeErr := b.file.Close()
+	removeErr := os.Remove(b.file.Name()) // attempted even when Close failed
+	if closeErr != nil {
+		return closeErr
+	}
+	if os.IsNotExist(removeErr) {
+		return nil // the file is already gone, nothing left to clean up
+	}
+	return removeErr
+}
+
+// Seek implements the io.Seeker interface.
+// After calling Seek you cannot call Write anymore.
+func (b *Body) Seek(offset int64, whence int) (int64, error) {
+	if err := b.switchToReading(); err != nil {
+		return 0, err
+	}
+	if b.file != nil {
+		return b.file.Seek(offset, whence)
+	}
+	return b.mem.Seek(offset, whence)
+}
diff --git a/internal/body/body_test.go b/internal/body/body_test.go
new file mode 100644
index 0000000..b442252
--- /dev/null
+++ b/internal/body/body_test.go
@@ -0,0 +1,158 @@
+package body
+
+import (
+	"bytes"
+	"io"
+	"os"
+	"testing"
+)
+
+// getBody returns a Body limited to maxMem bytes of memory with data already written to it.
+func getBody(maxMem int, data []byte) *Body {
+	b := New(maxMem)
+	_, _ = b.Write(data)
+	return b
+}
+
+// TestBody_Close checks Close for empty, memory-backed and file-backed bodies.
+func TestBody_Close(t *testing.T) {
+	fileAlreadyRemoved := getBody(2, []byte("test"))
+	// Delete the temporary file up front: Close must not report it as missing.
+	_ = os.Remove(fileAlreadyRemoved.file.Name())
+	tests := []struct {
+		name    string
+		body    *Body
+		wantErr bool
+	}{
+		{"noop", getBody(10, nil), false},
+		{"mem", getBody(10, []byte("test")), false},
+		{"file", getBody(2, []byte("test")), false},
+		{"file-already-removed", fileAlreadyRemoved, false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if err := tt.body.Close(); (err != nil) != tt.wantErr {
+				t.Errorf("Close() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+// TestBody exercises writing, reading and seeking for both storage modes and the error paths.
+func TestBody(t *testing.T) {
+	t.Run("mem", func(t *testing.T) {
+		b := getBody(10, []byte("test"))
+		defer b.Close()
+		_, err := b.Write([]byte("test"))
+		if err != nil {
+			t.Fatal("b.Write got error", err)
+		}
+		if b.file != nil {
+			t.Fatal("b.file needs to be nil")
+		}
+		var buf [10]byte
+		n, err := b.Read(buf[:])
+		if err != nil {
+			t.Fatal("b.Read got error", err)
+		}
+		if !bytes.Equal([]byte("testtest"), buf[:n]) {
+			t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
+		}
+		pos, err := b.Seek(0, io.SeekStart)
+		if err != nil {
+			t.Fatal("b.Seek got error", err)
+		}
+		if pos != 0 {
+			t.Fatal("b.Seek got pos", pos)
+		}
+		n, err = b.Read(buf[:])
+		if err != nil {
+			t.Fatal("b.Read got error", err)
+		}
+		if !bytes.Equal([]byte("testtest"), buf[:n]) {
+			t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
+		}
+	})
+	t.Run("file", func(t *testing.T) {
+		b := getBody(2, []byte("test"))
+		defer func() {
+			if b != nil {
+				b.Close()
+			}
+		}()
+		if b.file == nil {
+			t.Fatal("b.file is nil")
+		}
+		_, err := b.Write([]byte("test"))
+		if err != nil {
+			t.Fatal("b.Write got error", err)
+		}
+		var buf [10]byte
+		n, err := b.Read(buf[:])
+		if err != nil {
+			t.Fatal("b.Read got error", err)
+		}
+		if !bytes.Equal([]byte("testtest"), buf[:n]) {
+			t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
+		}
+		pos, err := b.Seek(0, io.SeekStart)
+		if err != nil {
+			t.Fatal("b.Seek got error", err)
+		}
+		if pos != 0 {
+			t.Fatal("b.Seek got pos", pos)
+		}
+		n, err = b.Read(buf[:])
+		if err != nil {
+			t.Fatal("b.Read got error", err)
+		}
+		if !bytes.Equal([]byte("testtest"), buf[:n]) {
+			t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
+		}
+		name := b.file.Name()
+		err = b.Close()
+		b = nil
+		if err != nil {
+			t.Fatal("b.Close got error", err)
+		}
+		_, err = os.Stat(name)
+		if err == nil || !os.IsNotExist(err) {
+			t.Fatalf("got %v expected to not find file", err)
+		}
+	})
+	t.Run("panic on Write after Read", func(t *testing.T) {
+		defer func() { _ = recover() }()
+		b := getBody(10, []byte("test"))
+		var buf [10]byte
+		_, _ = b.Read(buf[:])
+		_, _ = b.Write([]byte("test"))
+		t.Errorf("did not panic")
+	})
+	t.Run("panic on Write after Seek", func(t *testing.T) {
+		defer func() { _ = recover() }()
+		b := getBody(10, []byte("test"))
+		_, _ = b.Seek(0, io.SeekEnd)
+		_, _ = b.Write([]byte("test"))
+		t.Errorf("did not panic")
+	})
+	t.Run("temp file fail", func(t *testing.T) {
+		// Point the temp dir at a missing path; t.Setenv restores both variables afterwards.
+		t.Setenv("TMPDIR", "/this does not exist")
+		t.Setenv("TMP", "/this does not exist")
+		b := getBody(6, []byte("test"))
+		_, err := b.Write([]byte("test"))
+		if err == nil {
+			_ = b.Close()
+			t.Fatal("b.Write got nil error")
+		}
+	})
+	t.Run("file close fail", func(t *testing.T) {
+		// Close the file behind Body's back so that its own Close call fails.
+		b := getBody(2, []byte("test"))
+		_ = b.file.Close()
+		err := b.Close()
+		if err == nil {
+			t.Fatal("b.Close got nil error")
+		}
+	})
+}
diff --git a/mailfilter/transaction.go b/mailfilter/transaction.go
index dfd284f..6df3e18 100644
--- a/mailfilter/transaction.go
+++ b/mailfilter/transaction.go
@@ -4,10 +4,10 @@ import (
 	"bytes"
 	"context"
 	"io"
-	"os"
 	"regexp"
 
 	"github.com/d--j/go-milter"
+	"github.com/d--j/go-milter/internal/body"
 	"github.com/d--j/go-milter/internal/header"
 	"github.com/d--j/go-milter/internal/rcptto"
 	"github.com/d--j/go-milter/mailfilter/addr"
@@ -58,7 +58,7 @@ type transaction struct {
 	headers            *header.Header
 	origHeaders        *header.Header
 	enforceHeaderOrder bool
-	body               *os.File
+	body               *body.Body
 	replacementBody    io.Reader
 	queueId            string
 	hasDecision        bool
@@ -92,7 +92,6 @@ func (t *transaction) cleanup() {
 	t.closeReplacementBody()
 	if t.body != nil {
 		_ = t.body.Close()
-		_ = os.Remove(t.body.Name())
 		t.body = nil
 	}
 }
@@ -249,10 +248,7 @@ func (t *transaction) addHeader(key string, raw []byte) {
 
 func (t *transaction) addBodyChunk(chunk []byte) (err error) {
 	if t.body == nil {
-		t.body, err = os.CreateTemp("", "body-*")
-		if err != nil {
-			return
-		}
+		t.body = body.New(200 * 1024)
 	}
 	_, err = t.body.Write(chunk)
 	return