From ac5b5fd4b203318fd5b0fbef13401be9c84e991f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Wed, 19 Jan 2022 17:16:18 +0000 Subject: [PATCH] node/bindnode: improve support for pointer types In particular, write an extensive test that exercises all the edge cases we care about: 1) Each schema type, and its corresponding Go type (including multiple Go types where appropriate, e.g. links) 2) Each modifier: none, "optional", and "nullable" 3) Each data to decode: present, missing (absent), null The fix is to dereference or initialize pointers as needed. Some more edge cases remain, such as other schema types and non-default representation strategies; a TODO tracks those. While here, we improve the errors returned by AssignLink, so that it now reports a bindnode-specific error when the schema type is a Link but we can't assign to the Go value. --- node/bindnode/infer.go | 5 ++ node/bindnode/infer_test.go | 98 +++++++++++++++++++++++++- node/bindnode/node.go | 137 ++++++++++++++++++++---------------- node/bindnode/repr.go | 7 +- 4 files changed, 184 insertions(+), 63 deletions(-) diff --git a/node/bindnode/infer.go b/node/bindnode/infer.go index eac92a46..57529a55 100644 --- a/node/bindnode/infer.go +++ b/node/bindnode/infer.go @@ -39,6 +39,11 @@ type seenEntry struct { } func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaType schema.Type) { + // TODO(mvdan): support **T as well? + if goType.Kind() == reflect.Ptr { + goType = goType.Elem() + } + // Avoid endless loops. // // TODO(mvdan): this is easy but fairly allocation-happy. diff --git a/node/bindnode/infer_test.go b/node/bindnode/infer_test.go index 2641514f..f23b5b62 100644 --- a/node/bindnode/infer_test.go +++ b/node/bindnode/infer_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "html/template" "io/ioutil" "os/exec" "path/filepath" @@ -79,8 +80,8 @@ var prototypeTests = []struct { linkImpl Link }`, ptrType: (*struct { - LinkGeneric datamodel.Link LinkCID cid.Cid + LinkGeneric datamodel.Link LinkImpl cidlink.Link })(nil), prettyDagJSON: `{ @@ -291,6 +292,101 @@ func TestPrototype(t *testing.T) { } } +func TestPrototypePointerCombinations(t *testing.T) { + t.Parallel() + + // TODO: Null + // TODO: cover more schema types and repr strategies. + // Some of them are still using w.val directly without "nonPtr" calls. + kindTests := []struct { + name string + schemaType string + fieldPtrType interface{} + fieldDagJSON string + }{ + {"Bool", "Bool", (*bool)(nil), `true`}, + {"Int", "Int", (*int64)(nil), `23`}, + {"Float", "Float", (*float64)(nil), `34.5`}, + {"String", "String", (*string)(nil), `"foo"`}, + {"Bytes", "Bytes", (*[]byte)(nil), `{"/": {"bytes": "34cd"}}`}, + {"Link_CID", "Link", (*cid.Cid)(nil), `{"/": "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"}`}, + {"Link_Impl", "Link", (*cidlink.Link)(nil), `{"/": "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"}`}, + {"Link_Generic", "Link", (*datamodel.Link)(nil), `{"/": "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"}`}, + {"List_String", "[String]", (*[]string)(nil), `["foo", "bar"]`}, + {"Map_String_Int", "{String:Int}", (*struct { + Keys []string + Values map[string]int64 + })(nil), `{"x":3,"y":4}`}, + } + + for _, kindTest := range kindTests { + for _, modifier := range []string{"", "optional", "nullable"} { + // don't reuse range vars + kindTest := kindTest + modifier := modifier + t.Run(fmt.Sprintf("%s/%s", kindTest.name, modifier), func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := template.Must(template.New("").Parse(` + type Root struct { + field {{.Modifier}} {{.Type}} + }`)).Execute(&buf, struct { + Type, Modifier string + }{kindTest.schemaType, modifier}) + qt.Assert(t, err, qt.IsNil) + schemaSrc := buf.String() + + // *struct { Field {{.fieldPtrType}} } + ptrType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{ + {Name: "Field", Type: reflect.TypeOf(kindTest.fieldPtrType)}, + }))).Interface() + + ts, err := ipld.LoadSchemaBytes([]byte(schemaSrc)) + qt.Assert(t, err, qt.IsNil) + schemaType := ts.TypeByName("Root") + qt.Assert(t, schemaType, qt.Not(qt.IsNil)) + + proto := bindnode.Prototype(ptrType, schemaType) + wantEncodedBytes, err := json.Marshal(map[string]interface{}{"field": json.RawMessage(kindTest.fieldDagJSON)}) + qt.Assert(t, err, qt.IsNil) + wantEncoded := string(wantEncodedBytes) + + node := dagjsonDecode(t, proto.Representation(), wantEncoded).(schema.TypedNode) + + encoded := dagjsonEncode(t, node.Representation()) + qt.Assert(t, encoded, qt.Equals, wantEncoded) + + // Assigning with the missing field should only work with optional. + nb := proto.NewBuilder() + err = dagjson.Decode(nb, strings.NewReader(`{}`)) + if modifier == "optional" { + qt.Assert(t, err, qt.IsNil) + node := nb.Build() + // The resulting node should be non-nil with a nil field. + nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) + qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) + } else { + qt.Assert(t, err, qt.Not(qt.IsNil)) + } + + // Assigning with a null field should only work with nullable. + nb = proto.NewBuilder() + err = dagjson.Decode(nb, strings.NewReader(`{"field":null}`)) + if modifier == "nullable" { + qt.Assert(t, err, qt.IsNil) + node := nb.Build() + // The resulting node should be non-nil with a nil field. + nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) + qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) + } else { + qt.Assert(t, err, qt.Not(qt.IsNil)) + } + }) + } + } +} + type verifyBadType struct { ptrType interface{} panicRegexp string diff --git a/node/bindnode/node.go b/node/bindnode/node.go index 6b9ae7d7..8ac6f32a 100644 --- a/node/bindnode/node.go +++ b/node/bindnode/node.go @@ -118,6 +118,18 @@ func actualKind(schemaType schema.Type) datamodel.Kind { return schemaType.TypeKind().ActsLike() } +func nonPtrVal(val reflect.Value) reflect.Value { + // TODO: support **T as well as *T? + if val.Kind() == reflect.Ptr { + if val.IsNil() { + // TODO: error in this case? + return reflect.Value{} + } + val = val.Elem() + } + return val +} + func (w *_node) LookupByString(key string) (datamodel.Node, error) { switch typ := w.schemaType.(type) { case *schema.TypeStruct: @@ -128,7 +140,7 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) { Key: basicnode.NewString(key), } } - fval := w.val.FieldByName(fieldNameFromSchema(key)) + fval := nonPtrVal(w.val).FieldByName(fieldNameFromSchema(key)) if !fval.IsValid() { panic("TODO: go-schema mismatch") } @@ -154,7 +166,7 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) { return node, nil case *schema.TypeMap: var kval reflect.Value - valuesVal := w.val.FieldByName("Values") + valuesVal := nonPtrVal(w.val).FieldByName("Values") switch ktyp := typ.KeyType().(type) { case *schema.TypeString: kval = reflect.ValueOf(key) @@ -204,7 +216,7 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) { return nil, datamodel.ErrNotExists{Segment: datamodel.PathSegmentOfString(key)} } // TODO: we could look up the right Go field straight away via idx. - haveIdx, mval := unionMember(w.val) + haveIdx, mval := unionMember(nonPtrVal(w.val)) if haveIdx != idx { // mismatching type return nil, datamodel.ErrNotExists{Segment: datamodel.PathSegmentOfString(key)} } @@ -250,10 +262,11 @@ func unionSetMember(val reflect.Value, memberIdx int, memberPtr reflect.Value) { func (w *_node) LookupByIndex(idx int64) (datamodel.Node, error) { switch typ := w.schemaType.(type) { case *schema.TypeList: - if idx < 0 || int(idx) >= w.val.Len() { + val := nonPtrVal(w.val) + if idx < 0 || int(idx) >= val.Len() { return nil, datamodel.ErrNotExists{Segment: datamodel.PathSegmentOfInt(idx)} } - val := w.val.Index(int(idx)) + val = val.Index(int(idx)) if typ.ValueIsNullable() { if val.IsNil() { return datamodel.Null, nil @@ -316,36 +329,32 @@ func (w *_node) LookupByNode(key datamodel.Node) (datamodel.Node, error) { } func (w *_node) MapIterator() datamodel.MapIterator { + val := nonPtrVal(w.val) switch typ := w.schemaType.(type) { case *schema.TypeStruct: return &_structIterator{ schemaType: typ, fields: typ.Fields(), - val: w.val, + val: val, } case *schema.TypeUnion: return &_unionIterator{ schemaType: typ, members: typ.Members(), - val: w.val, + val: val, } case *schema.TypeMap: return &_mapIterator{ schemaType: typ, - keysVal: w.val.FieldByName("Keys"), - valuesVal: w.val.FieldByName("Values"), + keysVal: val.FieldByName("Keys"), + valuesVal: val.FieldByName("Values"), } } return nil } func (w *_node) ListIterator() datamodel.ListIterator { - val := w.val - if val.Type().Kind() == reflect.Ptr { - if !val.IsNil() { - val = val.Elem() - } - } + val := nonPtrVal(w.val) switch typ := w.schemaType.(type) { case *schema.TypeList: return &_listIterator{schemaType: typ, val: val} @@ -354,6 +363,7 @@ func (w *_node) ListIterator() datamodel.ListIterator { } func (w *_node) Length() int64 { + val := nonPtrVal(w.val) switch w.Kind() { case datamodel.Kind_Map: switch typ := w.schemaType.(type) { @@ -362,9 +372,9 @@ func (w *_node) Length() int64 { case *schema.TypeUnion: return 1 } - return int64(w.val.FieldByName("Keys").Len()) + return int64(val.FieldByName("Keys").Len()) case datamodel.Kind_List: - return int64(w.val.Len()) + return int64(val.Len()) } return -1 } @@ -383,42 +393,42 @@ func (w *_node) AsBool() (bool, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_Bool); err != nil { return false, err } - return w.val.Bool(), nil + return nonPtrVal(w.val).Bool(), nil } func (w *_node) AsInt() (int64, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_Int); err != nil { return 0, err } - return w.val.Int(), nil + return nonPtrVal(w.val).Int(), nil } func (w *_node) AsFloat() (float64, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_Float); err != nil { return 0, err } - return w.val.Float(), nil + return nonPtrVal(w.val).Float(), nil } func (w *_node) AsString() (string, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_String); err != nil { return "", err } - return w.val.String(), nil + return nonPtrVal(w.val).String(), nil } func (w *_node) AsBytes() ([]byte, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_Bytes); err != nil { return nil, err } - return w.val.Bytes(), nil + return nonPtrVal(w.val).Bytes(), nil } func (w *_node) AsLink() (datamodel.Link, error) { if err := compatibleKind(w.schemaType, datamodel.Kind_Link); err != nil { return nil, err } - switch val := w.val.Interface().(type) { + switch val := nonPtrVal(w.val).Interface().(type) { case datamodel.Link: return val, nil case cid.Cid: @@ -454,9 +464,14 @@ type _assembler struct { nullable bool // true if field or map value is nullable } -func (w *_assembler) nonPtrVal() reflect.Value { +func (w *_assembler) createNonPtrVal() reflect.Value { val := w.val - if w.nullable { + // TODO: support **T as well as *T? + if val.Kind() == reflect.Ptr { + // TODO: Sometimes we call createNonPtrVal before an assignment actually + // happens. Does that matter? + // If it matters and we only want to modify the destination value on + // success, then we should make use of the "finish" func. val.Set(reflect.New(val.Type().Elem())) val = val.Elem() } @@ -479,7 +494,7 @@ func (w *basicMapAssembler) Finish() error { return err } basicNode := w.builder.Build() - w.parent.nonPtrVal().Set(reflect.ValueOf(basicNode)) + w.parent.createNonPtrVal().Set(reflect.ValueOf(basicNode)) if w.parent.finish != nil { if err := w.parent.finish(); err != nil { return err @@ -498,7 +513,7 @@ func (w *_assembler) BeginMap(sizeHint int64) (datamodel.MapAssembler, error) { } return &basicMapAssembler{MapAssembler: mapAsm, builder: basicBuilder, parent: w}, nil case *schema.TypeStruct: - val := w.nonPtrVal() + val := w.createNonPtrVal() doneFields := make([]bool, val.NumField()) return &_structAssembler{ schemaType: typ, @@ -507,7 +522,7 @@ func (w *_assembler) BeginMap(sizeHint int64) (datamodel.MapAssembler, error) { finish: w.finish, }, nil case *schema.TypeMap: - val := w.nonPtrVal() + val := w.createNonPtrVal() keysVal := val.FieldByName("Keys") valuesVal := val.FieldByName("Values") if valuesVal.IsNil() { @@ -520,7 +535,7 @@ func (w *_assembler) BeginMap(sizeHint int64) (datamodel.MapAssembler, error) { finish: w.finish, }, nil case *schema.TypeUnion: - val := w.nonPtrVal() + val := w.createNonPtrVal() return &_unionAssembler{ schemaType: typ, val: val, @@ -547,7 +562,7 @@ func (w *basicListAssembler) Finish() error { return err } basicNode := w.builder.Build() - w.parent.nonPtrVal().Set(reflect.ValueOf(basicNode)) + w.parent.createNonPtrVal().Set(reflect.ValueOf(basicNode)) if w.parent.finish != nil { if err := w.parent.finish(); err != nil { return err @@ -566,7 +581,7 @@ func (w *_assembler) BeginList(sizeHint int64) (datamodel.ListAssembler, error) } return &basicListAssembler{ListAssembler: listAsm, builder: basicBuilder, parent: w}, nil case *schema.TypeList: - val := w.nonPtrVal() + val := w.createNonPtrVal() return &_listAssembler{ schemaType: typ, val: val, @@ -603,9 +618,9 @@ func (w *_assembler) AssignBool(b bool) error { return err } if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewBool(b))) + w.createNonPtrVal().Set(reflect.ValueOf(basicnode.NewBool(b))) } else { - w.nonPtrVal().SetBool(b) + w.createNonPtrVal().SetBool(b) } if w.finish != nil { if err := w.finish(); err != nil { @@ -620,9 +635,9 @@ func (w *_assembler) AssignInt(i int64) error { return err } if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewInt(i))) + w.createNonPtrVal().Set(reflect.ValueOf(basicnode.NewInt(i))) } else { - w.nonPtrVal().SetInt(i) + w.createNonPtrVal().SetInt(i) } if w.finish != nil { if err := w.finish(); err != nil { @@ -637,9 +652,9 @@ func (w *_assembler) AssignFloat(f float64) error { return err } if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewFloat(f))) + w.createNonPtrVal().Set(reflect.ValueOf(basicnode.NewFloat(f))) } else { - w.nonPtrVal().SetFloat(f) + w.createNonPtrVal().SetFloat(f) } if w.finish != nil { if err := w.finish(); err != nil { @@ -654,9 +669,9 @@ func (w *_assembler) AssignString(s string) error { return err } if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewString(s))) + w.createNonPtrVal().Set(reflect.ValueOf(basicnode.NewString(s))) } else { - w.nonPtrVal().SetString(s) + w.createNonPtrVal().SetString(s) } if w.finish != nil { if err := w.finish(); err != nil { @@ -671,9 +686,9 @@ func (w *_assembler) AssignBytes(p []byte) error { return err } if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewBytes(p))) + w.createNonPtrVal().Set(reflect.ValueOf(basicnode.NewBytes(p))) } else { - w.nonPtrVal().SetBytes(p) + w.createNonPtrVal().SetBytes(p) } if w.finish != nil { if err := w.finish(); err != nil { @@ -684,26 +699,30 @@ func (w *_assembler) AssignBytes(p []byte) error { } func (w *_assembler) AssignLink(link datamodel.Link) error { + val := w.createNonPtrVal() if _, ok := w.schemaType.(*schema.TypeAny); ok { - w.nonPtrVal().Set(reflect.ValueOf(basicnode.NewLink(link))) - } else { - newVal := reflect.ValueOf(link) - if !newVal.Type().AssignableTo(w.val.Type()) { - if newVal.Type() == goTypeCidLink && goTypeCid.AssignableTo(w.val.Type()) { - // Unbox a cidlink.Link to assign to a go-cid.Cid value. - newVal = newVal.FieldByName("Cid") - } else { - // The target value cannot be assigned a datamodel.Link or go-cid.Cid. - // TODO: revisit Eric's concerns in https://github.com/ipld/go-ipld-prime/pull/324#discussion_r780785847 - return datamodel.ErrWrongKind{ - TypeName: w.schemaType.Name(), - MethodName: "AssignLink", - AppropriateKind: datamodel.KindSet_JustLink, - ActualKind: actualKind(w.schemaType), - } - } + val.Set(reflect.ValueOf(basicnode.NewLink(link))) + } else if newVal := reflect.ValueOf(link); newVal.Type().AssignableTo(val.Type()) { + // Directly assignable. + val.Set(newVal) + } else if newVal.Type() == goTypeCidLink && goTypeCid.AssignableTo(val.Type()) { + // Unbox a cidlink.Link to assign to a go-cid.Cid value. + newVal = newVal.FieldByName("Cid") + val.Set(newVal) + } else if actual := actualKind(w.schemaType); actual != datamodel.Kind_Link { + // We're assigning a Link to a schema type that isn't a Link. + return datamodel.ErrWrongKind{ + TypeName: w.schemaType.Name(), + MethodName: "AssignLink", + AppropriateKind: datamodel.KindSet_JustLink, + ActualKind: actualKind(w.schemaType), } - w.nonPtrVal().Set(newVal) + } else { + // The schema type is a Link, but we somehow can't assign to the Go value. + // Almost certainly a bug; we should have verified for compatibility upfront. + // fmt.Println(newVal.Type().ConvertibleTo(val.Type())) + return fmt.Errorf("bindnode bug: AssignLink with %s argument can't be used on Go type %s", + newVal.Type(), val.Type()) } if w.finish != nil { if err := w.finish(); err != nil { diff --git a/node/bindnode/repr.go b/node/bindnode/repr.go index dcc3da54..2c46d92d 100644 --- a/node/bindnode/repr.go +++ b/node/bindnode/repr.go @@ -243,10 +243,11 @@ func (w *_nodeRepr) MapIterator() datamodel.MapIterator { case nil: switch st := (w.schemaType).(type) { case *schema.TypeMap: + val := nonPtrVal(w.val) return &_mapIteratorRepr{ schemaType: st, - keysVal: w.val.FieldByName("Keys"), - valuesVal: w.val.FieldByName("Values"), + keysVal: val.FieldByName("Keys"), + valuesVal: val.FieldByName("Values"), } default: panic(fmt.Sprintf("TODO: mapitr.repr for typekind %s", w.schemaType.TypeKind())) @@ -637,7 +638,7 @@ func (w *_assemblerRepr) AssignInt(i int64) error { if int64(reprInt) != i { continue } - val := (*_assembler)(w).nonPtrVal() + val := (*_assembler)(w).createNonPtrVal() switch val.Kind() { case reflect.String: return (*_assembler)(w).AssignString(member)