Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More alias handling in Unwrap functionality of Value #792

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions pkg/values/bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ func Test_BytesUnwrapTo(t *testing.T) {

assert.Equal(t, hs, got)

var s string
err = tr.UnwrapTo(&s)
var b bool
err = tr.UnwrapTo(&b)
require.Error(t, err)

str := ""
err = tr.UnwrapTo(&str)
require.NoError(t, err)
assert.Equal(t, []byte(str), tr.Underlying)

gotB := (*[]byte)(nil)
err = tr.UnwrapTo(gotB)
assert.ErrorContains(t, err, "cannot unwrap to nil pointer")
Expand Down Expand Up @@ -58,4 +63,15 @@ func Test_BytesUnwrapToAlias(t *testing.T) {
got = append(got, byte(b))
}
assert.Equal(t, underlying, got)

var oracleIDs [5]alias
underlying = []byte("hello")
bn = &Bytes{Underlying: underlying}
err = bn.UnwrapTo(&oracleIDs)
require.NoError(t, err)
got = []byte{}
for _, b := range oracleIDs {
got = append(got, byte(b))
}
assert.Equal(t, underlying, got)
}
80 changes: 16 additions & 64 deletions pkg/values/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package values
import (
"errors"
"fmt"
"math"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/values/pb"
Expand Down Expand Up @@ -42,74 +41,27 @@ func (i *Int64) UnwrapTo(to any) error {
return fmt.Errorf("cannot unwrap to nil pointer: %+v", to)
}

switch tv := to.(type) {
case *int64:
*tv = i.Underlying
return nil
case *int:
if i.Underlying > math.MaxInt {
return fmt.Errorf("cannot unwrap int64 to int: number would overflow %d", i)
}

if i.Underlying < math.MinInt {
return fmt.Errorf("cannot unwrap int64 to int: number would underflow %d", i)
}

*tv = int(i.Underlying)
return nil
case *uint:
if i.Underlying > math.MaxInt {
return fmt.Errorf("cannot unwrap int64 to int: number would overflow %d", i)
}

if i.Underlying < 0 {
return fmt.Errorf("cannot unwrap int64 to uint: number would underflow %d", i)
}

*tv = uint(i.Underlying)
return nil
case *uint32:
if i.Underlying > math.MaxInt {
return fmt.Errorf("cannot unwrap int64 to uint32: number would overflow %d", i)
}
if reflect.ValueOf(to).Kind() != reflect.Pointer {
return fmt.Errorf("cannot unwrap to non-pointer value: %+v", to)
}

if i.Underlying < 0 {
return fmt.Errorf("cannot unwrap int64 to uint32: number would underflow %d", i)
rToVal := reflect.Indirect(reflect.ValueOf(to))
switch rToVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rToVal.OverflowInt(i.Underlying) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)
}

*tv = uint32(i.Underlying)
return nil
case *uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if i.Underlying < 0 {
return fmt.Errorf("cannot unwrap int64 to uint: number would underflow %d", i)
return fmt.Errorf("cannot unwrap int64 to %T: underflow", to)
}

*tv = uint64(i.Underlying)
return nil
case *any:
*tv = i.Underlying
return nil
}

rv := reflect.ValueOf(to)
if rv.Kind() == reflect.Ptr {
switch rv.Elem().Kind() {
case reflect.Int64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int64(0)))).Interface())
case reflect.Int32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int32(0)))).Interface())
case reflect.Int:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int(0)))).Interface())
case reflect.Uint64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint64(0)))).Interface())
case reflect.Uint32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint32(0)))).Interface())
case reflect.Uint:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint(0)))).Interface())
default:
// fall through to the error, default is required by lint
if rToVal.OverflowUint(uint64(i.Underlying)) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)
}
case reflect.Interface:
default:
return fmt.Errorf("cannot unwrap to type %T", to)
}

