From 57744ebbe8f95ee65855f4611b93dd0f2c7b18ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Wed, 25 Dec 2024 14:53:14 +0300 Subject: [PATCH] :bug: bug: fix EnableSplittingOnParsers is not functional (#3231) * :bug: bug: fix EnableSplittingOnParsers is not functional * remove wrong testcase * add support for external xml decoders * improve test coverage * fix linter * update * add reset methods * improve test coverage * merge Form and MultipartForm methods * fix linter * split reset and putting steps * fix linter --- app.go | 10 +++ bind.go | 109 +++++++++++++++++++---- bind_test.go | 38 +++++--- binder/binder.go | 79 ++++++++++++++--- binder/binder_test.go | 28 ++++++ binder/cbor.go | 17 ++-- binder/cbor_test.go | 92 ++++++++++++++++++++ binder/cookie.go | 19 ++-- binder/cookie_test.go | 90 +++++++++++++++++++ binder/form.go | 32 +++++-- binder/form_test.go | 174 +++++++++++++++++++++++++++++++++++++ binder/header.go | 17 ++-- binder/header_test.go | 88 +++++++++++++++++++ binder/json.go | 17 ++-- binder/json_test.go | 69 +++++++++++++++ binder/mapping.go | 31 ++++--- binder/mapping_test.go | 80 +++++++++++++++++ binder/query.go | 19 ++-- binder/query_test.go | 87 +++++++++++++++++++ binder/resp_header.go | 17 ++-- binder/resp_header_test.go | 79 +++++++++++++++++ binder/uri.go | 11 ++- binder/uri_test.go | 77 ++++++++++++++++ binder/xml.go | 20 +++-- binder/xml_test.go | 135 ++++++++++++++++++++++++++++ ctx_test.go | 4 +- docs/api/bind.md | 46 ++-------- docs/api/fiber.md | 1 + docs/whats_new.md | 1 + redirect.go | 4 +- 30 files changed, 1339 insertions(+), 152 deletions(-) create mode 100644 binder/binder_test.go create mode 100644 binder/cbor_test.go create mode 100644 binder/cookie_test.go create mode 100644 binder/form_test.go create mode 100644 binder/header_test.go create mode 100644 binder/json_test.go create mode 100644 binder/query_test.go create mode 100644 binder/resp_header_test.go create mode 100644 binder/uri_test.go create mode 100644 binder/xml_test.go diff --git a/app.go b/app.go index 7f9193a1a1..5e5475b5f1 100644 --- a/app.go +++ b/app.go @@ -341,6 +341,13 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // Default: xml.Marshal XMLEncoder utils.XMLMarshal `json:"-"` + // XMLDecoder set by an external client of Fiber it will use the provided implementation of a + // XMLUnmarshal + // + // Allowing for flexibility in using another XML library for decoding + // Default: xml.Unmarshal + XMLDecoder utils.XMLUnmarshal `json:"-"` + // If you find yourself behind some sort of proxy, like a load balancer, // then certain header information may be sent to you using special X-Forwarded-* headers or the Forwarded header. // For example, the Host HTTP header is usually used to return the requested host. @@ -560,6 +567,9 @@ func New(config ...Config) *App { if app.config.XMLEncoder == nil { app.config.XMLEncoder = xml.Marshal } + if app.config.XMLDecoder == nil { + app.config.XMLDecoder = xml.Unmarshal + } if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } diff --git a/bind.go b/bind.go index 5af83743a0..13d9d3675e 100644 --- a/bind.go +++ b/bind.go @@ -77,7 +77,16 @@ func (b *Bind) Custom(name string, dest any) error { // Header binds the request header strings into the struct, map[string]string and map[string][]string. func (b *Bind) Header(out any) error { - if err := b.returnErr(binder.HeaderBinder.Bind(b.ctx.Request(), out)); err != nil { + bind := binder.GetFromThePool[*binder.HeaderBinding](&binder.HeaderBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.HeaderBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Request(), out)); err != nil { return err } @@ -86,7 +95,16 @@ func (b *Bind) Header(out any) error { // RespHeader binds the response header strings into the struct, map[string]string and map[string][]string. func (b *Bind) RespHeader(out any) error { - if err := b.returnErr(binder.RespHeaderBinder.Bind(b.ctx.Response(), out)); err != nil { + bind := binder.GetFromThePool[*binder.RespHeaderBinding](&binder.RespHeaderBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.RespHeaderBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Response(), out)); err != nil { return err } @@ -96,7 +114,16 @@ func (b *Bind) RespHeader(out any) error { // Cookie binds the request cookie strings into the struct, map[string]string and map[string][]string. // NOTE: If your cookie is like key=val1,val2; they'll be binded as an slice if your map is map[string][]string. Else, it'll use last element of cookie. func (b *Bind) Cookie(out any) error { - if err := b.returnErr(binder.CookieBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.CookieBinding](&binder.CookieBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.CookieBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -105,7 +132,16 @@ func (b *Bind) Cookie(out any) error { // Query binds the query string into the struct, map[string]string and map[string][]string. func (b *Bind) Query(out any) error { - if err := b.returnErr(binder.QueryBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.QueryBinding](&binder.QueryBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.QueryBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -114,7 +150,16 @@ func (b *Bind) Query(out any) error { // JSON binds the body string into the struct. func (b *Bind) JSON(out any) error { - if err := b.returnErr(binder.JSONBinder.Bind(b.ctx.Body(), b.ctx.App().Config().JSONDecoder, out)); err != nil { + bind := binder.GetFromThePool[*binder.JSONBinding](&binder.JSONBinderPool) + bind.JSONDecoder = b.ctx.App().Config().JSONDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.JSONBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } @@ -123,7 +168,16 @@ func (b *Bind) JSON(out any) error { // CBOR binds the body string into the struct. func (b *Bind) CBOR(out any) error { - if err := b.returnErr(binder.CBORBinder.Bind(b.ctx.Body(), b.ctx.App().Config().CBORDecoder, out)); err != nil { + bind := binder.GetFromThePool[*binder.CBORBinding](&binder.CBORBinderPool) + bind.CBORDecoder = b.ctx.App().Config().CBORDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.CBORBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } return b.validateStruct(out) @@ -131,7 +185,16 @@ func (b *Bind) CBOR(out any) error { // XML binds the body string into the struct. func (b *Bind) XML(out any) error { - if err := b.returnErr(binder.XMLBinder.Bind(b.ctx.Body(), out)); err != nil { + bind := binder.GetFromThePool[*binder.XMLBinding](&binder.XMLBinderPool) + bind.XMLDecoder = b.ctx.App().config.XMLDecoder + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.XMLBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(b.ctx.Body(), out)); err != nil { return err } @@ -139,8 +202,20 @@ func (b *Bind) XML(out any) error { } // Form binds the form into the struct, map[string]string and map[string][]string. +// If Content-Type is "application/x-www-form-urlencoded" or "multipart/form-data", it will bind the form values. +// +// Binding multipart files is not supported yet. func (b *Bind) Form(out any) error { - if err := b.returnErr(binder.FormBinder.Bind(b.ctx.RequestCtx(), out)); err != nil { + bind := binder.GetFromThePool[*binder.FormBinding](&binder.FormBinderPool) + bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers + + // Reset & put binder + defer func() { + bind.Reset() + binder.PutToThePool(&binder.FormBinderPool, bind) + }() + + if err := b.returnErr(bind.Bind(&b.ctx.RequestCtx().Request, out)); err != nil { return err } @@ -149,16 +224,14 @@ func (b *Bind) Form(out any) error { // URI binds the route parameters into the struct, map[string]string and map[string][]string. func (b *Bind) URI(out any) error { - if err := b.returnErr(binder.URIBinder.Bind(b.ctx.Route().Params, b.ctx.Params, out)); err != nil { - return err - } + bind := binder.GetFromThePool[*binder.URIBinding](&binder.URIBinderPool) - return b.validateStruct(out) -} + // Reset & put binder + defer func() { + binder.PutToThePool(&binder.URIBinderPool, bind) + }() -// MultipartForm binds the multipart form into the struct, map[string]string and map[string][]string. -func (b *Bind) MultipartForm(out any) error { - if err := b.returnErr(binder.FormBinder.BindMultipart(b.ctx.RequestCtx(), out)); err != nil { + if err := b.returnErr(bind.Bind(b.ctx.Route().Params, b.ctx.Params, out)); err != nil { return err } @@ -193,10 +266,8 @@ func (b *Bind) Body(out any) error { return b.XML(out) case MIMEApplicationCBOR: return b.CBOR(out) - case MIMEApplicationForm: + case MIMEApplicationForm, MIMEMultipartForm: return b.Form(out) - case MIMEMultipartForm: - return b.MultipartForm(out) } // No suitable content type found diff --git a/bind_test.go b/bind_test.go index 55d2dd75e9..52c9004c61 100644 --- a/bind_test.go +++ b/bind_test.go @@ -32,7 +32,9 @@ func Test_returnErr(t *testing.T) { // go test -run Test_Bind_Query -v func Test_Bind_Query(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Query struct { @@ -111,7 +113,9 @@ func Test_Bind_Query(t *testing.T) { func Test_Bind_Query_Map(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) @@ -318,13 +322,13 @@ func Test_Bind_Header(t *testing.T) { c.Request().Header.Add("Hobby", "golang,fiber") q := new(Header) require.NoError(t, c.Bind().Header(q)) - require.Len(t, q.Hobby, 2) + require.Len(t, q.Hobby, 1) c.Request().Header.Del("hobby") c.Request().Header.Add("Hobby", "golang,fiber,go") q = new(Header) require.NoError(t, c.Bind().Header(q)) - require.Len(t, q.Hobby, 3) + require.Len(t, q.Hobby, 1) empty := new(Header) c.Request().Header.Del("hobby") @@ -357,7 +361,7 @@ func Test_Bind_Header(t *testing.T) { require.Equal(t, "go,fiber", h2.Hobby) require.True(t, h2.Bool) require.Equal(t, "Jane Doe", h2.Name) // check value get overwritten - require.Equal(t, []string{"milo", "coke", "pepsi"}, h2.FavouriteDrinks) + require.Equal(t, []string{"milo,coke,pepsi"}, h2.FavouriteDrinks) var nilSlice []string require.Equal(t, nilSlice, h2.Empty) require.Equal(t, []string{""}, h2.Alloc) @@ -386,13 +390,13 @@ func Test_Bind_Header_Map(t *testing.T) { c.Request().Header.Add("Hobby", "golang,fiber") q := make(map[string][]string, 0) require.NoError(t, c.Bind().Header(&q)) - require.Len(t, q["Hobby"], 2) + require.Len(t, q["Hobby"], 1) c.Request().Header.Del("hobby") c.Request().Header.Add("Hobby", "golang,fiber,go") q = make(map[string][]string, 0) require.NoError(t, c.Bind().Header(&q)) - require.Len(t, q["Hobby"], 3) + require.Len(t, q["Hobby"], 1) empty := make(map[string][]string, 0) c.Request().Header.Del("hobby") @@ -543,7 +547,9 @@ func Test_Bind_Header_Schema(t *testing.T) { // go test -run Test_Bind_Resp_Header -v func Test_Bind_RespHeader(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Header struct { @@ -627,13 +633,13 @@ func Test_Bind_RespHeader_Map(t *testing.T) { c.Response().Header.Add("Hobby", "golang,fiber") q := make(map[string][]string, 0) require.NoError(t, c.Bind().RespHeader(&q)) - require.Len(t, q["Hobby"], 2) + require.Len(t, q["Hobby"], 1) c.Response().Header.Del("hobby") c.Response().Header.Add("Hobby", "golang,fiber,go") q = make(map[string][]string, 0) require.NoError(t, c.Bind().RespHeader(&q)) - require.Len(t, q["Hobby"], 3) + require.Len(t, q["Hobby"], 1) empty := make(map[string][]string, 0) c.Response().Header.Del("hobby") @@ -751,7 +757,9 @@ func Benchmark_Bind_Query_WithParseParam(b *testing.B) { func Benchmark_Bind_Query_Comma(b *testing.B) { var err error - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Query struct { @@ -1341,7 +1349,9 @@ func Benchmark_Bind_URI_Map(b *testing.B) { func Test_Bind_Cookie(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Cookie struct { @@ -1414,7 +1424,9 @@ func Test_Bind_Cookie(t *testing.T) { func Test_Bind_Cookie_Map(t *testing.T) { t.Parallel() - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) diff --git a/binder/binder.go b/binder/binder.go index bb3fc2b394..06c7c926a5 100644 --- a/binder/binder.go +++ b/binder/binder.go @@ -2,6 +2,7 @@ package binder import ( "errors" + "sync" ) // Binder errors @@ -10,15 +11,69 @@ var ( ErrMapNotConvertable = errors.New("binder: map is not convertable to map[string]string or map[string][]string") ) -// Init default binders for Fiber -var ( - HeaderBinder = &headerBinding{} - RespHeaderBinder = &respHeaderBinding{} - CookieBinder = &cookieBinding{} - QueryBinder = &queryBinding{} - FormBinder = &formBinding{} - URIBinder = &uriBinding{} - XMLBinder = &xmlBinding{} - JSONBinder = &jsonBinding{} - CBORBinder = &cborBinding{} -) +var HeaderBinderPool = sync.Pool{ + New: func() any { + return &HeaderBinding{} + }, +} + +var RespHeaderBinderPool = sync.Pool{ + New: func() any { + return &RespHeaderBinding{} + }, +} + +var CookieBinderPool = sync.Pool{ + New: func() any { + return &CookieBinding{} + }, +} + +var QueryBinderPool = sync.Pool{ + New: func() any { + return &QueryBinding{} + }, +} + +var FormBinderPool = sync.Pool{ + New: func() any { + return &FormBinding{} + }, +} + +var URIBinderPool = sync.Pool{ + New: func() any { + return &URIBinding{} + }, +} + +var XMLBinderPool = sync.Pool{ + New: func() any { + return &XMLBinding{} + }, +} + +var JSONBinderPool = sync.Pool{ + New: func() any { + return &JSONBinding{} + }, +} + +var CBORBinderPool = sync.Pool{ + New: func() any { + return &CBORBinding{} + }, +} + +func GetFromThePool[T any](pool *sync.Pool) T { + binder, ok := pool.Get().(T) + if !ok { + panic(errors.New("failed to type-assert to T")) + } + + return binder +} + +func PutToThePool[T any](pool *sync.Pool, binder T) { + pool.Put(binder) +} diff --git a/binder/binder_test.go b/binder/binder_test.go new file mode 100644 index 0000000000..d078ed02c6 --- /dev/null +++ b/binder/binder_test.go @@ -0,0 +1,28 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GetAndPutToThePool(t *testing.T) { + t.Parallel() + + // Panics in case we get from another pool + require.Panics(t, func() { + _ = GetFromThePool[*HeaderBinding](&CookieBinderPool) + }) + + // We get from the pool + binder := GetFromThePool[*HeaderBinding](&HeaderBinderPool) + PutToThePool(&HeaderBinderPool, binder) + + _ = GetFromThePool[*RespHeaderBinding](&RespHeaderBinderPool) + _ = GetFromThePool[*QueryBinding](&QueryBinderPool) + _ = GetFromThePool[*FormBinding](&FormBinderPool) + _ = GetFromThePool[*URIBinding](&URIBinderPool) + _ = GetFromThePool[*XMLBinding](&XMLBinderPool) + _ = GetFromThePool[*JSONBinding](&JSONBinderPool) + _ = GetFromThePool[*CBORBinding](&CBORBinderPool) +} diff --git a/binder/cbor.go b/binder/cbor.go index 6f47893531..8b1d0d4291 100644 --- a/binder/cbor.go +++ b/binder/cbor.go @@ -4,15 +4,22 @@ import ( "github.com/gofiber/utils/v2" ) -// cborBinding is the CBOR binder for CBOR request body. -type cborBinding struct{} +// CBORBinding is the CBOR binder for CBOR request body. +type CBORBinding struct { + CBORDecoder utils.CBORUnmarshal +} // Name returns the binding name. -func (*cborBinding) Name() string { +func (*CBORBinding) Name() string { return "cbor" } // Bind parses the request body as CBOR and returns the result. -func (*cborBinding) Bind(body []byte, cborDecoder utils.CBORUnmarshal, out any) error { - return cborDecoder(body, out) +func (b *CBORBinding) Bind(body []byte, out any) error { + return b.CBORDecoder(body, out) +} + +// Reset resets the CBORBinding binder. +func (b *CBORBinding) Reset() { + b.CBORDecoder = nil } diff --git a/binder/cbor_test.go b/binder/cbor_test.go new file mode 100644 index 0000000000..16c24cbbca --- /dev/null +++ b/binder/cbor_test.go @@ -0,0 +1,92 @@ +package binder + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/require" +) + +func Test_CBORBinder_Bind(t *testing.T) { + t.Parallel() + + b := &CBORBinding{ + CBORDecoder: cbor.Unmarshal, + } + require.Equal(t, "cbor", b.Name()) + + type Post struct { + Title string `cbor:"title"` + } + + type User struct { + Name string `cbor:"name"` + Posts []Post `cbor:"posts"` + Names []string `cbor:"names"` + Age int `cbor:"age"` + } + var user User + + wantedUser := User{ + Name: "john", + Names: []string{ + "john", + "doe", + }, + Age: 42, + Posts: []Post{ + {Title: "post1"}, + {Title: "post2"}, + {Title: "post3"}, + }, + } + + body, err := cbor.Marshal(wantedUser) + require.NoError(t, err) + + err = b.Bind(body, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.Nil(t, b.CBORDecoder) +} + +func Benchmark_CBORBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &CBORBinding{ + CBORDecoder: cbor.Unmarshal, + } + + type User struct { + Name string `cbor:"name"` + Age int `cbor:"age"` + } + + var user User + wantedUser := User{ + Name: "john", + Age: 42, + } + + body, err := cbor.Marshal(wantedUser) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + err = binder.Bind(body, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) +} diff --git a/binder/cookie.go b/binder/cookie.go index 62271c8e38..230794f45a 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -8,20 +8,22 @@ import ( "github.com/valyala/fasthttp" ) -// cookieBinding is the cookie binder for cookie request body. -type cookieBinding struct{} +// CookieBinding is the cookie binder for cookie request body. +type CookieBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*cookieBinding) Name() string { +func (*CookieBinding) Name() string { return "cookie" } // Bind parses the request cookie and returns the result. -func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.Request.Header.VisitAllCookie(func(key, val []byte) { + req.Header.VisitAllCookie(func(key, val []byte) { if err != nil { return } @@ -29,7 +31,7 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -45,3 +47,8 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the CookieBinding binder. +func (b *CookieBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/cookie_test.go b/binder/cookie_test.go new file mode 100644 index 0000000000..bca316c9fe --- /dev/null +++ b/binder/cookie_test.go @@ -0,0 +1,90 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_CookieBinder_Bind(t *testing.T) { + t.Parallel() + + b := &CookieBinding{ + EnableSplitting: true, + } + require.Equal(t, "cookie", b.Name()) + + type Post struct { + Title string `form:"title"` + } + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Posts []Post `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + + req.Header.SetCookie("name", "john") + req.Header.SetCookie("names", "john,doe") + req.Header.SetCookie("age", "42") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_CookieBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &CookieBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.Header.SetCookie("name", "john") + req.Header.SetCookie("age", "42") + req.Header.SetCookie("posts", "post1,post2,post3") + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/form.go b/binder/form.go index e0f1acd302..7ab0b1b258 100644 --- a/binder/form.go +++ b/binder/form.go @@ -8,20 +8,29 @@ import ( "github.com/valyala/fasthttp" ) -// formBinding is the form binder for form request body. -type formBinding struct{} +const MIMEMultipartForm string = "multipart/form-data" + +// FormBinding is the form binder for form request body. +type FormBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*formBinding) Name() string { +func (*FormBinding) Name() string { return "form" } // Bind parses the request body and returns the result. -func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.PostArgs().VisitAll(func(key, val []byte) { + // Handle multipart form + if FilterFlags(utils.UnsafeString(req.Header.ContentType())) == MIMEMultipartForm { + return b.bindMultipart(req, out) + } + + req.PostArgs().VisitAll(func(key, val []byte) { if err != nil { return } @@ -33,7 +42,7 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -50,12 +59,17 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } -// BindMultipart parses the request body and returns the result. -func (b *formBinding) BindMultipart(reqCtx *fasthttp.RequestCtx, out any) error { - data, err := reqCtx.MultipartForm() +// bindMultipart parses the request body and returns the result. +func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { + data, err := req.MultipartForm() if err != nil { return err } return parse(b.Name(), out, data.Value) } + +// Reset resets the FormBinding binder. +func (b *FormBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/form_test.go b/binder/form_test.go new file mode 100644 index 0000000000..c3c52c73fd --- /dev/null +++ b/binder/form_test.go @@ -0,0 +1,174 @@ +package binder + +import ( + "bytes" + "mime/multipart" + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_FormBinder_Bind(t *testing.T) { + t.Parallel() + + b := &FormBinding{ + EnableSplitting: true, + } + require.Equal(t, "form", b.Name()) + + type Post struct { + Title string `form:"title"` + } + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Posts []Post `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.SetBodyString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") + req.Header.SetContentType("application/x-www-form-urlencoded") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_FormBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &QueryBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3") + req.Header.SetContentType("application/x-www-form-urlencoded") + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) +} + +func Test_FormBinder_BindMultipart(t *testing.T) { + t.Parallel() + + b := &FormBinding{ + EnableSplitting: true, + } + require.Equal(t, "form", b.Name()) + + type User struct { + Name string `form:"name"` + Names []string `form:"names"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + + buf := &bytes.Buffer{} + mw := multipart.NewWriter(buf) + + require.NoError(t, mw.WriteField("name", "john")) + require.NoError(t, mw.WriteField("names", "john")) + require.NoError(t, mw.WriteField("names", "doe")) + require.NoError(t, mw.WriteField("age", "42")) + require.NoError(t, mw.Close()) + + req.Header.SetContentType(mw.FormDataContentType()) + req.SetBody(buf.Bytes()) + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") +} + +func Benchmark_FormBinder_BindMultipart(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &FormBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `form:"name"` + Posts []string `form:"posts"` + Age int `form:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + buf := &bytes.Buffer{} + mw := multipart.NewWriter(buf) + + require.NoError(b, mw.WriteField("name", "john")) + require.NoError(b, mw.WriteField("age", "42")) + require.NoError(b, mw.WriteField("posts", "post1")) + require.NoError(b, mw.WriteField("posts", "post2")) + require.NoError(b, mw.WriteField("posts", "post3")) + require.NoError(b, mw.Close()) + + req.Header.SetContentType(mw.FormDataContentType()) + req.SetBody(buf.Bytes()) + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) +} diff --git a/binder/header.go b/binder/header.go index 258a0b2229..b04ce9add3 100644 --- a/binder/header.go +++ b/binder/header.go @@ -8,22 +8,24 @@ import ( "github.com/valyala/fasthttp" ) -// headerBinding is the header binder for header request body. -type headerBinding struct{} +// v is the header binder for header request body. +type HeaderBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*headerBinding) Name() string { +func (*HeaderBinding) Name() string { return "header" } // Bind parses the request header and returns the result. -func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { +func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) req.Header.VisitAll(func(key, val []byte) { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -35,3 +37,8 @@ func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the HeaderBinding binder. +func (b *HeaderBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/header_test.go b/binder/header_test.go new file mode 100644 index 0000000000..bdef8680ac --- /dev/null +++ b/binder/header_test.go @@ -0,0 +1,88 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_HeaderBinder_Bind(t *testing.T) { + t.Parallel() + + b := &HeaderBinding{ + EnableSplitting: true, + } + require.Equal(t, "header", b.Name()) + + type User struct { + Name string `header:"Name"` + Names []string `header:"Names"` + Posts []string `header:"Posts"` + Age int `header:"Age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.Header.Set("name", "john") + req.Header.Set("names", "john,doe") + req.Header.Set("age", "42") + req.Header.Set("posts", "post1,post2,post3") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0]) + require.Equal(t, "post2", user.Posts[1]) + require.Equal(t, "post3", user.Posts[2]) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_HeaderBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &HeaderBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `header:"Name"` + Posts []string `header:"Posts"` + Age int `header:"Age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.Header.Set("name", "john") + req.Header.Set("age", "42") + req.Header.Set("posts", "post1,post2,post3") + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/json.go b/binder/json.go index 7889aee8a2..a6a904b550 100644 --- a/binder/json.go +++ b/binder/json.go @@ -4,15 +4,22 @@ import ( "github.com/gofiber/utils/v2" ) -// jsonBinding is the JSON binder for JSON request body. -type jsonBinding struct{} +// JSONBinding is the JSON binder for JSON request body. +type JSONBinding struct { + JSONDecoder utils.JSONUnmarshal +} // Name returns the binding name. -func (*jsonBinding) Name() string { +func (*JSONBinding) Name() string { return "json" } // Bind parses the request body as JSON and returns the result. -func (*jsonBinding) Bind(body []byte, jsonDecoder utils.JSONUnmarshal, out any) error { - return jsonDecoder(body, out) +func (b *JSONBinding) Bind(body []byte, out any) error { + return b.JSONDecoder(body, out) +} + +// Reset resets the JSONBinding binder. +func (b *JSONBinding) Reset() { + b.JSONDecoder = nil } diff --git a/binder/json_test.go b/binder/json_test.go new file mode 100644 index 0000000000..00718fdf26 --- /dev/null +++ b/binder/json_test.go @@ -0,0 +1,69 @@ +package binder + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_JSON_Binding_Bind(t *testing.T) { + t.Parallel() + + b := &JSONBinding{ + JSONDecoder: json.Unmarshal, + } + require.Equal(t, "json", b.Name()) + + type Post struct { + Title string `json:"title"` + } + + type User struct { + Name string `json:"name"` + Posts []Post `json:"posts"` + Age int `json:"age"` + } + var user User + + err := b.Bind([]byte(`{"name":"john","age":42,"posts":[{"title":"post1"},{"title":"post2"},{"title":"post3"}]}`), &user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + + b.Reset() + require.Nil(t, b.JSONDecoder) +} + +func Benchmark_JSON_Binding_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &JSONBinding{ + JSONDecoder: json.Unmarshal, + } + + type User struct { + Name string `json:"name"` + Posts []string `json:"posts"` + Age int `json:"age"` + } + + var user User + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind([]byte(`{"name":"john","age":42,"posts":["post1","post2","post3"]}`), &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Equal(b, "post1", user.Posts[0]) + require.Equal(b, "post2", user.Posts[1]) + require.Equal(b, "post3", user.Posts[2]) +} diff --git a/binder/mapping.go b/binder/mapping.go index 055345fe26..d8b692f7e4 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -32,7 +32,7 @@ var ( // decoderPoolMap helps to improve binders decoderPoolMap = map[string]*sync.Pool{} // tags is used to classify parser's pool - tags = []string{HeaderBinder.Name(), RespHeaderBinder.Name(), CookieBinder.Name(), QueryBinder.Name(), FormBinder.Name(), URIBinder.Name()} + tags = []string{"header", "respHeader", "cookie", "query", "form", "uri"} ) // SetParserDecoder allow globally change the option of form decoder, update decoderPool @@ -107,8 +107,9 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error { elem := reflect.TypeOf(ptr).Elem() - // map[string][]string - if elem.Kind() == reflect.Slice { + //nolint:exhaustive // it's not necessary to check all types + switch elem.Kind() { + case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { return ErrMapNotConvertable @@ -117,18 +118,20 @@ func parseToMap(ptr any, data map[string][]string) error { for k, v := range data { newMap[k] = v } + case reflect.String, reflect.Interface: + newMap, ok := ptr.(map[string]string) + if !ok { + return ErrMapNotConvertable + } - return nil - } - - // map[string]string - newMap, ok := ptr.(map[string]string) - if !ok { - return ErrMapNotConvertable - } + for k, v := range data { + if len(v) == 0 { + newMap[k] = "" + continue + } - for k, v := range data { - newMap[k] = v[len(v)-1] + newMap[k] = v[len(v)-1] + } } return nil @@ -223,7 +226,7 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool { continue } // Get tag from field if exist - inputFieldName := typeField.Tag.Get(QueryBinder.Name()) + inputFieldName := typeField.Tag.Get("query") // Name of query binder if inputFieldName == "" { inputFieldName = typeField.Name } else { diff --git a/binder/mapping_test.go b/binder/mapping_test.go index e6fc8146f7..75cdc78305 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -29,6 +29,21 @@ func Test_EqualFieldType(t *testing.T) { require.True(t, equalFieldType(&user, reflect.String, "Address")) require.True(t, equalFieldType(&user, reflect.Int, "AGE")) require.True(t, equalFieldType(&user, reflect.Int, "age")) + + var user2 struct { + User struct { + Name string + Address string `query:"address"` + Age int `query:"AGE"` + } `query:"user"` + } + + require.True(t, equalFieldType(&user2, reflect.String, "user.name")) + require.True(t, equalFieldType(&user2, reflect.String, "user.Name")) + require.True(t, equalFieldType(&user2, reflect.String, "user.address")) + require.True(t, equalFieldType(&user2, reflect.String, "user.Address")) + require.True(t, equalFieldType(&user2, reflect.Int, "user.AGE")) + require.True(t, equalFieldType(&user2, reflect.Int, "user.age")) } func Test_ParseParamSquareBrackets(t *testing.T) { @@ -97,3 +112,68 @@ func Test_ParseParamSquareBrackets(t *testing.T) { }) } } + +func Test_parseToMap(t *testing.T) { + inputMap := map[string][]string{ + "key1": {"value1", "value2"}, + "key2": {"value3"}, + "key3": {"value4"}, + } + + // Test map[string]string + m := make(map[string]string) + err := parseToMap(m, inputMap) + require.NoError(t, err) + + require.Equal(t, "value2", m["key1"]) + require.Equal(t, "value3", m["key2"]) + require.Equal(t, "value4", m["key3"]) + + // Test map[string][]string + m2 := make(map[string][]string) + err = parseToMap(m2, inputMap) + require.NoError(t, err) + + require.Len(t, m2["key1"], 2) + require.Contains(t, m2["key1"], "value1") + require.Contains(t, m2["key1"], "value2") + require.Len(t, m2["key2"], 1) + require.Len(t, m2["key3"], 1) + + // Test map[string]any + m3 := make(map[string]any) + err = parseToMap(m3, inputMap) + require.ErrorIs(t, err, ErrMapNotConvertable) +} + +func Test_FilterFlags(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + input: "text/javascript; charset=utf-8", + expected: "text/javascript", + }, + { + input: "text/javascript", + expected: "text/javascript", + }, + + { + input: "text/javascript; charset=utf-8; foo=bar", + expected: "text/javascript", + }, + { + input: "text/javascript charset=utf-8", + expected: "text/javascript", + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := FilterFlags(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/binder/query.go b/binder/query.go index 8f029d30c4..9ee500ba63 100644 --- a/binder/query.go +++ b/binder/query.go @@ -8,20 +8,22 @@ import ( "github.com/valyala/fasthttp" ) -// queryBinding is the query binder for query request body. -type queryBinding struct{} +// QueryBinding is the query binder for query request body. +type QueryBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*queryBinding) Name() string { +func (*QueryBinding) Name() string { return "query" } // Bind parses the request query and returns the result. -func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { +func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { data := make(map[string][]string) var err error - reqCtx.QueryArgs().VisitAll(func(key, val []byte) { + reqCtx.URI().QueryArgs().VisitAll(func(key, val []byte) { if err != nil { return } @@ -33,7 +35,7 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -49,3 +51,8 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the QueryBinding binder. +func (b *QueryBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/query_test.go b/binder/query_test.go new file mode 100644 index 0000000000..0d457e5795 --- /dev/null +++ b/binder/query_test.go @@ -0,0 +1,87 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_QueryBinder_Bind(t *testing.T) { + t.Parallel() + + b := &QueryBinding{ + EnableSplitting: true, + } + require.Equal(t, "query", b.Name()) + + type Post struct { + Title string `query:"title"` + } + + type User struct { + Name string `query:"name"` + Names []string `query:"names"` + Posts []Post `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + req.URI().SetQueryString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") + + t.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + err := b.Bind(req, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) + require.Contains(t, user.Names, "john") + require.Contains(t, user.Names, "doe") + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_QueryBinder_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &QueryBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `query:"name"` + Posts []string `query:"posts"` + Age int `query:"age"` + } + var user User + + req := fasthttp.AcquireRequest() + b.Cleanup(func() { + fasthttp.ReleaseRequest(req) + }) + + req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3") + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(req, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Len(b, user.Posts, 3) + require.Contains(b, user.Posts, "post1") + require.Contains(b, user.Posts, "post2") + require.Contains(b, user.Posts, "post3") +} diff --git a/binder/resp_header.go b/binder/resp_header.go index ef14255315..fc84d01402 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -8,22 +8,24 @@ import ( "github.com/valyala/fasthttp" ) -// respHeaderBinding is the respHeader binder for response header. -type respHeaderBinding struct{} +// RespHeaderBinding is the respHeader binder for response header. +type RespHeaderBinding struct { + EnableSplitting bool +} // Name returns the binding name. -func (*respHeaderBinding) Name() string { +func (*RespHeaderBinding) Name() string { return "respHeader" } // Bind parses the response header and returns the result. -func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { +func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { data := make(map[string][]string) resp.Header.VisitAll(func(key, val []byte) { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { + if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { values := strings.Split(v, ",") for i := 0; i < len(values); i++ { data[k] = append(data[k], values[i]) @@ -35,3 +37,8 @@ func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { return parse(b.Name(), out, data) } + +// Reset resets the RespHeaderBinding binder. +func (b *RespHeaderBinding) Reset() { + b.EnableSplitting = false +} diff --git a/binder/resp_header_test.go b/binder/resp_header_test.go new file mode 100644 index 0000000000..ff3b51f604 --- /dev/null +++ b/binder/resp_header_test.go @@ -0,0 +1,79 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_RespHeaderBinder_Bind(t *testing.T) { + t.Parallel() + + b := &RespHeaderBinding{ + EnableSplitting: true, + } + require.Equal(t, "respHeader", b.Name()) + + type User struct { + Name string `respHeader:"name"` + Posts []string `respHeader:"posts"` + Age int `respHeader:"age"` + } + var user User + + resp := fasthttp.AcquireResponse() + resp.Header.Set("name", "john") + resp.Header.Set("age", "42") + resp.Header.Set("posts", "post1,post2,post3") + + t.Cleanup(func() { + fasthttp.ReleaseResponse(resp) + }) + + err := b.Bind(resp, &user) + + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Equal(t, []string{"post1", "post2", "post3"}, user.Posts) + + b.Reset() + require.False(t, b.EnableSplitting) +} + +func Benchmark_RespHeaderBinder_Bind(b *testing.B) { + b.ReportAllocs() + + binder := &RespHeaderBinding{ + EnableSplitting: true, + } + + type User struct { + Name string `respHeader:"name"` + Posts []string `respHeader:"posts"` + Age int `respHeader:"age"` + } + var user User + + resp := fasthttp.AcquireResponse() + resp.Header.Set("name", "john") + resp.Header.Set("age", "42") + resp.Header.Set("posts", "post1,post2,post3") + + b.Cleanup(func() { + fasthttp.ReleaseResponse(resp) + }) + + b.ResetTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(resp, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Equal(b, []string{"post1", "post2", "post3"}, user.Posts) +} diff --git a/binder/uri.go b/binder/uri.go index b58d9d49c4..9b358c64b8 100644 --- a/binder/uri.go +++ b/binder/uri.go @@ -1,15 +1,15 @@ package binder // uriBinding is the URI binder for URI parameters. -type uriBinding struct{} +type URIBinding struct{} // Name returns the binding name. -func (*uriBinding) Name() string { +func (*URIBinding) Name() string { return "uri" } // Bind parses the URI parameters and returns the result. -func (b *uriBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { +func (b *URIBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { data := make(map[string][]string, len(params)) for _, param := range params { data[param] = append(data[param], paramsFunc(param)) @@ -17,3 +17,8 @@ func (b *uriBinding) Bind(params []string, paramsFunc func(key string, defaultVa return parse(b.Name(), out, data) } + +// Reset resets URIBinding binder. +func (*URIBinding) Reset() { + // Nothing to reset +} diff --git a/binder/uri_test.go b/binder/uri_test.go new file mode 100644 index 0000000000..8babdef962 --- /dev/null +++ b/binder/uri_test.go @@ -0,0 +1,77 @@ +package binder + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_URIBinding_Bind(t *testing.T) { + t.Parallel() + + b := &URIBinding{} + require.Equal(t, "uri", b.Name()) + + type User struct { + Name string `uri:"name"` + Posts []string `uri:"posts"` + Age int `uri:"age"` + } + var user User + + paramsKey := []string{"name", "age", "posts"} + paramsVals := []string{"john", "42", "post1,post2,post3"} + paramsFunc := func(key string, _ ...string) string { + for i, k := range paramsKey { + if k == key { + return paramsVals[i] + } + } + + return "" + } + + err := b.Bind(paramsKey, paramsFunc, &user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Equal(t, []string{"post1,post2,post3"}, user.Posts) + + b.Reset() +} + +func Benchmark_URIBinding_Bind(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + binder := &URIBinding{} + + type User struct { + Name string `uri:"name"` + Posts []string `uri:"posts"` + Age int `uri:"age"` + } + var user User + + paramsKey := []string{"name", "age", "posts"} + paramsVals := []string{"john", "42", "post1,post2,post3"} + paramsFunc := func(key string, _ ...string) string { + for i, k := range paramsKey { + if k == key { + return paramsVals[i] + } + } + + return "" + } + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(paramsKey, paramsFunc, &user) + } + + require.NoError(b, err) + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + require.Equal(b, []string{"post1,post2,post3"}, user.Posts) +} diff --git a/binder/xml.go b/binder/xml.go index 58da2b9b07..0c345a4236 100644 --- a/binder/xml.go +++ b/binder/xml.go @@ -1,23 +1,31 @@ package binder import ( - "encoding/xml" "fmt" + + "github.com/gofiber/utils/v2" ) -// xmlBinding is the XML binder for XML request body. -type xmlBinding struct{} +// XMLBinding is the XML binder for XML request body. +type XMLBinding struct { + XMLDecoder utils.XMLUnmarshal +} // Name returns the binding name. -func (*xmlBinding) Name() string { +func (*XMLBinding) Name() string { return "xml" } // Bind parses the request body as XML and returns the result. -func (*xmlBinding) Bind(body []byte, out any) error { - if err := xml.Unmarshal(body, out); err != nil { +func (b *XMLBinding) Bind(body []byte, out any) error { + if err := b.XMLDecoder(body, out); err != nil { return fmt.Errorf("failed to unmarshal xml: %w", err) } return nil } + +// Reset resets the XMLBinding binder. +func (b *XMLBinding) Reset() { + b.XMLDecoder = nil +} diff --git a/binder/xml_test.go b/binder/xml_test.go new file mode 100644 index 0000000000..879ccf0b78 --- /dev/null +++ b/binder/xml_test.go @@ -0,0 +1,135 @@ +package binder + +import ( + "encoding/xml" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_XMLBinding_Bind(t *testing.T) { + t.Parallel() + + b := &XMLBinding{ + XMLDecoder: xml.Unmarshal, + } + require.Equal(t, "xml", b.Name()) + + type Posts struct { + XMLName xml.Name `xml:"post"` + Title string `xml:"title"` + } + + type User struct { + Name string `xml:"name"` + Ignore string `xml:"-"` + Posts []Posts `xml:"posts>post"` + Age int `xml:"age"` + } + + user := new(User) + err := b.Bind([]byte(` + + john + 42 + ignore + + + post1 + + + post2 + + + + `), user) + require.NoError(t, err) + require.Equal(t, "john", user.Name) + require.Equal(t, 42, user.Age) + require.Empty(t, user.Ignore) + + require.Len(t, user.Posts, 2) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + + b.Reset() + require.Nil(t, b.XMLDecoder) +} + +func Test_XMLBinding_Bind_error(t *testing.T) { + t.Parallel() + b := &XMLBinding{ + XMLDecoder: xml.Unmarshal, + } + + type User struct { + Name string `xml:"name"` + Age int `xml:"age"` + } + + user := new(User) + err := b.Bind([]byte(` + + john + 42 + unknown + post"` + Age int `xml:"age"` + } + + user := new(User) + data := []byte(` + + john + 42 + ignore + + + post1 + + + post2 + + + + `) + + b.StartTimer() + + var err error + for i := 0; i < b.N; i++ { + err = binder.Bind(data, user) + } + require.NoError(b, err) + + user = new(User) + err = binder.Bind(data, user) + require.NoError(b, err) + + require.Equal(b, "john", user.Name) + require.Equal(b, 42, user.Age) + + require.Len(b, user.Posts, 2) + require.Equal(b, "post1", user.Posts[0].Title) + require.Equal(b, "post2", user.Posts[1].Title) +} diff --git a/ctx_test.go b/ctx_test.go index d025c24413..88b617eb5b 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1433,7 +1433,9 @@ func Benchmark_Ctx_Fresh_LastModified(b *testing.B) { func Test_Ctx_Binders(t *testing.T) { t.Parallel() // setup - app := New() + app := New(Config{ + EnableSplittingOnParsers: true, + }) type TestEmbeddedStruct struct { Names []string `query:"names"` diff --git a/docs/api/bind.md b/docs/api/bind.md index e91fed4273..d2b336310d 100644 --- a/docs/api/bind.md +++ b/docs/api/bind.md @@ -18,7 +18,6 @@ Make copies or use the [**`Immutable`**](./ctx.md) setting instead. [Read more.. - [Body](#body) - [Form](#form) - [JSON](#json) - - [MultipartForm](#multipartform) - [XML](#xml) - [CBOR](#cbor) - [Cookie](#cookie) @@ -83,7 +82,7 @@ curl -X POST -F name=john -F pass=doe http://localhost:3000 ### Form -Binds the request form body to a struct. +Binds the request or multipart form body data to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a form body with a field called `Pass`, you would use a struct field with `form:"pass"`. @@ -111,12 +110,16 @@ app.Post("/", func(c fiber.Ctx) error { }) ``` -Run tests with the following `curl` command: +Run tests with the following `curl` commands for both `application/x-www-form-urlencoded` and `multipart/form-data`: ```bash curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000 ``` +```bash +curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 +``` + ### JSON Binds the request JSON body to a struct. @@ -153,43 +156,6 @@ Run tests with the following `curl` command: curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000 ``` -### MultipartForm - -Binds the request multipart form body to a struct. - -It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a multipart form body with a field called `Pass`, you would use a struct field with `form:"pass"`. - -```go title="Signature" -func (b *Bind) MultipartForm(out any) error -``` - -```go title="Example" -// Field names should start with an uppercase letter -type Person struct { - Name string `form:"name"` - Pass string `form:"pass"` -} - -app.Post("/", func(c fiber.Ctx) error { - p := new(Person) - - if err := c.Bind().MultipartForm(p); err != nil { - return err - } - - log.Println(p.Name) // john - log.Println(p.Pass) // doe - - // ... -}) -``` - -Run tests with the following `curl` command: - -```bash -curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 -``` - ### XML Binds the request XML body to a struct. diff --git a/docs/api/fiber.md b/docs/api/fiber.md index 17cf3896b9..70320984da 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -83,6 +83,7 @@ app := fiber.New(fiber.Config{ | WriteBufferSize | `int` | Per-connection buffer size for responses' writing. | `4096` | | WriteTimeout | `time.Duration` | The maximum duration before timing out writes of the response. The default timeout is unlimited. | `nil` | | XMLEncoder | `utils.XMLMarshal` | Allowing for flexibility in using another XML library for encoding. | `xml.Marshal` | +| XMLDecoder | `utils.XMLUnmarshal` | Allowing for flexibility in using another XML library for decoding. | `xml.Unmarshal` | ## Server listening diff --git a/docs/whats_new.md b/docs/whats_new.md index eadc1afa4a..321df424d6 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -49,6 +49,7 @@ We have made several changes to the Fiber app, including: - `EnablePrintRoutes` - `ListenerNetwork` (previously `Network`) - **Trusted Proxy Configuration**: The `EnabledTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`. +- **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder. ### New Methods diff --git a/redirect.go b/redirect.go index bc79314922..483272c7b5 100644 --- a/redirect.go +++ b/redirect.go @@ -146,10 +146,8 @@ func (r *Redirect) WithInput() *Redirect { oldInput := make(map[string]string) switch ctype { - case MIMEApplicationForm: + case MIMEApplicationForm, MIMEMultipartForm: _ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed - case MIMEMultipartForm: - _ = r.c.Bind().MultipartForm(oldInput) //nolint:errcheck // not needed default: _ = r.c.Bind().Query(oldInput) //nolint:errcheck // not needed }