Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GetBSON() method usage #40

Merged
merged 2 commits into from
Oct 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"reflect"
"testing"
"time"
"strings"

"github.com/globalsign/mgo/bson"
. "gopkg.in/check.v1"
Expand Down Expand Up @@ -381,8 +382,54 @@ func (s *S) Test64bitInt(c *C) {
// --------------------------------------------------------------------------
// Generic two-way struct marshaling tests.

type prefixPtr string
type prefixVal string

func (t *prefixPtr) GetBSON() (interface{}, error) {
if t == nil {
return nil, nil
}
return "foo-" + string(*t), nil
}

func (t *prefixPtr) SetBSON(raw bson.Raw) error {
var s string
if raw.Kind == 0x0A {
return bson.ErrSetZero
}
if err := raw.Unmarshal(&s); err != nil {
return err
}
if !strings.HasPrefix(s, "foo-") {
return errors.New("Prefix not found: " + s)
}
*t = prefixPtr(s[4:])
return nil
}

func (t prefixVal) GetBSON() (interface{}, error) {
return "foo-" + string(t), nil
}

func (t *prefixVal) SetBSON(raw bson.Raw) error {
var s string
if raw.Kind == 0x0A {
return bson.ErrSetZero
}
if err := raw.Unmarshal(&s); err != nil {
return err
}
if !strings.HasPrefix(s, "foo-") {
return errors.New("Prefix not found: " + s)
}
*t = prefixVal(s[4:])
return nil
}

var bytevar = byte(8)
var byteptr = &bytevar
var prefixptr = prefixPtr("bar")
var prefixval = prefixVal("bar")

var structItems = []testItemType{
{&struct{ Ptr *byte }{nil},
Expand Down Expand Up @@ -419,6 +466,24 @@ var structItems = []testItemType{
// Byte arrays.
{&struct{ V [2]byte }{[2]byte{'y', 'o'}},
"\x05v\x00\x02\x00\x00\x00\x00yo"},

{&struct{ V prefixPtr }{prefixPtr("buzz")},
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},

{&struct{ V *prefixPtr }{&prefixptr},
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},

{&struct{ V *prefixPtr }{nil},
"\x0Av\x00"},

{&struct{ V prefixVal }{prefixVal("buzz")},
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},

{&struct{ V *prefixVal }{&prefixval},
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},

{&struct{ V *prefixVal }{nil},
"\x0Av\x00"},
}

func (s *S) TestMarshalStructItems(c *C) {
Expand Down
24 changes: 13 additions & 11 deletions bson/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,20 @@ func setterStyle(outt reflect.Type) int {
setterMutex.RLock()
style := setterStyles[outt]
setterMutex.RUnlock()
if style == setterUnknown {
setterMutex.Lock()
defer setterMutex.Unlock()
if outt.Implements(setterIface) {
setterStyles[outt] = setterType
} else if reflect.PtrTo(outt).Implements(setterIface) {
setterStyles[outt] = setterAddr
} else {
setterStyles[outt] = setterNone
}
style = setterStyles[outt]
if style != setterUnknown {
return style
}

setterMutex.Lock()
defer setterMutex.Unlock()
if outt.Implements(setterIface) {
style = setterType
} else if reflect.PtrTo(outt).Implements(setterIface) {
style = setterAddr
} else {
style = setterNone
}
setterStyles[outt] = style
return style
}

Expand Down
64 changes: 63 additions & 1 deletion bson/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"reflect"
"sort"
"strconv"
"sync"
"time"
)

Expand All @@ -60,13 +61,28 @@ var (

const itoaCacheSize = 32

const (
getterUnknown = iota
getterNone
getterTypeVal
getterTypePtr
getterAddr
)

var itoaCache []string

var getterStyles map[reflect.Type]int
var getterIface reflect.Type
var getterMutex sync.RWMutex

func init() {
itoaCache = make([]string, itoaCacheSize)
for i := 0; i != itoaCacheSize; i++ {
itoaCache[i] = strconv.Itoa(i)
}
var iface Getter
getterIface = reflect.TypeOf(&iface).Elem()
getterStyles = make(map[reflect.Type]int)
}

func itoa(i int) string {
Expand All @@ -76,6 +92,52 @@ func itoa(i int) string {
return strconv.Itoa(i)
}

func getterStyle(outt reflect.Type) int {
getterMutex.RLock()
style := getterStyles[outt]
getterMutex.RUnlock()
if style != getterUnknown {
return style
}

getterMutex.Lock()
defer getterMutex.Unlock()
if outt.Implements(getterIface) {
vt := outt
for vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
if vt.Implements(getterIface) {
style = getterTypeVal
} else {
style = getterTypePtr
}
} else if reflect.PtrTo(outt).Implements(getterIface) {
style = getterAddr
} else {
style = getterNone
}
getterStyles[outt] = style
return style
}

func getGetter(outt reflect.Type, out reflect.Value) Getter {
style := getterStyle(outt)
if style == getterNone {
return nil
}
if style == getterAddr {
if !out.CanAddr() {
return nil
}
return out.Addr().Interface().(Getter)
}
if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() {
return nil
}
return out.Interface().(Getter)
}

// --------------------------------------------------------------------------
// Marshaling of the document value itself.

Expand Down Expand Up @@ -253,7 +315,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
return
}

if getter, ok := v.Interface().(Getter); ok {
if getter := getGetter(v.Type(), v); getter != nil {
getv, err := getter.GetBSON()
if err != nil {
panic(err)
Expand Down