Skip to content

Commit 01c9d82

Browse files
crazycs520zz-jason
authored andcommitted
expression: fix cast json to decimal bug. (#8030) (#8109)
1 parent 6999f64 commit 01c9d82

File tree

5 files changed

+97
-4
lines changed

5 files changed

+97
-4
lines changed

executor/executor_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,21 @@ func (s *testSuite) TestJSON(c *C) {
13321332
// check CAST AS JSON.
13331333
result = tk.MustQuery(`select CAST('3' AS JSON), CAST('{}' AS JSON), CAST(null AS JSON)`)
13341334
result.Check(testkit.Rows(`3 {} <nil>`))
1335+
1336+
// Check cast json to decimal.
1337+
tk.MustExec("drop table if exists test_json")
1338+
tk.MustExec("create table test_json ( a decimal(60,2) as (JSON_EXTRACT(b,'$.c')), b json );")
1339+
tk.MustExec(`insert into test_json (b) values
1340+
('{"c": "1267.1"}'),
1341+
('{"c": "1267.01"}'),
1342+
('{"c": "1267.1234"}'),
1343+
('{"c": "1267.3456"}'),
1344+
('{"c": "1234567890123456789012345678901234567890123456789012345"}'),
1345+
('{"c": "1234567890123456789012345678901234567890123456789012345.12345"}');`)
1346+
1347+
tk.MustQuery("select a from test_json;").Check(testkit.Rows("1267.10", "1267.01", "1267.12",
1348+
"1267.35", "1234567890123456789012345678901234567890123456789012345.00",
1349+
"1234567890123456789012345678901234567890123456789012345.12"))
13351350
}
13361351

13371352
func (s *testSuite) TestMultiUpdate(c *C) {

expression/builtin_cast.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -1541,11 +1541,11 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
15411541
return res, isNull, errors.Trace(err)
15421542
}
15431543
sc := b.ctx.GetSessionVars().StmtCtx
1544-
f64, err := types.ConvertJSONToFloat(sc, val)
1545-
if err == nil {
1546-
res = new(types.MyDecimal)
1547-
err = res.FromFloat64(f64)
1544+
res, err = types.ConvertJSONToDecimal(sc, val)
1545+
if err != nil {
1546+
return res, false, errors.Trace(err)
15481547
}
1548+
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc)
15491549
return res, false, errors.Trace(err)
15501550
}
15511551

expression/builtin_cast_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,48 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
10821082
c.Assert(iRes, Equals, int64(0))
10831083
}
10841084

1085+
func (s *testEvaluatorSuite) TestCastJSONAsDecimalSig(c *C) {
1086+
ctx, sc := s.ctx, s.ctx.GetSessionVars().StmtCtx
1087+
originIgnoreTruncate := sc.IgnoreTruncate
1088+
sc.IgnoreTruncate = true
1089+
defer func() {
1090+
sc.IgnoreTruncate = originIgnoreTruncate
1091+
}()
1092+
1093+
col := &Column{RetType: types.NewFieldType(mysql.TypeJSON), Index: 0}
1094+
decFunc := newBaseBuiltinCastFunc(newBaseBuiltinFunc(ctx, []Expression{col}), false)
1095+
decFunc.tp = types.NewFieldType(mysql.TypeNewDecimal)
1096+
decFunc.tp.Flen = 60
1097+
decFunc.tp.Decimal = 2
1098+
sig := &builtinCastJSONAsDecimalSig{decFunc}
1099+
1100+
var tests = []struct {
1101+
In string
1102+
Out *types.MyDecimal
1103+
}{
1104+
{`{}`, types.NewDecFromStringForTest("0")},
1105+
{`[]`, types.NewDecFromStringForTest("0")},
1106+
{`3`, types.NewDecFromStringForTest("3")},
1107+
{`-3`, types.NewDecFromStringForTest("-3")},
1108+
{`4.5`, types.NewDecFromStringForTest("4.5")},
1109+
{`"1234"`, types.NewDecFromStringForTest("1234")},
1110+
// test truncate
1111+
{`"1234.1234"`, types.NewDecFromStringForTest("1234.12")},
1112+
{`"1234.4567"`, types.NewDecFromStringForTest("1234.46")},
1113+
// test big decimal
1114+
{`"1234567890123456789012345678901234567890123456789012345"`, types.NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")},
1115+
}
1116+
for _, tt := range tests {
1117+
j, err := json.ParseBinaryFromString(tt.In)
1118+
c.Assert(err, IsNil)
1119+
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(j)})
1120+
res, isNull, err := sig.evalDecimal(row.ToRow())
1121+
c.Assert(isNull, Equals, false)
1122+
c.Assert(err, IsNil)
1123+
c.Assert(res.Compare(tt.Out), Equals, 0)
1124+
}
1125+
}
1126+
10851127
// TestWrapWithCastAsTypesClasses tests WrapWithCastAsInt/Real/String/Decimal.
10861128
func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) {
10871129
ctx := s.ctx

types/convert.go

+15
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,21 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6
378378
return 0, errors.New("Unknown type code in JSON")
379379
}
380380

381+
// ConvertJSONToDecimal casts JSON into decimal.
382+
func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) {
383+
res := new(MyDecimal)
384+
if j.TypeCode != json.TypeCodeString {
385+
f64, err := ConvertJSONToFloat(sc, j)
386+
if err != nil {
387+
return res, errors.Trace(err)
388+
}
389+
err = res.FromFloat64(f64)
390+
return res, errors.Trace(err)
391+
}
392+
err := sc.HandleTruncate(res.FromString([]byte(j.GetString())))
393+
return res, errors.Trace(err)
394+
}
395+
381396
// getValidFloatPrefix gets prefix of string which can be successfully parsed as float.
382397
func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) {
383398
var (

types/convert_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,27 @@ func (s *testTypeConvertSuite) TestConvertJSONToFloat(c *C) {
804804
}
805805
}
806806

807+
func (s *testTypeConvertSuite) TestConvertJSONToDecimal(c *C) {
808+
var tests = []struct {
809+
In string
810+
Out *MyDecimal
811+
}{
812+
{`{}`, NewDecFromStringForTest("0")},
813+
{`[]`, NewDecFromStringForTest("0")},
814+
{`3`, NewDecFromStringForTest("3")},
815+
{`-3`, NewDecFromStringForTest("-3")},
816+
{`4.5`, NewDecFromStringForTest("4.5")},
817+
{`"1234"`, NewDecFromStringForTest("1234")},
818+
{`"1234567890123456789012345678901234567890123456789012345"`, NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")},
819+
}
820+
for _, tt := range tests {
821+
j, err := json.ParseBinaryFromString(tt.In)
822+
c.Assert(err, IsNil)
823+
casted, _ := ConvertJSONToDecimal(new(stmtctx.StatementContext), j)
824+
c.Assert(casted.Compare(tt.Out), Equals, 0)
825+
}
826+
}
827+
807828
func (s *testTypeConvertSuite) TestNumberToDuration(c *C) {
808829
var testCases = []struct {
809830
number int64

0 commit comments

Comments
 (0)