-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add middleware to set HTTP request content length
Adds middleware that will set the HTTP request's Content-Length header if the length of the serialized stream can be determined. If the length cannot be determined or the Header value is already set the middleware will do nothing. Fails the request if the underlying request stream reader is seekable but fails to seek.
- Loading branch information
Showing
3 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package http | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"strconv" | ||
|
||
"github.com/awslabs/smithy-go/middleware" | ||
) | ||
|
||
// ContentLengthMiddleware provides a middleware to set the content-length | ||
// header for the length of a serialize request body. | ||
type ContentLengthMiddleware struct { | ||
} | ||
|
||
// AddContentLengthMiddleware adds ContentLengthMiddleware to the middleware | ||
// stack's Build step. | ||
func AddContentLengthMiddleware(stack *middleware.Stack) { | ||
stack.Build.Add(&ContentLengthMiddleware{}, middleware.After) | ||
} | ||
|
||
// ID the identifier for the ContentLengthMiddleware | ||
func (m *ContentLengthMiddleware) ID() string { return "ContentLengthMiddleware" } | ||
|
||
// HandleBuild adds the length of the serialized request to the HTTP header | ||
// if the length can be determined. | ||
func (m *ContentLengthMiddleware) HandleBuild( | ||
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, | ||
) ( | ||
out middleware.BuildOutput, metadata middleware.Metadata, err error, | ||
) { | ||
req, ok := in.Request.(*Request) | ||
if !ok { | ||
return out, metadata, fmt.Errorf("unknown request type %T", req) | ||
} | ||
|
||
// Don't set content length if header is already set. | ||
if vs := req.Header.Values("Content-Length"); len(vs) != 0 { | ||
return next.HandleBuild(ctx, in) | ||
} | ||
|
||
if n, ok, err := req.StreamLength(); err != nil { | ||
return out, metadata, fmt.Errorf( | ||
"failed getting length of request stream, %w", err) | ||
} else if ok { | ||
req.Header.Set("Content-Length", strconv.FormatInt(n, 10)) | ||
} | ||
|
||
return next.HandleBuild(ctx, in) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
package http | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"fmt" | ||
"io" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/awslabs/smithy-go/middleware" | ||
) | ||
|
||
func TestContentLengthMiddleware(t *testing.T) { | ||
cases := map[string]struct { | ||
Stream io.Reader | ||
ExpectLen string | ||
ExpectErr string | ||
}{ | ||
// Cases | ||
"bytes.Reader": { | ||
Stream: bytes.NewReader(make([]byte, 10)), | ||
ExpectLen: "10", | ||
}, | ||
"bytes.Buffer": { | ||
Stream: bytes.NewBuffer(make([]byte, 10)), | ||
ExpectLen: "10", | ||
}, | ||
"strings.Reader": { | ||
Stream: strings.NewReader("hello"), | ||
ExpectLen: "5", | ||
}, | ||
"un-seekable and no length": { | ||
Stream: &basicReader{buf: make([]byte, 10)}, | ||
}, | ||
"with error": { | ||
Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")}, | ||
ExpectErr: "seek failed", | ||
}, | ||
} | ||
|
||
for name, c := range cases { | ||
t.Run(name, func(t *testing.T) { | ||
var err error | ||
req := NewStackRequest().(*Request) | ||
req, err = req.SetStream(c.Stream) | ||
if err != nil { | ||
t.Fatalf("expect to set stream, %v", err) | ||
} | ||
|
||
var m ContentLengthMiddleware | ||
_, _, err = m.HandleBuild(context.Background(), | ||
middleware.BuildInput{Request: req}, | ||
nopBuildHandler{}, | ||
) | ||
if len(c.ExpectErr) != 0 { | ||
if err == nil { | ||
t.Fatalf("expect error, got none") | ||
} | ||
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { | ||
t.Fatalf("expect error to contain %q, got %v", e, a) | ||
} | ||
} else if err != nil { | ||
t.Fatalf("expect no error, got %v", err) | ||
} | ||
|
||
t.Logf("request Content-Length:%v", req.Header.Get("Content-Length")) | ||
|
||
if e, a := c.ExpectLen, req.Header.Get("Content-Length"); e != a { | ||
t.Errorf("expect %v content-length, got %v", e, a) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestContentLengthMiddleware_HeaderSet(t *testing.T) { | ||
req := NewStackRequest().(*Request) | ||
req.Header.Set("Content-Length", "1234") | ||
|
||
var err error | ||
req, err = req.SetStream(strings.NewReader("hello")) | ||
if err != nil { | ||
t.Fatalf("expect to set stream, %v", err) | ||
} | ||
|
||
var m ContentLengthMiddleware | ||
_, _, err = m.HandleBuild(context.Background(), | ||
middleware.BuildInput{Request: req}, | ||
nopBuildHandler{}, | ||
) | ||
if err != nil { | ||
t.Fatalf("expect middleware to run, %v", err) | ||
} | ||
|
||
if e, a := "1234", req.Header.Get("Content-Length"); e != a { | ||
t.Errorf("expect Content-Length not to change, got %v", a) | ||
} | ||
} | ||
|
||
type nopBuildHandler struct{} | ||
|
||
func (nopBuildHandler) HandleBuild(ctx context.Context, in middleware.BuildInput) ( | ||
out middleware.BuildOutput, metadata middleware.Metadata, err error, | ||
) { | ||
return out, metadata, nil | ||
} | ||
|
||
type basicReader struct { | ||
buf []byte | ||
} | ||
|
||
func (r *basicReader) Read(p []byte) (int, error) { | ||
n := copy(p, r.buf) | ||
r.buf = r.buf[n:] | ||
return n, nil | ||
} | ||
|
||
type errorSecondSeekableReader struct { | ||
err error | ||
count int | ||
} | ||
|
||
func (r *errorSecondSeekableReader) Read(p []byte) (int, error) { | ||
return 0, io.EOF | ||
} | ||
func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) { | ||
r.count++ | ||
if r.count == 2 { | ||
return 0, r.err | ||
} | ||
return 0, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters