From 03463dde2e9cc4cad5ff615bc0fc77e35831a2e4 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte Date: Wed, 8 Jul 2020 17:32:46 -0700 Subject: [PATCH 1/3] Add middleware to set HTTP request content length Adds middleware that will set the HTTP request's Content-Length header if the length of the serialized stream can be determined. If the length cannot be determined or the Header value is already set, the middleware will do nothing. Fails the request if the underlying request stream reader is seekable but fails to seek. --- transport/http/middleware_content_length.go | 50 +++++++ .../http/middleware_content_length_test.go | 132 ++++++++++++++++++ transport/http/request.go | 26 ++++ 3 files changed, 208 insertions(+) create mode 100644 transport/http/middleware_content_length.go create mode 100644 transport/http/middleware_content_length_test.go diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go new file mode 100644 index 000000000..20586aad6 --- /dev/null +++ b/transport/http/middleware_content_length.go @@ -0,0 +1,50 @@ +package http + +import ( + "context" + "fmt" + "strconv" + + "github.com/awslabs/smithy-go/middleware" +) + +// ContentLengthMiddleware provides a middleware to set the content-length +// header for the length of a serialized request body. +type ContentLengthMiddleware struct { +} + +// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware +// stack's Build step. +func AddContentLengthMiddleware(stack *middleware.Stack) { + stack.Build.Add(&ContentLengthMiddleware{}, middleware.After) +} + +// ID returns the identifier for the ContentLengthMiddleware. +func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" } + +// HandleBuild adds the length of the serialized request to the HTTP header +// if the length can be determined. 
+func (m *ContentLengthMiddleware) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + // Don't set content length if header is already set. + if vs := req.Header.Values("Content-Length"); len(vs) != 0 { + return next.HandleBuild(ctx, in) + } + + if n, ok, err := req.StreamLength(); err != nil { + return out, metadata, fmt.Errorf( + "failed getting length of request stream, %w", err) + } else if ok { + req.Header.Set("Content-Length", strconv.FormatInt(n, 10)) + } + + return next.HandleBuild(ctx, in) +} diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go new file mode 100644 index 000000000..418b6ee22 --- /dev/null +++ b/transport/http/middleware_content_length_test.go @@ -0,0 +1,132 @@ +package http + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "testing" + + "github.com/awslabs/smithy-go/middleware" +) + +func TestContentLengthMiddleware(t *testing.T) { + cases := map[string]struct { + Stream io.Reader + ExpectLen string + ExpectErr string + }{ + // Cases + "bytes.Reader": { + Stream: bytes.NewReader(make([]byte, 10)), + ExpectLen: "10", + }, + "bytes.Buffer": { + Stream: bytes.NewBuffer(make([]byte, 10)), + ExpectLen: "10", + }, + "strings.Reader": { + Stream: strings.NewReader("hello"), + ExpectLen: "5", + }, + "un-seekable and no length": { + Stream: &basicReader{buf: make([]byte, 10)}, + }, + "with error": { + Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, + ExpectErr: "seek failed", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := NewStackRequest().(*Request) + req, err = req.SetStream(c.Stream) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m 
ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %q, got %v", e, a) + } + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + t.Logf("request Content-Length:%v", req.Header.Get("Content-Length")) + + if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + }) + } +} + +func TestContentLengthMiddleware_HeaderSet(t *testing.T) { + req := NewStackRequest().(*Request) + req.Header.Set("Content-Length", "1234") + + var err error + req, err = req.SetStream(strings.NewReader("hello")) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if err != nil { + t.Fatalf("expect middleware to run, %v", err) + } + + if e, a := "1234", req.Header.Get("Content-Length"); e != a { + t.Errorf("expect Content-Length not to change, got %v", a) + } +} + +type nopBuildHandler struct{} + +func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + return out, metadata, nil +} + +type basicReader struct { + buf []byte +} + +func (r *basicReader) Read(p []byte) (int, error) { + n := copy(p, r.buf) + r.buf = r.buf[n:] + return n, nil +} + +type errorSecondSeekableReader struct { + err error + count int +} + +func (r *errorSecondSeekableReader) Read(p []byte) (int, error) { + return 0, io.EOF +} +func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) { + r.count++ + if r.count == 2 { + return 0, r.err + } + return 0, nil +} diff --git 
a/transport/http/request.go b/transport/http/request.go index cb8a3c3e1..755e94bcb 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -38,6 +38,32 @@ func (r *Request) Clone() *Request { return &rc } +// StreamLength returns the number of bytes of the serialized stream attached +// to the request and ok set. If the length cannot be determined, an error will +// be returned. +func (r *Request) StreamLength() (size int64, ok bool, err error) { + if l, ok := r.stream.(interface{ Len() int }); ok { + return int64(l.Len()), true, nil + } + + if !r.isStreamSeekable { + return 0, false, nil + } + + s := r.stream.(io.Seeker) + endOffset, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, false, err + } + + _, err = s.Seek(r.streamStartPos, io.SeekStart) + if err != nil { + return 0, false, err + } + + return endOffset - r.streamStartPos, true, nil +} + // RewindStream will rewind the io.Reader to the relative start position if it // is an io.Seeker. func (r *Request) RewindStream() error { From cfcb303aec1b123af4aa2d31c566da15a8fbe451 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte Date: Wed, 8 Jul 2020 18:03:20 -0700 Subject: [PATCH 2/3] ensure content-length is available in request test --- .../HttpProtocolUnitTestRequestGenerator.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 49aea69fe..562fed8f1 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -168,8 +168,15 @@ protected void generateTestBodySetup(GoWriter writer) { */ 
protected void generateTestServerHandler(GoWriter writer) { writer.write("actualReq = r.Clone(r.Context())"); - // Go does not set RawPath on http server if nothing is excaped - writer.write("if len(actualReq.URL.RawPath) == 0 { actualReq.URL.RawPath = actualReq.URL.Path }"); + // Go does not set RawPath on http server if nothing is escaped + writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> { + writer.write("actualReq.URL.RawPath = actualReq.URL.Path"); + }); + // Go automatically removes Content-Length header setting it to the member. + writer.addUseImports(SmithyGoDependency.STRCONV); + writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> { + writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))"); + }); writer.addUseImports(SmithyGoDependency.BYTES); writer.write("var buf bytes.Buffer"); From 9012df2af9db17072dada6f4eefb7326e85bd41a Mon Sep 17 00:00:00 2001 From: Jason Del Ponte Date: Tue, 14 Jul 2020 16:18:22 -0700 Subject: [PATCH 3/3] fixup PR feedback, don't set content length if empty/unset --- transport/http/middleware_content_length.go | 3 ++- transport/http/middleware_content_length_test.go | 7 +++++++ transport/http/request.go | 10 ++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go index 20586aad6..dde2353f2 100644 --- a/transport/http/middleware_content_length.go +++ b/transport/http/middleware_content_length.go @@ -42,7 +42,8 @@ func (m *ContentLengthMiddleware) HandleBuild( if n, ok, err := req.StreamLength(); err != nil { return out, metadata, fmt.Errorf( "failed getting length of request stream, %w", err) - } else if ok { + } else if ok && n > 0 { + // Only set content-length header when it is a positive value. 
req.Header.Set("Content-Length", strconv.FormatInt(n, 10)) } diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go index 418b6ee22..30de5e5a4 100644 --- a/transport/http/middleware_content_length_test.go +++ b/transport/http/middleware_content_length_test.go @@ -30,6 +30,10 @@ func TestContentLengthMiddleware(t *testing.T) { Stream: strings.NewReader("hello"), ExpectLen: "5", }, + "empty stream": { + Stream: strings.NewReader(""), + }, + "nil stream": {}, "un-seekable and no length": { Stream: &basicReader{buf: make([]byte, 10)}, }, @@ -69,6 +73,9 @@ func TestContentLengthMiddleware(t *testing.T) { if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a { t.Errorf("expect %v content-length, got %v", e, a) } + if a := req.Header.Values("Content-Length"); len(c.ExpectLen) == 0 && len(a) != 0 { + t.Errorf("expect no content-length header, got %v", a) + } }) } } diff --git a/transport/http/request.go b/transport/http/request.go index 755e94bcb..11a177e52 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -42,6 +42,10 @@ func (r *Request) Clone() *Request { // to the request and ok set. If the length cannot be determined, an error will // be returned. func (r *Request) StreamLength() (size int64, ok bool, err error) { + if r.stream == nil { + return 0, true, nil + } + if l, ok := r.stream.(interface{ Len() int }); ok { return int64(l.Len()), true, nil } @@ -56,6 +60,12 @@ func (r *Request) StreamLength() (size int64, ok bool, err error) { return 0, false, err } + // Seek back to streamStartPos instead of 0 to ensure that the SDK only + // sends the stream from the position it was at when the user's application + // provided it to the SDK. For example, an application opens a file and + // wants to skip the first N bytes, uploading the rest. The application + // would move the file's offset N bytes, then hand it off to the SDK to + // send the remainder. 
The SDK should respect that initial offset. _, err = s.Seek(r.streamStartPos, io.SeekStart) if err != nil { return 0, false, err