diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 49aea69fe..562fed8f1 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -168,8 +168,15 @@ protected void generateTestBodySetup(GoWriter writer) { */ protected void generateTestServerHandler(GoWriter writer) { writer.write("actualReq = r.Clone(r.Context())"); - // Go does not set RawPath on http server if nothing is excaped - writer.write("if len(actualReq.URL.RawPath) == 0 { actualReq.URL.RawPath = actualReq.URL.Path }"); + // Go does not set RawPath on http server if nothing is escaped + writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> { + writer.write("actualReq.URL.RawPath = actualReq.URL.Path"); + }); + // Go automatically removes Content-Length header setting it to the member. + writer.addUseImports(SmithyGoDependency.STRCONV); + writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> { + writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))"); + }); writer.addUseImports(SmithyGoDependency.BYTES); writer.write("var buf bytes.Buffer"); diff --git a/transport/http/middleware_content_length.go b/transport/http/middleware_content_length.go new file mode 100644 index 000000000..dde2353f2 --- /dev/null +++ b/transport/http/middleware_content_length.go @@ -0,0 +1,51 @@ +package http + +import ( + "context" + "fmt" + "strconv" + + "github.com/awslabs/smithy-go/middleware" +) + +// ContentLengthMiddleware provides a middleware to set the content-length +// header for the length of a serialize request body. +type ContentLengthMiddleware struct { +} + +// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware +// stack's Build step. +func AddContentLengthMiddleware(stack *middleware.Stack) { + stack.Build.Add(&ContentLengthMiddleware{}, middleware.After) +} + +// ID the identifier for the ContentLengthMiddleware +func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" } + +// HandleBuild adds the length of the serialized request to the HTTP header +// if the length can be determined. +func (m *ContentLengthMiddleware) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + // Don't set content length if header is already set. + if vs := req.Header.Values("Content-Length"); len(vs) != 0 { + return next.HandleBuild(ctx, in) + } + + if n, ok, err := req.StreamLength(); err != nil { + return out, metadata, fmt.Errorf( + "failed getting length of request stream, %w", err) + } else if ok && n > 0 { + // Only set content-length header when it is a positive value. + req.Header.Set("Content-Length", strconv.FormatInt(n, 10)) + } + + return next.HandleBuild(ctx, in) +} diff --git a/transport/http/middleware_content_length_test.go b/transport/http/middleware_content_length_test.go new file mode 100644 index 000000000..30de5e5a4 --- /dev/null +++ b/transport/http/middleware_content_length_test.go @@ -0,0 +1,139 @@ +package http + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "testing" + + "github.com/awslabs/smithy-go/middleware" +) + +func TestContentLengthMiddleware(t *testing.T) { + cases := map[string]struct { + Stream io.Reader + ExpectLen string + ExpectErr string + }{ + // Cases + "bytes.Reader": { + Stream: bytes.NewReader(make([]byte, 10)), + ExpectLen: "10", + }, + "bytes.Buffer": { + Stream: bytes.NewBuffer(make([]byte, 10)), + ExpectLen: "10", + }, + "strings.Reader": { + Stream: strings.NewReader("hello"), + ExpectLen: "5", + }, + "empty stream": { + Stream: strings.NewReader(""), + }, + "nil stream": {}, + "un-seekable and no length": { + Stream: &basicReader{buf: make([]byte, 10)}, + }, + "with error": { + Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, + ExpectErr: "seek failed", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := NewStackRequest().(*Request) + req, err = req.SetStream(c.Stream) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %q, got %v", e, a) + } + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + t.Logf("request Content-Length:%v", req.Header.Get("Content-Length")) + + if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + if a := req.Header.Values("Content-Length"); len(c.ExpectLen) == 0 && len(a) != 0 { + t.Errorf("expect no content-length header, got %v", a) + } + }) + } +} + +func TestContentLengthMiddleware_HeaderSet(t *testing.T) { + req := NewStackRequest().(*Request) + req.Header.Set("Content-Length", "1234") + + var err error + req, err = req.SetStream(strings.NewReader("hello")) + if err != nil { + t.Fatalf("expect to set stream, %v", err) + } + + var m ContentLengthMiddleware + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + nopBuildHandler{}, + ) + if err != nil { + t.Fatalf("expect middleware to run, %v", err) + } + + if e, a := "1234", req.Header.Get("Content-Length"); e != a { + t.Errorf("expect Content-Length not to change, got %v", a) + } +} + +type nopBuildHandler struct{} + +func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + return out, metadata, nil +} + +type basicReader struct { + buf []byte +} + +func (r *basicReader) Read(p []byte) (int, error) { + n := copy(p, r.buf) + r.buf = r.buf[n:] + return n, nil +} + +type errorSecondSeekableReader struct { + err error + count int +} + +func (r *errorSecondSeekableReader) Read(p []byte) (int, error) { + return 0, io.EOF +} +func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) { + r.count++ + if r.count == 2 { + return 0, r.err + } + return 0, nil +} diff --git a/transport/http/request.go b/transport/http/request.go index cb8a3c3e1..11a177e52 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -38,6 +38,42 @@ func (r *Request) Clone() *Request { return &rc } +// StreamLength returns the number of bytes of the serialized stream attached +// to the request and ok set. If the length cannot be determined, an error will +// be returned. +func (r *Request) StreamLength() (size int64, ok bool, err error) { + if r.stream == nil { + return 0, true, nil + } + + if l, ok := r.stream.(interface{ Len() int }); ok { + return int64(l.Len()), true, nil + } + + if !r.isStreamSeekable { + return 0, false, nil + } + + s := r.stream.(io.Seeker) + endOffset, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, false, err + } + + // The reason to seek to streamStartPos instead of 0 is to ensure that the + // SDK only sends the stream from the starting position the user's + // application provided it to the SDK at. For example application opens a + // file, and wants to skip the first N bytes uploading the rest. The + // application would move the file's offset N bytes, then hand it off to + // the SDK to send the remaining. The SDK should respect that initial offset. + _, err = s.Seek(r.streamStartPos, io.SeekStart) + if err != nil { + return 0, false, err + } + + return endOffset - r.streamStartPos, true, nil +} + // RewindStream will rewind the io.Reader to the relative start position if it // is an io.Seeker. func (r *Request) RewindStream() error {