Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pointer receiver directive #357

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions _generated/pointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package _generated

import (
"fmt"
"time"

"github.com/tinylib/msgp/msgp"
)

//go:generate msgp $GOFILE$

// Generate only pointer receivers:

//msgp:pointer

var mustNoInterf = []interface{}{
Pointer0{},
NamedBoolPointer(true),
NamedIntPointer(0),
NamedFloat64Pointer(0),
NamedStringPointer(""),
NamedMapStructPointer(nil),
NamedMapStructPointer2(nil),
NamedMapStringPointer(nil),
NamedMapStringPointer2(nil),
EmbeddableStructPointer{},
EmbeddableStruct2Pointer{},
PointerHalfFull{},
PointerNoName{},
}

var mustHaveInterf = []interface{}{
&Pointer0{},
mustPtr(NamedBoolPointer(true)),
mustPtr(NamedIntPointer(0)),
mustPtr(NamedFloat64Pointer(0)),
mustPtr(NamedStringPointer("")),
mustPtr(NamedMapStructPointer(nil)),
mustPtr(NamedMapStructPointer2(nil)),
mustPtr(NamedMapStringPointer(nil)),
mustPtr(NamedMapStringPointer2(nil)),
&EmbeddableStructPointer{},
&EmbeddableStruct2Pointer{},
&PointerHalfFull{},
&PointerNoName{},
}

func mustPtr[T any](v T) *T {
return &v
}

func init() {
for _, v := range mustNoInterf {
if _, ok := v.(msgp.Marshaler); ok {
panic(fmt.Sprintf("type %T supports interface", v))
}
if _, ok := v.(msgp.Encodable); ok {
panic(fmt.Sprintf("type %T supports interface", v))
}
}
for _, v := range mustHaveInterf {
if _, ok := v.(msgp.Marshaler); !ok {
panic(fmt.Sprintf("type %T does not support interface", v))
}
if _, ok := v.(msgp.Encodable); !ok {
panic(fmt.Sprintf("type %T does not support interface", v))
}
}
}

type Pointer0 struct {
ABool bool `msg:"abool"`
AInt int `msg:"aint"`
AInt8 int8 `msg:"aint8"`
AInt16 int16 `msg:"aint16"`
AInt32 int32 `msg:"aint32"`
AInt64 int64 `msg:"aint64"`
AUint uint `msg:"auint"`
AUint8 uint8 `msg:"auint8"`
AUint16 uint16 `msg:"auint16"`
AUint32 uint32 `msg:"auint32"`
AUint64 uint64 `msg:"auint64"`
AFloat32 float32 `msg:"afloat32"`
AFloat64 float64 `msg:"afloat64"`
AComplex64 complex64 `msg:"acomplex64"`
AComplex128 complex128 `msg:"acomplex128"`

ANamedBool bool `msg:"anamedbool"`
ANamedInt int `msg:"anamedint"`
ANamedFloat64 float64 `msg:"anamedfloat64"`

AMapStrStr map[string]string `msg:"amapstrstr"`

APtrNamedStr *NamedString `msg:"aptrnamedstr"`

AString string `msg:"astring"`
ANamedString string `msg:"anamedstring"`
AByteSlice []byte `msg:"abyteslice"`

ASliceString []string `msg:"aslicestring"`
ASliceNamedString []NamedString `msg:"aslicenamedstring"`

ANamedStruct NamedStruct `msg:"anamedstruct"`
APtrNamedStruct *NamedStruct `msg:"aptrnamedstruct"`

AUnnamedStruct struct {
A string `msg:"a"`
} `msg:"aunnamedstruct"` // omitempty not supported on unnamed struct

EmbeddableStruct `msg:",flatten"` // embed flat

EmbeddableStruct2 `msg:"embeddablestruct2"` // embed non-flat

AArrayInt [5]int `msg:"aarrayint"` // not supported

ATime time.Time `msg:"atime"`
}

type (
NamedBoolPointer bool
NamedIntPointer int
NamedFloat64Pointer float64
NamedStringPointer string
NamedMapStructPointer map[string]Pointer0
NamedMapStructPointer2 map[string]*Pointer0
NamedMapStringPointer map[string]NamedStringPointer
NamedMapStringPointer2 map[string]*NamedStringPointer
)

type EmbeddableStructPointer struct {
SomeEmbed string `msg:"someembed"`
}

type EmbeddableStruct2Pointer struct {
SomeEmbed2 string `msg:"someembed2"`
}

type NamedStructPointer struct {
A string `msg:"a"`
B string `msg:"b"`
}

type PointerHalfFull struct {
Field00 string `msg:"field00"`
Field01 string `msg:"field01"`
Field02 string `msg:"field02"`
Field03 string `msg:"field03"`
}

