Skip to content

Commit

Permalink
Add middleware to set HTTP request content length
Browse files Browse the repository at this point in the history
Adds middleware that will set the HTTP request's Content-Length header
if the length of the serialized stream can be determined. If the length
cannot be determined or the Header value is already set the middleware
will do nothing.

Fails the request if the underlying request stream reader is seekable
but fails to seek.
  • Loading branch information
jasdel committed Jul 13, 2020
1 parent b42826b commit 03463dd
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 0 deletions.
50 changes: 50 additions & 0 deletions transport/http/middleware_content_length.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package http

import (
"context"
"fmt"
"strconv"

"github.com/awslabs/smithy-go/middleware"
)

// ContentLengthMiddleware provides a middleware to set the content-length
// header for the length of a serialize request body.
type ContentLengthMiddleware struct {
}

// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware
// stack's Build step.
func AddContentLengthMiddleware(stack *middleware.Stack) {
stack.Build.Add(&ContentLengthMiddleware{}, middleware.After)
}

// ID the identifier for the ContentLengthMiddleware
func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" }

// HandleBuild adds the length of the serialized request to the HTTP header
// if the length can be determined.
func (m *ContentLengthMiddleware) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}

// Don't set content length if header is already set.
if vs := req.Header.Values("Content-Length"); len(vs) != 0 {
return next.HandleBuild(ctx, in)
}

if n, ok, err := req.StreamLength(); err != nil {
return out, metadata, fmt.Errorf(
"failed getting length of request stream, %w", err)
} else if ok {
req.Header.Set("Content-Length", strconv.FormatInt(n, 10))
}

return next.HandleBuild(ctx, in)
}
132 changes: 132 additions & 0 deletions transport/http/middleware_content_length_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package http

import (
"bytes"
"context"
"fmt"
"io"
"strings"
"testing"

"github.com/awslabs/smithy-go/middleware"
)

func TestContentLengthMiddleware(t *testing.T) {
cases := map[string]struct {
Stream io.Reader
ExpectLen string
ExpectErr string
}{
// Cases
"bytes.Reader": {
Stream: bytes.NewReader(make([]byte, 10)),
ExpectLen: "10",
},
"bytes.Buffer": {
Stream: bytes.NewBuffer(make([]byte, 10)),
ExpectLen: "10",
},
"strings.Reader": {
Stream: strings.NewReader("hello"),
ExpectLen: "5",
},
"un-seekable and no length": {
Stream: &basicReader{buf: make([]byte, 10)},
},
"with error": {
Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
ExpectErr: "seek failed",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := NewStackRequest().(*Request)
req, err = req.SetStream(c.Stream)
if err != nil {
t.Fatalf("expect to set stream, %v", err)
}

var m ContentLengthMiddleware
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
nopBuildHandler{},
)
if len(c.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %q, got %v", e, a)
}
} else if err != nil {
t.Fatalf("expect no error, got %v", err)
}

t.Logf("request Content-Length:%v", req.Header.Get("Content-Length"))

if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a {
t.Errorf("expect %v content-length, got %v", e, a)
}
})
}
}

func TestContentLengthMiddleware_HeaderSet(t *testing.T) {
req := NewStackRequest().(*Request)
req.Header.Set("Content-Length", "1234")

var err error
req, err = req.SetStream(strings.NewReader("hello"))
if err != nil {
t.Fatalf("expect to set stream, %v", err)
}

var m ContentLengthMiddleware
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
nopBuildHandler{},
)
if err != nil {
t.Fatalf("expect middleware to run, %v", err)
}

if e, a := "1234", req.Header.Get("Content-Length"); e != a {
t.Errorf("expect Content-Length not to change, got %v", a)
}
}

type nopBuildHandler struct{}

func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
return out, metadata, nil
}

type basicReader struct {
buf []byte
}

func (r *basicReader) Read(p []byte) (int, error) {
n := copy(p, r.buf)
r.buf = r.buf[n:]
return n, nil
}

type errorSecondSeekableReader struct {
err error
count int
}

func (r *errorSecondSeekableReader) Read(p []byte) (int, error) {
return 0, io.EOF
}
func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) {
r.count++
if r.count == 2 {
return 0, r.err
}
return 0, nil
}
26 changes: 26 additions & 0 deletions transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ func (r *Request) Clone() *Request {
return &rc
}

// StreamLength returns the number of bytes of the serialized stream attached
// to the request and ok set. If the length cannot be determined, an error will
// be returned.
func (r *Request) StreamLength() (size int64, ok bool, err error) {
if l, ok := r.stream.(interface{ Len() int }); ok {
return int64(l.Len()), true, nil
}

if !r.isStreamSeekable {
return 0, false, nil
}

s := r.stream.(io.Seeker)
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, false, err
}

_, err = s.Seek(r.streamStartPos, io.SeekStart)
if err != nil {
return 0, false, err
}

return endOffset - r.streamStartPos, true, nil
}

// RewindStream will rewind the io.Reader to the relative start position if it
// is an io.Seeker.
func (r *Request) RewindStream() error {
Expand Down

0 comments on commit 03463dd

Please sign in to comment.