Skip to content

Commit

Permalink
Fix AttributeValue marshaling and handling in expressions
Browse files Browse the repository at this point in the history
Updates the `attributevalue` and `expression` package's handling of
AttributeValue marshaling fixing several bugs in the packages.

* Fixes aws#1569 `Inconsistent struct field name marshaled`. Fields will now
  be consistent with the EncoderOptions or DecoderOptions the Go struct
  was used with. Previously the Go struct fields would be cached with
  the first options used for the type. Causes subsequent usages to have
  the wrong field names if the encoding options used different TagKeys.

* Fixes aws#645, aws#411 `Support more than string types for map keys`. Updates
  (un)marshaler to support number, bool, and types that implement
  encoding.Text(Un)Marshaler interfaces.

* Fixes Support for expression Names with literal dots in name. Adds new
  function NameNoDotSplit to expression package. This function allows
  you to provide a literal expression Name containing dots. Also adds a
  new method to NameBuilder, AppendName, for joining multiple name path
  components together. Helpful for joining names with literal dots with
  subsequent object path fields.

* Fixes bug with AttributeValue marshaler struct struct tag usage that
  caused TagKey to be ignored if the member had a struct tag with
  `dynamodbav` struct tag. Now both tags will be read as documented,
  with the TagKey struct tag options taking precedence.
  • Loading branch information
jasdel committed Feb 17, 2022
1 parent c363cb8 commit eff88d0
Show file tree
Hide file tree
Showing 15 changed files with 944 additions and 129 deletions.
149 changes: 129 additions & 20 deletions feature/dynamodb/attributevalue/decode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package attributevalue

import (
"encoding"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int
}

// DecoderOptions is a collection of options to configure how the decoder
// unmarshalls the value.
// unmarshals the value.
type DecoderOptions struct {
// Support other custom struct tag keys, such as `yaml`, `json`, or `toml`.
// Note that values provided with a custom TagKey must also be supported
Expand All @@ -221,7 +222,7 @@ type Decoder struct {
// NewDecoder creates a new Decoder with default configuration. Use
// the `opts` functional options to override the default configuration.
func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder {
var options DecoderOptions
options := DecoderOptions{TagKey: defaultTagKey}
for _, fn := range optFns {
fn(&options)
}
Expand Down Expand Up @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag)
var u Unmarshaler
_, isNull := av.(*types.AttributeValueMemberNULL)
if av == nil || isNull {
u, v = indirect(v, true)
u, v = indirect(v, indirectOptions{decodeNull: true})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
return d.decodeNull(v)
}

u, v = indirect(v, false)
u, v = indirect(v, indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
Expand Down Expand Up @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs})
}
Expand Down Expand Up @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns})
}
Expand Down Expand Up @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err
return nil
}

func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error {
func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) {
var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error

switch v.Kind() {
case reflect.Map:
t := v.Type()
if t.Key().Kind() != reflect.String {
return &UnmarshalTypeError{Value: "map string key", Type: t.Key()}
decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key())
if err != nil {
return err
}

if v.IsNil() {
v.Set(reflect.MakeMap(t))
v.Set(reflect.MakeMap(v.Type()))
}
case reflect.Struct:
case reflect.Interface:
v.Set(reflect.MakeMap(stringInterfaceMapType))
decodeMapKey = d.decodeString
v = v.Elem()
default:
return &UnmarshalTypeError{Value: "map", Type: v.Type()}
}

if v.Kind() == reflect.Map {
keyType := v.Type().Key()
valueType := v.Type().Elem()
for k, av := range avMap {
key := reflect.New(v.Type().Key()).Elem()
key.SetString(k)
elem := reflect.New(v.Type().Elem()).Elem()
key := reflect.New(keyType).Elem()
// handle pointer keys
_, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true})
if err := decodeMapKey(k, indirectKey, tag{}); err != nil {
return &UnmarshalTypeError{
Value: fmt.Sprintf("map key %q", k),
Type: keyType,
Err: err,
}
}

elem := reflect.New(valueType).Elem()
if err := d.decode(av, elem, tag{}); err != nil {
return err
}

v.SetMapIndex(key, elem)
}
} else if v.Kind() == reflect.Struct {
Expand All @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val
return nil
}

var numberType = reflect.TypeOf(Number(""))
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) {
// Test the key type to determine if it implements the TextUnmarshaler interface.
if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) {
return func(v string, k reflect.Value, _ tag) error {
if !k.CanAddr() {
return fmt.Errorf("cannot take address of map key, %v", k.Type())
}
return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v))
}, nil
}

var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error

switch keyType.Kind() {
case reflect.Bool:
decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error {
b, err := strconv.ParseBool(v)
if err != nil {
return err
}
return d.decodeBool(b, key)
}
case reflect.String:
// Number type handled as a string
decodeMapKey = d.decodeString

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
decodeMapKey = d.decodeNumber

default:
return nil, &UnmarshalTypeError{
Value: "map key must be string, number, bool, or TextUnmarshaler",
Type: keyType,
}
}

return decodeMapKey, nil
}

func (d *Decoder) decodeNull(v reflect.Value) error {
if v.IsValid() && v.CanSet() {
v.Set(reflect.Zero(v.Type()))
Expand Down Expand Up @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss})
}
Expand Down Expand Up @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value {
return v
}

type indirectOptions struct {
decodeNull bool
skipUnmarshaler bool
}

// indirect will walk a value's interface or pointer value types. Returning
// the final value or the value a unmarshaler is defined on.
//
// Based on the enoding/json type reflect value type indirection in Go Stdlib
// https://golang.org/src/encoding/json/decode.go indirect func.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
//
// The logic below effectively does this when it first addresses the value
// (to satisfy possible pointer methods) and continues to dereference
// subsequent pointers as necessary.
//
// After the first round-trip, we set v back to the original value to
// preserve the original RW flags contained in reflect.Value.
v0 := v
haveAddr := false

// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}

for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) {
haveAddr = false
v = e
continue
}
if e.Kind() != reflect.Ptr && e.IsValid() {
return nil, e
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
if opts.decodeNull && v.CanSet() {
break
}

// Prevent infinite loop if v is an interface pointing to its own address:
// var v interface{}
// v = &v
if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v {
v = v.Elem()
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, reflect.Value{}
}
}
v = v.Elem()

if haveAddr {
v = v0 // restore original value after round-trip Value.Addr().Elem()
haveAddr = false
} else {
v = v.Elem()
}
}

return nil, v
Expand Down Expand Up @@ -782,8 +887,12 @@ func (n Number) String() string {
type UnmarshalTypeError struct {
Value string
Type reflect.Type
Err error
}

// Unwrap returns the underlying error if any.
func (e *UnmarshalTypeError) Unwrap() error { return e.Err }

// Error returns the string representation of the error.
// satisfying the error interface
func (e *UnmarshalTypeError) Error() string {
Expand Down
Loading

0 comments on commit eff88d0

Please sign in to comment.