diff --git a/http.go b/http.go index 5d8dc93477..595eafa9c2 100644 --- a/http.go +++ b/http.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "math" "mime/multipart" "net" "os" @@ -2210,7 +2209,7 @@ func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, err return dst[:offset], ErrBodyTooLarge } if len(dst) == offset { - n := round2(2 * offset) + n := roundUpForSliceCap(2 * offset) if maxBodySize > 0 && n > maxBodySize { n = maxBodySize + 1 } @@ -2229,7 +2228,7 @@ func appendBodyFixedSize(r *bufio.Reader, dst []byte, n int) ([]byte, error) { offset := len(dst) dstLen := offset + n if cap(dst) < dstLen { - b := make([]byte, round2(dstLen)) + b := make([]byte, roundUpForSliceCap(dstLen)) copy(b, dst) dst = b } @@ -2339,26 +2338,6 @@ func readCrLf(r *bufio.Reader) error { return nil } -func round2(n int) int { - if n <= 0 { - return 0 - } - - x := uint32(n - 1) - x |= x >> 1 - x |= x >> 2 - x |= x >> 4 - x |= x >> 8 - x |= x >> 16 - - // Make sure we don't return 0 due to overflow, even on 32 bit systems - if x >= uint32(math.MaxInt32) { - return math.MaxInt32 - } - - return int(x + 1) -} - // SetTimeout sets timeout for the request. // // req.SetTimeout(t); c.Do(&req, &resp) is equivalent to diff --git a/http_test.go b/http_test.go index d4717f640c..370e9aa9e5 100644 --- a/http_test.go +++ b/http_test.go @@ -16,6 +16,7 @@ import ( "strings" "testing" "time" + "unsafe" "github.com/valyala/bytebufferpool" ) @@ -1967,25 +1968,31 @@ func testSetResponseBodyStreamChunked(t *testing.T, body string, trailer map[str } } -func TestRound2(t *testing.T) { +func TestRound2ForSliceCap(t *testing.T) { t.Parallel() - testRound2(t, 0, 0) - testRound2(t, 1, 1) - testRound2(t, 2, 2) - testRound2(t, 3, 4) - testRound2(t, 4, 4) - testRound2(t, 5, 8) - testRound2(t, 7, 8) - testRound2(t, 8, 8) - testRound2(t, 9, 16) - testRound2(t, 0x10001, 0x20000) - testRound2(t, math.MaxInt32-1, math.MaxInt32) + testRound2ForSliceCap(t, 0, 0) + testRound2ForSliceCap(t, 1, 1) + testRound2ForSliceCap(t, 2, 2) + testRound2ForSliceCap(t, 3, 4) + testRound2ForSliceCap(t, 4, 4) + testRound2ForSliceCap(t, 5, 8) + testRound2ForSliceCap(t, 7, 8) + testRound2ForSliceCap(t, 8, 8) + testRound2ForSliceCap(t, 9, 16) + testRound2ForSliceCap(t, 0x10001, 0x20000) + + if unsafe.Sizeof(int(0)) == 4 { + testRound2ForSliceCap(t, math.MaxInt32-1, math.MaxInt32) + } else { + testRound2ForSliceCap(t, math.MaxInt32, math.MaxInt32) + testRound2ForSliceCap(t, math.MaxInt64-1, math.MaxInt64-1) + } } -func testRound2(t *testing.T, n, expectedRound2 int) { - if round2(n) != expectedRound2 { - t.Fatalf("Unexpected round2(%d)=%d. Expected %d", n, round2(n), expectedRound2) +func testRound2ForSliceCap(t *testing.T, n, expectedRound2 int) { + if roundUpForSliceCap(n) != expectedRound2 { + t.Fatalf("Unexpected round2(%d)=%d. Expected %d", n, roundUpForSliceCap(n), expectedRound2) } } diff --git a/round2_32.go b/round2_32.go new file mode 100644 index 0000000000..541b85e216 --- /dev/null +++ b/round2_32.go @@ -0,0 +1,29 @@ +//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x +// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x + +package fasthttp + +func roundUpForSliceCap(n int) int { + if n <= 0 { + return 0 + } + + // Above 100MB, we don't round up as the overhead is too large. + if n > 100*1024*1024 { + return n + } + + x := uint32(n - 1) + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + + // Make sure we don't return 0 due to overflow, even on 32 bit systems + if x >= uint32(math.MaxInt32) { + return math.MaxInt32 + } + + return int(x + 1) +} diff --git a/round2_64.go b/round2_64.go new file mode 100644 index 0000000000..8a8e2a23e1 --- /dev/null +++ b/round2_64.go @@ -0,0 +1,24 @@ +//go:build amd64 || arm64 || ppc64 || ppc64le || s390x +// +build amd64 arm64 ppc64 ppc64le s390x + +package fasthttp + +func roundUpForSliceCap(n int) int { + if n <= 0 { + return 0 + } + + // Above 100MB, we don't round up as the overhead is too large. + if n > 100*1024*1024 { + return n + } + + x := uint64(n - 1) + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + + return int(x + 1) +}