diff --git a/client.go b/client.go index 79dc931..4441f06 100644 --- a/client.go +++ b/client.go @@ -130,6 +130,20 @@ func (r *Request) SetBody(rawBody interface{}) error { } r.body = bodyReader r.ContentLength = contentLength + if bodyReader != nil { + r.GetBody = func() (io.ReadCloser, error) { + body, err := bodyReader() + if err != nil { + return nil, err + } + if rc, ok := body.(io.ReadCloser); ok { + return rc, nil + } + return ioutil.NopCloser(body), nil + } + } else { + r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } + } return nil } @@ -257,18 +271,17 @@ func FromRequest(r *http.Request) (*Request, error) { // NewRequest creates a new wrapped request. func NewRequest(method, url string, rawBody interface{}) (*Request, error) { - bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) + httpReq, err := http.NewRequest(method, url, nil) if err != nil { return nil, err } - httpReq, err := http.NewRequest(method, url, nil) - if err != nil { + req := &Request{Request: httpReq} + if err := req.SetBody(rawBody); err != nil { return nil, err } - httpReq.ContentLength = contentLength - return &Request{bodyReader, httpReq}, nil + return req, nil } // Logger interface allows to use other loggers than diff --git a/client_test.go b/client_test.go index b8a3d9d..02086f3 100644 --- a/client_test.go +++ b/client_test.go @@ -842,3 +842,44 @@ func TestClient_StandardClient(t *testing.T) { t.Fatalf("expected %v, got %v", client, v) } } + +func TestClient_RedirectWithBody(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/foo/redirect": + w.Header().Set("Location", "/foo/redirected") + w.WriteHeader(307) + case "/foo/redirected": + w.WriteHeader(200) + default: + t.Fatalf("bad uri: %s", r.RequestURI) + } + })) + defer ts.Close() + + client := NewClient() + + // has body + req, err := NewRequest(http.MethodPost, ts.URL+"/foo/redirect", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + // no body + if err := req.SetBody(nil); err != nil { + t.Fatalf("err: %v", err) + } + + resp, err = client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() +}