Skip to content

Commit

Permalink
executor,planner: fix update join update unmatched outer row (#23491) (
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Mar 26, 2021
1 parent cb0040c commit 317e1a1
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 79 deletions.
5 changes: 3 additions & 2 deletions ddl/db_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,9 +793,10 @@ func (s *testStateChangeSuite) TestWriteOnlyForDropColumn(c *C) {
c.Assert(err, IsNil)
}()

sqls := make([]sqlWithErr, 2)
sqls := make([]sqlWithErr, 3)
sqls[0] = sqlWithErr{"update t set c1='5', c3='2020-03-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")}
sqls[1] = sqlWithErr{"update t t1, tt t2 set t1.c1='5', t1.c3='2020-03-01', t2.c1='10' where t1.c4=t2.c4",
sqls[1] = sqlWithErr{"update t set c1='5', c3='2020-03-01' where c4 = 8;", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")}
sqls[2] = sqlWithErr{"update t t1, tt t2 set t1.c1='5', t1.c3='2020-03-01', t2.c1='10' where t1.c4=t2.c4",
errors.New("[planner:1054]Unknown column 'c3' in 'field list'")}
// TODO: Fix the case of sqls[2].
// sqls[2] = sqlWithErr{"update t set c1='5' where c3='2017-07-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")}
Expand Down
28 changes: 28 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,15 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor {
}
base := newBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec)
base.initCap = chunk.ZeroCapacity
var assignFlag []int
assignFlag, b.err = getAssignFlag(b.ctx, v, selExec.Schema().Len())
if b.err != nil {
return nil
}
b.err = plannercore.CheckUpdateList(assignFlag, v)
if b.err != nil {
return nil
}
updateExec := &UpdateExec{
baseExecutor: base,
OrderedList: v.OrderedList,
Expand All @@ -1850,10 +1859,29 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor {
multiUpdateOnSameTable: multiUpdateOnSameTable,
tblID2table: tblID2table,
tblColPosInfos: v.TblColPosInfos,
assignFlag: assignFlag,
}
return updateExec
}

func getAssignFlag(ctx sessionctx.Context, v *plannercore.Update, schemaLen int) ([]int, error) {
assignFlag := make([]int, schemaLen)
for i := range assignFlag {
assignFlag[i] = -1
}
for _, assign := range v.OrderedList {
if !ctx.GetSessionVars().AllowWriteRowID && assign.Col.ID == model.ExtraHandleID {
return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.")
}
tblIdx, found := v.TblColPosInfos.FindTblIdx(assign.Col.Index)
if found {
colIdx := assign.Col.Index
assignFlag[colIdx] = tblIdx
}
}
return assignFlag, nil
}

func (b *executorBuilder) buildDelete(v *plannercore.Delete) Executor {
tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos))
for _, info := range v.TblColPosInfos {
Expand Down
167 changes: 167 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4129,6 +4129,173 @@ func (s *testSuiteP1) TestUnionAutoSignedCast(c *C) {
Check(testkit.Rows("1 1", "2 -1", "3 -1"))
}

func (s *testSuiteP1) TestUpdateClustered(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

type resultChecker struct {
check string
assert []string
}

for _, clustered := range []string{"", "clustered"} {
tests := []struct {
initSchema []string
initData []string
dml string
resultCheck []resultChecker
}{
{ // left join + update both + match & unmatched + pk
[]string{
"drop table if exists a, b",
"create table a (k1 int, k2 int, v int)",
fmt.Sprintf("create table b (a int not null, k1 int, k2 int, v int, primary key(k1, k2) %s)", clustered),
},
[]string{
"insert into a values (1, 1, 1), (2, 2, 2)", // unmatched + matched
"insert into b values (2, 2, 2, 2)",
},
"update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.v = 20, b.v = 100, a.k1 = a.k1 + 1, b.k1 = b.k1 + 1, a.k2 = a.k2 + 2, b.k2 = b.k2 + 2",
[]resultChecker{
{
"select * from b",
[]string{"2 3 4 100"},
},
{
"select * from a",
[]string{"2 3 20", "3 4 20"},
},
},
},
{ // left join + update both + match & unmatched + pk
[]string{
"drop table if exists a, b",
"create table a (k1 int, k2 int, v int)",
fmt.Sprintf("create table b (a int not null, k1 int, k2 int, v int, primary key(k1, k2) %s)", clustered),
},
[]string{
"insert into a values (1, 1, 1), (2, 2, 2)", // unmatched + matched
"insert into b values (2, 2, 2, 2)",
},
"update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100",
[]resultChecker{
{
"select * from b",
[]string{"2 3 4 100"},
},
{
"select * from a",
[]string{"2 3 20", "3 4 20"},
},
},
},
{ // left join + update both + match & unmatched + prefix pk
[]string{
"drop table if exists a, b",
"create table a (k1 varchar(100), k2 varchar(100), v varchar(100))",
fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered),
},
[]string{
"insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched
"insert into b values ('22', '22', '22', '22')",
},
"update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100",
[]resultChecker{
{
"select * from b",
[]string{"22 23 24 100"},
},
{
"select * from a",
[]string{"12 13 20", "23 24 20"},
},
},
},
{ // right join + update both + match & unmatched + prefix pk
[]string{
"drop table if exists a, b",
"create table a (k1 varchar(100), k2 varchar(100), v varchar(100))",
fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered),
},
[]string{
"insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched
"insert into b values ('22', '22', '22', '22')",
},
"update b right join a on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100",
[]resultChecker{
{
"select * from b",
[]string{"22 23 24 100"},
},
{
"select * from a",
[]string{"12 13 20", "23 24 20"},
},
},
},
{ // inner join + update both + match & unmatched + prefix pk
[]string{
"drop table if exists a, b",
"create table a (k1 varchar(100), k2 varchar(100), v varchar(100))",
fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered),
},
[]string{
"insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched
"insert into b values ('22', '22', '22', '22')",
},
"update b join a on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100",
[]resultChecker{
{
"select * from b",
[]string{"22 23 24 100"},
},
{
"select * from a",
[]string{"11 11 11", "23 24 20"},
},
},
},
{
[]string{
"drop table if exists a, b",
"create table a (k1 varchar(100), k2 varchar(100), v varchar(100))",
fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered),
},
[]string{
"insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched
"insert into b values ('22', '22', '22', '22')",
},
"update a set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, a.v = 20 where exists (select 1 from b where a.k1 = b.k1 and a.k2 = b.k2)",
[]resultChecker{
{
"select * from b",
[]string{"22 22 22 22"},
},
{
"select * from a",
[]string{"11 11 11", "23 24 20"},
},
},
},
}

for _, test := range tests {
for _, s := range test.initSchema {
tk.MustExec(s)
}
for _, s := range test.initData {
tk.MustExec(s)
}
tk.MustExec(test.dml)
for _, checker := range test.resultCheck {
tk.MustQuery(checker.check).Check(testkit.Rows(checker.assert...))
}
tk.MustExec("admin check table a")
tk.MustExec("admin check table b")
}
}
}

func (s *testSuite6) TestUpdateJoin(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
Loading

0 comments on commit 317e1a1

Please sign in to comment.