Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change to RequestSizeLimiter & better tests #4

Merged
merged 2 commits into from
Feb 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ language: go
sudo: false

go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func handler(ctx *gin.Context) {

func main() {
rtr := gin.Default()
rtr.Use(limits.RateLimiter(10))
rtr.Use(limits.RequestSizeLimiter(10))
rtr.POST("/", handler)
rtr.Run(":8080")
}
Expand Down
2 changes: 1 addition & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func handler(ctx *gin.Context) {

func main() {
rtr := gin.Default()
rtr.Use(limits.RateLimiter(10))
rtr.Use(limits.RequestSizeLimiter(10))
rtr.POST("/", handler)
rtr.Run(":8080")
}
6 changes: 3 additions & 3 deletions size.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ func (mbr *maxBytesReader) Close() error {
return mbr.rdr.Close()
}

// RateLimiter returns a middleware that limits the size of request
// RequestSizeLimiter returns a middleware that limits the size of request
// When a request is over the limit, the following will happen:
// * Error will be added to the context
// * Connection: close header will be set
// * Error 413 will be send to client (http.StatusRequestEntityTooLarge)
// * Error 413 will be sent to the client (http.StatusRequestEntityTooLarge)
// * Current context will be aborted
func RateLimiter(limit int64) gin.HandlerFunc {
func RequestSizeLimiter(limit int64) gin.HandlerFunc {
return func(ctx *gin.Context) {
ctx.Request.Body = &maxBytesReader{
ctx: ctx,
Expand Down
117 changes: 39 additions & 78 deletions size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,97 +2,58 @@ package limits

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"os"
"os/exec"
"net/http/httptest"
"testing"
"text/template"
"time"
)

var (
params = struct {
Size int
Port int
}{10, 9388}

codeFile = "/tmp/ratelimit_test_server.go"
serverURL string
"github.com/gin-gonic/gin"
)

func init() {
tmpl := template.Must(template.ParseFiles("test_server.tmpl"))
fp, err := os.Create(codeFile)
if err != nil {
panic(fmt.Errorf("can't open %s - %s", codeFile, err))
}
err = tmpl.Execute(fp, params)
if err != nil {
panic(fmt.Errorf("can't create %s - %s", codeFile, err))
}
serverURL = fmt.Sprintf("http://localhost:%d", params.Port)
}

func waitForServer() error {
timeout := 30 * time.Second
ch := make(chan bool)
go func() {
for {
_, err := http.Post(serverURL, "text/plain", nil)
if err == nil {
ch <- true
}
time.Sleep(10 * time.Millisecond)
func TestRequestSizeLimiterOK(t *testing.T) {
router := gin.New()
router.Use(RequestSizeLimiter(10))
router.POST("/test_ok", func(c *gin.Context) {
ioutil.ReadAll(c.Request.Body)
if len(c.Errors) > 0 {
return
}
}()

select {
case <-ch:
return nil
case <-time.After(timeout):
return fmt.Errorf("server did not reply after %v", timeout)
}
c.Request.Body.Close()
c.String(http.StatusOK, "OK")
})
resp := performRequest(http.MethodPost, "/test_ok", "big=abc", router)

}

func runServer() (*exec.Cmd, error) {
cmd := exec.Command("go", "run", codeFile)
cmd.Start()
if err := waitForServer(); err != nil {
return nil, err
if resp.Code != http.StatusOK {
t.Fatalf("error posting - http status %v", resp.Code)
}
return cmd, nil
}

func doPost(val string) (*http.Response, error) {
cmd, err := runServer()
if err != nil {
return nil, err
}
defer cmd.Process.Kill()

var buf bytes.Buffer
fmt.Fprintf(&buf, "big=%s", val)
return http.Post(serverURL, "application/x-www-form-urlencoded", &buf)
}
func TestRequestSizeLimiterOver(t *testing.T) {
router := gin.New()
router.Use(RequestSizeLimiter(10))
router.POST("/test_large", func(c *gin.Context) {
ioutil.ReadAll(c.Request.Body)
if len(c.Errors) > 0 {
return
}
c.Request.Body.Close()
c.String(http.StatusOK, "OK")
})
resp := performRequest(http.MethodPost, "/test_large", "big=abcdefghijklmnop", router)

func TestRateLimiterOK(t *testing.T) {
resp, err := doPost("abc")
if err != nil {
t.Fatalf("error posting - %s", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("bad status - %d", resp.StatusCode)
if resp.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("error posting - http status %v", resp.Code)
}
}

func TestRateLimiterOver(t *testing.T) {
resp, err := doPost("abcdefghijklmnop")
if err != nil {
t.Fatalf("error posting - %s", err)
}
if resp.StatusCode != http.StatusRequestEntityTooLarge {
t.Fatalf("bad status - %d", resp.StatusCode)
func performRequest(method, target, body string, router *gin.Engine) *httptest.ResponseRecorder {
var buf *bytes.Buffer
if body != "" {
buf = new(bytes.Buffer)
buf.WriteString(body)
}
r := httptest.NewRequest(method, target, buf)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
return w
}
23 changes: 0 additions & 23 deletions test_server.tmpl

This file was deleted.