Skip to content

Commit

Permalink
Change to RequestSizeLimiter & better tests (#4)
Browse files Browse the repository at this point in the history
* change name to proper meaning; simplified tests

* go 1.7+
  • Loading branch information
gabstv authored and appleboy committed Feb 16, 2018
1 parent 6d03d1b commit 161dc59
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 107 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ language: go
sudo: false

go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func handler(ctx *gin.Context) {

func main() {
rtr := gin.Default()
rtr.Use(limits.RateLimiter(10))
rtr.Use(limits.RequestSizeLimiter(10))
rtr.POST("/", handler)
rtr.Run(":8080")
}
Expand Down
2 changes: 1 addition & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func handler(ctx *gin.Context) {

func main() {
rtr := gin.Default()
rtr.Use(limits.RateLimiter(10))
rtr.Use(limits.RequestSizeLimiter(10))
rtr.POST("/", handler)
rtr.Run(":8080")
}
6 changes: 3 additions & 3 deletions size.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ func (mbr *maxBytesReader) Close() error {
return mbr.rdr.Close()
}

// RateLimiter returns a middleware that limits the size of request
// RequestSizeLimiter returns a middleware that limits the size of request
// When a request is over the limit, the following will happen:
// * Error will be added to the context
// * Connection: close header will be set
// * Error 413 will be send to client (http.StatusRequestEntityTooLarge)
// * Error 413 will be sent to the client (http.StatusRequestEntityTooLarge)
// * Current context will be aborted
func RateLimiter(limit int64) gin.HandlerFunc {
func RequestSizeLimiter(limit int64) gin.HandlerFunc {
return func(ctx *gin.Context) {
ctx.Request.Body = &maxBytesReader{
ctx: ctx,
Expand Down
117 changes: 39 additions & 78 deletions size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,97 +2,58 @@ package limits

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/exec"
"net/http/httptest"
"testing"
"text/template"
"time"
)

var (
params = struct {
Size int
Port int
}{10, 9388}

codeFile = "/tmp/ratelimit_test_server.go"
serverURL string
"github.com/gin-gonic/gin"
)

func init() {
tmpl := template.Must(template.ParseFiles("test_server.tmpl"))
fp, err := os.Create(codeFile)
if err != nil {
panic(fmt.Errorf("can't open %s - %s", codeFile, err))
}
err = tmpl.Execute(fp, params)
if err != nil {
panic(fmt.Errorf("can't create %s - %s", codeFile, err))
}
serverURL = fmt.Sprintf("http://localhost:%d", params.Port)
}

func waitForServer() error {
timeout := 30 * time.Second
ch := make(chan bool)
go func() {
for {
_, err := http.Post(serverURL, "text/plain", nil)
if err == nil {
ch <- true
}
time.Sleep(10 * time.Millisecond)
func TestRequestSizeLimiterOK(t *testing.T) {
router := gin.New()
router.Use(RequestSizeLimiter(10))
router.POST("/test_ok", func(c *gin.Context) {
ioutil.ReadAll(c.Request.Body)
if len(c.Errors) > 0 {
return
}
}()

select {
case <-ch:
return nil
case <-time.After(timeout):
return fmt.Errorf("server did not reply after %v", timeout)
}
c.Request.Body.Close()
c.String(http.StatusOK, "OK")
})
resp := performRequest(http.MethodPost, "/test_ok", "big=abc", router)

}

func runServer() (*exec.Cmd, error) {
cmd := exec.Command("go", "run", codeFile)
cmd.Start()
if err := waitForServer(); err != nil {
return nil, err
if resp.Code != http.StatusOK {
t.Fatalf("error posting - http status %v", resp.Code)
}
return cmd, nil
}

func doPost(val string) (*http.Response, error) {
cmd, err := runServer()
if err != nil {
return nil, err
}
defer cmd.Process.Kill()

var buf bytes.Buffer
fmt.Fprintf(&buf, "big=%s", val)
return http.Post(serverURL, "application/x-www-form-urlencoded", &buf)
}
func TestRequestSizeLimiterOver(t *testing.T) {
router := gin.New()
router.Use(RequestSizeLimiter(10))
router.POST("/test_large", func(c *gin.Context) {
ioutil.ReadAll(c.Request.Body)
if len(c.Errors) > 0 {
return
}
c.Request.Body.Close()
c.String(http.StatusOK, "OK")
})
resp := performRequest(http.MethodPost, "/test_large", "big=abcdefghijklmnop", router)

func TestRateLimiterOK(t *testing.T) {
resp, err := doPost("abc")
if err != nil {
t.Fatalf("error posting - %s", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("bad status - %d", resp.StatusCode)
if resp.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("error posting - http status %v", resp.Code)
}
}

func TestRateLimiterOver(t *testing.T) {
resp, err := doPost("abcdefghijklmnop")
if err != nil {
t.Fatalf("error posting - %s", err)
}
if resp.StatusCode != http.StatusRequestEntityTooLarge {
t.Fatalf("bad status - %d", resp.StatusCode)
func performRequest(method, target, body string, router *gin.Engine) *httptest.ResponseRecorder {
var buf *bytes.Buffer
if body != "" {
buf = new(bytes.Buffer)
buf.WriteString(body)
}
r := httptest.NewRequest(method, target, buf)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
return w
}
23 changes: 0 additions & 23 deletions test_server.tmpl

This file was deleted.

0 comments on commit 161dc59

Please sign in to comment.