Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow max request size to be user-specified #4824

Merged
merged 4 commits into from
Jul 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions command/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ type ServerCommand struct {

type ServerListener struct {
net.Listener
config map[string]interface{}
config map[string]interface{}
maxRequestSize int64
}

func (c *ServerCommand) Synopsis() string {
Expand Down Expand Up @@ -689,11 +690,6 @@ CLUSTER_SYNTHESIS_COMPLETE:
return 1
}

lns = append(lns, ServerListener{
Listener: ln,
config: lnConfig.Config,
})

if reloadFunc != nil {
relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type]
relSlice = append(relSlice, reloadFunc)
Expand Down Expand Up @@ -728,6 +724,26 @@ CLUSTER_SYNTHESIS_COMPLETE:
props["cluster address"] = addr
}

var maxRequestSize int64 = vaulthttp.DefaultMaxRequestSize
if valRaw, ok := lnConfig.Config["max_request_size"]; ok {
val, err := parseutil.ParseInt(valRaw)
if err != nil {
c.UI.Error(fmt.Sprintf("Could not parse max_request_size value %v", valRaw))
return 1
}

if val >= 0 {
maxRequestSize = val
}
}
props["max_request_size"] = fmt.Sprintf("%d", maxRequestSize)

lns = append(lns, ServerListener{
Listener: ln,
config: lnConfig.Config,
maxRequestSize: maxRequestSize,
})

// Store the listener props for output later
key := fmt.Sprintf("listener %d", i+1)
propsList := make([]string, 0, len(props))
Expand Down Expand Up @@ -792,7 +808,9 @@ CLUSTER_SYNTHESIS_COMPLETE:
// This needs to happen before we first unseal, so before we trigger dev
// mode if it's set
core.SetClusterListenerAddrs(clusterAddrs)
core.SetClusterHandler(vaulthttp.Handler(core))
core.SetClusterHandler(vaulthttp.Handler(&vault.HandlerProperties{
Core: core,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't we ever need a size limit on the cluster connection?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the time we get to the HTTP handler we've already handshaked the cluster connection so we already know if the connection is authorized -- if so we shouldn't limit it.

}))

err = core.UnsealWithStoredKeys(context.Background())
if err != nil {
Expand Down Expand Up @@ -925,7 +943,10 @@ CLUSTER_SYNTHESIS_COMPLETE:

// Initialize the HTTP servers
for _, ln := range lns {
handler := vaulthttp.Handler(core)
handler := vaulthttp.Handler(&vault.HandlerProperties{
Core: core,
MaxRequestSize: ln.maxRequestSize,
})

// We perform validation on the config earlier, we can just cast here
if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok {
Expand Down Expand Up @@ -1195,7 +1216,9 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
c.UI.Output("")

for _, core := range testCluster.Cores {
core.Server.Handler = vaulthttp.Handler(core.Core)
core.Server.Handler = vaulthttp.Handler(&vault.HandlerProperties{
Core: core.Core,
})
core.SetClusterHandler(core.Server.Handler)
}

Expand Down
1 change: 1 addition & 0 deletions command/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
"x_forwarded_for_reject_not_authorized",
"x_forwarded_for_reject_not_present",
"infrastructure",
"max_request_size",
"node_id",
"proxy_protocol_behavior",
"proxy_protocol_authorized_addrs",
Expand Down
29 changes: 22 additions & 7 deletions helper/forwarding/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -56,11 +59,30 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
}

func GenerateForwardedRequest(req *http.Request) (*Request, error) {
var reader io.Reader = req.Body
ctx := req.Context()
maxRequestSize := ctx.Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
reader = io.LimitReader(req.Body, max)
}
}

body, err := ioutil.ReadAll(reader)
if err != nil {
return nil, err
}

fq := Request{
Method: req.Method,
HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)),
Host: req.Host,
RemoteAddr: req.RemoteAddr,
Body: body,
}

reqURL := req.URL
Expand All @@ -80,13 +102,6 @@ func GenerateForwardedRequest(req *http.Request) (*Request, error) {
}
}

buf := bytes.NewBuffer(nil)
_, err := buf.ReadFrom(req.Body)
if err != nil {
return nil, err
}
fq.Body = buf.Bytes()

