diff --git a/README.md b/README.md index 23fcc2e6a..4fb87b7e4 100644 --- a/README.md +++ b/README.md @@ -41,8 +41,10 @@ Hertz [həːts] is a high-usability, high-performance and high-extensibility Gol The Hertz-Examples repository provides code out of the box. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/example/) ### Basic Features Contains introduction and use of general middleware, context selection, data binding, data rendering, direct access, logging, error handling. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/basic-feature/) +### Observability + Contains instrumentation, logging, tracing, monitoring, OpenTelemetry integration. [more](https://www.cloudwego.io/docs/hertz/tutorials/observability/) ### Service Governance - Contains tracer monitor. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/service-governance/) + Contains service registration and discovery extensions, Sentinel integration. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/service-governance/) ### Framework Extension Contains network library extensions. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/framework-exten/) ### Reference @@ -51,10 +53,10 @@ Hertz [həːts] is a high-usability, high-performance and high-extensibility Gol Frequently Asked Questions. [more](https://www.cloudwego.io/zh/docs/hertz/faq/) ## Performance Performance testing can only provide a relative reference. In production, there are many factors that can affect actual performance. - We provide the hertz-benchmark project to track and compare the performance of Hertz and other frameworks in different situations for reference. + We provide the [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark) project to track and compare the performance of Hertz and other frameworks in different situations for reference. ## Related Projects - [Netpoll](https://github.com/cloudwego/netpoll): A high-performance network library. Hertz integrated by default. -- [Hertz-Contrib](https://github.com/hertz-contrib): A partial extension library of Hertz, which users can integrate into Hertz through options according to their needs. +- [Hertz-contrib](https://github.com/hertz-contrib): A partial extension library of Hertz, which users can integrate into Hertz through options according to their needs. - [Example](https://github.com/cloudwego/hertz-examples): Use examples of Hertz. ## Extensions @@ -120,7 +122,7 @@ Thank you for your contribution to Hertz! ## Landscapes

-   +  

CloudWeGo enriches the CNCF CLOUD NATIVE Landscape.

diff --git a/README_cn.md b/README_cn.md index 6468418b4..812c65973 100644 --- a/README_cn.md +++ b/README_cn.md @@ -41,8 +41,10 @@ Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,在设计之初参考了 ### 用户指南 ### 基本特性 包含通用中间件的介绍和使用,上下文选择,数据绑定,数据渲染,直连访问,日志,错误处理,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/basic-feature/) +### 可观测性 + 包含日志,链路追踪,埋点,监控,OpenTelemetry 集成,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/observability/) ### 治理特性 - 包含 trace monitor,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/service-governance/) + 包含服务注册与发现扩展,Sentinel 集成,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/service-governance/) ### 框架扩展 包含网络库扩展,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/framework-exten/) ### 参考 @@ -51,10 +53,10 @@ Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,在设计之初参考了 常见问题排查,[详见文档](https://www.cloudwego.io/zh/docs/hertz/faq/) ## 框架性能 性能测试只能提供相对参考,工业场景下,有诸多因素可以影响实际的性能表现 - 我们提供了 hertz-benchmark 项目用来长期追踪和比较 Hertz 与其他框架在不同情况下的性能数据以供参考 + 我们提供了 [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark) 项目用来长期追踪和比较 Hertz 与其他框架在不同情况下的性能数据以供参考 ## 相关项目 - [Netpoll](https://github.com/cloudwego/netpoll): 自研高性能网络库,Hertz 默认集成 -- [Hertz-Contrib](https://github.com/hertz-contrib): Hertz 扩展仓库,提供中间件、tracer 等能力 +- [hertz-Contrib](https://github.com/hertz-contrib): Hertz 扩展仓库,提供可观测、安全、流量治理、协议、HTTP 通用能力等扩展 - [Example](https://github.com/cloudwego/hertz-examples): Hertz 使用例子 ## 相关拓展 @@ -121,7 +123,7 @@ Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LI ## Landscapes

-   +  

CloudWeGo 丰富了 CNCF 云原生生态

