Skip to content

Commit

Permalink
expression, parser: fix issue #3691, cast compatibility (#3894)
Browse files Browse the repository at this point in the history
  • Loading branch information
winkyao authored and jackysp committed Aug 1, 2017
1 parent 01c1d4c commit d0dcb5b
Show file tree
Hide file tree
Showing 15 changed files with 297 additions and 50 deletions.
73 changes: 72 additions & 1 deletion executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@ func (s *testSuite) TestUnion(c *C) {
tk.MustExec("CREATE TABLE t (a DECIMAL(4,2))")
tk.MustExec("INSERT INTO t VALUE(12.34)")
r = tk.MustQuery("SELECT 1 AS c UNION select a FROM t")
r.Check(testkit.Rows("1.00", "12.34"))
r.Sort().Check(testkit.Rows("1.00", "12.34"))

// #issue3771
r = tk.MustQuery("SELECT 'a' UNION SELECT CONCAT('a', -4)")
r.Sort().Check(testkit.Rows("a", "a-4"))
Expand Down Expand Up @@ -1189,6 +1191,75 @@ func (s *testSuite) TestBuiltin(c *C) {
result = tk.MustQuery("select cast(-1 as unsigned)")
result.Check(testkit.Rows("18446744073709551615"))

// Fix issue #3691, cast compability.
result = tk.MustQuery("select cast('18446744073709551616' as unsigned);")
result.Check(testkit.Rows("18446744073709551615"))
result = tk.MustQuery("select cast('18446744073709551616' as signed);")
result.Check(testkit.Rows("-1"))
result = tk.MustQuery("select cast('9223372036854775808' as signed);")
result.Check(testkit.Rows("-9223372036854775808"))
result = tk.MustQuery("select cast('9223372036854775809' as signed);")
result.Check(testkit.Rows("-9223372036854775807"))
result = tk.MustQuery("select cast('9223372036854775807' as signed);")
result.Check(testkit.Rows("9223372036854775807"))
result = tk.MustQuery("select cast('18446744073709551615' as signed);")
result.Check(testkit.Rows("-1"))
result = tk.MustQuery("select cast('18446744073709551614' as signed);")
result.Check(testkit.Rows("-2"))
result = tk.MustQuery("select cast(18446744073709551615 as unsigned);")
result.Check(testkit.Rows("18446744073709551615"))
result = tk.MustQuery("select cast(18446744073709551616 as unsigned);")
result.Check(testkit.Rows("18446744073709551615"))
result = tk.MustQuery("select cast(18446744073709551616 as signed);")
result.Check(testkit.Rows("9223372036854775807"))
result = tk.MustQuery("select cast(18446744073709551617 as signed);")
result.Check(testkit.Rows("9223372036854775807"))
result = tk.MustQuery("select cast(18446744073709551615 as signed);")
result.Check(testkit.Rows("-1"))
result = tk.MustQuery("select cast(18446744073709551614 as signed);")
result.Check(testkit.Rows("-2"))
result = tk.MustQuery("select cast(-18446744073709551616 as signed);")
result.Check(testkit.Rows("-9223372036854775808"))
result = tk.MustQuery("select cast(18446744073709551614.9 as unsigned);") // Round up
result.Check(testkit.Rows("18446744073709551615"))
result = tk.MustQuery("select cast(18446744073709551614.4 as unsigned);") // Round down
result.Check(testkit.Rows("18446744073709551614"))
result = tk.MustQuery("select cast(-9223372036854775809 as signed);")
result.Check(testkit.Rows("-9223372036854775808"))
result = tk.MustQuery("select cast(-9223372036854775809 as unsigned);")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select cast(-9223372036854775808 as unsigned);")
result.Check(testkit.Rows("9223372036854775808"))
result = tk.MustQuery("select cast('-9223372036854775809' as unsigned);")
result.Check(testkit.Rows("9223372036854775808"))
result = tk.MustQuery("select cast('-9223372036854775807' as unsigned);")
result.Check(testkit.Rows("9223372036854775809"))
result = tk.MustQuery("select cast('-2' as unsigned);")
result.Check(testkit.Rows("18446744073709551614"))
result = tk.MustQuery("select cast(cast(1-2 as unsigned) as signed integer);")
result.Check(testkit.Rows("-1"))
result = tk.MustQuery("select cast(1 as signed int)")
result.Check(testkit.Rows("1"))

// test cast time as decimal overflow
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(s1 time);")
tk.MustExec("insert into t1 values('11:11:11');")
result = tk.MustQuery("select cast(s1 as decimal(7, 2)) from t1;")
result.Check(testkit.Rows("99999.99"))
result = tk.MustQuery("select cast(s1 as decimal(8, 2)) from t1;")
result.Check(testkit.Rows("111111.00"))
_, err := tk.Exec("insert into t1 values(cast('111111.00' as decimal(7, 2)));")
c.Assert(err, NotNil)

result = tk.MustQuery(`select CAST(0x8fffffffffffffff as signed) a,
CAST(0xfffffffffffffffe as signed) b,
CAST(0xffffffffffffffff as unsigned) c;`)
result.Check(testkit.Rows("-8070450532247928833 -2 18446744073709551615"))

result = tk.MustQuery(`select cast("1:2:3" as TIME) = "1:02:03"`)
result.Check(testkit.Rows("0"))

// fixed issue #3471
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a time(6));")
Expand All @@ -1206,7 +1277,7 @@ func (s *testSuite) TestBuiltin(c *C) {

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a bigint(30));")
_, err := tk.Exec("insert into t values(-9223372036854775809)")
_, err = tk.Exec("insert into t values(-9223372036854775809)")
c.Assert(err, NotNil)

// test unhex and hex
Expand Down
19 changes: 13 additions & 6 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,34 +333,41 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
sessVars := ctx.GetSessionVars()
sc := new(variable.StatementContext)
sc.TimeZone = sessVars.GetTimeZone()

switch s.(type) {
case *ast.UpdateStmt, *ast.DeleteStmt:
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
sc.InUpdateOrDeleteStmt = true
case *ast.InsertStmt:
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
sc.InInsertStmt = true
case *ast.CreateTableStmt, *ast.AlterTableStmt:
// Make sure the sql_mode is strict when checking column default value.
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = false
case *ast.LoadDataStmt:
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
case *ast.SelectStmt:
sc.IgnoreOverflow = true
sc.InSelectStmt = true

// see https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sql-mode-strict
// said "For statements such as SELECT that do not change data, invalid values
// generate a warning in strict mode, not an error."
// and https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
sc.OverflowAsWarning = true

// Return warning for truncate error in selection.
sc.IgnoreTruncate = false
sc.TruncateAsWarning = true
default:
sc.IgnoreTruncate = true
sc.IgnoreOverflow = false
sc.OverflowAsWarning = false
if show, ok := s.(*ast.ShowStmt); ok {
if show.Tp == ast.ShowWarnings {
sc.InShowWarning = true
Expand Down
76 changes: 61 additions & 15 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
package expression

import (
"math"
"strconv"
"strings"

"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
)
Expand Down Expand Up @@ -544,22 +547,24 @@ func (b *builtinCastDecimalAsIntSig) evalInt(row []types.Datum) (res int64, isNu
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}

// Round is needed for both unsigned and signed.
var to types.MyDecimal
val.Round(&to, 0, types.ModeHalfEven)

if mysql.HasUnsignedFlag(b.tp.Flag) {
var (
floatVal float64
uintRes uint64
)
floatVal, err = val.ToFloat64()
if err != nil {
return res, false, errors.Trace(err)
}
uintRes, err = types.ConvertFloatToUint(sc, floatVal, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
var uintRes uint64
uintRes, err = to.ToUint()
res = int64(uintRes)
} else {
var to types.MyDecimal
val.Round(&to, 0, types.ModeHalfEven)
res, err = to.ToInt()
}

if terror.ErrorEqual(err, types.ErrOverflow) {
warnErr := types.ErrTruncatedWrongVal.GenByArgs("DECIMAL", val)
err = sc.HandleOverflow(err, warnErr)
}

return res, false, errors.Trace(err)
}

Expand Down Expand Up @@ -638,6 +643,30 @@ type builtinCastStringAsIntSig struct {
baseIntBuiltinFunc
}

// handleOverflow handles the overflow caused by cast string as int,
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html.
// When an out-of-range value is assigned to an integer column, MySQL stores the value representing the corresponding endpoint of the column data type range. If it is in select statement, it will return the
// endpoint value with a warning.
func (b *builtinCastStringAsIntSig) handleOverflow(origRes int64, origStr string, origErr error, isNegative bool) (res int64, err error) {
res, err = origRes, origErr
if err == nil {
return
}

sc := b.getCtx().GetSessionVars().StmtCtx
if sc.InSelectStmt && terror.ErrorEqual(origErr, types.ErrOverflow) {
if isNegative {
res = math.MinInt64
} else {
uval := uint64(math.MaxUint64)
res = int64(uval)
}
warnErr := types.ErrTruncatedWrongVal.GenByArgs("INTEGER", origStr)
err = sc.HandleOverflow(origErr, warnErr)
}
return
}

func (b *builtinCastStringAsIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) {
sc := b.getCtx().GetSessionVars().StmtCtx
if IsHybridType(b.args[0]) {
Expand All @@ -647,13 +676,30 @@ func (b *builtinCastStringAsIntSig) evalInt(row []types.Datum) (res int64, isNul
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
if mysql.HasUnsignedFlag(b.tp.Flag) {
var ures uint64

val = strings.TrimSpace(val)
isNegative := false
if len(val) > 1 && val[0] == '-' { // negative number
isNegative = true
}

var ures uint64
if isNegative {
res, err = types.StrToInt(sc, val)
if err == nil {
// If overflow, don't append this warnings
sc.AppendWarning(types.ErrCastNegIntAsUnsigned)
}
} else {
ures, err = types.StrToUint(sc, val)
res = int64(ures)
} else {
res, err = types.StrToInt(sc, val)

if err == nil && !mysql.HasUnsignedFlag(b.tp.Flag) && ures > uint64(math.MaxInt64) {
sc.AppendWarning(types.ErrCastAsSignedOverflow)
}
}

res, err = b.handleOverflow(res, val, err, isNegative)
return res, false, errors.Trace(err)
}

Expand Down
87 changes: 87 additions & 0 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

. "github.com/pingcap/check"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
Expand Down Expand Up @@ -75,6 +76,92 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(len(res.GetString()), Equals, 5)
c.Assert(res.GetString(), Equals, string([]byte{'a', 0x00, 0x00, 0x00, 0x00}))

origSc := sc
sc.InSelectStmt = true
sc.OverflowAsWarning = true

// cast('18446744073709551616' as unsigned);
tp1 := &types.FieldType{
Tp: mysql.TypeLonglong,
Flag: mysql.BinaryFlag,
Charset: charset.CharsetBin,
Collate: charset.CollationBin,
Flen: mysql.MaxIntWidth,
}
f = NewCastFunc(tp1, &Constant{Value: types.NewDatum("18446744073709551616"), RetType: types.NewFieldType(mysql.TypeString)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
c.Assert(res.GetUint64() == math.MaxUint64, IsTrue)

warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)

f = NewCastFunc(tp1, &Constant{Value: types.NewDatum("-1"), RetType: types.NewFieldType(mysql.TypeString)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
c.Assert(res.GetUint64() == 18446744073709551615, IsTrue)

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrCastNegIntAsUnsigned, lastWarn), IsTrue)

f = NewCastFunc(tp1, &Constant{Value: types.NewDatum("-18446744073709551616"), RetType: types.NewFieldType(mysql.TypeString)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
t := math.MinInt64
// 9223372036854775808
c.Assert(res.GetUint64() == uint64(t), IsTrue)

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)

// cast('18446744073709551616' as signed);
mask := ^mysql.UnsignedFlag
tp1.Flag &= uint(mask)
f = NewCastFunc(tp1, &Constant{Value: types.NewDatum("18446744073709551616"), RetType: types.NewFieldType(mysql.TypeString)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
c.Check(res.GetInt64(), Equals, int64(-1))

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)

// cast('18446744073709551614' as signed);
f = NewCastFunc(tp1, &Constant{Value: types.NewDatum("18446744073709551614"), RetType: types.NewFieldType(mysql.TypeString)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
c.Check(res.GetInt64(), Equals, int64(-2))

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrCastAsSignedOverflow, lastWarn), IsTrue)

// create table t1(s1 time);
// insert into t1 values('11:11:11');
// select cast(s1 as decimal(7, 2)) from t1;
tpDecimal := &types.FieldType{
Tp: mysql.TypeNewDecimal,
Flag: mysql.BinaryFlag | mysql.UnsignedFlag,
Charset: charset.CharsetBin,
Collate: charset.CollationBin,
Flen: 7,
Decimal: 2,
}
f = NewCastFunc(tpDecimal, &Constant{Value: timeDatum, RetType: types.NewFieldType(mysql.TypeDatetime)}, ctx)
res, err = f.Eval(nil)
c.Assert(err, IsNil)
resDecimal := new(types.MyDecimal)
resDecimal.FromString([]byte("99999.99"))
c.Assert(res.GetMysqlDecimal().Compare(resDecimal), Equals, 0)

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn), IsTrue)
sc = origSc

// cast(bad_string as decimal)
for _, s := range []string{"hello", ""} {
f = NewCastFunc(tp, &Constant{Value: types.NewDatum(s), RetType: types.NewFieldType(mysql.TypeDecimal)}, ctx)
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Cont
sc := ctx.GetSessionVars().StmtCtx
overflow := false
// TODO: Handle float overflow.
if arg, ok := argExpr.(*Constant); sc.IgnoreOverflow && ok &&
if arg, ok := argExpr.(*Constant); sc.InSelectStmt && ok &&
arg.GetTypeClass() == types.ClassInt {
overflow = b.handleIntOverflow(arg)
if overflow {
Expand Down
6 changes: 3 additions & 3 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ func (s *testEvaluatorSuite) TestUnary(c *C) {
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808
}
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreOverflow
sc.IgnoreOverflow = true
origin := sc.InSelectStmt
sc.InSelectStmt = true
defer func() {
sc.IgnoreOverflow = origin
sc.InSelectStmt = origin
}()

for _, t := range cases {
Expand Down
1 change: 0 additions & 1 deletion expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ func (b *builtinTimeDiffSig) eval(row []types.Datum) (d types.Datum, err error)
if err != nil {
return d, errors.Trace(err)
}

t := t1.Sub(&t2)
d.SetMysqlDuration(t)
return
Expand Down
Loading

0 comments on commit d0dcb5b

Please sign in to comment.