Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware to set HTTP request content length #108

Merged
merged 3 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,15 @@ protected void generateTestBodySetup(GoWriter writer) {
*/
protected void generateTestServerHandler(GoWriter writer) {
writer.write("actualReq = r.Clone(r.Context())");
// Go does not set RawPath on http server if nothing is excaped
writer.write("if len(actualReq.URL.RawPath) == 0 { actualReq.URL.RawPath = actualReq.URL.Path }");
// Go does not set RawPath on http server if nothing is escaped
writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> {
writer.write("actualReq.URL.RawPath = actualReq.URL.Path");
});
// Go automatically removes Content-Length header setting it to the member.
writer.addUseImports(SmithyGoDependency.STRCONV);
writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> {
writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))");
});

writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("var buf bytes.Buffer");
Expand Down
50 changes: 50 additions & 0 deletions transport/http/middleware_content_length.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package http

import (
"context"
"fmt"
"strconv"

"github.com/awslabs/smithy-go/middleware"
)

// ContentLengthMiddleware provides a middleware to set the content-length
// header for the length of a serialize request body.
type ContentLengthMiddleware struct {
}

// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware
// stack's Build step.
func AddContentLengthMiddleware(stack *middleware.Stack) {
stack.Build.Add(&ContentLengthMiddleware{}, middleware.After)
}

// ID the identifier for the ContentLengthMiddleware
func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" }

// HandleBuild adds the length of the serialized request to the HTTP header
// if the length can be determined.
func (m *ContentLengthMiddleware) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}

// Don't set content length if header is already set.
if vs := req.Header.Values("Content-Length"); len(vs) != 0 {
jasdel marked this conversation as resolved.
Show resolved Hide resolved
return next.HandleBuild(ctx, in)
}

if n, ok, err := req.StreamLength(); err != nil {
return out, metadata, fmt.Errorf(
"failed getting length of request stream, %w", err)
} else if ok {
req.Header.Set("Content-Length", strconv.FormatInt(n, 10))
}

return next.HandleBuild(ctx, in)
}
132 changes: 132 additions & 0 deletions transport/http/middleware_content_length_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package http

import (
"bytes"
"context"
"fmt"
"io"
"strings"
"testing"

"github.com/awslabs/smithy-go/middleware"
)

func TestContentLengthMiddleware(t *testing.T) {
cases := map[string]struct {
Stream io.Reader
ExpectLen string
ExpectErr string
}{
// Cases
"bytes.Reader": {
Stream: bytes.NewReader(make([]byte, 10)),
ExpectLen: "10",
},
"bytes.Buffer": {
Stream: bytes.NewBuffer(make([]byte, 10)),
ExpectLen: "10",
},
"strings.Reader": {
Stream: strings.NewReader("hello"),
ExpectLen: "5",
},
"un-seekable and no length": {
Stream: &basicReader{buf: make([]byte, 10)},
},
"with error": {
Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
ExpectErr: "seek failed",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := NewStackRequest().(*Request)
req, err = req.SetStream(c.Stream)
if err != nil {
t.Fatalf("expect to set stream, %v", err)
}

var m ContentLengthMiddleware
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
nopBuildHandler{},
)
if len(c.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %q, got %v", e, a)
}
} else if err != nil {
t.Fatalf("expect no error, got %v", err)
}

t.Logf("request Content-Length:%v", req.Header.Get("Content-Length"))

if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a {
t.Errorf("expect %v content-length, got %v", e, a)
}
})
}
}

func TestContentLengthMiddleware_HeaderSet(t *testing.T) {
req := NewStackRequest().(*Request)
req.Header.Set("Content-Length", "1234")

var err error
req, err = req.SetStream(strings.NewReader("hello"))
if err != nil {
t.Fatalf("expect to set stream, %v", err)
}

var m ContentLengthMiddleware
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
nopBuildHandler{},
)
if err != nil {
t.Fatalf("expect middleware to run, %v", err)
}

if e, a := "1234", req.Header.Get("Content-Length"); e != a {
t.Errorf("expect Content-Length not to change, got %v", a)
}
}

type nopBuildHandler struct{}

func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
return out, metadata, nil
}

type basicReader struct {
buf []byte
}

func (r *basicReader) Read(p []byte) (int, error) {
n := copy(p, r.buf)
r.buf = r.buf[n:]
return n, nil
}

type errorSecondSeekableReader struct {
err error
count int
}

func (r *errorSecondSeekableReader) Read(p []byte) (int, error) {
return 0, io.EOF
}
func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) {
r.count++
if r.count == 2 {
return 0, r.err
}
return 0, nil
}
26 changes: 26 additions & 0 deletions transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ func (r *Request) Clone() *Request {
return &rc
}

// StreamLength returns the number of bytes of the serialized stream attached
// to the request and ok set. If the length cannot be determined, an error will
// be returned.
func (r *Request) StreamLength() (size int64, ok bool, err error) {
if l, ok := r.stream.(interface{ Len() int }); ok {
return int64(l.Len()), true, nil
}

if !r.isStreamSeekable {
return 0, false, nil
}

s := r.stream.(io.Seeker)
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, false, err
}

_, err = s.Seek(r.streamStartPos, io.SeekStart)
if err != nil {
return 0, false, err
}
jasdel marked this conversation as resolved.
Show resolved Hide resolved

return endOffset - r.streamStartPos, true, nil
}

// RewindStream will rewind the io.Reader to the relative start position if it
// is an io.Seeker.
func (r *Request) RewindStream() error {
Expand Down