Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow proxy middleware to pass rewritten query part (fix #1798) #1802

Merged
merged 2 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package middleware

import (
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -49,30 +48,39 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
return rulesRegex
}

func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
rawPath := req.URL.RawPath
if rawPath != "" {
// RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath
// because encoded Path could match rules that RawPath did not
if replacer := captureTokens(k, rawPath); replacer != nil {
rawPath = replacer.Replace(v)

req.URL.RawPath = rawPath
req.URL.Path, _ = url.PathUnescape(rawPath)

return // rewrite only once
}
func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
if len(rewriteRegex) == 0 {
return nil
}

continue
// Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
// We only want to use path part for rewriting and therefore trim prefix if it exists
rawURI := req.RequestURI
if rawURI != "" && rawURI[0] != '/' {
prefix := ""
if req.URL.Scheme != "" {
prefix = req.URL.Scheme + "://"
}
if req.URL.Host != "" {
prefix += req.URL.Host // host or host:port
}
if prefix != "" {
rawURI = strings.TrimPrefix(rawURI, prefix)
lammel marked this conversation as resolved.
Show resolved Hide resolved
}
}

if replacer := captureTokens(k, req.URL.Path); replacer != nil {
req.URL.Path = replacer.Replace(v)
for k, v := range rewriteRegex {
if replacer := captureTokens(k, rawURI); replacer != nil {
url, err := req.URL.Parse(replacer.Replace(v))
if err != nil {
return err
}
req.URL = url

return // rewrite only once
return nil // rewrite only once
}
}
return nil
}

// DefaultSkipper returns false which processes the middleware.
Expand Down
29 changes: 26 additions & 3 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"testing"
)

func TestRewritePath(t *testing.T) {
func TestRewriteURL(t *testing.T) {
var testCases = []struct {
whenURL string
expectPath string
expectRawPath string
expectQuery string
expectErr string
}{
{
whenURL: "http://localhost:8080/old",
Expand All @@ -28,16 +30,18 @@ func TestRewritePath(t *testing.T) {
whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1",
expectPath: "/user/+_+/order/___++++",
expectRawPath: "",
expectQuery: "test=1",
},
{
whenURL: "http://localhost:8080/users/%20a/orders/%20aa",
expectPath: "/user/ a/order/ aa",
expectRawPath: "",
},
{
whenURL: "http://localhost:8080/%47%6f%2f",
whenURL: "http://localhost:8080/%47%6f%2f?test=1",
expectPath: "/Go/",
expectRawPath: "/%47%6f%2f",
expectQuery: "test=1",
},
{
whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
Expand All @@ -49,21 +53,40 @@ func TestRewritePath(t *testing.T) {
expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
},
{
whenURL: "http://localhost:8080/static",
expectPath: "/static/path",
expectRawPath: "",
expectQuery: "role=AUTHOR&limit=1000",
},
{
whenURL: "/static",
expectPath: "/static/path",
expectRawPath: "",
expectQuery: "role=AUTHOR&limit=1000",
},
}

rules := map[*regexp.Regexp]string{
regexp.MustCompile("^/old$"): "/new",
regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2",
regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000",
}

for _, tc := range testCases {
t.Run(tc.whenURL, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)

rewritePath(rules, req)
err := rewriteURL(rules, req)

if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path.
assert.Equal(t, tc.expectQuery, req.URL.RawQuery)
})
}
}
5 changes: 3 additions & 2 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt)

// Set rewrite path and raw path
rewritePath(config.RegexRewrite, req)
if err := rewriteURL(config.RegexRewrite, req); err != nil {
return err
}

// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
Expand Down
27 changes: 19 additions & 8 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,16 @@ func TestProxyRewrite(t *testing.T) {

func TestProxyRewriteRegex(t *testing.T) {
// Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
receivedRequestURI := make(chan string, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
receivedRequestURI <- r.RequestURI
}))
defer upstream.Close()
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
tmpUrL, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}})

// Rewrite
e := echo.New()
Expand Down Expand Up @@ -279,14 +283,21 @@ func TestProxyRewriteRegex(t *testing.T) {
{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
{"/x/ignore/test", http.StatusOK, "/v4/test"},
{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
// NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
// $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
{"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
}

for _, tc := range testCases {
t.Run(tc.requestPath, func(t *testing.T) {
req.URL, _ = url.Parse(tc.requestPath)
rec = httptest.NewRecorder()
targetURL, _ := url.Parse(tc.requestPath)
req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectPath, req.URL.EscapedPath())

actualRequestURI := <-receivedRequestURI
assert.Equal(t, tc.expectPath, actualRequestURI)
assert.Equal(t, tc.statusCode, rec.Code)
})
}
Expand Down
6 changes: 3 additions & 3 deletions middleware/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
return next(c)
}

req := c.Request()
// Set rewrite path and raw path
rewritePath(config.RegexRules, req)
if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
return err
}
return next(c)
}
}
Expand Down
9 changes: 7 additions & 2 deletions middleware/rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ func TestEchoRewriteWithRegexRules(t *testing.T) {
func TestEchoRewriteReplacementEscaping(t *testing.T) {
e := echo.New()

// NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts
// so in reality they append query and fragment part as `$1` matches everything after that prefix
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
"^/a/*": "/$1?query=param",
Expand All @@ -228,6 +230,7 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) {
RegexRules: map[*regexp.Regexp]string{
regexp.MustCompile("^/x/(.*)"): "/$1?query=param",
regexp.MustCompile("^/y/(.*)"): "/$1;part#one",
regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test",
},
}))

Expand All @@ -236,21 +239,23 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) {

testCases := []struct {
requestPath string
expectPath string
expect string
}{
{"/unmatched", "/unmatched"},
{"/a/test", "/test?query=param"},
{"/b/foo/bar", "/foo/bar;part#one"},
{"/x/test", "/test?query=param"},
{"/y/foo/bar", "/foo/bar;part#one"},
{"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"},
{"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending
}

for _, tc := range testCases {
t.Run(tc.requestPath, func(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectPath, req.URL.Path)
assert.Equal(t, tc.expect, req.URL.String())
})
}
}