From 11de1a4527e62cf71de0ebe906f53bdbeb449c26 Mon Sep 17 00:00:00 2001 From: sigua-cs Date: Sat, 3 Feb 2024 13:25:18 +0200 Subject: [PATCH] fix: restore req body --- main_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ proxy.go | 16 +++++++--------- proxy_test.go | 19 +++++++++++++++++++ utils.go | 23 +++++++++++++++++++++++ 4 files changed, 100 insertions(+), 9 deletions(-) diff --git a/main_test.go b/main_test.go index c60852bf..d3e1d707 100644 --- a/main_test.go +++ b/main_test.go @@ -212,6 +212,37 @@ func TestServe(t *testing.T) { }, startTLS, }, + { + "https request body is not empty", + "testdata/https.yml", + func(t *testing.T) { + query := "SELECT SleepTimeout" + buf := bytes.NewBufferString(query) + req, err := http.NewRequest("POST", "https://127.0.0.1:8443", buf) + checkErr(t, err) + req.SetBasicAuth("default", "qwerty") + req.Close = true + + resp, err := tlsClient.Do(req) + checkErr(t, err) + if resp.StatusCode != http.StatusGatewayTimeout { + t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusGatewayTimeout) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error while reading body from response; err: %q", err) + } + + b := string(bodyBytes) + if !strings.Contains(b, query) { + t.Fatalf("expected request body: %q; got: %q", query, b) + } + + resp.Body.Close() + }, + startTLS, + }, { "https cache with mix query source", "testdata/https.cache.yml", @@ -1019,6 +1050,26 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprint(w, b) fmt.Fprint(w, "Ok.\n") + case strings.Contains(q, "SELECT SleepTimeout"): + w.WriteHeader(http.StatusGatewayTimeout) + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + fmt.Fprintf(w, "query: %s; error while reading body: %s", query, err) + return + } + + b := string(bodyBytes) + // Ensure the original request body is not empty and remains unchanged + // after it is processed by getFullQuery. + if b == "" && b != q { + fmt.Fprintf(w, "got original req body: <%s>; escaped query: <%s>", b, q) + return + } + + // execute sleep 1.5 sec + time.Sleep(1500 * time.Millisecond) + fmt.Fprint(w, b) default: if strings.Contains(string(query), killQueryPattern) { fakeCHState.kill() diff --git a/proxy.go b/proxy.go index 9312a83f..d166aadf 100644 --- a/proxy.go +++ b/proxy.go @@ -207,23 +207,21 @@ func executeWithRetry( startTime := time.Now() var since float64 - // keep the request body - body, err := io.ReadAll(req.Body) - req.Body.Close() + // Use readAndRestoreRequestBody to read the entire request body into a byte slice, + // and to restore req.Body so that it can be reused later in the code. + body, err := readAndRestoreRequestBody(req) if err != nil { - since = time.Since(startTime).Seconds() - + since := time.Since(startTime).Seconds() return since, err } numRetry := 0 for { - // update body - req.Body = io.NopCloser(bytes.NewBuffer(body)) - req.Body.Close() - rp(rw, req) + // Restore req.Body after it's consumed by 'rp' for potential reuse. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + err := ctx.Err() if err != nil { since = time.Since(startTime).Seconds() diff --git a/proxy_test.go b/proxy_test.go index 43bd8e14..4301a88a 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1012,6 +1012,25 @@ func TestReverseProxy_ServeHTTP2(t *testing.T) { t.Fatalf("expected response: %q; got: %q", expected, b) } }) + + t.Run("request body not empty", func(t *testing.T) { + proxy, err := getProxy(goodCfg) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + body := bytes.NewBufferString("SELECT sleep(1.5)") + expected := "SELECT sleep(1.5)" + req := httptest.NewRequest("POST", fakeServer.URL, body) + + resp := makeCustomRequest(proxy, req) + b := bbToString(t, resp.Body) + resp.Body.Close() + + if !strings.Contains(b, expected) { + t.Fatalf("expected response: %q; got: %q", expected, b) + } + + }) } func getNetwork(s string) *net.IPNet { diff --git a/utils.go b/utils.go index cf7cd65a..c3d97d14 100644 --- a/utils.go +++ b/utils.go @@ -98,9 +98,17 @@ func getQuerySnippetFromBody(req *http.Request) string { // 'read' request body, so it traps into to crc. // Ignore any errors, since getQuerySnippet is called only // during error reporting. + // Temporary solution: Quick and dirty way to work with the request body. + // TODO: Create an original copy of req.Body and work with the copy to avoid altering the original request. + // This current approach consumes the req.Body content with io.Copy(io.Discard, crc) to reset the internal state of crc. + // However, it is not the most efficient or safest method, as it modifies the original req.Body. io.Copy(io.Discard, crc) // nolint data := crc.String() + // Here, we attempt to restore req.Body by wrapping the string data in a ReadCloser. + // This is part of the temporary solution and should be replaced with a more robust method that does not consume the original req.Body. + req.Body = io.NopCloser(strings.NewReader(data)) + u := getDecompressor(req) if u == nil { return data @@ -295,3 +303,18 @@ func calcCredentialHash(user string, pwd string) (uint32, error) { _, err := h.Write([]byte(user + pwd)) return h.Sum32(), err } + +// Function to read the request body and return it as a byte slice. +// It also restores the req.Body to be used again. +func readAndRestoreRequestBody(req *http.Request) ([]byte, error) { + // Read the entire request body. + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Restore the req.Body with a new reader for the original content. + req.Body = io.NopCloser(bytes.NewReader(body)) + + // Return the read body. + return body, nil +}