diff --git a/_typos.toml b/_typos.toml index 3d3103e19..bb673bac6 100644 --- a/_typos.toml +++ b/_typos.toml @@ -18,4 +18,8 @@ HeaderReferer = "HeaderReferer" expectedReferer = "expectedReferer" Referer = "Referer" O_WRONLY = "O_WRONLY" -WRONLY = "WRONLY" \ No newline at end of file +WRONLY = "WRONLY" +ome = "ome" +ifModifiedSice = "ifModifiedSice" +hd = "hd" +pn = "pn" \ No newline at end of file diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go index 2313bb8a4..116df93cc 100644 --- a/cmd/hz/app/app.go +++ b/cmd/hz/app/app.go @@ -179,8 +179,10 @@ func Init() *cli.App { noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse} forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew} enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends} + sortRouterFlag := cli.BoolFlag{Name: "sort_router", Usage: "Sort router register code, to avoid code difference", Destination: &globalArgs.SortRouter} jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr} + queryEnumIntFlag := cli.BoolFlag{Name: "query_enumint", Usage: "Use num instead of string for query enum parameter.", Destination: &globalArgs.QueryEnumAsInt} unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty} protoCamelJSONTag := cli.BoolFlag{Name: "pb_camel_json_tag", Usage: "Convert Name style for json tag to camel(Only works protobuf).", Destination: &globalArgs.ProtobufCamelJSONTag} snakeNameFlag := cli.BoolFlag{Name: "snake_tag", Usage: "Use snake_case style naming for tags. (Only works for 'form', 'query', 'json')", Destination: &globalArgs.SnakeName} @@ -226,6 +228,7 @@ func Init() *cli.App { &noRecurseFlag, &forceNewFlag, &enableExtendsFlag, + &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, @@ -260,6 +263,7 @@ func Init() *cli.App { &optPkgFlag, &noRecurseFlag, &enableExtendsFlag, + &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, @@ -316,6 +320,7 @@ func Init() *cli.App { &enableExtendsFlag, &jsonEnumStrFlag, + &queryEnumIntFlag, &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, diff --git a/cmd/hz/config/argument.go b/cmd/hz/config/argument.go index defc505cc..4e6d75ed5 100644 --- a/cmd/hz/config/argument.go +++ b/cmd/hz/config/argument.go @@ -57,6 +57,7 @@ type Argument struct { NeedGoMod bool JSONEnumStr bool + QueryEnumAsInt bool UnsetOmitempty bool ProtobufCamelJSONTag bool ProtocOptions []string // options to pass through to protoc @@ -71,6 +72,7 @@ type Argument struct { ForceNew bool SnakeStyleMiddleware bool EnableExtends bool + SortRouter bool CustomizeLayout string CustomizeLayoutData string diff --git a/cmd/hz/generator/client.go b/cmd/hz/generator/client.go index ff32c4b45..7545a5023 100644 --- a/cmd/hz/generator/client.go +++ b/cmd/hz/generator/client.go @@ -34,7 +34,12 @@ type ClientMethod struct { FormFileCode string } +type ClientConfig struct { + QueryEnumAsInt bool +} + type ClientFile struct { + Config ClientConfig FilePath string PackageName string ServiceName string @@ -64,6 +69,7 @@ func (pkgGen *HttpPackageGenerator) genClient(pkg *HttpPackage, clientDir string ServiceName: util.ToCamelCase(s.Name), ClientMethods: s.ClientMethods, BaseDomain: baseDomain, + Config: ClientConfig{QueryEnumAsInt: pkgGen.QueryEnumAsInt}, } if !isExist { err := pkgGen.TemplateGenerator.Generate(client, hertzClientTplName, hertzClientPath, false) diff --git a/cmd/hz/generator/custom_files.go b/cmd/hz/generator/custom_files.go index 1315948f1..d3a905ba7 100644 --- a/cmd/hz/generator/custom_files.go +++ b/cmd/hz/generator/custom_files.go @@ -259,7 +259,7 @@ func getInsertImportContent(tplInfo *Template, renderInfo interface{}, fileConte } imptSlice = append(imptSlice, [2]string{"", impt[1 : len(impt)-1]}) } else { // 3. alias "import" - idx := strings.Index(impt, "\n") + idx := strings.Index(impt, "\"") if idx == -1 { return nil, fmt.Errorf("error import format for file: %s", tplInfo.Path) } diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index 609e1d6cf..eeab0cf7a 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -161,7 +161,7 @@ func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *Route } handler.Imports[mm.PackageName] = mm } - err := root.Update(m, handler.PackageName, singleHandlerPackage) + err := root.Update(m, handler.PackageName, singleHandlerPackage, pkgGen.SortRouter) if err != nil { return err } diff --git a/cmd/hz/generator/package.go b/cmd/hz/generator/package.go index 01a6884a1..db4d9762b 100644 --- a/cmd/hz/generator/package.go +++ b/cmd/hz/generator/package.go @@ -63,11 +63,13 @@ type HttpPackageGenerator struct { IdlClientDir string // client dir for "client" command ForceClientDir string // client dir without namespace for "client" command BaseDomain string // request domain for "client" command + QueryEnumAsInt bool // client code use number for query parameter ServiceGenDir string NeedModel bool HandlerByMethod bool // generate handler files with method dimension SnakeStyleMiddleware bool // use snake name style for middleware + SortRouter bool loadedBackend Backend curModel *model.Model diff --git a/cmd/hz/generator/package_tpl.go b/cmd/hz/generator/package_tpl.go index 7f72b9b91..0efa0ccfe 100644 --- a/cmd/hz/generator/package_tpl.go +++ b/cmd/hz/generator/package_tpl.go @@ -324,7 +324,6 @@ type Option struct { type Options struct { hostUrl string - enumAsInt bool doer client.Doer header http.Header requestBodyBind bindRequestBodyFunc @@ -376,13 +375,6 @@ func WithResponseResultDecider(decider ResponseResultDecider) Option { }} } -// WithQueryEnumAsInt is used to set enum as int for query parameters -func WithQueryEnumAsInt(enable bool) Option { - return Option{func(op *Options) { - op.enumAsInt = enable - }} -} - func withHostUrl(HostUrl string) Option { return Option{func(op *Options) { op.hostUrl = HostUrl @@ -392,7 +384,6 @@ func withHostUrl(HostUrl string) Option { // underlying client type cli struct { hostUrl string - enumAsInt bool doer client.Doer header http.Header bindRequestBody bindRequestBodyFunc @@ -428,7 +419,6 @@ func newClient(opts *Options) (*cli, error) { c := &cli{ hostUrl: opts.hostUrl, - enumAsInt: opts.enumAsInt, doer: opts.doer, header: opts.header, bindRequestBody: opts.requestBodyBind, @@ -505,12 +495,13 @@ func (c *cli) execute(req *request) (*response, error) { // r get request func (c *cli) r() *request { return &request{ - queryParam: url.Values{}, - header: http.Header{}, - pathParam: map[string]string{}, - formParam: map[string]string{}, - fileParam: map[string]string{}, - client: c, + queryParam: url.Values{}, + header: http.Header{}, + pathParam: map[string]string{}, + formParam: map[string]string{}, + fileParam: map[string]string{}, + client: c, + queryEnumAsInt: {{.Config.QueryEnumAsInt}}, } } @@ -556,6 +547,7 @@ type request struct { client *cli url string method string + queryEnumAsInt bool queryParam url.Values header http.Header pathParam map[string]string @@ -601,10 +593,14 @@ func (r *request) setQueryParam(param string, value interface{}) *request { switch v.Kind() { case reflect.Slice, reflect.Array: for index := 0; index < v.Len(); index++ { - r.queryParam.Add(param, fmt.Sprint(v.Index(index).Interface())) + if r.queryEnumAsInt && (v.Index(index).Kind() == reflect.Int32 || v.Index(index).Kind() == reflect.Int64) { + r.queryParam.Add(param, fmt.Sprintf("%d", v.Index(index).Interface())) + } else { + r.queryParam.Add(param, fmt.Sprint(v.Index(index).Interface())) + } } case reflect.Int32, reflect.Int64: - if r.client.enumAsInt { + if r.queryEnumAsInt { r.queryParam.Add(param, fmt.Sprintf("%d", v.Interface())) } else { r.queryParam.Add(param, fmt.Sprint(v)) diff --git a/cmd/hz/generator/router.go b/cmd/hz/generator/router.go index 10f431657..37366ea23 100644 --- a/cmd/hz/generator/router.go +++ b/cmd/hz/generator/router.go @@ -20,10 +20,13 @@ import ( "bytes" "fmt" "io/ioutil" + "math" "path/filepath" "regexp" "sort" + "strconv" "strings" + "unicode" "github.com/cloudwego/hertz/cmd/hz/util" ) @@ -73,7 +76,7 @@ func (routerNode *RouterNode) Sort() { sort.Sort(routerNode.Children) } -func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string) error { +func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string, sortRouter bool) error { if method.Path == "" { return fmt.Errorf("empty path for method '%s'", method.Name) } @@ -81,12 +84,12 @@ func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg if paths[0] == "" { paths = paths[1:] } - parent, last := routerNode.FindNearest(paths) + parent, last := routerNode.FindNearest(paths, method.HTTPMethod, sortRouter) if last == len(paths) { return fmt.Errorf("path '%s' has been registered", method.Path) } name := util.ToVarName(paths[:last]) - parent.Insert(name, method, handlerType, paths[last:], handlerPkg) + parent.Insert(name, method, handlerType, paths[last:], handlerPkg, sortRouter) parent.Sort() return nil } @@ -192,7 +195,7 @@ func (routerNode *RouterNode) DFS(i int, hook func(layer int, node *RouterNode) var handlerPkgMap map[string]string -func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string) { +func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string, sortRouter bool) { cur := routerNode for i, p := range paths { c := &RouterNode{ @@ -229,6 +232,9 @@ func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerTyp cur.Children = make([]*RouterNode, 0, 1) } cur.Children = append(cur.Children, c) + if sortRouter { + sort.Sort(cur.Children) + } cur = c } } @@ -240,14 +246,21 @@ func getHttpMethod(method string) string { return strings.ToUpper(method) } -func (routerNode *RouterNode) FindNearest(paths []string) (*RouterNode, int) { +func (routerNode *RouterNode) FindNearest(paths []string, method string, sortRouter bool) (*RouterNode, int) { ns := len(paths) cur := routerNode i := 0 path := paths[i] for j := 0; j < len(cur.Children); j++ { c := cur.Children[j] + tmpMethod := "" // group do not have http method + if i == ns { // only i==ns, the path is http method node + tmpMethod = method + } if ("/" + path) == c.Path { + if sortRouter && !strings.EqualFold(c.HttpMethod, tmpMethod) { + continue + } i++ if i == ns { return cur, i - 1 @@ -270,17 +283,35 @@ func (c childrenRouterInfo) Len() int { // Less reports whether the element with // index i should sort before the element with index j. func (c childrenRouterInfo) Less(i, j int) bool { - ci := c[i].Path - if len(c[i].Children) != 0 { - ci = ci[1:] + if c[i].HttpMethod == "" && c[j].HttpMethod != "" { + return false + } + if c[i].HttpMethod != "" && c[j].HttpMethod == "" { + return true } - cj := c[j].Path - if len(c[j].Children) != 0 { - cj = cj[1:] + // remove non-litter char + // eg. /a -> a + // /:a -> a + ci := removeNonLetterPrefix(c[i].Path) + cj := removeNonLetterPrefix(c[j].Path) + + // if ci == cj, use HTTP mothod for sort, preventing sorting inconsistencies + if ci == cj { + return c[i].HttpMethod < c[j].HttpMethod } + return ci < cj } +func removeNonLetterPrefix(str string) string { + for i, char := range str { + if unicode.IsLetter(char) || unicode.IsDigit(char) { + return str[i:] + } + } + return str +} + // Swap swaps the elements with indexes i and j. func (c childrenRouterInfo) Swap(i, j int) { c[i], c[j] = c[j], c[i] @@ -341,6 +372,29 @@ func (pkgGen *HttpPackageGenerator) updateRegister(pkg, rDir, pkgName string) er return nil } +func appendMw(mws []string, mw string) ([]string, string) { + for i := 0; true; i++ { + if i == math.MaxInt { + break + } + if !stringsIncludes(mws, mw) { + mws = append(mws, mw) + break + } + mw += strconv.Itoa(i) + } + return mws, mw +} + +func stringsIncludes(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} + func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode, handlerPackage, routerDir, routerPackage string) error { err := root.DyeGroupName(pkgGen.SnakeStyleMiddleware) if err != nil { @@ -367,6 +421,31 @@ func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode router.HandlerPackages = handlerMap } + if pkgGen.SnakeStyleMiddleware { // unique middleware name for SnakeStyleMiddleware + mws := []string{} + hook := func(layer int, node *RouterNode) error { + if len(node.Children) == 0 { + return nil + } + groupMwName := node.GroupMiddleware + handlerMwName := node.HandlerMiddleware + if len(groupMwName) != 0 { + mws, groupMwName = appendMw(mws, groupMwName) + } + if len(handlerMwName) != 0 { + mws, handlerMwName = appendMw(mws, handlerMwName) + } + if groupMwName != node.GroupMiddleware { + node.GroupMiddleware = groupMwName + } + if handlerMwName != node.HandlerMiddleware { + node.HandlerMiddleware = handlerMwName + } + return nil + } + root.DFS(0, hook) + } + // store router info pkg.RouterInfo = &router diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index 7441eaba3..c6018814b 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -19,7 +19,7 @@ package meta import "runtime" // Version hz version -const Version = "v0.8.1" +const Version = "v0.9.0" const DefaultServiceName = "hertz_service" diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go index 260334012..d6f775cef 100644 --- a/cmd/hz/protobuf/plugin.go +++ b/cmd/hz/protobuf/plugin.go @@ -619,7 +619,9 @@ func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps IdlClientDir: plugin.IdlClientDir, ForceClientDir: args.ForceClientDir, BaseDomain: args.BaseDomain, + QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { diff --git a/cmd/hz/thrift/plugin.go b/cmd/hz/thrift/plugin.go index c315c9dcb..91fab02fc 100644 --- a/cmd/hz/thrift/plugin.go +++ b/cmd/hz/thrift/plugin.go @@ -150,7 +150,9 @@ func (plugin *Plugin) Run() int { IdlClientDir: util.SubDir(modelDir, pkgInfo.Package), ForceClientDir: args.ForceClientDir, BaseDomain: args.BaseDomain, + QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) diff --git a/cmd/hz/thrift/plugin_test.go b/cmd/hz/thrift/plugin_test.go index 5e471b67e..1cd4d8b67 100644 --- a/cmd/hz/thrift/plugin_test.go +++ b/cmd/hz/thrift/plugin_test.go @@ -79,12 +79,21 @@ func TestRun(t *testing.T) { HandlerDir: handlerDir, RouterDir: routerDir, ModelDir: modelDir, + UseDir: args.Use, ClientDir: clientDir, TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, + Excludes: args.Excludes, }, - ProjPackage: pkg, - Options: options, + ProjPackage: pkg, + Options: options, + HandlerByMethod: args.HandlerByMethod, + CmdType: args.CmdType, + ForceClientDir: args.ForceClientDir, + BaseDomain: args.BaseDomain, + QueryEnumAsInt: args.QueryEnumAsInt, + SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) diff --git a/images/wechat_group_cn.png b/images/wechat_group_cn.png deleted file mode 100644 index a9bf48ce1..000000000 Binary files a/images/wechat_group_cn.png and /dev/null differ diff --git a/images/wechat_group_en.png b/images/wechat_group_en.png deleted file mode 100644 index d7c73bcae..000000000 Binary files a/images/wechat_group_en.png and /dev/null differ diff --git a/pkg/app/context.go b/pkg/app/context.go index 9117e3b34..369ece736 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -236,6 +236,17 @@ type RequestContext struct { binder binding.Binder validator binding.StructValidator + exiled bool +} + +// Exile marks this RequestContext as not to be recycled. +// Experimental features: Use with caution, it may have a slight impact on performance. +func (ctx *RequestContext) Exile() { + ctx.exiled = true +} + +func (ctx *RequestContext) IsExiled() bool { + return ctx.exiled } // Flush is the shortcut for ctx.Response.GetHijackWriter().Flush(). @@ -1227,6 +1238,10 @@ func (ctx *RequestContext) Cookie(key string) []byte { // 4. ctx.SetCookie("user", "", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) // add response header ---> Set-Cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; func (ctx *RequestContext) SetCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { + ctx.setCookie(name, value, maxAge, path, domain, sameSite, secure, httpOnly, false) +} + +func (ctx *RequestContext) setCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly, partitioned bool) { if path == "" { path = "/" } @@ -1240,9 +1255,20 @@ func (ctx *RequestContext) SetCookie(name, value string, maxAge int, path, domai cookie.SetSecure(secure) cookie.SetHTTPOnly(httpOnly) cookie.SetSameSite(sameSite) + cookie.SetPartitioned(partitioned) ctx.Response.Header.SetCookie(cookie) } +// SetPartitionedCookie adds a partitioned cookie to the Response's headers. +// Use protocol.CookieSameSiteNoneMode for cross-site cookies to work. +// +// Usage: ctx.SetPartitionedCookie("user", "name", 10, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) +// +// This adds the response header: Set-Cookie: user=name; Max-Age=10; Domain=localhost; Path=/; HttpOnly; Secure; SameSite=None; Partitioned +func (ctx *RequestContext) SetPartitionedCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { + ctx.setCookie(name, value, maxAge, path, domain, sameSite, secure, httpOnly, true) +} + // UserAgent returns the value of the request user_agent. func (ctx *RequestContext) UserAgent() []byte { return ctx.Request.Header.UserAgent() diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 7cbe71c7d..1317a062c 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1615,8 +1615,14 @@ func TestSetBinder(t *testing.T) { func TestRequestContext_SetCookie(t *testing.T) { c := NewContext(0) - c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteLaxMode, true, true) - assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax", c.Response.Header.Get("Set-Cookie")) + c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) + assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None", c.Response.Header.Get("Set-Cookie")) +} + +func TestRequestContext_SetPartitionedCookie(t *testing.T) { + c := NewContext(0) + c.SetPartitionedCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) + assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None; Partitioned", c.Response.Header.Get("Set-Cookie")) } func TestRequestContext_SetCookiePathEmpty(t *testing.T) { diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 7919d8d60..225b8bd2e 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -417,10 +417,11 @@ func TestBind_ZeroValueBind(t *testing.T) { func TestBind_DefaultValueBind(t *testing.T) { var s struct { - A int `default:"15"` - B float64 `query:"b" default:"17"` - C []int `default:"15"` - D []string `default:"qwe"` + A int `default:"15"` + B float64 `query:"b" default:"17"` + C []int `default:"[15]"` + D []string `default:"['qwe','asd']"` + F [2]string `default:"['qwe','asd','zxc']"` } req := newMockRequest(). SetRequestURI("http://foobar.com") @@ -432,7 +433,23 @@ func TestBind_DefaultValueBind(t *testing.T) { assert.DeepEqual(t, 15, s.A) assert.DeepEqual(t, float64(17), s.B) assert.DeepEqual(t, 15, s.C[0]) + assert.DeepEqual(t, 2, len(s.D)) assert.DeepEqual(t, "qwe", s.D[0]) + assert.DeepEqual(t, "asd", s.D[1]) + assert.DeepEqual(t, 2, len(s.F)) + assert.DeepEqual(t, "qwe", s.F[0]) + assert.DeepEqual(t, "asd", s.F[1]) + + var s2 struct { + F [2]string `default:"['qwe']"` + } + err = DefaultBinder().Bind(req.Req, &s2, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 2, len(s2.F)) + assert.DeepEqual(t, "qwe", s2.F[0]) + assert.DeepEqual(t, "", s2.F[1]) var d struct { D [2]string `default:"qwe"` @@ -1549,6 +1566,32 @@ func TestBind_Issue1015(t *testing.T) { assert.DeepEqual(t, "asd", result.A) } +func TestBind_JSONWithDefault(t *testing.T) { + type Req struct { + J1 string `json:"j1" default:"j1default"` + } + + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(`{"j1":"j1"}`)) + var result Req + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1", result.J1) + + result = Req{} + req = newMockRequest(). + SetJSONContentType(). + SetBody([]byte(`{"j2":"j2"}`)) + err = DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1default", result.J1) +} + func TestBind_WithoutPreBindForTag(t *testing.T) { type BaseQuery struct { Action string `query:"Action" binding:"required"` diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index ece04f737..9c5cb1200 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -69,14 +69,17 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { - defaultValue = tagInfo.Default if tagInfo.Key == jsonTag { + defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", d.fieldName, tagInfo.JSONName) } + if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { + defaultValue = "" + } } continue } @@ -94,7 +97,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa return err } if len(text) == 0 && len(defaultValue) != 0 { - text = defaultValue + text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index 19efa46ae..ab343c811 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -60,7 +60,12 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { - defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + defaultValue = tagInfo.Default + if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { + defaultValue = "" + } + } continue } text, exist = tagInfo.Getter(req, params, tagInfo.Value) @@ -73,7 +78,7 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. return nil } if len(text) == 0 && len(defaultValue) != 0 { - text = defaultValue + text = toDefaultValue(d.fieldType, defaultValue) } v, err := d.decodeFunc(req, params, text) diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index f18a68127..4c36bcea9 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -124,7 +124,7 @@ func getFieldDecoder(pInfo parentInfos, field reflect.StructField, index int, by // JSONName is like 'a.b.c' for 'required validate' fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, pInfo.JSONName, config) if len(fieldTagInfos) == 0 && !config.DisableDefaultTag { - fieldTagInfos = getDefaultFieldTags(field) + fieldTagInfos, newParentJSONName = getDefaultFieldTags(field, pInfo.JSONName) } if len(byTag) != 0 { fieldTagInfos = getFieldTagInfoByTag(field, byTag) diff --git a/pkg/app/server/binding/internal/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go index 95e3c4fc5..130d65d8f 100644 --- a/pkg/app/server/binding/internal/decoder/gjson_required.go +++ b/pkg/app/server/binding/internal/decoder/gjson_required.go @@ -47,3 +47,12 @@ func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { } return true } + +func keyExist(req *protocol.Request, tagInfo TagInfo) bool { + ct := bytesconv.B2s(req.Header.ContentType()) + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { + return false + } + result := gjson.GetBytes(req.Body(), tagInfo.JSONName) + return result.Exists() +} diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index 31fe85a1b..59fed7716 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -61,14 +61,17 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { - defaultValue = tagInfo.Default if tagInfo.Key == jsonTag { + defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) } + if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { + defaultValue = "" + } } continue } @@ -86,7 +89,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par return err } if len(text) == 0 && len(defaultValue) != 0 { - text = defaultValue + text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index fc5c9814f..c2887d1c4 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -61,16 +61,20 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P var texts []string var defaultValue string var bindRawBody bool + var isDefault bool for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { - defaultValue = tagInfo.Default if tagInfo.Key == jsonTag { + defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) } + if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { // + defaultValue = "" + } } continue } @@ -91,7 +95,9 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P return err } if len(texts) == 0 && len(defaultValue) != 0 { + defaultValue = toDefaultValue(d.fieldType, defaultValue) texts = append(texts, defaultValue) + isDefault = true } if len(texts) == 0 { return nil @@ -113,7 +119,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P } if d.isArray { - if len(texts) != field.Len() { + if len(texts) != field.Len() && !isDefault { return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) } } else { @@ -135,6 +141,13 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P elemKind = t.Kind() ptrDepth++ } + if isDefault { + err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) + if err != nil { + return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err) + } + return nil + } for idx, text := range texts { var vv reflect.Value @@ -218,33 +231,3 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn isArray: isArray, }}, nil } - -func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) { - v = reflect.New(elemType).Elem() - if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist { - val, err := customizedFunc(req, params, text) - if err != nil { - return reflect.Value{}, err - } - return val, nil - } - switch elemType.Kind() { - case reflect.Struct: - err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) - case reflect.Map: - err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) - case reflect.Array, reflect.Slice: - // do nothing - default: - decoder, err := SelectTextDecoder(elemType) - if err != nil { - return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) - } - err = decoder.UnmarshalString(text, v, config.LooseZeroMode) - if err != nil { - return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) - } - } - - return v, err -} diff --git a/pkg/app/server/binding/internal/decoder/sonic_required.go b/pkg/app/server/binding/internal/decoder/sonic_required.go index e408901a9..61f3d8d68 100644 --- a/pkg/app/server/binding/internal/decoder/sonic_required.go +++ b/pkg/app/server/binding/internal/decoder/sonic_required.go @@ -60,3 +60,12 @@ func stringSliceForInterface(s string) (ret []interface{}) { } return } + +func keyExist(req *protocol.Request, tagInfo TagInfo) bool { + ct := bytesconv.B2s(req.Header.ContentType()) + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { + return false + } + node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) + return node.Exists() +} diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index 3030f2ac6..75f3ae4aa 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -38,14 +38,17 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { - defaultValue = tagInfo.Default if tagInfo.Key == jsonTag { + defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) } + if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { + defaultValue = "" + } } continue } @@ -63,7 +66,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. return err } if len(text) == 0 && len(defaultValue) != 0 { - text = defaultValue + text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil diff --git a/pkg/app/server/binding/internal/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go index 6df09aaa3..8ca5ae0e6 100644 --- a/pkg/app/server/binding/internal/decoder/tag.go +++ b/pkg/app/server/binding/internal/decoder/tag.go @@ -87,7 +87,7 @@ func lookupFieldTags(field reflect.StructField, parentJSONName string, config *D tagValue = field.Name } skip := false - jsonName := "" + jsonName := parentJSONName + "." + field.Name if tag == jsonTag { jsonName = parentJSONName + "." + tagValue } @@ -120,7 +120,7 @@ func lookupFieldTags(field reflect.StructField, parentJSONName string, config *D return tagInfos, newParentJSONName, needValidate } -func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { +func getDefaultFieldTags(field reflect.StructField, parentJSONName string) (tagInfos []TagInfo, newParentJSONName string) { defaultVal := "" if val, ok := field.Tag.Lookup(defaultTag); ok { defaultVal = val @@ -128,8 +128,10 @@ func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, fileNameTag} for _, tag := range tags { - tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal}) + jsonName := strings.TrimPrefix(parentJSONName+"."+field.Name, ".") + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal, JSONName: jsonName}) } + newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") return } diff --git a/pkg/app/server/binding/internal/decoder/util.go b/pkg/app/server/binding/internal/decoder/util.go new file mode 100644 index 000000000..be141c282 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/util.go @@ -0,0 +1,76 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "fmt" + "reflect" + "strings" + + "github.com/cloudwego/hertz/internal/bytesconv" + hJson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +const ( + specialChar = "\x07" +) + +// toDefaultValue will preprocess the default value and transfer it to be standard format +func toDefaultValue(typ reflect.Type, defaultValue string) string { + switch typ.Kind() { + case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: + // escape single quote and double quote, replace single quote with double quote + defaultValue = strings.Replace(defaultValue, `"`, `\"`, -1) + defaultValue = strings.Replace(defaultValue, `\'`, specialChar, -1) + defaultValue = strings.Replace(defaultValue, `'`, `"`, -1) + defaultValue = strings.Replace(defaultValue, specialChar, `'`, -1) + } + return defaultValue +} + +// stringToValue is used to dynamically create reflect.Value for 'text' +func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) { + v = reflect.New(elemType).Elem() + if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist { + val, err := customizedFunc(req, params, text) + if err != nil { + return reflect.Value{}, err + } + return val, nil + } + switch elemType.Kind() { + case reflect.Struct: + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Map: + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Array, reflect.Slice: + // do nothing + default: + decoder, err := SelectTextDecoder(elemType) + if err != nil { + return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) + } + err = decoder.UnmarshalString(text, v, config.LooseZeroMode) + if err != nil { + return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) + } + } + + return v, err +} diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 82221745c..1533b5aea 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -539,38 +539,37 @@ func TestPath(t *testing.T) { assert.DeepEqual(t, (*int64)(nil), recv.Z) } -// FIXME: 复杂类型的默认值,暂时先不做,低优 func TestDefault(t *testing.T) { - //type S struct { - // SS string `json:"ss"` - //} + type S struct { + SS string `json:"ss"` + } type Recv struct { X **struct { - A []string `path:"a" json:"a"` - B int32 `path:"b" default:"32"` - C bool `json:"c" default:"true"` - D *float32 `default:"123.4"` - // E *[]string `default:"['a','b','c','d,e,f']"` - // F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` - // G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` - // H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` - // I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` - Empty string `default:""` - Null string `default:""` - CommaSpace string `default:",a:c "` - Dash string `default:"-"` + A []string `path:"a" json:"a"` + B int32 `path:"b" default:"32"` + C bool `json:"c" default:"true"` + D *float32 `default:"123.4"` + E *[]string `default:"['a','b','c','d,e,f']"` + F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` + G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` + H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` + I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` + Empty string `default:""` + Null string `default:""` + CommaSpace string `default:",a:c "` + Dash string `default:"-"` // InvalidInt int `default:"abc"` // InvalidMap map[string]string `default:"abc"` } - Y string `json:"y" default:"y1"` - Z int64 - W string `json:"w"` - // V []int64 `json:"u" default:"[1,2,3]"` - // U []float32 `json:"u" default:"[1.1,2,3]"` - T *string `json:"t" default:"t1"` - // S S `default:"{'ss':'test'}"` - // O *S `default:"{'ss':'test2'}"` - // Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` + Y string `json:"y" default:"y1"` + Z int64 + W string `json:"w"` + V []int64 `json:"v" default:"[1,2,3]"` + U []float32 `json:"u" default:"[1.1,2,3]"` + T *string `json:"t" default:"t1"` + S S `default:"{'ss':'test'}"` + O *S `default:"{'ss':'test2'}"` + Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` } bodyReader := strings.NewReader(`{ @@ -601,11 +600,11 @@ func TestDefault(t *testing.T) { assert.DeepEqual(t, int32(32), (**recv.X).B) assert.DeepEqual(t, true, (**recv.X).C) assert.DeepEqual(t, float32(123.4), *(**recv.X).D) - // assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) - // assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) - // assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) - // assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) - // assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) + assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) + assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) + assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) + assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) + assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) assert.DeepEqual(t, "", (**recv.X).Empty) assert.DeepEqual(t, "", (**recv.X).Null) assert.DeepEqual(t, ",a:c ", (**recv.X).CommaSpace) @@ -615,11 +614,11 @@ func TestDefault(t *testing.T) { assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, "t1", *recv.T) assert.DeepEqual(t, int64(6), recv.Z) - // assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) - // assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) - // assert.DeepEqual(t, S{SS: "test"}, recv.S) - // assert.DeepEqual(t, &S{SS: "test2"}, recv.O) - // assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) + assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) + assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) + assert.DeepEqual(t, S{SS: "test"}, recv.S) + assert.DeepEqual(t, &S{SS: "test2"}, recv.O) + assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) } func TestAuto(t *testing.T) { @@ -1196,29 +1195,29 @@ func TestIssue26(t *testing.T) { assert.DeepEqual(t, recv, recv2) } -// FIXME: after 'json unmarshal', the default value will change it -//func TestDefault2(t *testing.T) { -// type Recv struct { -// X **struct { -// Dash string `default:"xxxx"` -// } -// } -// bodyReader := strings.NewReader(`{ -// "X": { -// "Dash": "hello Dash" -// } -// }`) -// header := make(http.Header) -// header.Set("Content-Type", consts.MIMEApplicationJSON) -// req := newRequest("", header, nil, bodyReader) -// recv := new(Recv) -// -// err := DefaultBinder().Bind(req.Req, nil, recv) -// if err != nil { -// t.Error(err) -// } -// assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) -//} +// BUGFIX: after 'json unmarshal', the default value will change it +func TestDefault2(t *testing.T) { + type Recv struct { + X **struct { + Dash string `default:"xxxx"` + } + } + bodyReader := strings.NewReader(`{ + "X": { + "Dash": "hello Dash" + } + }`) + header := make(http.Header) + header.Set("Content-Type", consts.MIMEApplicationJSON) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) +} type ( files map[string][]file diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 4a9cee3d0..3352d34c8 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -196,7 +196,10 @@ func formatAsDate(t time.Time) string { } // copied from router -var default400Body = []byte("400 bad request") +var ( + default400Body = []byte("400 bad request") + requiredHostBody = []byte("missing required Host header") +) func TestServer_Use(t *testing.T) { router := New() @@ -284,6 +287,13 @@ func TestNotAbsolutePath(t *testing.T) { func TestNotAbsolutePathWithRawPath(t *testing.T) { engine := New(WithHostPorts("127.0.0.1:9991"), WithUseRawPath(true)) + const ( + MiddlewareKey = "middleware_key" + MiddlewareValue = "middleware_value" + ) + engine.Use(func(c context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) + }) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { }) engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { @@ -301,6 +311,8 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) + gh := ctx.Response.Header.Get(MiddlewareKey) + assert.DeepEqual(t, MiddlewareValue, gh) s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) @@ -312,6 +324,49 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) + gh = ctx.Response.Header.Get(MiddlewareKey) + assert.DeepEqual(t, MiddlewareValue, gh) +} + +func TestNotValidHost(t *testing.T) { + engine := New(WithHostPorts("127.0.0.1:9992")) + const ( + MiddlewareKey = "middleware_key" + MiddlewareValue = "middleware_value" + ) + engine.Use(func(c context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) + }) + engine.POST("/", func(c context.Context, ctx *app.RequestContext) { + }) + engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { + }) + + s := "POST ?a=b HTTP/1.1\r\nHost: \r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + zr := mock.NewZeroCopyReader(s) + + ctx := app.NewContext(0) + if err := req.Read(&ctx.Request, zr); err != nil { + t.Fatalf("unexpected error: %s", err) + } + engine.ServeHTTP(context.Background(), ctx) + assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) + assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) + gh := ctx.Response.Header.Get(MiddlewareKey) + assert.DeepEqual(t, MiddlewareValue, gh) + + s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + zr = mock.NewZeroCopyReader(s) + + ctx = app.NewContext(0) + if err := req.Read(&ctx.Request, zr); err != nil { + t.Fatalf("unexpected error: %s", err) + } + engine.ServeHTTP(context.Background(), ctx) + assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) + assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) + gh = ctx.Response.Header.Get(MiddlewareKey) + assert.DeepEqual(t, MiddlewareValue, gh) } func TestWithBasePath(t *testing.T) { diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index 278b137b5..88885483c 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -1766,6 +1766,8 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool { // Transfer-Encoding is managed automatically. return true } else if utils.CaseInsensitiveCompare(bytestr.StrTrailer, key) { + // copy value to avoid panic + value = append(h.bufKV.value[:0], value...) h.Trailer().SetTrailers(value) return true } diff --git a/pkg/protocol/header_test.go b/pkg/protocol/header_test.go index 35cc0011f..ba2bd6998 100644 --- a/pkg/protocol/header_test.go +++ b/pkg/protocol/header_test.go @@ -69,18 +69,18 @@ func TestResponseHeaderSetHeaderLength(t *testing.T) { func TestSetNoHTTP11(t *testing.T) { rh := ResponseHeader{} - rh.SetNoHTTP11(true) + rh.SetProtocol(consts.HTTP10) assert.DeepEqual(t, consts.HTTP10, rh.protocol) - rh.SetNoHTTP11(false) + rh.SetProtocol(consts.HTTP11) assert.DeepEqual(t, consts.HTTP11, rh.protocol) assert.True(t, rh.IsHTTP11()) h := RequestHeader{} - h.SetNoHTTP11(true) + h.SetProtocol(consts.HTTP10) assert.DeepEqual(t, consts.HTTP10, h.protocol) - h.SetNoHTTP11(false) + h.SetProtocol(consts.HTTP11) assert.DeepEqual(t, consts.HTTP11, h.protocol) assert.True(t, h.IsHTTP11()) } @@ -101,6 +101,17 @@ func TestSetContentLengthBytes(t *testing.T) { assert.DeepEqual(t, rh.contentLengthBytes, []byte("foo")) } +func TestInitContentLengthWithValue(t *testing.T) { + initLength := 100 + h := RequestHeader{} + h.InitContentLengthWithValue(initLength) + assert.DeepEqual(t, h.contentLength, initLength) + + rh := ResponseHeader{} + rh.InitContentLengthWithValue(initLength) + assert.DeepEqual(t, rh.contentLength, initLength) +} + func TestSetContentEncoding(t *testing.T) { rh := ResponseHeader{} rh.SetContentEncoding("gzip") @@ -171,24 +182,71 @@ func TestResponseHeaderGet(t *testing.T) { assert.DeepEqual(t, val, rightVal) } +func TestRequestHeaderGetAll(t *testing.T) { + h := RequestHeader{} + h.Set("Foo-Bar", "foo") + h.Add("Foo-Bar", "bar") + h.Add("Foo-Bar", "foo-bar") + values := h.GetAll("Foo-Bar") + assert.DeepEqual(t, values, []string{"foo", "bar", "foo-bar"}) +} + +func TestResponseHeaderGetAll(t *testing.T) { + h := ResponseHeader{} + h.Set("Foo-Bar", "foo") + h.Add("Foo-Bar", "bar") + h.Add("Foo-Bar", "foo-bar") + values := h.GetAll("Foo-Bar") + assert.DeepEqual(t, values, []string{"foo", "bar", "foo-bar"}) +} + func TestRequestHeaderVisitAll(t *testing.T) { h := RequestHeader{} h.Set("xxx", "yyy") h.Set("xxx2", "yyy2") - h.VisitAll( - func(k, v []byte) { - key := string(k) - value := string(v) - if key != "Xxx" && key != "Xxx2" { - t.Fatalf("Unexpected %v. Expected %v", key, "xxx or yyy") - } - if key == "Xxx" && value != "yyy" { - t.Fatalf("Unexpected %v. Expected %v", value, "yyy") - } - if key == "Xxx2" && value != "yyy2" { - t.Fatalf("Unexpected %v. Expected %v", value, "yyy2") - } - }) + h.SetHost("host") + h.SetContentLengthBytes([]byte("content-length")) + h.Set(consts.HeaderContentType, "content-type") + h.Set(consts.HeaderUserAgent, "user-agent") + err := h.Trailer().SetTrailers([]byte("foo, bar")) + if err != nil { + t.Fatalf("Set trailer err %v", err) + } + h.SetCookie("foo", "bar") + h.Set(consts.HeaderConnection, "close") + h.VisitAll(func(k, v []byte) { + key := string(k) + value := string(v) + switch key { + case consts.HeaderHost: + assert.DeepEqual(t, value, "host") + case consts.HeaderContentLength: + assert.DeepEqual(t, value, "content-length") + case consts.HeaderContentType: + assert.DeepEqual(t, value, "content-type") + case consts.HeaderUserAgent: + assert.DeepEqual(t, value, "user-agent") + case consts.HeaderTrailer: + assert.DeepEqual(t, value, "Foo, Bar") + case consts.HeaderCookie: + assert.DeepEqual(t, value, "foo=bar") + case consts.HeaderConnection: + assert.DeepEqual(t, value, "close") + case "Xxx": + assert.DeepEqual(t, value, "yyy") + case "Xxx2": + assert.DeepEqual(t, value, "yyy2") + default: + t.Fatalf("Unexpected key %v", key) + } + }) +} + +func TestRequestHeaderCookie(t *testing.T) { + var h RequestHeader + h.SetCookie("foo", "bar") + cookie := h.Cookie("foo") + assert.DeepEqual(t, []byte("bar"), cookie) } func TestRequestHeaderCookies(t *testing.T) { @@ -215,6 +273,7 @@ func TestRequestHeaderDel(t *testing.T) { h.Set(consts.HeaderServer, "aaabbb") h.Set(consts.HeaderContentLength, "1123") h.Set(consts.HeaderTrailer, "foo, bar") + h.Set(consts.HeaderUserAgent, "foo-bar") h.SetHost("foobar") h.SetCookie("foo", "bar") @@ -226,6 +285,7 @@ func TestRequestHeaderDel(t *testing.T) { h.del([]byte("Set-Cookie")) h.del([]byte("Host")) h.del([]byte(consts.HeaderTrailer)) + h.del([]byte(consts.HeaderUserAgent)) h.DelCookie("foo") h.Del("ccc") @@ -269,6 +329,10 @@ func TestRequestHeaderDel(t *testing.T) { if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } + hv = h.Peek(consts.HeaderUserAgent) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } if h.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength()) } @@ -584,6 +648,16 @@ func TestRequestHeaderDelAllCookies(t *testing.T) { } } +func TestResponseHeaderDelAllCookies(t *testing.T) { + var h ResponseHeader + h.SetCanonical([]byte(consts.HeaderSetCookie), []byte("foo")) + h.DelAllCookies() + hv := h.FullCookie() + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } +} + func TestRequestHeaderSetNoDefaultContentType(t *testing.T) { var h RequestHeader h.SetMethod(http.MethodPost) @@ -732,3 +806,13 @@ func TestResponseHeaderDateEmpty(t *testing.T) { t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet } } + +func TestSetTrailerWithROString(t *testing.T) { + h := &RequestHeader{} + h.Add(consts.HeaderTrailer, "foo,bar,hertz") + assert.DeepEqual(t, "Foo, Bar, Hertz", h.Get(consts.HeaderTrailer)) + + h1 := &ResponseHeader{} + h1.Add(consts.HeaderTrailer, "foo,bar,hertz") + assert.DeepEqual(t, "Foo, Bar, Hertz", h1.Get(consts.HeaderTrailer)) +} diff --git a/pkg/protocol/http1/proxy/proxy.go b/pkg/protocol/http1/proxy/proxy.go index f8bae7608..2b243ff04 100644 --- a/pkg/protocol/http1/proxy/proxy.go +++ b/pkg/protocol/http1/proxy/proxy.go @@ -81,13 +81,11 @@ func SetupProxy(conn network.Conn, addr string, proxyURI *protocol.URI, tlsConfi defer close(didReadResponse) err = reqI.Write(connectReq, conn) - if err != nil { return } err = conn.Flush() - if err != nil { return } diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index a5d33f13d..3ea659603 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -114,9 +114,6 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { defer func() { if s.EnableTrace { - if shouldRecordInTraceError(err) { - ctx.GetTraceInfo().Stats().SetError(err) - } // in case of error, we need to trigger all events if eventsToTrigger != nil { for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() { @@ -124,8 +121,11 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { } s.eventStackPool.Put(eventsToTrigger) } - - traceCtl.DoFinish(cc, ctx, err) + if shouldRecordInTraceError(err) { + traceCtl.DoFinish(cc, ctx, err) + } else { + traceCtl.DoFinish(cc, ctx, nil) + } } // Hijack may release and close the connection already @@ -133,6 +133,11 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { zr.Release() //nolint:errcheck zr = nil } + + if ctx.IsExiled() { + return + } + ctx.Reset() s.Core.GetCtxPool().Put(ctx) }() @@ -384,7 +389,11 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { } // general case if s.EnableTrace { - traceCtl.DoFinish(cc, ctx, err) + if shouldRecordInTraceError(err) { + traceCtl.DoFinish(cc, ctx, err) + } else { + traceCtl.DoFinish(cc, ctx, nil) + } } ctx.ResetWithoutConn() diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index dc8790a97..2263ece77 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -70,6 +70,7 @@ func TestTraceEventCompleted(t *testing.T) { assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) + assert.Nil(t, traceInfo.Stats().Error()) } func TestTraceEventReadHeaderError(t *testing.T) { diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 8b65e6e7a..30b899dac 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -726,6 +726,7 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { // align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2 if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect { + ctx.SetHandlers(engine.Handlers) serveError(c, ctx, consts.StatusBadRequest, requiredHostBody) return } @@ -743,6 +744,7 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { // Follow RFC7230#section-5.3 if rPath == "" || rPath[0] != '/' { + ctx.SetHandlers(engine.Handlers) serveError(c, ctx, consts.StatusBadRequest, default400Body) return } diff --git a/version.go b/version.go index b23641ecd..fe052d1c5 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.8.1" + Version = "v0.9.0" )