diff --git a/logical/error.go b/logical/error.go index b033e0eac5d6..b2017000e7f1 100644 --- a/logical/error.go +++ b/logical/error.go @@ -20,6 +20,10 @@ var ( // ErrMultiAuthzPending is returned if the the request needs more // authorizations ErrMultiAuthzPending = errors.New("request needs further approval") + + // ErrUpstreamRateLimited is returned when Vault recieves a rate limited + // response from an upstream + ErrUpstreamRateLimited = errors.New("upstream rate limited") ) type HTTPCodedError interface { diff --git a/logical/response_util.go b/logical/response_util.go index c89e547c7726..fbbd021d8cd4 100644 --- a/logical/response_util.go +++ b/logical/response_util.go @@ -105,6 +105,8 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { statusCode = http.StatusNotFound case errwrap.Contains(err, ErrInvalidRequest.Error()): statusCode = http.StatusBadRequest + case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): + statusCode = http.StatusBadGateway } } diff --git a/logical/response_util_test.go b/logical/response_util_test.go new file mode 100644 index 000000000000..823c4afbdc21 --- /dev/null +++ b/logical/response_util_test.go @@ -0,0 +1,83 @@ +package logical + +import ( + "strings" + "testing" +) + +func TestResponseUtil_RespondErrorCommon_basic(t *testing.T) { + testCases := []struct { + title string + req *Request + resp *Response + respErr error + expectedStatus int + expectedErr error + }{ + { + title: "Throttled, no error", + respErr: ErrUpstreamRateLimited, + resp: &Response{}, + expectedStatus: 502, + }, + { + title: "Throttled, with error", + respErr: ErrUpstreamRateLimited, + resp: &Response{ + Data: map[string]interface{}{ + "error": "rate limited", + }, + }, + expectedStatus: 502, + }, + { + title: "Read not found", + req: &Request{ + Operation: ReadOperation, + }, + respErr: nil, + expectedStatus: 404, + }, + { + title: "List with response and no keys", + req: &Request{ + Operation: ListOperation, + }, + resp: &Response{}, + respErr: nil, + expectedStatus: 404, + }, + { + title: "List with response and keys", + req: &Request{ + Operation: ListOperation, + }, + resp: &Response{ + Data: map[string]interface{}{ + "keys": []string{"some", "things", "here"}, + }, + }, + respErr: nil, + expectedStatus: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + var status int + var err, respErr error + if tc.respErr != nil { + respErr = tc.respErr + } + status, err = RespondErrorCommon(tc.req, tc.resp, respErr) + if status != tc.expectedStatus { + t.Fatalf("Expected (%d) status code, got (%d)", tc.expectedStatus, status) + } + if tc.expectedErr != nil { + if !strings.Contains(tc.expectedErr.Error(), err.Error()) { + t.Fatalf("Expected error to contain:\n%s\n\ngot:\n%s\n", tc.expectedErr, err) + } + } + }) + } +}