diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go new file mode 100644 index 000000000..20586aad6 --- /dev/null +++ b/transport/http/middleware_content_length.go @@ -0,0 +1,50 @@ +package http + +import ( + "context" + "fmt" + "strconv" + + "github.com/awslabs/smithy-go/middleware" +) + +// ContentLengthMiddleware provides a middleware to set the content-length +// header for the length of a serialize request body. +type ContentLengthMiddleware struct { +} + +// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware +// stack's Build step. +func AddContentLengthMiddleware(stack *middleware.Stack) { + stack.Build.Add(&ContentLengthMiddleware{}, middleware.After) +} + +// ID the identifier for the ContentLengthMiddleware +func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" } + +// HandleBuild adds the length of the serialized request to the HTTP header +// if the length can be determined. +func (m *ContentLengthMiddleware) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + // Don't set content length if header is already set. + if vs := req.Header.Values("Content-Length"); len(vs) != 0 { + return next.HandleBuild(ctx, in) + } + + if n, ok, err := req.StreamLength(); err != nil { + return out, metadata, fmt.Errorf( + "failed getting length of request stream, %w", err) + } else if ok { + req.Header.Set("Content-Length", strconv.FormatInt(n, 10)) + } + + return next.HandleBuild(ctx, in) +} diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go new file mode 100644 index 000000000..418b6ee22 --- /dev/null +++ b/transport/http/middleware_content_length_test.go @@ -0,0 +1,132 @@ +package http + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "testing" + + "github.com/awslabs/smithy-go/middleware" +) + +func TestContentLengthMiddleware(t *testing.T) { + cases := map[string]struct { + Stream io.Reader + ExpectLen string + ExpectErr string + }{ + // Cases + "bytes.Reader": { + Stream: bytes.NewReader(make([]byte, 10)), + ExpectLen: "10", + }, + "bytes.Buffer": { + Stream: bytes.NewBuffer(make([]byte, 10)), + ExpectLen: "10", + }, + "strings.Reader": { + Stream: strings.NewReader("hello"), + ExpectLen: "5", + }, + "un-seekable and no length": { + Stream: &basicReader{buf: make([]byte, 10)}, + }, + "with error": { + Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, + ExpectErr: "seek failed", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := NewStackRequest().(*Request) + req, err = req.SetStream(c.Stream) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %q, got %v", e, a) + } + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + t.Logf("request Content-Length:%v", req.Header.Get("Content-Length")) + + if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + }) + } +} + +func TestContentLengthMiddleware_HeaderSet(t *testing.T) { + req := NewStackRequest().(*Request) + req.Header.Set("Content-Length", "1234") + + var err error + req, err = req.SetStream(strings.NewReader("hello")) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if err != nil { + t.Fatalf("expect middleware to run, %v", err) + } + + if e, a := "1234", req.Header.Get("Content-Length"); e != a { + t.Errorf("expect Content-Length not to change, got %v", a) + } +} + +type nopBuildHandler struct{} + +func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + return out, metadata, nil +} + +type basicReader struct { + buf []byte +} + +func (r *basicReader) Read(p []byte) (int, error) { + n := copy(p, r.buf) + r.buf = r.buf[n:] + return n, nil +} + +type errorSecondSeekableReader struct { + err error + count int +} + +func (r *errorSecondSeekableReader) Read(p []byte) (int, error) { + return 0, io.EOF +} +func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) { + r.count++ + if r.count == 2 { + return 0, r.err + } + return 0, nil +} diff --git a/transport/http/request.go b/transport/http/request.go index cb8a3c3e1..755e94bcb 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -38,6 +38,32 @@ func (r *Request) Clone() *Request { return &rc } +// StreamLength returns the number of bytes of the serialized stream attached +// to the request and ok set. If the length cannot be determined, an error will +// be returned. +func (r *Request) StreamLength() (size int64, ok bool, err error) { + if l, ok := r.stream.(interface{ Len() int }); ok { + return int64(l.Len()), true, nil + } + + if !r.isStreamSeekable { + return 0, false, nil + } + + s := r.stream.(io.Seeker) + endOffset, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, false, err + } + + _, err = s.Seek(r.streamStartPos, io.SeekStart) + if err != nil { + return 0, false, err + } + + return endOffset - r.streamStartPos, true, nil +} + // RewindStream will rewind the io.Reader to the relative start position if it // is an io.Seeker. func (r *Request) RewindStream() error {