From 5064f7f4da485004f380dc277ad2dec1adc3ab73 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Sun, 24 Jun 2018 12:33:39 -0400 Subject: [PATCH 1/3] Allow max request size to be user-specified This turned out to be way more impactful than I'd expected because I felt like the right granularity was per-listener, since an org may want to treat external clients differently from internal clients. It's pretty straightforward though. This also introduces actually using request contexts for values, which so far we have not done (using our own logical.Request struct instead), but this allows non-logical methods to still get this benefit. --- command/server.go | 41 +++++++++++++++---- command/server/config.go | 1 + helper/forwarding/util.go | 29 +++++++++---- http/handler.go | 35 +++++++++++++--- http/testing.go | 4 +- vault/request_handling.go | 7 ++++ vault/testing.go | 6 ++- .../docs/configuration/listener/tcp.html.md | 4 ++ 8 files changed, 102 insertions(+), 25 deletions(-) diff --git a/command/server.go b/command/server.go index 2a240dfcf1ba..eb08b1eeb760 100644 --- a/command/server.go +++ b/command/server.go @@ -97,7 +97,8 @@ type ServerCommand struct { type ServerListener struct { net.Listener - config map[string]interface{} + config map[string]interface{} + maxRequestSize int64 } func (c *ServerCommand) Synopsis() string { @@ -689,11 +690,6 @@ CLUSTER_SYNTHESIS_COMPLETE: return 1 } - lns = append(lns, ServerListener{ - Listener: ln, - config: lnConfig.Config, - }) - if reloadFunc != nil { relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type] relSlice = append(relSlice, reloadFunc) @@ -728,6 +724,26 @@ CLUSTER_SYNTHESIS_COMPLETE: props["cluster address"] = addr } + var maxRequestSize int64 = 32 * 1024 * 1024 + if valRaw, ok := lnConfig.Config["max_request_size"]; ok { + val, err := parseutil.ParseInt(valRaw) + if err != nil { + c.UI.Error(fmt.Sprintf("Could not parse max_request_size value %v", valRaw)) + return 1 + } + + if val >= 0 { + maxRequestSize = val + } + } + props["max_request_size"] = fmt.Sprintf("%d", maxRequestSize) + + lns = append(lns, ServerListener{ + Listener: ln, + config: lnConfig.Config, + maxRequestSize: maxRequestSize, + }) + // Store the listener props for output later key := fmt.Sprintf("listener %d", i+1) propsList := make([]string, 0, len(props)) @@ -792,7 +808,9 @@ CLUSTER_SYNTHESIS_COMPLETE: // This needs to happen before we first unseal, so before we trigger dev // mode if it's set core.SetClusterListenerAddrs(clusterAddrs) - core.SetClusterHandler(vaulthttp.Handler(core)) + core.SetClusterHandler(vaulthttp.Handler(&vault.HandlerProperties{ + Core: core, + })) err = core.UnsealWithStoredKeys(context.Background()) if err != nil { @@ -925,7 +943,10 @@ CLUSTER_SYNTHESIS_COMPLETE: // Initialize the HTTP servers for _, ln := range lns { - handler := vaulthttp.Handler(core) + handler := vaulthttp.Handler(&vault.HandlerProperties{ + Core: core, + MaxRequestSize: ln.maxRequestSize, + }) // We perform validation on the config earlier, we can just cast here if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok { @@ -1195,7 +1216,9 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m c.UI.Output("") for _, core := range testCluster.Cores { - core.Server.Handler = vaulthttp.Handler(core.Core) + core.Server.Handler = vaulthttp.Handler(&vault.HandlerProperties{ + Core: core.Core, + }) core.SetClusterHandler(core.Server.Handler) } diff --git a/command/server/config.go b/command/server/config.go index 33c98db4a6d3..7a5212aa999e 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -804,6 +804,7 @@ func parseListeners(result *Config, list *ast.ObjectList) error { "x_forwarded_for_reject_not_authorized", "x_forwarded_for_reject_not_present", "infrastructure", + "max_request_size", "node_id", "proxy_protocol_behavior", "proxy_protocol_authorized_addrs", diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index 92e6cb152426..67897830e2a3 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "errors" + "io" "net/http" "net/url" "os" @@ -56,11 +58,31 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request } func GenerateForwardedRequest(req *http.Request) (*Request, error) { + var reader io.Reader = req.Body + ctx := req.Context() + maxRequestSize := ctx.Value("max_request_size") + if maxRequestSize != nil { + max, ok := maxRequestSize.(int64) + if !ok { + return nil, errors.New("could not parse max_request_size from request context") + } + if max > 0 { + reader = io.LimitReader(req.Body, max) + } + } + + buf := bytes.NewBuffer(nil) + _, err := buf.ReadFrom(reader) + if err != nil { + return nil, err + } + fq := Request{ Method: req.Method, HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)), Host: req.Host, RemoteAddr: req.RemoteAddr, + Body: buf.Bytes(), } reqURL := req.URL @@ -80,13 +102,6 @@ func GenerateForwardedRequest(req *http.Request) (*Request, error) { } } - buf := bytes.NewBuffer(nil) - _, err := buf.ReadFrom(req.Body) - if err != nil { - return nil, err - } - fq.Body = buf.Bytes() - if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 { fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates)) for i, cert := range req.TLS.PeerCertificates { diff --git a/http/handler.go b/http/handler.go index a9be673cb675..c275cf300e3a 100644 --- a/http/handler.go +++ b/http/handler.go @@ -1,7 +1,9 @@ package http import ( + "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -67,7 +69,9 @@ var ( // Handler returns an http.Handler for the API. This can be used on // its own to mount the Vault API within another web server. -func Handler(core *vault.Core) http.Handler { +func Handler(props *vault.HandlerProperties) http.Handler { + core := props.Core + // Create the muxer to handle the actual endpoints mux := http.NewServeMux() mux.Handle("/v1/sys/init", handleSysInit(core)) @@ -108,7 +112,7 @@ func Handler(core *vault.Core) http.Handler { // Wrap the help wrapped handler with another layer with a generic // handler - genericWrappedHandler := wrapGenericHandler(corsWrappedHandler) + genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize) // Wrap the handler with PrintablePathCheckHandler to check for non-printable // characters in the request path. @@ -120,12 +124,20 @@ func Handler(core *vault.Core) http.Handler { // wrapGenericHandler wraps the handler with an extra layer of handler where // tasks that should be commonly handled for all the requests and/or responses // are performed. -func wrapGenericHandler(h http.Handler) http.Handler { +func wrapGenericHandler(h http.Handler, maxRequestSize int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") - h.ServeHTTP(w, r) + + // Add a context and put the request limit for this handler in it + if maxRequestSize > 0 { + ctx := context.WithValue(r.Context(), "max_request_size", maxRequestSize) + h.ServeHTTP(w, r.WithContext(ctx)) + } else { + h.ServeHTTP(w, r) + } + return }) } @@ -326,8 +338,19 @@ func (fs *UIAssetWrapper) Open(name string) (http.File, error) { func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error { // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. - limit := http.MaxBytesReader(w, r.Body, MaxRequestSize) - err := jsonutil.DecodeJSONFromReader(limit, out) + reader := r.Body + ctx := r.Context() + maxRequestSize := ctx.Value("max_request_size") + if maxRequestSize != nil { + max, ok := maxRequestSize.(int64) + if !ok { + return errors.New("could not parse max_request_size from request context") + } + if max > 0 { + reader = http.MaxBytesReader(w, r.Body, max) + } + } + err := jsonutil.DecodeJSONFromReader(reader, out) if err != nil && err != io.EOF { return errwrap.Wrapf("failed to parse JSON input: {{err}}", err) } diff --git a/http/testing.go b/http/testing.go index 2299006c98bf..36695d6f8940 100644 --- a/http/testing.go +++ b/http/testing.go @@ -30,7 +30,9 @@ func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *v // for tests. mux := http.NewServeMux() mux.Handle("/_test/auth", http.HandlerFunc(testHandleAuth)) - mux.Handle("/", Handler(core)) + mux.Handle("/", Handler(&vault.HandlerProperties{ + Core: core, + })) server := &http.Server{ Addr: ln.Addr().String(), diff --git a/vault/request_handling.go b/vault/request_handling.go index fd91e33dbc85..a6424b362673 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -26,6 +26,13 @@ const ( replTimeout = 10 * time.Second ) +// HanlderProperties is used to seed configuration into a vaulthttp.Handler. +// It's in this package to avoid a circular dependency +type HandlerProperties struct { + Core *Core + MaxRequestSize int64 +} + // fetchEntityAndDerivedPolicies returns the entity object for the given entity // ID. If the entity is merged into a different entity object, the entity into // which the given entity ID is merged into will be returned. This function diff --git a/vault/testing.go b/vault/testing.go index da8985e81fc9..99f0d65a5b7b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -880,7 +880,7 @@ type TestClusterCore struct { type TestClusterOptions struct { KeepStandbysSealed bool SkipInit bool - HandlerFunc func(*Core) http.Handler + HandlerFunc func(*HandlerProperties) http.Handler BaseListenAddress string NumCores int SealFunc func() Seal @@ -1249,7 +1249,9 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } cores = append(cores, c) if opts != nil && opts.HandlerFunc != nil { - handlers[i] = opts.HandlerFunc(c) + handlers[i] = opts.HandlerFunc(&HandlerProperties{ + Core: c, + }) servers[i].Handler = handlers[i] } } diff --git a/website/source/docs/configuration/listener/tcp.html.md b/website/source/docs/configuration/listener/tcp.html.md index 36ad045b3f95..2e786ffca616 100644 --- a/website/source/docs/configuration/listener/tcp.html.md +++ b/website/source/docs/configuration/listener/tcp.html.md @@ -29,6 +29,10 @@ listener "tcp" { they need to hop through a TCP load balancer or some other scheme in order to talk. +- `max_request_size` `(int: 33554432)` – Specifies a hard maximum allowed + request size, in bytes. Defaults to 32 MB. Specifying a number less than or + equal to `0` turns off limiting altogether. + - `proxy_protocol_behavior` `(string: "") – When specified, turns on the PROXY protocol for the listener. Accepted Values: From e073371ce75977a9592496df0d76732d7c81debf Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Sun, 24 Jun 2018 10:20:30 -0700 Subject: [PATCH 2/3] Switch to ioutil.ReadAll() --- helper/forwarding/util.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index 67897830e2a3..0a4973e9f84e 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "errors" "io" + "io/ioutil" "net/http" "net/url" "os" @@ -71,8 +72,7 @@ func GenerateForwardedRequest(req *http.Request) (*Request, error) { } } - buf := bytes.NewBuffer(nil) - _, err := buf.ReadFrom(reader) + body, err := ioutil.ReadAll(reader) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func GenerateForwardedRequest(req *http.Request) (*Request, error) { HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)), Host: req.Host, RemoteAddr: req.RemoteAddr, - Body: buf.Bytes(), + Body: body, } reqURL := req.URL From a7e43b0ffa55358cc1770f8c3762819560142ad6 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Sun, 24 Jun 2018 17:56:25 -0400 Subject: [PATCH 3/3] Fix tests --- command/server.go | 2 +- http/forwarded_for_test.go | 12 ++++++------ http/handler.go | 9 +++++---- http/logical_test.go | 2 +- http/testing.go | 3 ++- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/command/server.go b/command/server.go index eb08b1eeb760..ede3611661d7 100644 --- a/command/server.go +++ b/command/server.go @@ -724,7 +724,7 @@ CLUSTER_SYNTHESIS_COMPLETE: props["cluster address"] = addr } - var maxRequestSize int64 = 32 * 1024 * 1024 + var maxRequestSize int64 = vaulthttp.DefaultMaxRequestSize if valRaw, ok := lnConfig.Config["max_request_size"]; ok { val, err := parseutil.ParseInt(valRaw) if err != nil { diff --git a/http/forwarded_for_test.go b/http/forwarded_for_test.go index 0eec439f4adc..170b54334dc2 100644 --- a/http/forwarded_for_test.go +++ b/http/forwarded_for_test.go @@ -24,7 +24,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // First: test reject not present t.Run("reject_not_present", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -69,7 +69,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test allow unauth t.Run("allow_unauth", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -106,7 +106,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test fail unauth t.Run("fail_unauth", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -140,7 +140,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test bad hops (too many) t.Run("too_many_hops", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -174,7 +174,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: test picking correct value t.Run("correct_hop_skipping", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) @@ -211,7 +211,7 @@ func TestHandler_XForwardedFor(t *testing.T) { // Next: multi-header approach t.Run("correct_hop_skipping_multi_header", func(t *testing.T) { t.Parallel() - testHandler := func(c *vault.Core) http.Handler { + testHandler := func(props *vault.HandlerProperties) http.Handler { origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.RemoteAddr)) diff --git a/http/handler.go b/http/handler.go index c275cf300e3a..6cfd4a7b99e2 100644 --- a/http/handler.go +++ b/http/handler.go @@ -54,10 +54,11 @@ const ( // soft-mandatory Sentinel policies. PolicyOverrideHeaderName = "X-Vault-Policy-Override" - // MaxRequestSize is the maximum accepted request size. This is to prevent - // a denial of service attack where no Content-Length is provided and the server - // is fed ever more data until it exhausts memory. - MaxRequestSize = 32 * 1024 * 1024 + // DefaultMaxRequestSize is the default maximum accepted request size. This + // is to prevent a denial of service attack where no Content-Length is + // provided and the server is fed ever more data until it exhausts memory. + // Can be overridden per listener. + DefaultMaxRequestSize = 32 * 1024 * 1024 ) var ( diff --git a/http/logical_test.go b/http/logical_test.go index e6ec3da29374..cd868bcfad92 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -261,7 +261,7 @@ func TestLogical_RequestSizeLimit(t *testing.T) { // Write a very large object, should fail resp := testHttpPut(t, token, addr+"/v1/secret/foo", map[string]interface{}{ - "data": make([]byte, MaxRequestSize), + "data": make([]byte, DefaultMaxRequestSize), }) testResponseStatus(t, resp, 413) } diff --git a/http/testing.go b/http/testing.go index 36695d6f8940..13501f5daf19 100644 --- a/http/testing.go +++ b/http/testing.go @@ -31,7 +31,8 @@ func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *v mux := http.NewServeMux() mux.Handle("/_test/auth", http.HandlerFunc(testHandleAuth)) mux.Handle("/", Handler(&vault.HandlerProperties{ - Core: core, + Core: core, + MaxRequestSize: DefaultMaxRequestSize, })) server := &http.Server{