diff --git a/ast/dml.go b/ast/dml.go index 134dc5edbe9ae..f3cd8006d963b 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -71,6 +71,8 @@ type Join struct { On *OnCondition // Using represents join using clause. Using []*ColumnName + // NaturalJoin represents join is natural join + NaturalJoin bool } // Accept implements Node Accept interface. @@ -248,6 +250,19 @@ const ( SelectLockInShareMode ) +// String implements fmt.Stringer. +func (slt SelectLockType) String() string { + switch slt { + case SelectLockNone: + return "none" + case SelectLockForUpdate: + return "for update" + case SelectLockInShareMode: + return "in share mode" + } + return "unsupported select lock type" +} + // WildCardField is a special type of select field content. type WildCardField struct { node diff --git a/executor/analyze_test.go b/executor/analyze_test.go index 018f1ab9186a8..b77f0e6c7c15f 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -38,11 +38,11 @@ func (s *testSuite) TestAnalyzeTable(c *C) { tk.MustExec("insert into t1 (a) values (1)") result := tk.MustQuery("explain select * from t1 where t1.a = 1") rowStr := fmt.Sprintf("%s", result.Rows()) - c.Check(rowStr, Equals, "[[IndexScan_7 cop ] [IndexReader_8 root ]]") + c.Check(rowStr, Equals, "[[IndexScan_7 cop table:t1, index:a, range:[1,1], out of order:true] [IndexReader_8 root index:IndexScan_7]]") tk.MustExec("analyze table t1") result = tk.MustQuery("explain select * from t1 where t1.a = 1") rowStr = fmt.Sprintf("%s", result.Rows()) - c.Check(rowStr, Equals, "[[TableScan_4 Selection_5 cop ] [Selection_5 cop eq(test.t1.a, 1)] [TableReader_6 root ]]") + c.Check(rowStr, Equals, "[[TableScan_4 Selection_5 cop table:t1, range:(-inf,+inf), keep order:false] [Selection_5 cop eq(test.t1.a, 1)] [TableReader_6 root data:Selection_5]]") tk.MustExec("drop table if exists t1") tk.MustExec("create table t1 (a int)") @@ -51,7 +51,7 @@ func (s *testSuite) TestAnalyzeTable(c *C) { tk.MustExec("analyze table t1 index ind_a") result = tk.MustQuery("explain select * from t1 where t1.a = 1") rowStr = fmt.Sprintf("%s", result.Rows()) - c.Check(rowStr, Equals, "[[TableScan_4 Selection_5 cop ] [Selection_5 cop eq(test.t1.a, 1)] [TableReader_6 root ]]") + c.Check(rowStr, Equals, "[[TableScan_4 Selection_5 cop table:t1, range:(-inf,+inf), keep order:false] [Selection_5 cop eq(test.t1.a, 1)] [TableReader_6 root data:Selection_5]]") } type recordSet struct { diff --git a/executor/executor_test.go b/executor/executor_test.go index 593c1e14e90f6..753aa088b14ae 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -217,6 +217,68 @@ func (s *testSuite) TestSelectWithoutFrom(c *C) { r.Check(testkit.Rows("string")) } +func (s *testSuite) TestSelectBackslashN(c *C) { + defer func() { + s.cleanEnv(c) + testleak.AfterTest(c)() + }() + tk := testkit.NewTestKit(c, s.store) + + sql := `select \N;` + r := tk.MustQuery(sql) + r.Check(testkit.Rows("")) + rs, err := tk.Exec(sql) + c.Check(err, IsNil) + fields, err := rs.Fields() + c.Check(err, IsNil) + c.Check(len(fields), Equals, 1) + c.Check(fields[0].Column.Name.O, Equals, "NULL") + + sql = `select "\N";` + r = tk.MustQuery(sql) + r.Check(testkit.Rows("N")) + rs, err = tk.Exec(sql) + c.Check(err, IsNil) + fields, err = rs.Fields() + c.Check(err, IsNil) + c.Check(len(fields), Equals, 1) + c.Check(fields[0].Column.Name.O, Equals, `"\N"`) + + tk.MustExec("use test;") + tk.MustExec("create table test (`\\N` int);") + tk.MustExec("insert into test values (1);") + tk.CheckExecResult(1, 0) + sql = "select * from test;" + r = tk.MustQuery(sql) + r.Check(testkit.Rows("1")) + rs, err = tk.Exec(sql) + c.Check(err, IsNil) + fields, err = rs.Fields() + c.Check(err, IsNil) + c.Check(len(fields), Equals, 1) + c.Check(fields[0].Column.Name.O, Equals, `\N`) + + sql = "select \\N from test;" + r = tk.MustQuery(sql) + r.Check(testkit.Rows("")) + rs, err = tk.Exec(sql) + c.Check(err, IsNil) + fields, err = rs.Fields() + c.Check(err, IsNil) + c.Check(len(fields), Equals, 1) + c.Check(fields[0].Column.Name.O, Equals, `NULL`) + + sql = "select `\\N` from test;" + r = tk.MustQuery(sql) + r.Check(testkit.Rows("1")) + rs, err = tk.Exec(sql) + c.Check(err, IsNil) + fields, err = rs.Fields() + c.Check(err, IsNil) + c.Check(len(fields), Equals, 1) + c.Check(fields[0].Column.Name.O, Equals, `\N`) +} + func (s *testSuite) TestSelectLimit(c *C) { tk := testkit.NewTestKit(c, s.store) defer func() { @@ -983,6 +1045,17 @@ func (s *testSuite) TestStringBuiltin(c *C) { result.Check(testutil.RowsWithSep(",", "bar ,bar,,")) result = tk.MustQuery(`select rtrim(' bar '), rtrim('bar'), rtrim(''), rtrim(null)`) result.Check(testutil.RowsWithSep(",", " bar,bar,,")) + + // for trim + result = tk.MustQuery(`select trim(' bar '), trim(leading 'x' from 'xxxbarxxx'), trim(trailing 'xyz' from 'barxxyz'), trim(both 'x' from 'xxxbarxxx')`) + result.Check(testkit.Rows("bar barxxx barx bar")) + result = tk.MustQuery(`select trim(leading from ' bar'), trim('x' from 'xxxbarxxx'), trim('x' from 'bar'), trim('' from ' bar ')`) + result.Check(testutil.RowsWithSep(",", "bar,bar,bar, bar ")) + result = tk.MustQuery(`select trim(''), trim('x' from '')`) + result.Check(testutil.RowsWithSep(",", ",")) + result = tk.MustQuery(`select trim(null from 'bar'), trim('x' from null), trim(null), trim(leading null from 'bar')`) + // FIXME: the result for trim(leading null from 'bar') should be , current is 'bar' + result.Check(testkit.Rows(" bar")) } func (s *testSuite) TestEncryptionBuiltin(c *C) { diff --git a/executor/explain_test.go b/executor/explain_test.go deleted file mode 100644 index 5d8d750061136..0000000000000 --- a/executor/explain_test.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb/util/testkit" - "github.com/pingcap/tidb/util/testleak" -) - -func (s *testSuite) TestExplain(c *C) { - tk := testkit.NewTestKit(c, s.store) - defer func() { - s.cleanEnv(c) - testleak.AfterTest(c)() - tk.MustExec("set @@session.tidb_opt_insubquery_unfold = 0") - }() - tk.MustExec("use test") - tk.MustExec("drop table if exists t1, t2") - tk.MustExec("create table t1 (c1 int primary key, c2 int, c3 int, index c2 (c2))") - tk.MustExec("create table t2 (c1 int unique, c2 int)") - tk.MustExec("insert into t2 values(1, 0), (2, 1)") - - tests := []struct { - sql string - expect []string - }{ - { - "select * from t1", - []string{ - "TableScan_3 cop ", - "TableReader_4 root ", - }, - }, - { - "select * from t1 order by c2", - []string{ - "IndexScan_13 cop ", - "TableScan_14 cop ", - "IndexLookUp_15 root ", - }, - }, - { - "select * from t2 order by c2", - []string{ - "TableScan_4 cop ", - "TableReader_5 Sort_3 root ", - "Sort_3 root t2.c2:asc", - }, - }, - { - "select * from t1 where t1.c1 > 0", - []string{ - "TableScan_4 cop ", - "TableReader_5 root ", - }, - }, - { - "select t1.c1, t1.c2 from t1 where t1.c2 = 1", - []string{ - "IndexScan_7 cop ", - "IndexReader_8 root ", - }, - }, - { - "select * from t1 left join t2 on t1.c2 = t2.c1 where t1.c1 > 1", - []string{ - "TableScan_22 cop ", - "TableReader_23 IndexJoin_7 root ", - "IndexScan_33 cop ", - "TableScan_34 cop ", - "IndexLookUp_35 IndexJoin_7 root ", - "IndexJoin_7 root outer:TableReader_23, outer key:test.t1.c2, inner key:test.t2.c1", - }, - }, - { - "update t1 set t1.c2 = 2 where t1.c1 = 1", - []string{ - "TableScan_4 cop ", - "TableReader_5 Update_3 root ", - "Update_3 root ", - }, - }, - { - "delete from t1 where t1.c2 = 1", - []string{ - "IndexScan_7 cop ", - "TableScan_8 cop ", - "IndexLookUp_9 Delete_3 root ", - "Delete_3 root ", - }, - }, - { - "select count(b.c2) from t1 a, t2 b where a.c1 = b.c2 group by a.c1", - []string{ - "TableScan_17 HashAgg_16 cop ", - "HashAgg_16 cop type:complete, group by:b.c2, funcs:count(b.c2), firstrow(b.c2)", - "TableReader_21 HashAgg_20 root ", - "HashAgg_20 IndexJoin_9 root type:final, group by:, funcs:count(col_0), firstrow(col_1)", - "TableScan_12 cop ", - "TableReader_31 IndexJoin_9 root ", - "IndexJoin_9 Projection_8 root outer:TableReader_31, outer key:b.c2, inner key:a.c1", - "Projection_8 root cast(join_agg_0)", - }, - }, - { - "select * from t2 order by t2.c2 limit 0, 1", - []string{ - "TableScan_7 TopN_5 cop ", - "TopN_5 cop ", - "TableReader_10 TopN_5 root ", - "TopN_5 root ", - }, - }, - { - "select * from t1 where c1 > 1 and c2 = 1 and c3 < 1", - []string{ - "IndexScan_7 Selection_9 cop ", - "Selection_9 cop gt(test.t1.c1, 1)", - "TableScan_8 Selection_10 cop ", - "Selection_10 cop lt(test.t1.c3, 1)", - "IndexLookUp_11 root ", - }, - }, - { - "select * from t1 where c1 =1 and c2 > 1", - []string{ - "TableScan_4 Selection_5 cop ", - "Selection_5 cop gt(test.t1.c2, 1)", - "TableReader_6 root ", - }, - }, - { - "select sum(t1.c1 in (select c1 from t2)) from t1", - []string{ - "TableScan_11 HashAgg_10 cop ", - "HashAgg_10 cop type:complete, funcs:sum(in(test.t1.c1, 1, 2))", - "TableReader_14 HashAgg_13 root ", - "HashAgg_13 root type:final, funcs:sum(col_0)", - }, - }, - { - "select c1 from t1 where c1 in (select c2 from t2)", - []string{ - "TableScan_11 cop ", - "TableReader_12 root ", - }, - }, - { - "select (select count(1) k from t1 s where s.c1 = t1.c1 having k != 0) from t1", - []string{ - "TableScan_13 cop ", - "TableReader_14 Apply_12 root ", - "TableScan_18 cop ", - "TableReader_19 Selection_4 root ", - "Selection_4 HashAgg_17 root eq(s.c1, test.t1.c1)", - "HashAgg_17 Selection_10 root type:complete, funcs:count(1)", - "Selection_10 Apply_12 root ne(k, 0)", - "Apply_12 Projection_2 root left outer join, small:Selection_10, right:Selection_10", - "Projection_2 root k", - }, - }, - { - "select * from information_schema.columns", - []string{ - "MemTableScan_3 root ", - }, - }, - } - tk.MustExec("set @@session.tidb_opt_insubquery_unfold = 1") - for _, tt := range tests { - result := tk.MustQuery("explain " + tt.sql) - result.Check(testkit.Rows(tt.expect...)) - } -} diff --git a/executor/join_test.go b/executor/join_test.go index 03937228ed130..b9c5eadc0eb73 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -324,6 +324,25 @@ func (s *testSuite) TestUsing(c *C) { tk.MustExec("select * from (t1 join t2 using (a)) join (t3 join t4 using (a)) on (t2.a = t4.a and t1.a = t3.a)") } +func (s *testSuite) TestNaturalJoin(c *C) { + defer func() { + s.cleanEnv(c) + testleak.AfterTest(c)() + }() + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int, b int)") + tk.MustExec("create table t2 (a int, c int)") + tk.MustExec("insert t1 values (1, 2), (10, 20)") + tk.MustExec("insert t2 values (1, 3), (100, 200)") + + tk.MustQuery("select * from t1 natural join t2").Check(testkit.Rows("1 2 3")) + tk.MustQuery("select * from t1 natural left join t2 order by a").Check(testkit.Rows("1 2 3", "10 20 ")) + tk.MustQuery("select * from t1 natural right join t2 order by a").Check(testkit.Rows("1 3 2", "100 200 ")) +} + func (s *testSuite) TestMultiJoin(c *C) { defer func() { s.cleanEnv(c) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 6f75331ebc32f..f07d8130d1535 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -104,7 +104,9 @@ var ( _ builtinFunc = &builtinHexStrArgSig{} _ builtinFunc = &builtinHexIntArgSig{} _ builtinFunc = &builtinUnHexSig{} - _ builtinFunc = &builtinTrimSig{} + _ builtinFunc = &builtinTrim1ArgSig{} + _ builtinFunc = &builtinTrim2ArgsSig{} + _ builtinFunc = &builtinTrim3ArgsSig{} _ builtinFunc = &builtinLTrimSig{} _ builtinFunc = &builtinRTrimSig{} _ builtinFunc = &builtinRpadSig{} @@ -996,8 +998,6 @@ func (b *builtinLocateSig) eval(row []types.Datum) (d types.Datum, err error) { return d, nil } -const spaceChars = "\n\t\r " - type hexFunctionClass struct { baseFunctionClass } @@ -1118,75 +1118,149 @@ func (b *builtinUnHexSig) evalString(row []types.Datum) (string, bool, error) { return string(bs), false, nil } +const spaceChars = "\n\t\r " + type trimFunctionClass struct { baseFunctionClass } +// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)', +// but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST. func (c *trimFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { - sig := &builtinTrimSig{newBaseBuiltinFunc(args, ctx)} - return sig.setSelf(sig), errors.Trace(c.verifyArgs(args)) + if err := c.verifyArgs(args); err != nil { + return nil, errors.Trace(err) + } + + switch len(args) { + case 1: + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString) + if err != nil { + return nil, errors.Trace(err) + } + argType := args[0].GetType() + bf.tp.Flen = argType.Flen + if mysql.HasBinaryFlag(argType.Flag) { + types.SetBinChsClnFlag(bf.tp) + } + sig := &builtinTrim1ArgSig{baseStringBuiltinFunc{bf}} + return sig.setSelf(sig), nil + + case 2: + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString) + if err != nil { + return nil, errors.Trace(err) + } + argType := args[0].GetType() + bf.tp.Flen = argType.Flen + if mysql.HasBinaryFlag(argType.Flag) { + types.SetBinChsClnFlag(bf.tp) + } + sig := &builtinTrim2ArgsSig{baseStringBuiltinFunc{bf}} + return sig.setSelf(sig), nil + + case 3: + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString, tpInt) + if err != nil { + return nil, errors.Trace(err) + } + argType := args[0].GetType() + bf.tp.Flen = argType.Flen + if mysql.HasBinaryFlag(argType.Flag) { + types.SetBinChsClnFlag(bf.tp) + } + sig := &builtinTrim3ArgsSig{baseStringBuiltinFunc{bf}} + return sig.setSelf(sig), nil + + default: + return nil, errors.Trace(c.verifyArgs(args)) + } } -type builtinTrimSig struct { - baseBuiltinFunc +type builtinTrim1ArgSig struct { + baseStringBuiltinFunc } -// eval evals a builtinTrimSig. +// evalString evals a builtinTrim1ArgSig, corresponding to trim(str) // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim -func (b *builtinTrimSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) +func (b *builtinTrim1ArgSig) evalString(row []types.Datum) (d string, isNull bool, err error) { + d, isNull, err = b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - // args[0] -> Str - // args[1] -> RemStr - // args[2] -> Direction - // eval str - if args[0].IsNull() { - return d, nil + return strings.Trim(d, spaceChars), false, nil +} + +type builtinTrim2ArgsSig struct { + baseStringBuiltinFunc +} + +// evalString evals a builtinTrim2ArgsSig, corresponding to trim(str, remstr) +// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim +func (b *builtinTrim2ArgsSig) evalString(row []types.Datum) (d string, isNull bool, err error) { + var str, remstr string + + sc := b.ctx.GetSessionVars().StmtCtx + str, isNull, err = b.args[0].EvalString(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - str, err := args[0].ToString() - if err != nil { - return d, errors.Trace(err) + remstr, isNull, err = b.args[1].EvalString(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - remstr := "" - // eval remstr - if len(args) > 1 { - if args[1].Kind() != types.KindNull { - remstr, err = args[1].ToString() - if err != nil { - return d, errors.Trace(err) - } - } + d = trimLeft(str, remstr) + d = trimRight(d, remstr) + return d, false, nil +} + +type builtinTrim3ArgsSig struct { + baseStringBuiltinFunc +} + +// evalString evals a builtinTrim3ArgsSig, corresponding to trim(str, remstr, direction) +// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim +func (b *builtinTrim3ArgsSig) evalString(row []types.Datum) (d string, isNull bool, err error) { + var ( + str, remstr string + x int64 + direction ast.TrimDirectionType + isRemStrNull bool + ) + sc := b.ctx.GetSessionVars().StmtCtx + str, isNull, err = b.args[0].EvalString(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - // do trim - var result string - var direction ast.TrimDirectionType - if len(args) > 2 { - direction = args[2].GetValue().(ast.TrimDirectionType) - } else { - direction = ast.TrimBothDefault + remstr, isRemStrNull, err = b.args[1].EvalString(row, sc) + if err != nil { + return d, isNull, errors.Trace(err) } + x, isNull, err = b.args[2].EvalInt(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) + } + direction = ast.TrimDirectionType(x) if direction == ast.TrimLeading { - if len(remstr) > 0 { - result = trimLeft(str, remstr) + if isRemStrNull { + d = strings.TrimLeft(str, spaceChars) } else { - result = strings.TrimLeft(str, spaceChars) + d = trimLeft(str, remstr) } } else if direction == ast.TrimTrailing { - if len(remstr) > 0 { - result = trimRight(str, remstr) + if isRemStrNull { + d = strings.TrimRight(str, spaceChars) } else { - result = strings.TrimRight(str, spaceChars) + d = trimRight(str, remstr) } - } else if len(remstr) > 0 { - x := trimLeft(str, remstr) - result = trimRight(x, remstr) } else { - result = strings.Trim(str, spaceChars) + if isRemStrNull { + d = strings.Trim(str, spaceChars) + } else { + d = trimLeft(str, remstr) + d = trimRight(d, remstr) + } } - d.SetString(result) - return d, nil + return d, false, nil } type lTrimFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 66563420ef08f..51f06552a7533 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -867,51 +867,55 @@ func (s *testEvaluatorSuite) TestLocate(c *C) { func (s *testEvaluatorSuite) TestTrim(c *C) { defer testleak.AfterTest(c)() - tbl := []struct { - str interface{} - remstr interface{} - dir ast.TrimDirectionType - result interface{} + cases := []struct { + args []interface{} + isNil bool + getErr bool + res string }{ - {" bar ", nil, ast.TrimBothDefault, "bar"}, - {"xxxbarxxx", "x", ast.TrimLeading, "barxxx"}, - {"xxxbarxxx", "x", ast.TrimBoth, "bar"}, - {"barxxyz", "xyz", ast.TrimTrailing, "barx"}, - {nil, "xyz", ast.TrimBoth, nil}, - {1, 2, ast.TrimBoth, "1"}, - {" \t\rbar\n ", nil, ast.TrimBothDefault, "bar"}, + {[]interface{}{" bar "}, false, false, "bar"}, + {[]interface{}{""}, false, false, ""}, + {[]interface{}{nil}, true, false, ""}, + {[]interface{}{"xxxbarxxx", "x"}, false, false, "bar"}, + {[]interface{}{"bar", "x"}, false, false, "bar"}, + {[]interface{}{" bar ", ""}, false, false, " bar "}, + {[]interface{}{"", "x"}, false, false, ""}, + {[]interface{}{"bar", nil}, true, false, ""}, + {[]interface{}{nil, "x"}, true, false, ""}, + {[]interface{}{"xxxbarxxx", "x", int(ast.TrimLeading)}, false, false, "barxxx"}, + {[]interface{}{"barxxyz", "xyz", int(ast.TrimTrailing)}, false, false, "barx"}, + {[]interface{}{"xxxbarxxx", "x", int(ast.TrimBoth)}, false, false, "bar"}, + // FIXME: the result for this test shuold be nil, current is "bar" + {[]interface{}{"bar", nil, int(ast.TrimLeading)}, false, false, "bar"}, + {[]interface{}{errors.New("must error")}, false, true, ""}, } - for _, v := range tbl { - fc := funcs[ast.Trim] - f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.remstr, v.dir)), s.ctx) - c.Assert(err, IsNil) - r, err := f.eval(nil) + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.Trim, primitiveValsToConstants(t.args)...) c.Assert(err, IsNil) - c.Assert(r, testutil.DatumEquals, types.NewDatum(v.result)) + d, err := f.Eval(nil) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if t.isNil { + c.Assert(d.Kind(), Equals, types.KindNull) + } else { + c.Assert(d.GetString(), Equals, t.res) + } + } } - for _, v := range []struct { - str, result interface{} - fn string - }{ - {" ", "", ast.LTrim}, - {" ", "", ast.RTrim}, - {"foo0", "foo0", ast.LTrim}, - {"bar0", "bar0", ast.RTrim}, - {" foo1", "foo1", ast.LTrim}, - {"bar1 ", "bar1", ast.RTrim}, - {spaceChars + "foo2 ", "foo2 ", ast.LTrim}, - {" bar2" + spaceChars, " bar2", ast.RTrim}, - {nil, nil, ast.LTrim}, - {nil, nil, ast.RTrim}, - } { - fc := funcs[v.fn] - f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str)), s.ctx) - c.Assert(err, IsNil) - r, err := f.eval(nil) - c.Assert(err, IsNil) - c.Assert(r, testutil.DatumEquals, types.NewDatum(v.result)) - } + f, err := funcs[ast.Trim].getFunction([]Expression{Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) + + f, err = funcs[ast.Trim].getFunction([]Expression{Zero, Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) + + f, err = funcs[ast.Trim].getFunction([]Expression{Zero, Zero, Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) } func (s *testEvaluatorSuite) TestLTrim(c *C) { diff --git a/expression/explain.go b/expression/explain.go index 764670ca95b9e..58716562bb235 100644 --- a/expression/explain.go +++ b/expression/explain.go @@ -16,6 +16,8 @@ package expression import ( "bytes" "fmt" + + "github.com/pingcap/tidb/util/types" ) // ExplainInfo implements the Expression interface. @@ -40,7 +42,11 @@ func (expr *Column) ExplainInfo() string { func (expr *Constant) ExplainInfo() string { valStr, err := expr.Value.ToString() if err != nil { - valStr = "not recognized const value" + if expr.Value.Kind() == types.KindNull { + valStr = "null" + } else { + valStr = "not recognized const value" + } } return valStr } diff --git a/parser/lexer_test.go b/parser/lexer_test.go index d7917a6506b60..c6dd0cea3f467 100644 --- a/parser/lexer_test.go +++ b/parser/lexer_test.go @@ -120,6 +120,7 @@ func (s *testLexerSuite) TestLiteral(c *C) { {".1_t_1_x", int('.')}, {"N'some text'", underscoreCS}, {"n'some text'", underscoreCS}, + {"\\N", null}, } runTest(c, table) } diff --git a/parser/misc.go b/parser/misc.go index e341f9eb08aec..57db72a8668ad 100644 --- a/parser/misc.go +++ b/parser/misc.go @@ -119,6 +119,7 @@ func init() { initTokenString("<>", neqSynonym) initTokenString("<<", lsh) initTokenString(">>", rsh) + initTokenString("\\N", null) initTokenFunc("@", startWithAt) initTokenFunc("/", startWithSlash) @@ -614,6 +615,7 @@ var tokenMap = map[string]int{ "UUID": uuid, "UUID_SHORT": uuidShort, "KILL": kill, + "NATURAL": natural, } func isTokenIdentifier(s string, buf *bytes.Buffer) int { diff --git a/parser/parser.y b/parser/parser.y index 8c503ac0ddfba..3cff9dadce816 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -230,6 +230,7 @@ import ( xor "XOR" yearMonth "YEAR_MONTH" zerofill "ZEROFILL" + natural "NATURAL" /* the following tokens belong to NotKeywordToken*/ abs "ABS" @@ -894,7 +895,7 @@ import ( %precedence lowerThanKey %precedence key -%left join inner cross left right full +%left join inner cross left right full natural /* A dummy token to force the priority of TableRef production in a join. */ %left tableRefPriority %precedence lowerThanOn @@ -2434,7 +2435,7 @@ ReservedKeyword: | "STARTING" | "TABLE" | "STORED" | "TERMINATED" | "THEN" | "TINYBLOB" | "TINYINT" | "TINYTEXT" | "TO" | "TRAILING" | "TRIGGER" | "TRUE" | "UNION" | "UNIQUE" | "UNLOCK" | "UNSIGNED" | "UPDATE" | "USE" | "USING" | "UTC_DATE" | "UTC_TIMESTAMP" | "VALUES" | "VARBINARY" | "VARCHAR" | "VIRTUAL" -| "WHEN" | "WHERE" | "WRITE" | "XOR" | "YEAR_MONTH" | "ZEROFILL" +| "WHEN" | "WHERE" | "WRITE" | "XOR" | "YEAR_MONTH" | "ZEROFILL" | "NATURAL" /* | "DELAYED" | "HIGH_PRIORITY" | "LOW_PRIORITY"| "WITH" */ @@ -3523,7 +3524,7 @@ FunctionCallNonKeyword: | "TRIM" '(' TrimDirection "FROM" Expression ')' { nilVal := ast.NewValueExpr(nil) - direction := ast.NewValueExpr($3) + direction := ast.NewValueExpr(int($3.(ast.TrimDirectionType))) $$ = &ast.FuncCallExpr{ FnName: model.NewCIStr($1), Args: []ast.ExprNode{$5.(ast.ExprNode), nilVal, direction}, @@ -3531,7 +3532,7 @@ FunctionCallNonKeyword: } | "TRIM" '(' TrimDirection Expression "FROM" Expression ')' { - direction := ast.NewValueExpr($3) + direction := ast.NewValueExpr(int($3.(ast.TrimDirectionType))) $$ = &ast.FuncCallExpr{ FnName: model.NewCIStr($1), Args: []ast.ExprNode{$6.(ast.ExprNode),$4.(ast.ExprNode), direction}, @@ -4655,6 +4656,14 @@ JoinTable: { $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), Using: $8.([]*ast.ColumnName)} } +| TableRef "NATURAL" "JOIN" TableRef + { + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $4.(ast.ResultSetNode), NaturalJoin: true} + } +| TableRef "NATURAL" JoinType OuterOpt "JOIN" TableRef + { + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $6.(ast.ResultSetNode), Tp: $3.(ast.JoinType), NaturalJoin: true} + } JoinType: "LEFT" diff --git a/parser/parser_test.go b/parser/parser_test.go index f33954883ceff..14217515e9f53 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -312,6 +312,11 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"select * from t1 join t2 left join t3 using (id)", true}, {"select * from t1 right join t2 using (id) left join t3 using (id)", true}, {"select * from t1 right join t2 using (id) left join t3", false}, + {"select * from t1 natural join t2", true}, + {"select * from t1 natural right join t2", true}, + {"select * from t1 natural left outer join t2", true}, + {"select * from t1 natural inner join t2", false}, + {"select * from t1 natural cross join t2", false}, // for admin {"admin show ddl;", true}, diff --git a/plan/explain.go b/plan/explain.go index 8baa31b9b6d17..9fa530df96783 100644 --- a/plan/explain.go +++ b/plan/explain.go @@ -56,6 +56,85 @@ func setParents4FinalPlan(plan PhysicalPlan) { } } +// ExplainInfo implements PhysicalPlan interface. +func (p *SelectLock) ExplainInfo() string { + return p.Lock.String() +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalIndexScan) ExplainInfo() string { + buffer := bytes.NewBufferString("") + tblName := p.Table.Name.O + if p.TableAsName != nil && p.TableAsName.O != "" { + tblName = p.TableAsName.O + } + buffer.WriteString(fmt.Sprintf("table:%s", tblName)) + if len(p.Index.Columns) > 0 { + buffer.WriteString(", index:") + for i, idxCol := range p.Index.Columns { + buffer.WriteString(idxCol.Name.O) + if i+1 < len(p.Index.Columns) { + buffer.WriteString(", ") + } + } + } + if len(p.Ranges) > 0 { + buffer.WriteString(", range:") + for i, idxRange := range p.Ranges { + buffer.WriteString(idxRange.String()) + if i+1 < len(p.Ranges) { + buffer.WriteString(", ") + } + } + } + buffer.WriteString(fmt.Sprintf(", out of order:%v", p.OutOfOrder)) + return buffer.String() +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalTableScan) ExplainInfo() string { + buffer := bytes.NewBufferString("") + tblName := p.Table.Name.O + if p.TableAsName != nil && p.TableAsName.O != "" { + tblName = p.TableAsName.O + } + buffer.WriteString(fmt.Sprintf("table:%s", tblName)) + if p.pkCol != nil { + buffer.WriteString(fmt.Sprintf(", pk col:%s", p.pkCol.ExplainInfo())) + } + if len(p.Ranges) > 0 { + buffer.WriteString(", range:") + for i, idxRange := range p.Ranges { + buffer.WriteString(idxRange.String()) + if i+1 < len(p.Ranges) { + buffer.WriteString(", ") + } + } + } + buffer.WriteString(fmt.Sprintf(", keep order:%v", p.KeepOrder)) + return buffer.String() +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalTableReader) ExplainInfo() string { + return fmt.Sprintf("data:%s", p.tablePlan.ID()) +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalIndexReader) ExplainInfo() string { + return fmt.Sprintf("index:%s", p.indexPlan.ID()) +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalIndexLookUpReader) ExplainInfo() string { + return fmt.Sprintf("index:%s, table:%s", p.indexPlan.ID(), p.tablePlan.ID()) +} + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalUnionScan) ExplainInfo() string { + return string(expression.ExplainExpressionList(p.Conditions)) +} + // ExplainInfo implements PhysicalPlan interface. func (p *Selection) ExplainInfo() string { return string(expression.ExplainExpressionList(p.Conditions)) @@ -208,16 +287,16 @@ func (p *PhysicalMergeJoin) ExplainInfo() string { expression.ExplainExpressionList(p.OtherConditions))) } if p.Desc { - buffer.WriteString("desc") + buffer.WriteString(", desc") } else { - buffer.WriteString("asc") + buffer.WriteString(", asc") } if len(p.leftKeys) > 0 { - buffer.WriteString(fmt.Sprintf("left key:%s", + buffer.WriteString(fmt.Sprintf(", left key:%s", expression.ExplainColumnList(p.leftKeys))) } if len(p.rightKeys) > 0 { - buffer.WriteString(fmt.Sprintf("right key:%s", + buffer.WriteString(fmt.Sprintf(", right key:%s", expression.ExplainColumnList(p.rightKeys))) } return buffer.String() diff --git a/plan/explain_test.go b/plan/explain_test.go new file mode 100644 index 0000000000000..eee228c926d7b --- /dev/null +++ b/plan/explain_test.go @@ -0,0 +1,207 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testExplainSuite{}) + +type testExplainSuite struct { +} + +func (s *testExplainSuite) TestExplain(c *C) { + store, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + testleak.AfterTest(c)() + tk.MustExec("set @@session.tidb_opt_insubquery_unfold = 0") + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3") + tk.MustExec("create table t1 (c1 int primary key, c2 int, c3 int, index c2 (c2))") + tk.MustExec("create table t2 (c1 int unique, c2 int)") + tk.MustExec("insert into t2 values(1, 0), (2, 1)") + tk.MustExec("create table t3 (a bigint, b bigint, c bigint, d bigint)") + + tests := []struct { + sql string + expect []string + }{ + { + "select * from t3 where exists (select s.a from t3 s having sum(s.a) = t3.a )", + []string{ + "TableScan_15 cop table:t3, range:(-inf,+inf), keep order:false", + "TableReader_16 Projection_12 root data:TableScan_15", + "Projection_12 HashSemiJoin_14 root test.t3.a, test.t3.b, test.t3.c, test.t3.d, cast(test.t3.a)", + "TableScan_18 HashAgg_17 cop table:s, range:(-inf,+inf), keep order:false", + "HashAgg_17 cop type:complete, funcs:sum(s.a)", + "TableReader_20 HashAgg_19 root data:HashAgg_17", + "HashAgg_19 HashSemiJoin_14 root type:final, funcs:sum(col_0)", + "HashSemiJoin_14 Projection_11 root right:HashAgg_19, equal:[eq(cast(test.t3.a), sel_agg_1)]", + "Projection_11 root test.t3.a, test.t3.b, test.t3.c, test.t3.d", + }, + }, + { + "select * from t1", + []string{ + "TableScan_3 cop table:t1, range:(-inf,+inf), keep order:false", + "TableReader_4 root data:TableScan_3", + }, + }, + { + "select * from t1 order by c2", + []string{ + "IndexScan_13 cop table:t1, index:c2, range:[,+inf], out of order:false", + "TableScan_14 cop table:t1, keep order:false", + "IndexLookUp_15 root index:IndexScan_13, table:TableScan_14", + }, + }, + { + "select * from t2 order by c2", + []string{ + "TableScan_4 cop table:t2, range:(-inf,+inf), keep order:false", + "TableReader_5 Sort_3 root data:TableScan_4", + "Sort_3 root t2.c2:asc", + }, + }, + { + "select * from t1 where t1.c1 > 0", + []string{ + "TableScan_4 cop table:t1, range:[1,+inf), keep order:false", + "TableReader_5 root data:TableScan_4", + }, + }, + { + "select t1.c1, t1.c2 from t1 where t1.c2 = 1", + []string{ + "IndexScan_7 cop table:t1, index:c2, range:[1,1], out of order:true", + "IndexReader_8 root index:IndexScan_7", + }, + }, + { + "select * from t1 left join t2 on t1.c2 = t2.c1 where t1.c1 > 1", + []string{ + "TableScan_22 cop table:t1, range:[2,+inf), keep order:false", + "TableReader_23 IndexJoin_7 root data:TableScan_22", + "IndexScan_33 cop table:t2, index:c1, range:[,+inf], out of order:false", + "TableScan_34 cop table:t2, keep order:false", + "IndexLookUp_35 IndexJoin_7 root index:IndexScan_33, table:TableScan_34", + "IndexJoin_7 root outer:TableReader_23, outer key:test.t1.c2, inner key:test.t2.c1", + }, + }, + { + "update t1 set t1.c2 = 2 where t1.c1 = 1", + []string{ + "TableScan_4 cop table:t1, range:[1,1], keep order:false", + "TableReader_5 Update_3 root data:TableScan_4", + "Update_3 root ", + }, + }, + { + "delete from t1 where t1.c2 = 1", + []string{ + "IndexScan_7 cop table:t1, index:c2, range:[1,1], out of order:true", + "TableScan_8 cop table:t1, keep order:false", + "IndexLookUp_9 Delete_3 root index:IndexScan_7, table:TableScan_8", + "Delete_3 root ", + }, + }, + { + "select count(b.c2) from t1 a, t2 b where a.c1 = b.c2 group by a.c1", + []string{ + "TableScan_17 HashAgg_16 cop table:b, range:(-inf,+inf), keep order:false", + "HashAgg_16 cop type:complete, group by:b.c2, funcs:count(b.c2), firstrow(b.c2)", + "TableReader_21 HashAgg_20 root data:HashAgg_16", + "HashAgg_20 IndexJoin_9 root type:final, group by:, funcs:count(col_0), firstrow(col_1)", + "TableScan_12 cop table:a, range:(-inf,+inf), keep order:true", + "TableReader_31 IndexJoin_9 root data:TableScan_12", + "IndexJoin_9 Projection_8 root outer:TableReader_31, outer key:b.c2, inner key:a.c1", + "Projection_8 root cast(join_agg_0)", + }, + }, + { + "select * from t2 order by t2.c2 limit 0, 1", + []string{ + "TableScan_7 TopN_5 cop table:t2, range:(-inf,+inf), keep order:false", + "TopN_5 cop ", + "TableReader_10 TopN_5 root data:TopN_5", + "TopN_5 root ", + }, + }, + { + "select * from t1 where c1 > 1 and c2 = 1 and c3 < 1", + []string{ + "IndexScan_7 Selection_9 cop table:t1, index:c2, range:[1,1], out of order:true", + "Selection_9 cop gt(test.t1.c1, 1)", + "TableScan_8 Selection_10 cop table:t1, keep order:false", + "Selection_10 cop lt(test.t1.c3, 1)", + "IndexLookUp_11 root index:Selection_9, table:Selection_10", + }, + }, + { + "select * from t1 where c1 =1 and c2 > 1", + []string{ + "TableScan_4 Selection_5 cop table:t1, range:[1,1], keep order:false", + "Selection_5 cop gt(test.t1.c2, 1)", + "TableReader_6 root data:Selection_5", + }, + }, + { + "select sum(t1.c1 in (select c1 from t2)) from t1", + []string{ + "TableScan_11 HashAgg_10 cop table:t1, range:(-inf,+inf), keep order:false", + "HashAgg_10 cop type:complete, funcs:sum(in(test.t1.c1, 1, 2))", + "TableReader_14 HashAgg_13 root data:HashAgg_10", + "HashAgg_13 root type:final, funcs:sum(col_0)", + }, + }, + { + "select c1 from t1 where c1 in (select c2 from t2)", + []string{ + "TableScan_11 cop table:t1, range:[0,0], [1,1], keep order:false", + "TableReader_12 root data:TableScan_11", + }, + }, + { + "select (select count(1) k from t1 s where s.c1 = t1.c1 having k != 0) from t1", + []string{ + "TableScan_13 cop table:t1, range:(-inf,+inf), keep order:false", + "TableReader_14 Apply_12 root data:TableScan_13", + "TableScan_18 cop table:s, range:(-inf,+inf), keep order:false", + "TableReader_19 Selection_4 root data:TableScan_18", + "Selection_4 HashAgg_17 root eq(s.c1, test.t1.c1)", + "HashAgg_17 Selection_10 root type:complete, funcs:count(1)", + "Selection_10 Apply_12 root ne(k, 0)", + "Apply_12 Projection_2 root left outer join, small:Selection_10, right:Selection_10", + "Projection_2 root k", + }, + }, + { + "select * from information_schema.columns", + []string{ + "MemTableScan_3 root ", + }, + }, + } + tk.MustExec("set @@session.tidb_opt_insubquery_unfold = 1") + for _, tt := range tests { + result := tk.MustQuery("explain " + tt.sql) + result.Check(testkit.Rows(tt.expect...)) + } +} diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index 7ea3528e2b2d7..2fcddca8a65d6 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -15,7 +15,6 @@ package plan import ( "fmt" - "sort" "github.com/cznic/mathutil" "github.com/juju/errors" @@ -257,7 +256,12 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan { } } - if join.Using != nil { + if join.NaturalJoin { + if err := b.buildNaturalJoin(joinPlan, leftPlan, rightPlan, join); err != nil { + b.err = err + return nil + } + } else if join.Using != nil { if err := b.buildUsingClause(joinPlan, leftPlan, rightPlan, join); err != nil { b.err = err return nil @@ -295,72 +299,87 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan { // Second, columns unique to the first table, in order in which they occur in that table. // Third, columns unique to the second table, in order in which they occur in that table. func (b *planBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error { + filter := make(map[string]bool, len(join.Using)) + for _, col := range join.Using { + filter[col.Name.L] = true + } + return b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp == ast.RightJoin, filter) +} + +// buildNaturalJoin build natural join output schema. It find out all the common columns +// then using the same mechanism as buildUsingClause to eliminate redundant columns and build join conditions. +// According to standard SQL, producing this display order: +// All the common columns +// Every column in the first (left) table that is not a common column +// Every column in the second (right) table that is not a common column +func (b *planBuilder) buildNaturalJoin(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error { + return b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp == ast.RightJoin, nil) +} + +// coalesceCommonColumns is used by buildUsingClause and buildNaturalJoin. The filter is used by buildUsingClause. +func (b *planBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, rightJoin bool, filter map[string]bool) error { lsc := leftPlan.Schema().Clone() rsc := rightPlan.Schema().Clone() + lColumns, rColumns := lsc.Columns, rsc.Columns + if rightJoin { + lColumns, rColumns = rsc.Columns, lsc.Columns + } - schemaCols := make([]*expression.Column, 0, len(lsc.Columns)+len(rsc.Columns)-len(join.Using)) - redundantCols := make([]*expression.Column, 0, len(join.Using)) - conds := make([]*expression.ScalarFunction, 0, len(join.Using)) + // Find out all the common columns and put them ahead. + commonLen := 0 + for i, lCol := range lColumns { + for j := commonLen; j < len(rColumns); j++ { + if lCol.ColName.L != rColumns[j].ColName.L { + continue + } - redundant := make(map[string]bool, len(join.Using)) - for _, col := range join.Using { - var ( - err error - lc, rc *expression.Column - cond expression.Expression - ) + if len(filter) > 0 { + if !filter[lCol.ColName.L] { + break + } + // Mark this column exist. + filter[lCol.ColName.L] = false + } - if lc, err = lsc.FindColumn(col); err != nil { - return errors.Trace(err) - } - if rc, err = rsc.FindColumn(col); err != nil { - return errors.Trace(err) - } - redundant[col.Name.L] = true - if lc == nil || rc == nil { - // Same as MySQL. - return ErrUnknownColumn.GenByArgs(col.Name, "from clause") - } + col := rColumns[i] + copy(rColumns[commonLen+1:i+1], rColumns[commonLen:i]) + rColumns[commonLen] = col - if cond, err = expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc); err != nil { - return errors.Trace(err) - } - conds = append(conds, cond.(*expression.ScalarFunction)) + col = lColumns[j] + copy(lColumns[commonLen+1:j+1], lColumns[commonLen:j]) + lColumns[commonLen] = col - if join.Tp == ast.RightJoin { - schemaCols = append(schemaCols, rc) - redundantCols = append(redundantCols, lc) - } else { - schemaCols = append(schemaCols, lc) - redundantCols = append(redundantCols, rc) + commonLen++ + break } } - // Columns in using clause may not ordered in the order in which they occur in the first table, so reorder them. - sort.Slice(schemaCols, func(i, j int) bool { - return schemaCols[i].Position < schemaCols[j].Position - }) - - if join.Tp == ast.RightJoin { - lsc, rsc = rsc, lsc - } - for _, col := range lsc.Columns { - if !redundant[col.ColName.L] { - schemaCols = append(schemaCols, col) + if len(filter) > 0 && len(filter) != commonLen { + for col, notExist := range filter { + if notExist { + return ErrUnknownColumn.GenByArgs(col, "from clause") + } } } - for _, col := range rsc.Columns { - if !redundant[col.ColName.L] { - schemaCols = append(schemaCols, col) + + schemaCols := make([]*expression.Column, len(lColumns)+len(rColumns)-commonLen) + copy(schemaCols[:len(lColumns)], lColumns) + copy(schemaCols[len(lColumns):], rColumns[commonLen:]) + + conds := make([]*expression.ScalarFunction, 0, commonLen) + for i := 0; i < commonLen; i++ { + lc, rc := lsc.Columns[i], rsc.Columns[i] + cond, err := expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc) + if err != nil { + return errors.Trace(err) } + conds = append(conds, cond.(*expression.ScalarFunction)) } p.SetSchema(expression.NewSchema(schemaCols...)) + p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rColumns[:commonLen]...)) p.EqualConditions = append(conds, p.EqualConditions...) - // p.redundantSchema may contains columns which are merged from sub join, so merge it with redundantCols. - p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(redundantCols...)) - return nil } @@ -426,8 +445,14 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, if _, ok := innerExpr.(*ast.ValueExpr); ok && innerExpr.Text() != "" { colName = model.NewCIStr(innerExpr.Text()) } else { + //Change column name \N to NULL, just when original sql contains \N column + fieldText := field.Text() + if fieldText == "\\N" { + fieldText = "NULL" + } + // Remove special comment code for field part, see issue #3739 for detail. - colName = model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)) + colName = model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(fieldText, parser.TrimComment)) } } } diff --git a/plan/physical_plans.go b/plan/physical_plans.go index f30139baecbcb..4f1f12f4de72a 100644 --- a/plan/physical_plans.go +++ b/plan/physical_plans.go @@ -58,6 +58,9 @@ var ( _ PhysicalPlan = &Insert{} _ PhysicalPlan = &PhysicalIndexScan{} _ PhysicalPlan = &PhysicalTableScan{} + _ PhysicalPlan = &PhysicalTableReader{} + _ PhysicalPlan = &PhysicalIndexReader{} + _ PhysicalPlan = &PhysicalIndexLookUpReader{} _ PhysicalPlan = &PhysicalAggregation{} _ PhysicalPlan = &PhysicalApply{} _ PhysicalPlan = &PhysicalIndexJoin{} diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index 7f20e35eb5103..769f34a20e072 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -116,6 +116,8 @@ func (s *testPlanSuite) TestInferType(c *C) { {"ltrim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"rtrim(c_char)", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, {"rtrim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, + {"trim(c_char)", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, + {"trim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"cot(c_int)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"cot(c_float)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},