return fmt.Errorf("cannot unwrap to type %T", to)
return unwrapTo(i.Underlying, to)
}
61 changes: 61 additions & 0 deletions pkg/values/int_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package values

import (
"math"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -36,3 +37,63 @@ func Test_IntUnwrapTo(t *testing.T) {
err = in.UnwrapTo(&i)
assert.ErrorContains(t, err, "cannot unwrap nil")
}

func Test_IntUnwrapping(t *testing.T) {
t.Run("int64 -> int32", func(st *testing.T) {
expected := int64(100)
v := NewInt64(expected)
got := int32(0)
err := v.UnwrapTo(&got)
require.NoError(t, err)
assert.Equal(t, int32(expected), got)
})

t.Run("int64 -> int32; overflow", func(st *testing.T) {
expected := int64(math.MaxInt64)
v := NewInt64(expected)
got := int32(0)
err := v.UnwrapTo(&got)
assert.NotNil(t, err)
})

t.Run("int64 -> int32; underflow", func(st *testing.T) {
expected := int64(math.MinInt64)
v := NewInt64(expected)
got := int32(0)
err := v.UnwrapTo(&got)
assert.NotNil(t, err)
})

t.Run("int64 -> uint32", func(st *testing.T) {
expected := int64(100)
v := NewInt64(expected)
got := uint32(0)
err := v.UnwrapTo(&got)
require.NoError(t, err)
assert.Equal(t, uint32(expected), got)
})

t.Run("int64 -> uint32; overflow", func(st *testing.T) {
expected := int64(math.MaxInt64)
v := NewInt64(expected)
got := uint32(0)
err := v.UnwrapTo(&got)
assert.NotNil(t, err)
})

t.Run("int64 -> uint32; underflow", func(st *testing.T) {
expected := int64(math.MinInt64)
v := NewInt64(expected)
got := uint32(0)
err := v.UnwrapTo(&got)
assert.NotNil(t, err)
})

t.Run("int64 -> uint64; underflow", func(st *testing.T) {
expected := int64(math.MinInt64)
v := NewInt64(expected)
got := uint64(0)
err := v.UnwrapTo(&got)
assert.NotNil(t, err)
})
}
19 changes: 14 additions & 5 deletions pkg/values/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,28 @@ func unwrapTo[T any](underlying T, to any) error {
// eg: type FeedId string allows verification of FeedId's shape while unmarshalling
rTo := reflect.ValueOf(to)
rUnderlying := reflect.ValueOf(underlying)
underlyingPtr := reflect.PointerTo(rUnderlying.Type())
if rTo.Kind() != reflect.Pointer {
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}

if rTo.CanConvert(underlyingPtr) {
reflect.Indirect(rTo.Convert(underlyingPtr)).Set(rUnderlying)
if rUnderlying.CanConvert(reflect.Indirect(rTo).Type()) {
conv := rUnderlying.Convert(reflect.Indirect(rTo).Type())
reflect.Indirect(rTo).Set(conv)
return nil
}

rToVal := reflect.Indirect(rTo)
if rToVal.Kind() == reflect.Slice && rUnderlying.Kind() == reflect.Slice {
newList := reflect.MakeSlice(rToVal.Type(), rUnderlying.Len(), rUnderlying.Len())
if rUnderlying.Kind() == reflect.Slice {
var newList reflect.Value
if rToVal.Kind() == reflect.Array {
newListPtr := reflect.New(reflect.ArrayOf(rUnderlying.Len(), rToVal.Type().Elem()))
newList = reflect.Indirect(newListPtr)
} else if rToVal.Kind() == reflect.Slice {
newList = reflect.MakeSlice(rToVal.Type(), rUnderlying.Len(), rUnderlying.Len())
} else {
return fmt.Errorf("cannot unwrap slice to value of type: %T", to)
}

for i := 0; i < rUnderlying.Len(); i++ {
el := rUnderlying.Index(i)
toEl := newList.Index(i)
Expand Down
Loading