From b840d7d16f96ab916acec3231405a02f2ba632c9 Mon Sep 17 00:00:00 2001 From: Nathan Smith Date: Thu, 20 Jan 2022 11:11:41 -0800 Subject: [PATCH] Set max request size to 4MiB Signed-off-by: Nathan Smith --- cmd/app/serve.go | 46 +++++++++++++++----------- pkg/api/max_bytes.go | 27 ++++++++++++++++ pkg/api/max_bytes_test.go | 68 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 pkg/api/max_bytes.go create mode 100644 pkg/api/max_bytes_test.go diff --git a/cmd/app/serve.go b/cmd/app/serve.go index 4f39f8406..712e24cbe 100644 --- a/cmd/app/serve.go +++ b/cmd/app/serve.go @@ -163,12 +163,24 @@ func runServeCmd(cmd *cobra.Command, args []string) { log.Logger.Fatal(err) } - decorateHandler := func(h http.Handler) http.Handler { - // Wrap the inner func with instrumentation to get latencies - // that get partitioned by 'code' and 'method'. - return promhttp.InstrumentHandlerDuration( - api.MetricLatency, - http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + prom := http.Server{ + Addr: ":2112", + Handler: promhttp.Handler(), + } + go func() { + _ = prom.ListenAndServe() + }() + + host, port := viper.GetString("host"), viper.GetString("port") + log.Logger.Infof("%s:%s", host, port) + + var handler http.Handler + { + handler = api.NewHandler() + + // Inject dependencies + withDependencies := func(inner http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() // For each request, infuse context with our snapshot of the FulcioConfig. @@ -179,23 +191,21 @@ func runServeCmd(cmd *cobra.Command, args []string) { ctx = api.WithCA(ctx, baseca) ctx = api.WithCTLogURL(ctx, viper.GetString("ct-log-url")) - h.ServeHTTP(rw, r.WithContext(ctx)) - })) - } + inner.ServeHTTP(rw, r.WithContext(ctx)) + }) + } + handler = withDependencies(handler) - prom := http.Server{ - Addr: ":2112", - Handler: promhttp.Handler(), + // Instrument Prometheus metrics + handler = promhttp.InstrumentHandlerDuration(api.MetricLatency, handler) + + // Limit request size + handler = api.WithMaxBytes(handler, 1<<22) // 4MiB } - go func() { - _ = prom.ListenAndServe() - }() - host, port := viper.GetString("host"), viper.GetString("port") - log.Logger.Infof("%s:%s", host, port) api := http.Server{ Addr: host + ":" + port, - Handler: decorateHandler(api.NewHandler()), + Handler: handler, } if err := api.ListenAndServe(); err != nil && err != http.ErrServerClosed { diff --git a/pkg/api/max_bytes.go b/pkg/api/max_bytes.go new file mode 100644 index 000000000..7c267ece3 --- /dev/null +++ b/pkg/api/max_bytes.go @@ -0,0 +1,27 @@ +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package api + +import "net/http" + +// WithMaxBytes sets the max request size on a handler to n bytes. +func WithMaxBytes(next http.Handler, n int64) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limitedReader := http.MaxBytesReader(w, r.Body, n) + r.Body = limitedReader + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/api/max_bytes_test.go b/pkg/api/max_bytes_test.go new file mode 100644 index 000000000..529f9aa43 --- /dev/null +++ b/pkg/api/max_bytes_test.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package api + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestWithMaxBytes(t *testing.T) { + var maxBodySize int64 = 10 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + return + }) + + ts := httptest.NewServer(WithMaxBytes(handler, maxBodySize)) + + tests := map[string]struct { + Body string + ExpectedStatus int + }{ + "Less than max": { + Body: strings.Repeat("a", int(maxBodySize-1)), + ExpectedStatus: http.StatusOK, + }, + "At max": { + Body: strings.Repeat("b", int(maxBodySize)), + ExpectedStatus: http.StatusOK, + }, + "Over max": { + Body: strings.Repeat("c", int(maxBodySize+1)), + ExpectedStatus: http.StatusBadRequest, + }, + } + + for testcase, data := range tests { + t.Run(testcase, func(t *testing.T) { + resp, err := http.Post(ts.URL, "text/plain", strings.NewReader(data.Body)) + if err != nil { + t.Fatal("Failed to send request to test server", err) + } + if resp.StatusCode != data.ExpectedStatus { + t.Error("Expected status code", data.ExpectedStatus, "but got", resp.StatusCode) + } + }) + } +}