diff --git a/errors.go b/errors.go index fc3eb6547..ab9613be9 100644 --- a/errors.go +++ b/errors.go @@ -149,3 +149,34 @@ func NewError(err any) *Error { return &Error{Code: ErrorUnsupported, Message: fmt.Sprintf("Unsupported type to linodego.NewError: %s", reflect.TypeOf(e))} } } + +// IsNotFound indicates if err indicates a 404 Not Found error from the Linode API. +func IsNotFound(err error) bool { + return ErrHasStatus(err, http.StatusNotFound) +} + +// ErrHasStatus checks if err is an error from the Linode API, and whether it contains the given HTTP status code. +// More than one status code may be given. +// If len(code) == 0, err is nil or is not a [Error], ErrHasStatus will return false. +func ErrHasStatus(err error, code ...int) bool { + if err == nil { + return false + } + + // Short-circuit if the caller did not provide any status codes. + if len(code) == 0 { + return false + } + + var e *Error + if !errors.As(err, &e) { + return false + } + ec := e.StatusCode() + for _, c := range code { + if ec == c { + return true + } + } + return false +} diff --git a/errors_test.go b/errors_test.go index ec5519256..c10cfdb9e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -294,3 +294,78 @@ func TestErrorIs(t *testing.T) { }) } } + +func TestIsNotFound(t *testing.T) { + tests := []struct { + code int + match bool + }{ + {code: http.StatusNotFound, match: true}, + {code: http.StatusInternalServerError}, + {code: http.StatusFound}, + {code: http.StatusOK}, + } + + for _, tt := range tests { + name := http.StatusText(tt.code) + t.Run(name, func(t *testing.T) { + err := &Error{Code: tt.code} + if matches := IsNotFound(err); !matches && tt.match { + t.Errorf("should have matched %d", tt.code) + } else if matches && !tt.match { + t.Errorf("shoudl not have matched %d", tt.code) + } + }) + } +} + +func TestErrHasStatusCode(t *testing.T) { + tests := []struct { + name string + err error + codes []int + match bool + }{ + { + name: "NotFound", + err: &Error{Code: http.StatusNotFound}, + codes: []int{http.StatusNotFound}, + match: true, + }, + { + name: "NoCodes", + err: &Error{Code: http.StatusInternalServerError}, + }, + { + name: "MultipleCodes", + err: &Error{Code: http.StatusTeapot}, + codes: []int{http.StatusBadRequest, http.StatusTeapot, http.StatusUnavailableForLegalReasons}, + match: true, + }, + { + name: "NotALinodeError", + err: io.EOF, + codes: []int{http.StatusTeapot}, + }, + { + name: "NoMatch", + err: &Error{Code: http.StatusTooEarly}, + codes: []int{http.StatusLocked, http.StatusTooManyRequests}, + }, + { + name: "NilError", + codes: []int{http.StatusGone}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ErrHasStatus(tt.err, tt.codes...) + if !got && tt.match { + t.Errorf("should have matched") + } else if got && !tt.match { + t.Errorf("should not have matched") + } + }) + } +}