if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 {
fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates))
for i, cert := range req.TLS.PeerCertificates {
Expand Down
12 changes: 6 additions & 6 deletions http/forwarded_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// First: test reject not present
t.Run("reject_not_present", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down Expand Up @@ -69,7 +69,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// Next: test allow unauth
t.Run("allow_unauth", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// Next: test fail unauth
t.Run("fail_unauth", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// Next: test bad hops (too many)
t.Run("too_many_hops", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down Expand Up @@ -174,7 +174,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// Next: test picking correct value
t.Run("correct_hop_skipping", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestHandler_XForwardedFor(t *testing.T) {
// Next: multi-header approach
t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
t.Parallel()
testHandler := func(c *vault.Core) http.Handler {
testHandler := func(props *vault.HandlerProperties) http.Handler {
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.RemoteAddr))
Expand Down
44 changes: 34 additions & 10 deletions http/handler.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package http

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -52,10 +54,11 @@ const (
// soft-mandatory Sentinel policies.
PolicyOverrideHeaderName = "X-Vault-Policy-Override"

// MaxRequestSize is the maximum accepted request size. This is to prevent
// a denial of service attack where no Content-Length is provided and the server
// is fed ever more data until it exhausts memory.
MaxRequestSize = 32 * 1024 * 1024
// DefaultMaxRequestSize is the default maximum accepted request size. This
// is to prevent a denial of service attack where no Content-Length is
// provided and the server is fed ever more data until it exhausts memory.
// Can be overridden per listener.
DefaultMaxRequestSize = 32 * 1024 * 1024
)

var (
Expand All @@ -67,7 +70,9 @@ var (

// Handler returns an http.Handler for the API. This can be used on
// its own to mount the Vault API within another web server.
func Handler(core *vault.Core) http.Handler {
func Handler(props *vault.HandlerProperties) http.Handler {
core := props.Core

// Create the muxer to handle the actual endpoints
mux := http.NewServeMux()
mux.Handle("/v1/sys/init", handleSysInit(core))
Expand Down Expand Up @@ -108,7 +113,7 @@ func Handler(core *vault.Core) http.Handler {

// Wrap the help wrapped handler with another layer with a generic
// handler
genericWrappedHandler := wrapGenericHandler(corsWrappedHandler)
genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize)

// Wrap the handler with PrintablePathCheckHandler to check for non-printable
// characters in the request path.
Expand All @@ -120,12 +125,20 @@ func Handler(core *vault.Core) http.Handler {
// wrapGenericHandler wraps the handler with an extra layer of handler where
// tasks that should be commonly handled for all the requests and/or responses
// are performed.
func wrapGenericHandler(h http.Handler) http.Handler {
func wrapGenericHandler(h http.Handler, maxRequestSize int64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set the Cache-Control header for all the responses returned
// by Vault
w.Header().Set("Cache-Control", "no-store")
h.ServeHTTP(w, r)

// Add a context and put the request limit for this handler in it
if maxRequestSize > 0 {
ctx := context.WithValue(r.Context(), "max_request_size", maxRequestSize)
h.ServeHTTP(w, r.WithContext(ctx))
} else {
h.ServeHTTP(w, r)
}

return
})
}
Expand Down Expand Up @@ -326,8 +339,19 @@ func (fs *UIAssetWrapper) Open(name string) (http.File, error) {
func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error {
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
limit := http.MaxBytesReader(w, r.Body, MaxRequestSize)
err := jsonutil.DecodeJSONFromReader(limit, out)
reader := r.Body
ctx := r.Context()
maxRequestSize := ctx.Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return errors.New("could not parse max_request_size from request context")
}
if max > 0 {
reader = http.MaxBytesReader(w, r.Body, max)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why do we use io.LimitReader above and http.MaxBytesReader here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different function signatures, based on what's available.

}
}
err := jsonutil.DecodeJSONFromReader(reader, out)
if err != nil && err != io.EOF {
return errwrap.Wrapf("failed to parse JSON input: {{err}}", err)
}
Expand Down
2 changes: 1 addition & 1 deletion http/logical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestLogical_RequestSizeLimit(t *testing.T) {

// Write a very large object, should fail
resp := testHttpPut(t, token, addr+"/v1/secret/foo", map[string]interface{}{
"data": make([]byte, MaxRequestSize),
"data": make([]byte, DefaultMaxRequestSize),
})
testResponseStatus(t, resp, 413)
}
Expand Down
5 changes: 4 additions & 1 deletion http/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *v
// for tests.
mux := http.NewServeMux()
mux.Handle("/_test/auth", http.HandlerFunc(testHandleAuth))
mux.Handle("/", Handler(core))
mux.Handle("/", Handler(&vault.HandlerProperties{
Core: core,
MaxRequestSize: DefaultMaxRequestSize,
}))

server := &http.Server{
Addr: ln.Addr().String(),
Expand Down
7 changes: 7 additions & 0 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ const (
replTimeout = 10 * time.Second
)

// HanlderProperties is used to seed configuration into a vaulthttp.Handler.
// It's in this package to avoid a circular dependency
type HandlerProperties struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Nitpick] Can we call this HandlerOptions instead, unless there was a reason to call this HandlerProperties?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I called it Properties since it's not just user-set options (e.g. Core).

Core *Core
MaxRequestSize int64
}

// fetchEntityAndDerivedPolicies returns the entity object for the given entity
// ID. If the entity is merged into a different entity object, the entity into
// which the given entity ID is merged into will be returned. This function
Expand Down
6 changes: 4 additions & 2 deletions vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ type TestClusterCore struct {
type TestClusterOptions struct {
KeepStandbysSealed bool
SkipInit bool
HandlerFunc func(*Core) http.Handler
HandlerFunc func(*HandlerProperties) http.Handler
BaseListenAddress string
NumCores int
SealFunc func() Seal
Expand Down Expand Up @@ -1249,7 +1249,9 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
}
cores = append(cores, c)
if opts != nil && opts.HandlerFunc != nil {
handlers[i] = opts.HandlerFunc(c)
handlers[i] = opts.HandlerFunc(&HandlerProperties{
Core: c,
})
servers[i].Handler = handlers[i]
}
}
Expand Down
4 changes: 4 additions & 0 deletions website/source/docs/configuration/listener/tcp.html.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ listener "tcp" {
they need to hop through a TCP load balancer or some other scheme in order to
talk.

- `max_request_size` `(int: 33554432)` – Specifies a hard maximum allowed
request size, in bytes. Defaults to 32 MB. Specifying a number less than or
equal to `0` turns off limiting altogether.

- `proxy_protocol_behavior` `(string: "") – When specified, turns on the PROXY
protocol for the listener.
Accepted Values:
Expand Down