Skip to content

Commit

Permalink
Make RetryError and HaltError able to be fetched for root cause (than…
Browse files Browse the repository at this point in the history
…os-io#7043)

* Make RetryError and HaltError able to be fetched for root cause

Signed-off-by: Alex Le <leqiyue@amazon.com>

* Added unit test

Signed-off-by: Alex Le <leqiyue@amazon.com>

* fix lint

Signed-off-by: Alex Le <leqiyue@amazon.com>

* fixed IsRetryError and IsHaltError functions

Signed-off-by: Alex Le <leqiyue@amazon.com>

---------

Signed-off-by: Alex Le <leqiyue@amazon.com>
  • Loading branch information
alexqyle authored Jan 15, 2024
1 parent bee20b9 commit 324846f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
12 changes: 10 additions & 2 deletions pkg/compact/compact.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,14 @@ func (e HaltError) Error() string {
return e.err.Error()
}

func (e HaltError) Unwrap() error {
return errors.Cause(e.err)
}

// IsHaltError returns true if the base error is a HaltError.
// If a multierror is passed, any halt error will return true.
func IsHaltError(err error) bool {
if multiErr, ok := errors.Cause(err).(errutil.NonNilMultiError); ok {
if multiErr, ok := errors.Cause(err).(errutil.NonNilMultiRootError); ok {
for _, err := range multiErr {
if _, ok := errors.Cause(err).(HaltError); ok {
return true
Expand Down Expand Up @@ -974,10 +978,14 @@ func (e RetryError) Error() string {
return e.err.Error()
}

func (e RetryError) Unwrap() error {
return errors.Cause(e.err)
}

// IsRetryError returns true if the base error is a RetryError.
// If a multierror is passed, all errors must be retriable.
func IsRetryError(err error) bool {
if multiErr, ok := errors.Cause(err).(errutil.NonNilMultiError); ok {
if multiErr, ok := errors.Cause(err).(errutil.NonNilMultiRootError); ok {
for _, err := range multiErr {
if _, ok := errors.Cause(err).(RetryError); !ok {
return false
Expand Down
29 changes: 29 additions & 0 deletions pkg/errutil/multierror.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"bytes"
"fmt"
"sync"

"github.com/pkg/errors"
)

// The MultiError type implements the error interface, and contains the
Expand Down Expand Up @@ -62,6 +64,33 @@ type NonNilMultiError MultiError

// Returns a concatenated string of the contained errors.
func (es NonNilMultiError) Error() string {
return multiErrorString(es)
}

func (es NonNilMultiError) Cause() error {
return es.getCause()
}

func (es NonNilMultiError) getCause() NonNilMultiRootError {
var causes []error
for _, err := range es {
if multiErr, ok := errors.Cause(err).(NonNilMultiError); ok {
causes = append(causes, multiErr.getCause()...)
} else {
causes = append(causes, errors.Cause(err))
}
}
return causes
}

type NonNilMultiRootError MultiError

// Returns a concatenated string of the contained errors.
func (es NonNilMultiRootError) Error() string {
return multiErrorString(es)
}

func multiErrorString(es []error) string {
var buf bytes.Buffer

if len(es) > 1 {
Expand Down
47 changes: 47 additions & 0 deletions pkg/errutil/multierror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,56 @@ package errutil
import (
"fmt"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)

func TestMultiSyncErrorAdd(t *testing.T) {
sme := &SyncMultiError{}
sme.Add(fmt.Errorf("test"))
}

func TestNonNilMultiErrorCause_SingleCause(t *testing.T) {
rootCause := fmt.Errorf("test root cause")
me := MultiError{}
me.Add(errors.Wrap(rootCause, "wrapped error"))
causes, ok := errors.Cause(NonNilMultiError(me)).(NonNilMultiRootError)
require.True(t, ok)
require.Equal(t, 1, len(causes))
require.Equal(t, rootCause, causes[0])
}

func TestNonNilMultiErrorCause_MultipleCauses(t *testing.T) {
rootCause1 := fmt.Errorf("test root cause 1")
rootCause2 := fmt.Errorf("test root cause 2")
rootCause3 := fmt.Errorf("test root cause 3")
me := MultiError{}
me.Add(errors.Wrap(rootCause1, "wrapped error 1"))
me.Add(errors.Wrap(errors.Wrap(rootCause2, "wrapped error 2"), "wrapped error 2 again"))
me.Add(rootCause3)
causes, ok := errors.Cause(NonNilMultiError(me)).(NonNilMultiRootError)
require.True(t, ok)
require.Equal(t, 3, len(causes))
require.Contains(t, causes, rootCause1)
require.Contains(t, causes, rootCause2)
require.Contains(t, causes, rootCause3)
}

func TestNonNilMultiErrorCause_MultipleCausesWithNestedNonNilMultiError(t *testing.T) {
rootCause1 := fmt.Errorf("test root cause 1")
rootCause2 := fmt.Errorf("test root cause 2")
rootCause3 := fmt.Errorf("test root cause 3")
me1 := MultiError{}
me1.Add(errors.Wrap(rootCause1, "wrapped error 1"))
me1.Add(errors.Wrap(rootCause2, "wrapped error 2"))
me := MultiError{}
me.Add(errors.Wrap(rootCause3, "wrapped error 3"))
me.Add(NonNilMultiError(me1))
causes, ok := errors.Cause(NonNilMultiError(me)).(NonNilMultiRootError)
require.True(t, ok)
require.Equal(t, 3, len(causes))
require.Contains(t, causes, rootCause1)
require.Contains(t, causes, rootCause2)
require.Contains(t, causes, rootCause3)
}

0 comments on commit 324846f

Please sign in to comment.