type PointerNoName struct {
ABool bool `msg:""`
AInt int `msg:""`
AInt8 int8 `msg:""`
AInt16 int16 `msg:""`
AInt32 int32 `msg:""`
AInt64 int64 `msg:""`
AUint uint `msg:""`
AUint8 uint8 `msg:""`
AUint16 uint16 `msg:""`
AUint32 uint32 `msg:""`
AUint64 uint64 `msg:""`
AFloat32 float32 `msg:""`
AFloat64 float64 `msg:""`
AComplex64 complex64 `msg:""`
AComplex128 complex128 `msg:""`

ANamedBool bool `msg:""`
ANamedInt int `msg:""`
ANamedFloat64 float64 `msg:""`

AMapStrF map[string]NamedFloat64Pointer `msg:""`
AMapStrStruct map[string]PointerHalfFull `msg:""`
AMapStrStruct2 map[string]*PointerHalfFull `msg:""`

APtrNamedStr *NamedStringPointer `msg:""`

AString string `msg:""`
AByteSlice []byte `msg:""`

ASliceString []string `msg:""`
ASliceNamedString []NamedStringPointer `msg:""`

ANamedStruct NamedStructPointer `msg:""`
APtrNamedStruct *NamedStructPointer `msg:""`

AUnnamedStruct struct {
A string `msg:""`
} `msg:""` // omitempty not supported on unnamed struct

EmbeddableStructPointer `msg:",flatten"` // embed flat

EmbeddableStruct2Pointer `msg:""` // embed non-flat

AArrayInt [5]int `msg:""` // not supported

ATime time.Time `msg:""`
ADur time.Duration `msg:""`
}
14 changes: 13 additions & 1 deletion gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,22 @@ var builtins = map[string]struct{}{
}

// common data/methods for every Elem
type common struct{ vname, alias string }
type common struct {
vname, alias string
ptrRcv bool
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) AlwaysPtr(set *bool) bool {
if c != nil && set != nil {
c.ptrRcv = *set
}
return c.ptrRcv
}

func IsPrintable(e Elem) bool {
if be, ok := e.(*BaseElem); ok && !be.Printable() {
Expand Down Expand Up @@ -191,6 +200,9 @@ type Elem interface {
// This is true for slices and maps.
AllowNil() bool

// AlwaysPtr will return true if receiver should always be a pointer.
AlwaysPtr(set *bool) bool

// IfZeroExpr returns the expression to compare to an empty value
// for this type, per the rules of the `omitempty` feature.
// It is meant to be used in an if statement
Expand Down
11 changes: 9 additions & 2 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error {
e.ctx = &Context{}

e.p.comment("EncodeMsg implements msgp.Encodable")

e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv)
next(e, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}
e.p.nakedReturn()
return e.p.err
}
Expand Down
13 changes: 10 additions & 3 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error {
// calling methodReceiver so
// that z.Msgsize() is printed correctly
c := p.Varname()

m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv)
m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c)
next(m, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}

m.p.nakedReturn()
return m.p.err
}
Expand Down Expand Up @@ -280,7 +288,6 @@ func (m *marshalGen) gBase(b *BaseElem) {
}
m.fuseHook()
vname := b.Varname()

if b.Convert {
if b.ShimMode == Cast {
vname = tobaseConvert(b)
Expand Down
10 changes: 9 additions & 1 deletion gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error {

s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message")

s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv)
s.state = assign
next(s, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}
s.p.nakedReturn()
return s.p.err
}
Expand Down
10 changes: 9 additions & 1 deletion parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ var directives = map[string]directive{
// to add an early directive, define a func([]string, *FileSet) error
// and then add it to this list.
var earlyDirectives = map[string]directive{
"tag": tag,
"tag": tag,
"pointer": pointer,
}

var passDirectives = map[string]passDirective{
Expand Down Expand Up @@ -120,6 +121,7 @@ func replace(text []string, f *FileSet) error {
return err
}
e := f.parseExpr(expr)
e.AlwaysPtr(&f.pointerRcv)

if be, ok := e.(*gen.BaseElem); ok {
be.Convert = true
Expand Down Expand Up @@ -178,3 +180,9 @@ func tag(text []string, f *FileSet) error {
f.tagName = strings.TrimSpace(text[1])
return nil
}

//msgp:pointer
func pointer(text []string, f *FileSet) error {
f.pointerRcv = true
return nil
}
2 changes: 2 additions & 0 deletions parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type FileSet struct {
Directives []string // raw preprocessor directives
Imports []*ast.ImportSpec // imports
tagName string // tag to read field names from
pointerRcv bool // generate with pointer receivers.
}

// File parses a file at the relative path
Expand Down Expand Up @@ -199,6 +200,7 @@ parse:
popstate()
continue parse
}
el.AlwaysPtr(&f.pointerRcv)
// push unresolved identities into
// the graph of links and resolve after
// we've handled every possible named type.
Expand Down
5 changes: 5 additions & 0 deletions printer/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error {
}
err = <-res
if err != nil {
os.WriteFile(file+".broken", out.Bytes(), os.ModePerm)
if Logf != nil {
Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken")
}

return err
}
return nil
Expand Down
Loading