From 128b73100a53d16160593ff02e9ee7dcb2932d4d Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Sat, 16 Nov 2024 11:57:12 -0600 Subject: [PATCH] fix: properly support negative numerator for CBOR rational numbers (#793) --- cbor/tags.go | 49 +++++++++++++++++++++++++++++++++++++---------- cbor/tags_test.go | 8 ++++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/cbor/tags.go b/cbor/tags.go index cf9af612..9039d910 100644 --- a/cbor/tags.go +++ b/cbor/tags.go @@ -15,6 +15,7 @@ package cbor import ( + "fmt" "math/big" "reflect" @@ -93,16 +94,30 @@ type Rat struct { } func (r *Rat) UnmarshalCBOR(cborData []byte) error { - tmpRat := []uint64{} + tmpRat := []any{} if _, err := Decode(cborData, &tmpRat); err != nil { return err } - // Convert numerator and denominator to big.Int - // It's necessary to do this to support num/denom larger than int64 (up to uint64) + // Convert numerator to big.Int tmpNum := new(big.Int) - tmpNum.SetUint64(tmpRat[0]) + switch v := tmpRat[0].(type) { + case int64: + tmpNum.SetInt64(v) + case uint64: + tmpNum.SetUint64(v) + default: + return fmt.Errorf("unsupported numerator type for cbor.Rat: %T", v) + } + // Convert denominator to big.Int tmpDenom := new(big.Int) - tmpDenom.SetUint64(tmpRat[1]) + switch v := tmpRat[1].(type) { + case int64: + tmpDenom.SetInt64(v) + case uint64: + tmpDenom.SetUint64(v) + default: + return fmt.Errorf("unsupported demoninator type for cbor.Rat: %T", v) + } // Create new big.Rat with num/denom set to big.Int values above r.Rat = new(big.Rat) r.Rat.SetFrac(tmpNum, tmpDenom) @@ -110,12 +125,26 @@ func (r *Rat) UnmarshalCBOR(cborData []byte) error { } func (r *Rat) MarshalCBOR() ([]byte, error) { + tmpContent := make([]any, 2) + // Numerator + if r.Num().IsUint64() { + tmpContent[0] = r.Num().Uint64() + } else if r.Num().IsInt64() { + tmpContent[0] = r.Num().Int64() + } else { + return nil, fmt.Errorf("numerator cannot be represented at int64/uint64") + } + // Denominator + if r.Denom().IsUint64() { + tmpContent[1] = r.Denom().Uint64() + } else if r.Denom().IsInt64() { + tmpContent[1] = r.Denom().Int64() + } else { + return nil, fmt.Errorf("numerator cannot be represented at int64/uint64") + } tmpData := _cbor.Tag{ - Number: CborTagRational, - Content: []uint64{ - r.Num().Uint64(), - r.Denom().Uint64(), - }, + Number: CborTagRational, + Content: tmpContent, } return Encode(&tmpData) } diff --git a/cbor/tags_test.go b/cbor/tags_test.go index 341237c5..d0fa9858 100644 --- a/cbor/tags_test.go +++ b/cbor/tags_test.go @@ -54,6 +54,7 @@ var tagsTestDefs = []struct { }, ), }, + // 30([9223372036854775809, 10000000000000000000]) { cborHex: "d81e821b80000000000000011b8ac7230489e80000", object: cbor.Rat{ @@ -63,6 +64,13 @@ var tagsTestDefs = []struct { ), }, }, + // 30([-1, 2]) + { + cborHex: "d81e822002", + object: cbor.Rat{ + Rat: big.NewRat(-1, 2), + }, + }, } func TestTagsDecode(t *testing.T) {