Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plan: refactor the code of building Insert. #7068

Merged
merged 9 commits into from
Jul 25, 2018
35 changes: 23 additions & 12 deletions ddl/db_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ func (s *testStateChangeSuite) TestTwoStates(c *C) {
testInfo.sqlInfos[0].sql = "insert into t (c1, c2, c3, c4) value(2, 'b', 'N', '2017-07-02')"
testInfo.sqlInfos[1].sql = "insert into t (c1, c2, c3, d3, c4) value(3, 'b', 'N', 'a', '2017-07-03')"
unknownColErr := errors.New("unknown column d3")
testInfo.sqlInfos[1].cases[0].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[1].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[2].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[3].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[0].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[1].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[2].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[3].expectedCompileErr = unknownColErr
testInfo.sqlInfos[2].sql = "update t set c2 = 'c2_update'"
testInfo.sqlInfos[3].sql = "replace into t values(5, 'e', 'N', '2017-07-05')'"
testInfo.sqlInfos[3].cases[4].expectedErr = errors.New("Column count doesn't match value count at row 1")
testInfo.sqlInfos[3].cases[4].expectedCompileErr = errors.New("Column count doesn't match value count at row 1")
alterTableSQL := "alter table t add column d3 enum('a', 'b') not null default 'a' after c3"
s.test(c, "", alterTableSQL, testInfo)
// TODO: Add more DDL statements.
Expand Down Expand Up @@ -227,10 +227,11 @@ func (s *testStateChangeSuite) test(c *C, tableName, alterTableSQL string, testI
}

type stateCase struct {
session session.Session
rawStmt ast.StmtNode
stmt ast.Statement
expectedErr error
session session.Session
rawStmt ast.StmtNode
stmt ast.Statement
expectedExecErr error
expectedCompileErr error
}

type sqlInfo struct {
Expand Down Expand Up @@ -299,6 +300,13 @@ func (t *testExecInfo) compileSQL(idx int) (err error) {
return errors.Trace(err)
}
c.stmt, err = compiler.Compile(ctx, c.rawStmt)
if c.expectedCompileErr != nil {
if err == nil {
err = errors.Errorf("expected error %s but got nil", c.expectedCompileErr)
} else if strings.Contains(err.Error(), c.expectedCompileErr.Error()) {
err = nil
}
}
if err != nil {
return errors.Trace(err)
}
Expand All @@ -309,11 +317,14 @@ func (t *testExecInfo) compileSQL(idx int) (err error) {
func (t *testExecInfo) execSQL(idx int) error {
for _, sqlInfo := range t.sqlInfos {
c := sqlInfo.cases[idx]
if c.expectedCompileErr != nil {
continue
}
_, err := c.stmt.Exec(context.TODO())
if c.expectedErr != nil {
if c.expectedExecErr != nil {
if err == nil {
err = errors.Errorf("expected error %s but got nil", c.expectedErr)
} else if strings.Contains(err.Error(), c.expectedErr.Error()) {
err = errors.Errorf("expected error %s but got nil", c.expectedExecErr)
} else if strings.Contains(err.Error(), c.expectedExecErr.Error()) {
err = nil
}
}
Expand Down
1 change: 0 additions & 1 deletion executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ var (

ErrPasswordNoMatch = terror.ClassExecutor.New(mysql.ErrPasswordNoMatch, mysql.MySQLErrName[mysql.ErrPasswordNoMatch])
ErrCannotUser = terror.ClassExecutor.New(mysql.ErrCannotUser, mysql.MySQLErrName[mysql.ErrCannotUser])
ErrWrongValueCountOnRow = terror.ClassExecutor.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow])
ErrPasswordFormat = terror.ClassExecutor.New(mysql.ErrPasswordFormat, mysql.MySQLErrName[mysql.ErrPasswordFormat])
ErrCantChangeTxCharacteristics = terror.ClassExecutor.New(mysql.ErrCantChangeTxCharacteristics, mysql.MySQLErrName[mysql.ErrCantChangeTxCharacteristics])
)
Expand Down
9 changes: 6 additions & 3 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ func (s *testSuite) TestMultiUpdate(c *C) {
func (s *testSuite) TestGeneratedColumnWrite(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`CREATE TABLE test_gc_write (a int primary key, b int, c int as (a+8) virtual)`)
tk.MustExec(`CREATE TABLE test_gc_write (a int primary key auto_increment, b int, c int as (a+8) virtual)`)
tk.MustExec(`CREATE TABLE test_gc_write_1 (a int primary key, b int, c int)`)

tests := []struct {
Expand All @@ -1336,20 +1336,23 @@ func (s *testSuite) TestGeneratedColumnWrite(c *C) {
// Can insert without generated columns.
{`insert into test_gc_write (a, b) values (1, 1)`, 0},
{`insert into test_gc_write set a = 2, b = 2`, 0},
{`insert into test_gc_write (b) select c from test_gc_write`, 0},
// Can update without generated columns.
{`update test_gc_write set b = 2 where a = 2`, 0},
{`update test_gc_write t1, test_gc_write_1 t2 set t1.b = 3, t2.b = 4`, 0},

// But now we can't do this, just as same with MySQL 5.7:
{`insert into test_gc_write values (1, 1)`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write select 1, 1`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write (c) select a, b from test_gc_write`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write (b, c) select a, b from test_gc_write`, mysql.ErrBadGeneratedColumn},
}
for _, tt := range tests {
_, err := tk.Exec(tt.stmt)
if tt.err != 0 {
c.Assert(err, NotNil)
c.Assert(err, NotNil, Commentf("sql is `%v`", tt.stmt))
terr := errors.Trace(err).(*errors.Err).Cause().(*terror.Error)
c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err))
c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err), Commentf("sql is %v", tt.stmt))
} else {
c.Assert(err, IsNil)
}
Expand Down
36 changes: 0 additions & 36 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,47 +148,14 @@ func (e *InsertValues) fillValueList() error {
return nil
}

func (e *InsertValues) checkValueCount(insertValueCount, valueCount, genColsCount, num int, cols []*table.Column) error {
// TODO: This check should be done in plan builder.
if insertValueCount != valueCount {
// "insert into t values (), ()" is valid.
// "insert into t values (), (1)" is not valid.
// "insert into t values (1), ()" is not valid.
// "insert into t values (1,2), (1)" is not valid.
// So the value count must be same for all insert list.
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
}
if valueCount == 0 && len(e.Columns) > 0 {
// "insert into t (c1) values ()" is not valid.
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
} else if valueCount > 0 {
explicitSetLen := 0
if len(e.Columns) != 0 {
explicitSetLen = len(e.Columns)
} else {
explicitSetLen = len(e.SetList)
}
if explicitSetLen > 0 && valueCount+genColsCount != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
} else if explicitSetLen == 0 && valueCount != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
}
}
return nil
}

func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows []types.DatumRow) error) (err error) {
// process `insert|replace ... set x=y...`
if err = e.fillValueList(); err != nil {
return errors.Trace(err)
}

rows := make([]types.DatumRow, len(e.Lists))
length := len(e.Lists[0])
for i, list := range e.Lists {
if err = e.checkValueCount(length, len(list), len(e.GenColumns), i, cols); err != nil {
return errors.Trace(err)
}
e.rowCount = uint64(i)
rows[i], err = e.getRow(cols, list, i)
if err != nil {
Expand Down Expand Up @@ -277,9 +244,6 @@ func (e *InsertValues) fillDefaultValues(row types.DatumRow, hasValue []bool) er
func (e *InsertValues) insertRowsFromSelect(ctx context.Context, cols []*table.Column, exec func(rows []types.DatumRow) error) error {
// process `insert|replace into ... select ... from ...`
selectExec := e.children[0]
if selectExec.Schema().Len() != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(1)
}
fields := selectExec.retTypes()
chk := selectExec.newChunk()
iter := chunk.NewIterator4Chunk(chk)
Expand Down
21 changes: 21 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ func (s *testSuite) TestInsert(c *C) {
c.Assert(err, NotNil)
tk.MustExec("rollback")

errInsertSelectSQL = `insert insert_test_1 values(default, default, default, default, default)`
tk.MustExec("begin")
_, err = tk.Exec(errInsertSelectSQL)
c.Assert(err, NotNil)
tk.MustExec("rollback")

// Updating column is PK handle.
// Make sure the record is "1, 1, nil, 1".
r := tk.MustQuery("select * from insert_test where id = 1;")
Expand Down Expand Up @@ -240,6 +246,21 @@ func (s *testSuite) TestInsert(c *C) {
Check(testkit.Rows("Warning 1690 constant -1.1 overflows float", "Warning 1690 constant -1.1 overflows double",
"Warning 1690 constant -2.1 overflows float", "Warning 1690 constant -2.1 overflows double"))
tk.MustQuery("select * from t").Check(testkit.Rows("0 0", "0 0", "0 0", "1.1 1.1"))

// issue 7061
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 1, b int default 2)")
tk.MustExec("insert into t values(default, default)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t values(default(b), default(a))")
tk.MustQuery("select * from t").Check(testkit.Rows("2 1"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t (b) values(default)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t (b) values(default(a))")
tk.MustQuery("select * from t").Check(testkit.Rows("1 1"))
}

func (s *testSuite) TestInsertAutoInc(c *C) {
Expand Down
3 changes: 3 additions & 0 deletions plan/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (
codeMixOfGroupFuncAndFields = mysql.ErrMixOfGroupFuncAndFields
codeNonUniqTable = mysql.ErrNonuniqTable
codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect
codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow
)

// error definitions.
Expand Down Expand Up @@ -81,6 +82,7 @@ var (
ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal])
ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by")
ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable])
ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow])
)

func init() {
Expand All @@ -107,6 +109,7 @@ func init() {
codeMixOfGroupFuncAndFields: mysql.ErrMixOfGroupFuncAndFields,
codeNonUniqTable: mysql.ErrNonuniqTable,
codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect,
codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow,
}
terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap
}
2 changes: 1 addition & 1 deletion plan/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) {
ans []visitInfo
}{
{
sql: "insert into t values (1)",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the old test result in an error after this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this the length of value_list is not same with the number of table's columns.

sql: "insert into t (a) values (1)",
ans: []visitInfo{
{mysql.InsertPriv, "test", "t", ""},
},
Expand Down
2 changes: 1 addition & 1 deletion plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderBasePhysicalPlan(c *C) {
},
// Test simple insert.
{
sql: "insert into t values(0,0,0,0,0,0,0)",
sql: "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)",
best: "Insert",
},
// Test dual.
Expand Down
Loading