diff --git a/unmarshaler.go b/unmarshaler.go index 857f3cf5..98231bae 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -35,6 +35,9 @@ type Decoder struct { // global settings strict bool + + // toggles unmarshaler interface + unmarshalerInterface bool } // NewDecoder creates a new Decoder that will read from r. @@ -54,6 +57,24 @@ func (d *Decoder) DisallowUnknownFields() *Decoder { return d } +// EnableUnmarshalerInterface allows to enable unmarshaler interface. +// +// With this feature enabled, types implementing the unstable/Unmarshaler +// interface can be decoded from any structure of the document. It allows types +// that don't have a straightfoward TOML representation to provide their own +// decoding logic. +// +// Currently, types can only decode from a single value. Tables and array tables +// are not supported. +// +// *Unstable:* This method does not follow the compatibility guarantees of +// semver. It can be changed or removed without a new major version being +// issued. +func (d *Decoder) EnableUnmarshalerInterface() *Decoder { + d.unmarshalerInterface = true + return d +} + // Decode the whole content of r into v. // // By default, values in the document that don't exist in the target Go value @@ -108,6 +129,7 @@ func (d *Decoder) Decode(v interface{}) error { strict: strict{ Enabled: d.strict, }, + unmarshalerInterface: d.unmarshalerInterface, } return dec.FromParser(v) @@ -143,6 +165,9 @@ type decoder struct { // Strict mode strict strict + // Flag that enables/disables unmarshaler interface. + unmarshalerInterface bool + // Current context for the error. errorContext *errorContext } @@ -648,6 +673,14 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error { v = initAndDereferencePointer(v) } + if d.unmarshalerInterface { + if v.CanAddr() && v.Addr().CanInterface() { + if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok { + return outi.UnmarshalTOML(value) + } + } + } + ok, err := d.tryTextUnmarshaler(value, v) if ok || err != nil { return err diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 78e06895..c1833fb0 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pelletier/go-toml/v2" + "github.com/pelletier/go-toml/v2/unstable" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -3772,3 +3773,95 @@ func TestUnmarshal_Nil(t *testing.T) { }) } } + +type CustomUnmarshalerKey struct { + A int64 +} + +func (k *CustomUnmarshalerKey) UnmarshalTOML(value *unstable.Node) error { + item, err := strconv.ParseInt(string(value.Data), 10, 64) + if err != nil { + return fmt.Errorf("error converting to int64, %v", err) + } + k.A = item + return nil + +} + +func TestUnmarshal_CustomUnmarshaler(t *testing.T) { + type MyConfig struct { + Unmarshalers []CustomUnmarshalerKey `toml:"unmarshalers"` + Foo *string `toml:"foo,omitempty"` + } + + examples := []struct { + desc string + disableUnmarshalerInterface bool + input string + expected MyConfig + err bool + }{ + { + desc: "empty", + input: ``, + expected: MyConfig{Unmarshalers: []CustomUnmarshalerKey{}, Foo: nil}, + }, + { + desc: "simple", + input: `unmarshalers = [1,2,3]`, + expected: MyConfig{ + Unmarshalers: []CustomUnmarshalerKey{ + {A: 1}, + {A: 2}, + {A: 3}, + }, + Foo: nil, + }, + }, + { + desc: "unmarshal string and custom unmarshaler", + input: `unmarshalers = [1,2,3] +foo = "bar"`, + expected: MyConfig{ + Unmarshalers: []CustomUnmarshalerKey{ + {A: 1}, + {A: 2}, + {A: 3}, + }, + Foo: func(v string) *string { + return &v + }("bar"), + }, + }, + { + desc: "simple example, but unmarshaler interface disabled", + disableUnmarshalerInterface: true, + input: `unmarshalers = [1,2,3]`, + err: true, + }, + } + + for _, ex := range examples { + e := ex + t.Run(e.desc, func(t *testing.T) { + foo := MyConfig{} + + decoder := toml.NewDecoder(bytes.NewReader([]byte(e.input))) + if !ex.disableUnmarshalerInterface { + decoder.EnableUnmarshalerInterface() + } + err := decoder.Decode(&foo) + + if e.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, len(foo.Unmarshalers), len(e.expected.Unmarshalers)) + for i := 0; i < len(foo.Unmarshalers); i++ { + require.Equal(t, foo.Unmarshalers[i], e.expected.Unmarshalers[i]) + } + require.Equal(t, foo.Foo, e.expected.Foo) + } + }) + } +} diff --git a/unstable/unmarshaler.go b/unstable/unmarshaler.go new file mode 100644 index 00000000..00cfd6de --- /dev/null +++ b/unstable/unmarshaler.go @@ -0,0 +1,7 @@ +package unstable + +// The Unmarshaler interface may be implemented by types to customize their +// behavior when being unmarshaled from a TOML document. +type Unmarshaler interface { + UnmarshalTOML(value *Node) error +}