From 1da72008b2d1dcbfe09cf85e4e1f499d75fd6362 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Mon, 6 Feb 2023 20:59:01 +1100 Subject: [PATCH] feat: support decoding query parameters into pointers Also fix incorrect error message and bad reference to "matches". --- cmd/happy/main.go | 29 ++++++++++++++++++----------- testdata/main.go | 2 +- testdata/main_api.go | 11 ++++++----- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/cmd/happy/main.go b/cmd/happy/main.go index 73c596c..54f9c95 100644 --- a/cmd/happy/main.go +++ b/cmd/happy/main.go @@ -351,24 +351,31 @@ func genQueryDecoderFunc(gctx *genContext, paramType types.Type) (name string, e w = w.Push() for i := 0; i < strct.NumFields(); i++ { field := strct.Field(i) - fieldName := lcFirst(strings.ReplaceAll(field.Name(), ".", "")) + fieldName := lcFirst(field.Name()) w.L("if q, ok := p[%q]; ok {", fieldName) w = w.Push() - switch field.Type().String() { + fieldType := field.Type() + strctRef := "out" + if _, ptr := fieldType.(*types.Pointer); ptr { + fieldType = fieldType.(*types.Pointer).Elem() + w.L("out.%s = new(%s)", field.Name(), fieldType) + strctRef = "*" + strctRef + } + switch fieldType.String() { case "bool": gctx.Import("strconv") - w.L("if out.%s, err = strconv.ParseBool(q[len(q)-1]); err != nil {", field.Name()) - w.L(` return fmt.Errorf("failed to decode query parameter \"%s\" as %s: %%w", err)`, fieldName, field.Type()) + w.L("if %s.%s, err = strconv.ParseBool(q[len(q)-1]); err != nil {", strctRef, field.Name()) + w.L(` return fmt.Errorf("failed to decode query parameter \"%s\" as %s: %%w", err)`, fieldName, fieldType) w.L("}") case "int": gctx.Import("strconv") - w.L("if out.%s, err = strconv.Atoi(q[len(q)-1]); err != nil {", field.Name()) - w.L(` return fmt.Errorf("failed to decode query parameter \"%s\" as %s: %%w", err)`, fieldName, field.Type()) + w.L("if %s.%s, err = strconv.Atoi(q[len(q)-1]); err != nil {", strctRef, field.Name()) + w.L(` return fmt.Errorf("failed to decode query parameter \"%s\" as %s: %%w", err)`, fieldName, fieldType) w.L("}") case "string": - w.L("out.%s = q[len(q)-1]", field.Name()) + w.L("%s.%s = q[len(q)-1]", strctRef, field.Name()) default: - return "", fmt.Errorf("can't decode query parameter into field %s.%s of type %s", paramType, field.Name(), field.Type()) + return "", fmt.Errorf("can't decode query parameter into field %s.%s of type %s, only int, string and bool are supported", paramType, field.Name(), field.Type()) } w = w.Pop() w.L("}") @@ -465,9 +472,9 @@ func genEndpoint(gctx *genContext, w *codewriter.Writer, ep endpoint) error { case bt == "string": if bt != ref { - args = append(args, fmt.Sprintf("%s(matches[%d])", ref, index)) + args = append(args, fmt.Sprintf("%s(params[%d])", ref, index)) } else { - args = append(args, fmt.Sprintf("matches[%d]", index)) + args = append(args, fmt.Sprintf("params[%d]", index)) } case bt == "int": @@ -505,7 +512,7 @@ func genEndpoint(gctx *genContext, w *codewriter.Writer, ep endpoint) error { return fmt.Errorf("%s: %w", pos, err) } w.L("if err := %s(r.URL.Query(), ¶m%d); err != nil {", decoderFn, i) - w.L(` http.Error(w, fmt.Sprintf("Failed to decode query parameters: %s", err), http.StatusBadRequest)`) + w.L(` http.Error(w, fmt.Sprintf("Failed to decode query parameters: %%s", err), http.StatusBadRequest)`) w.L(" return") w.L("}") } else { diff --git a/testdata/main.go b/testdata/main.go index bb14e6a..33103ef 100644 --- a/testdata/main.go +++ b/testdata/main.go @@ -92,7 +92,7 @@ func (s *Service) CreateUser(r *http.Request, user User) error { type Paginate struct { Page int Size int - Sparse bool + Sparse *bool } //happy:api GET /users diff --git a/testdata/main_api.go b/testdata/main_api.go index 7dea701..1cac3aa 100644 --- a/testdata/main_api.go +++ b/testdata/main_api.go @@ -61,7 +61,7 @@ func (h *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "GET": var param0 Paginate if err := decodePaginate(r.URL.Query(), ¶m0); err != nil { - http.Error(w, fmt.Sprintf("Failed to decode query parameters: %!s(MISSING)", err), http.StatusBadRequest) + http.Error(w, fmt.Sprintf("Failed to decode query parameters: %s", err), http.StatusBadRequest) return } res, err = h.ListUsers(param0) @@ -169,17 +169,18 @@ matched: func decodePaginate(p url.Values, out *Paginate) (err error) { if q, ok := p["page"]; ok { if out.Page, err = strconv.Atoi(q[len(q)-1]); err != nil { - return fmt.Errorf("failed to decode query parameter \"page\" into type int: %w", err) + return fmt.Errorf("failed to decode query parameter \"page\" as int: %w", err) } } if q, ok := p["size"]; ok { if out.Size, err = strconv.Atoi(q[len(q)-1]); err != nil { - return fmt.Errorf("failed to decode query parameter \"size\" into type int: %w", err) + return fmt.Errorf("failed to decode query parameter \"size\" as int: %w", err) } } if q, ok := p["sparse"]; ok { - if out.Sparse, err = strconv.ParseBool(q[len(q)-1]); err != nil { - return fmt.Errorf("failed to decode query parameter \"sparse\" into type bool: %w", err) + out.Sparse = new(bool) + if *out.Sparse, err = strconv.ParseBool(q[len(q)-1]); err != nil { + return fmt.Errorf("failed to decode query parameter \"sparse\" as bool: %w", err) } } return nil