From c9c7f462eb34d18f5d5804743ec21fb7472ba8e8 Mon Sep 17 00:00:00 2001 From: utrehubenka Date: Sun, 20 Jan 2019 22:16:22 +0300 Subject: [PATCH] Added WithDisablePathLengthFallback option (to fix issue #447) --- runtime/mux.go | 30 ++++++++++++++++++----------- runtime/mux_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/runtime/mux.go b/runtime/mux.go index 3064c69ecbd..ec81e55b5ef 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -20,13 +20,14 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str // It matches http requests to patterns and invokes the corresponding handler. type ServeMux struct { // handlers maps HTTP method to a list of handlers. - handlers map[string][]handler - forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error - marshalers marshalerRegistry - incomingHeaderMatcher HeaderMatcherFunc - outgoingHeaderMatcher HeaderMatcherFunc - metadataAnnotators []func(context.Context, *http.Request) metadata.MD - protoErrorHandler ProtoErrorHandlerFunc + handlers map[string][]handler + forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error + marshalers marshalerRegistry + incomingHeaderMatcher HeaderMatcherFunc + outgoingHeaderMatcher HeaderMatcherFunc + metadataAnnotators []func(context.Context, *http.Request) metadata.MD + protoErrorHandler ProtoErrorHandlerFunc + disablePathLengthFallback bool } // ServeMuxOption is an option that can be given to a ServeMux on construction. @@ -102,6 +103,13 @@ func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption { } } +// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback. +func WithDisablePathLengthFallback() ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.disablePathLengthFallback = true + } +} + // NewServeMux returns a new ServeMux whose internal mapping is empty. func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ @@ -177,7 +185,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { components[l-1], verb = c[:idx], c[idx+1:] } - if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) { + if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) { r.Method = strings.ToUpper(override) if err := r.ParseForm(); err != nil { if s.protoErrorHandler != nil { @@ -211,7 +219,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { continue } // X-HTTP-Method-Override is optional. Always allow fallback to POST. - if isPathLengthFallback(r) { + if s.isPathLengthFallback(r) { if err := r.ParseForm(); err != nil { if s.protoErrorHandler != nil { _, outboundMarshaler := MarshalerForRequest(s, r) @@ -250,8 +258,8 @@ func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.Resp return s.forwardResponseOptions } -func isPathLengthFallback(r *http.Request) bool { - return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" +func (s *ServeMux) isPathLengthFallback(r *http.Request) bool { + return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" } type handler struct { diff --git a/runtime/mux_test.go b/runtime/mux_test.go index f3ecbf63388..08ed6cc6806 100644 --- a/runtime/mux_test.go +++ b/runtime/mux_test.go @@ -27,6 +27,8 @@ func TestMuxServeHTTP(t *testing.T) { respStatus int respContent string + + disablePathLengthFallback bool }{ { patterns: nil, @@ -122,6 +124,45 @@ func TestMuxServeHTTP(t *testing.T) { respStatus: http.StatusOK, respContent: "GET /foo", }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + respStatus: http.StatusMethodNotAllowed, + respContent: "Method Not Allowed\n", + disablePathLengthFallback: true, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + respStatus: http.StatusOK, + respContent: "POST /foo", + disablePathLengthFallback: true, + }, { patterns: []stubPattern{ { @@ -199,7 +240,11 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo/{id=*}:verb", }, } { - mux := runtime.NewServeMux() + var opts []runtime.ServeMuxOption + if spec.disablePathLengthFallback { + opts = append(opts, runtime.WithDisablePathLengthFallback()) + } + mux := runtime.NewServeMux(opts...) for _, p := range spec.patterns { func(p stubPattern) { pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb)