Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added WithDisablePathLengthFallback option (to fix issue #447) #855

Merged
merged 1 commit into from
Jan 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
disablePathLengthFallback bool
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand Down Expand Up @@ -102,6 +103,13 @@ func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
}
}

// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.disablePathLengthFallback = true
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
Expand Down Expand Up @@ -177,7 +185,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
components[l-1], verb = c[:idx], c[idx+1:]
}

if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
r.Method = strings.ToUpper(override)
if err := r.ParseForm(); err != nil {
if s.protoErrorHandler != nil {
Expand Down Expand Up @@ -211,7 +219,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
if isPathLengthFallback(r) {
if s.isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
if s.protoErrorHandler != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
Expand Down Expand Up @@ -250,8 +258,8 @@ func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.Resp
return s.forwardResponseOptions
}

func isPathLengthFallback(r *http.Request) bool {
return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}

type handler struct {
Expand Down
47 changes: 46 additions & 1 deletion runtime/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func TestMuxServeHTTP(t *testing.T) {

respStatus int
respContent string

disablePathLengthFallback bool
}{
{
patterns: nil,
Expand Down Expand Up @@ -122,6 +124,45 @@ func TestMuxServeHTTP(t *testing.T) {
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
respStatus: http.StatusMethodNotAllowed,
respContent: "Method Not Allowed\n",
disablePathLengthFallback: true,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
{
method: "POST",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
respStatus: http.StatusOK,
respContent: "POST /foo",
disablePathLengthFallback: true,
},
{
patterns: []stubPattern{
{
Expand Down Expand Up @@ -199,7 +240,11 @@ func TestMuxServeHTTP(t *testing.T) {
respContent: "GET /foo/{id=*}:verb",
},
} {
mux := runtime.NewServeMux()
var opts []runtime.ServeMuxOption
if spec.disablePathLengthFallback {
opts = append(opts, runtime.WithDisablePathLengthFallback())
}
mux := runtime.NewServeMux(opts...)
for _, p := range spec.patterns {
func(p stubPattern) {
pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb)
Expand Down