Skip to content

Commit

Permalink
Merge pull request #499 from aws/fixIssue497
Browse files Browse the repository at this point in the history
Fix status waiter with status code > 4xx
  • Loading branch information
jasdel committed Jan 13, 2016
2 parents c924893 + c7e0589 commit 64ecfaa
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 42 deletions.
6 changes: 5 additions & 1 deletion private/protocol/query/unmarshal_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ func UnmarshalError(r *request.Request) {
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
reqID := resp.RequestID
if reqID == "" {
reqID = r.RequestID
}
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
reqID,
)
}
}
4 changes: 4 additions & 0 deletions private/protocol/rest/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func Unmarshal(r *request.Request) {
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
if r.RequestID == "" {
// Alternative version of request id in the header
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
}
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalLocationElements(r, v)
Expand Down
8 changes: 3 additions & 5 deletions private/waiter/waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,15 @@ func (w *Waiter) Wait() error {

err := req.Send()
for _, a := range w.Acceptors {
if err != nil && a.Matcher != "error" {
// Only matcher error is valid if there is a request error
continue
}

result := false
var vals []interface{}
switch a.Matcher {
case "pathAll", "path":
// Require all matches to be equal for result to match
vals, _ = awsutil.ValuesAtPath(req.Data, a.Argument)
if len(vals) == 0 {
break
}
result = true
for _, val := range vals {
if !awsutil.DeepEqual(val, a.Expected) {
Expand Down
24 changes: 17 additions & 7 deletions private/waiter/waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ func TestWaiterError(t *testing.T) {
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.ValidateResponse.Clear()

reqNum := 0
Expand All @@ -291,14 +292,14 @@ func TestWaiterError(t *testing.T) {
numBuiltReq++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 1 {
r.Error = awserr.New("MockException", "mock exception message", nil)
r.HTTPResponse = &http.Response{
StatusCode: 400,
Status: http.StatusText(400),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
reqNum++
code = 400
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
Expand All @@ -309,6 +310,14 @@ func TestWaiterError(t *testing.T) {
r.Data = resps[reqNum]
reqNum++
})
svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
if reqNum == 1 {
r.Error = awserr.New("MockException", "mock exception message", nil)
// If there was an error unmarshal error will be called instead of unmarshal
// need to increment count here also
reqNum++
}
})

waiterCfg := waiter.Config{
Operation: "Mock",
Expand Down Expand Up @@ -358,6 +367,7 @@ func TestWaiterStatus(t *testing.T) {
code := 200
if reqNum == 3 {
code = 404
r.Error = awserr.New("NotFound", "Not Found", nil)
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Expand Down
16 changes: 11 additions & 5 deletions service/s3/unmarshal_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,23 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()

if r.HTTPResponse.StatusCode == http.StatusMovedPermanently {
r.Error = awserr.New("BucketRegionError",
fmt.Sprintf("incorrect region, the bucket is not in '%s' region", aws.StringValue(r.Config.Region)), nil)
r.Error = awserr.NewRequestFailure(
awserr.New("BucketRegionError",
fmt.Sprintf("incorrect region, the bucket is not in '%s' region",
aws.StringValue(r.Config.Region)),
nil),
r.HTTPResponse.StatusCode,
r.RequestID,
)
return
}

if r.HTTPResponse.ContentLength == int64(0) {
if r.HTTPResponse.ContentLength <= 1 {
// No body, use status code to generate an awserr.Error
r.Error = awserr.NewRequestFailure(
awserr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil),
r.HTTPResponse.StatusCode,
"",
r.RequestID,
)
return
}
Expand All @@ -45,7 +51,7 @@ func unmarshalError(r *request.Request) {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
"",
r.RequestID,
)
}
}
103 changes: 79 additions & 24 deletions service/s3/unmarshal_error_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package s3_test

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -15,39 +16,93 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
)

var s3StatusCodeErrorTests = []struct {
scode int
status string
body string
code string
message string
}{
{301, "Moved Permanently", "", "BucketRegionError", "incorrect region, the bucket is not in 'mock-region' region"},
{403, "Forbidden", "", "Forbidden", "Forbidden"},
{400, "Bad Request", "", "BadRequest", "Bad Request"},
{404, "Not Found", "", "NotFound", "Not Found"},
{500, "Internal Error", "", "InternalError", "Internal Error"},
type testErrorCase struct {
RespFn func() *http.Response
ReqID string
Code, Msg string
}

func TestStatusCodeError(t *testing.T) {
for _, test := range s3StatusCodeErrorTests {
var testUnmarshalCases = []testErrorCase{
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 301,
Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}},
Body: ioutil.NopCloser(nil),
ContentLength: -1,
}
},
ReqID: "abc123",
Code: "BucketRegionError", Msg: "incorrect region, the bucket is not in 'mock-region' region",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 403,
Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}},
Body: ioutil.NopCloser(nil),
ContentLength: 0,
}
},
ReqID: "abc123",
Code: "Forbidden", Msg: "Forbidden",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 400,
Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}},
Body: ioutil.NopCloser(nil),
ContentLength: 0,
}
},
ReqID: "abc123",
Code: "BadRequest", Msg: "Bad Request",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}},
Body: ioutil.NopCloser(nil),
ContentLength: 0,
}
},
ReqID: "abc123",
Code: "NotFound", Msg: "Not Found",
},
{
RespFn: func() *http.Response {
body := `<Error><Code>SomeException</Code><Message>Exception message</Message></Error>`
return &http.Response{
StatusCode: 500,
Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}},
Body: ioutil.NopCloser(strings.NewReader(body)),
ContentLength: int64(len(body)),
}
},
ReqID: "abc123",
Code: "SomeException", Msg: "Exception message",
},
}

func TestUnmarshalError(t *testing.T) {
for _, c := range testUnmarshalCases {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
body := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{
ContentLength: int64(len(test.body)),
StatusCode: test.scode,
Status: test.status,
Body: body,
}
r.HTTPResponse = c.RespFn()
r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode)
})
_, err := s.PutBucketAcl(&s3.PutBucketAclInput{
Bucket: aws.String("bucket"), ACL: aws.String("public-read"),
})

fmt.Printf("%#v\n", err)

assert.Error(t, err)
assert.Equal(t, test.code, err.(awserr.Error).Code())
assert.Equal(t, test.message, err.(awserr.Error).Message())
assert.Equal(t, c.Code, err.(awserr.Error).Code())
assert.Equal(t, c.Msg, err.(awserr.Error).Message())
assert.Equal(t, c.ReqID, err.(awserr.RequestFailure).RequestID())
}
}

0 comments on commit 64ecfaa

Please sign in to comment.