diff --git a/executor/replace.go b/executor/replace.go index c4bf39db0596d..8f35be4d05dbd 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -20,8 +20,10 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -72,7 +74,7 @@ func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle return false, err } - rowUnchanged, err := types.EqualDatums(e.ctx.GetSessionVars().StmtCtx, oldRow, newRow) + rowUnchanged, err := e.EqualDatumsAsBinary(e.ctx.GetSessionVars().StmtCtx, oldRow, newRow) if err != nil { return false, err } @@ -89,6 +91,27 @@ func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle return false, nil } +// EqualDatumsAsBinary compare if a and b contains the same datum values in binary collation. +func (e *ReplaceExec) EqualDatumsAsBinary(sc *stmtctx.StatementContext, a []types.Datum, b []types.Datum) (bool, error) { + if len(a) != len(b) { + return false, nil + } + for i, ai := range a { + collation := ai.Collation() + // We should use binary collation to compare datum, otherwise the result will be incorrect + ai.SetCollation(charset.CollationBin) + v, err := ai.CompareDatum(sc, &b[i]) + ai.SetCollation(collation) + if err != nil { + return false, errors.Trace(err) + } + if v != 0 { + return false, nil + } + } + return true, nil +} + // replaceRow removes all duplicate rows for one row, then inserts it. func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { txn, err := e.ctx.Txn(true) diff --git a/executor/write_test.go b/executor/write_test.go index fdf8510065ee7..e1f7d8a1691c1 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" @@ -2978,3 +2979,55 @@ func (s *testSerialSuite) TestIssue20724(c *C) { tk.MustQuery("select * from t1").Check(testkit.Rows("A")) tk.MustExec("drop table t1") } + +func (s *testSerialSuite) TestIssue20840(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("set tidb_enable_clustered_index = 0") + tk.MustExec("create table t1 (i varchar(20) unique key) collate=utf8mb4_general_ci") + tk.MustExec("insert into t1 values ('a')") + tk.MustExec("replace into t1 values ('A')") + tk.MustQuery("select * from t1").Check(testkit.Rows("A")) + tk.MustExec("drop table t1") +} + +func (s *testSuite) TestEqualDatumsAsBinary(c *C) { + tests := []struct { + a []interface{} + b []interface{} + same bool + }{ + // Positive cases + {[]interface{}{1}, []interface{}{1}, true}, + {[]interface{}{1, "aa"}, []interface{}{1, "aa"}, true}, + {[]interface{}{1, "aa", 1}, []interface{}{1, "aa", 1}, true}, + + // negative cases + {[]interface{}{1}, []interface{}{2}, false}, + {[]interface{}{1, "a"}, []interface{}{1, "aaaaaa"}, false}, + {[]interface{}{1, "aa", 3}, []interface{}{1, "aa", 2}, false}, + + // Corner cases + {[]interface{}{}, []interface{}{}, true}, + {[]interface{}{nil}, []interface{}{nil}, true}, + {[]interface{}{}, []interface{}{1}, false}, + {[]interface{}{1}, []interface{}{1, 1}, false}, + {[]interface{}{nil}, []interface{}{1}, false}, + } + for _, tt := range tests { + testEqualDatumsAsBinary(c, tt.a, tt.b, tt.same) + } +} + +func testEqualDatumsAsBinary(c *C, a []interface{}, b []interface{}, same bool) { + sc := new(stmtctx.StatementContext) + re := new(executor.ReplaceExec) + sc.IgnoreTruncate = true + res, err := re.EqualDatumsAsBinary(sc, types.MakeDatums(a...), types.MakeDatums(b...)) + c.Assert(err, IsNil) + c.Assert(res, Equals, same, Commentf("a: %v, b: %v", a, b)) +} diff --git a/types/datum.go b/types/datum.go index d9d89fd2034c1..4190df25e9e4d 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1962,29 +1962,6 @@ func MaxValueDatum() Datum { return Datum{k: KindMaxValue} } -// EqualDatums compare if a and b contains the same datum values. -func EqualDatums(sc *stmtctx.StatementContext, a []Datum, b []Datum) (bool, error) { - if len(a) != len(b) { - return false, nil - } - if a == nil && b == nil { - return true, nil - } - if a == nil || b == nil { - return false, nil - } - for i, ai := range a { - v, err := ai.CompareDatum(sc, &b[i]) - if err != nil { - return false, errors.Trace(err) - } - if v != 0 { - return false, nil - } - } - return true, nil -} - // SortDatums sorts a slice of datum. func SortDatums(sc *stmtctx.StatementContext, datums []Datum) error { sorter := datumsSorter{datums: datums, sc: sc} diff --git a/types/datum_test.go b/types/datum_test.go index c5b3ea491a48a..cdb11b279d2f6 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -114,42 +114,6 @@ func (ts *testDatumSuite) TestToBool(c *C) { c.Assert(err, NotNil) } -func (ts *testDatumSuite) TestEqualDatums(c *C) { - tests := []struct { - a []interface{} - b []interface{} - same bool - }{ - // Positive cases - {[]interface{}{1}, []interface{}{1}, true}, - {[]interface{}{1, "aa"}, []interface{}{1, "aa"}, true}, - {[]interface{}{1, "aa", 1}, []interface{}{1, "aa", 1}, true}, - - // negative cases - {[]interface{}{1}, []interface{}{2}, false}, - {[]interface{}{1, "a"}, []interface{}{1, "aaaaaa"}, false}, - {[]interface{}{1, "aa", 3}, []interface{}{1, "aa", 2}, false}, - - // Corner cases - {[]interface{}{}, []interface{}{}, true}, - {[]interface{}{nil}, []interface{}{nil}, true}, - {[]interface{}{}, []interface{}{1}, false}, - {[]interface{}{1}, []interface{}{1, 1}, false}, - {[]interface{}{nil}, []interface{}{1}, false}, - } - for _, tt := range tests { - testEqualDatums(c, tt.a, tt.b, tt.same) - } -} - -func testEqualDatums(c *C, a []interface{}, b []interface{}, same bool) { - sc := new(stmtctx.StatementContext) - sc.IgnoreTruncate = true - res, err := EqualDatums(sc, MakeDatums(a...), MakeDatums(b...)) - c.Assert(err, IsNil) - c.Assert(res, Equals, same, Commentf("a: %v, b: %v", a, b)) -} - func testDatumToInt64(c *C, val interface{}, expect int64) { d := NewDatum(val) sc := new(stmtctx.StatementContext) diff --git a/util/profile/flamegraph_test.go b/util/profile/flamegraph_test.go index d02d1912a25e7..5517ccd54a7c7 100644 --- a/util/profile/flamegraph_test.go +++ b/util/profile/flamegraph_test.go @@ -89,7 +89,14 @@ func (s *profileInternalSuite) TestProfileToDatum(c *C) { c.Assert(err, IsNil, comment) comment = Commentf("row %2d, actual (%s), expected (%s)", i, rowStr, expectStr) - equal, err := types.EqualDatums(nil, row, datums[i]) + equal := true + for j, r := range row { + v, err := r.CompareDatum(nil, &datums[i][j]) + if v != 0 || err != nil { + equal = false + break + } + } c.Assert(err, IsNil, comment) c.Assert(equal, IsTrue, comment) }