Skip to content

Commit

Permalink
*: support natural join. (#3861)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zejun Li authored and hanfei1991 committed Aug 1, 2017
1 parent 26f622d commit 7874e98
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 52 deletions.
2 changes: 2 additions & 0 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ type Join struct {
On *OnCondition
// Using represents join using clause.
Using []*ColumnName
// NaturalJoin indicates whether the join is a natural join.
NaturalJoin bool
}

// Accept implements Node Accept interface.
Expand Down
19 changes: 19 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,25 @@ func (s *testSuite) TestUsing(c *C) {
tk.MustExec("select * from (t1 join t2 using (a)) join (t3 join t4 using (a)) on (t2.a = t4.a and t1.a = t3.a)")
}

// TestNaturalJoin checks the rows and column order produced by NATURAL JOIN,
// NATURAL LEFT JOIN and NATURAL RIGHT JOIN over two tables that share column a.
func (s *testSuite) TestNaturalJoin(c *C) {
	defer func() {
		s.cleanEnv(c)
		testleak.AfterTest(c)()
	}()
	tk := testkit.NewTestKit(c, s.store)

	// Fixture: t1 and t2 share only column a, and only a = 1 exists on both sides.
	setup := []string{
		"use test",
		"drop table if exists t1, t2",
		"create table t1 (a int, b int)",
		"create table t2 (a int, c int)",
		"insert t1 values (1, 2), (10, 20)",
		"insert t2 values (1, 3), (100, 200)",
	}
	for _, stmt := range setup {
		tk.MustExec(stmt)
	}

	// Each case pairs a query with the exact rows it must return.
	cases := []struct {
		sql  string
		rows []string
	}{
		{"select * from t1 natural join t2", []string{"1 2 3"}},
		{"select * from t1 natural left join t2 order by a", []string{"1 2 3", "10 20 <nil>"}},
		{"select * from t1 natural right join t2 order by a", []string{"1 3 2", "100 200 <nil>"}},
	}
	for _, tc := range cases {
		tk.MustQuery(tc.sql).Check(testkit.Rows(tc.rows...))
	}
}

func (s *testSuite) TestMultiJoin(c *C) {
defer func() {
s.cleanEnv(c)
Expand Down
1 change: 1 addition & 0 deletions parser/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ var tokenMap = map[string]int{
"UUID": uuid,
"UUID_SHORT": uuidShort,
"KILL": kill,
"NATURAL": natural,
}

func isTokenIdentifier(s string, buf *bytes.Buffer) int {
Expand Down
13 changes: 11 additions & 2 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ import (
xor "XOR"
yearMonth "YEAR_MONTH"
zerofill "ZEROFILL"
natural "NATURAL"

/* the following tokens belong to NotKeywordToken*/
abs "ABS"
Expand Down Expand Up @@ -894,7 +895,7 @@ import (
%precedence lowerThanKey
%precedence key

%left join inner cross left right full
%left join inner cross left right full natural
/* A dummy token to force the priority of TableRef production in a join. */
%left tableRefPriority
%precedence lowerThanOn
Expand Down Expand Up @@ -2434,7 +2435,7 @@ ReservedKeyword:
| "STARTING" | "TABLE" | "STORED" | "TERMINATED" | "THEN" | "TINYBLOB" | "TINYINT" | "TINYTEXT" | "TO"
| "TRAILING" | "TRIGGER" | "TRUE" | "UNION" | "UNIQUE" | "UNLOCK" | "UNSIGNED"
| "UPDATE" | "USE" | "USING" | "UTC_DATE" | "UTC_TIMESTAMP" | "VALUES" | "VARBINARY" | "VARCHAR" | "VIRTUAL"
| "WHEN" | "WHERE" | "WRITE" | "XOR" | "YEAR_MONTH" | "ZEROFILL"
| "WHEN" | "WHERE" | "WRITE" | "XOR" | "YEAR_MONTH" | "ZEROFILL" | "NATURAL"
/*
| "DELAYED" | "HIGH_PRIORITY" | "LOW_PRIORITY"| "WITH"
*/
Expand Down Expand Up @@ -4655,6 +4656,14 @@ JoinTable:
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), Using: $8.([]*ast.ColumnName)}
}
| TableRef "NATURAL" "JOIN" TableRef
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $4.(ast.ResultSetNode), NaturalJoin: true}
}
| TableRef "NATURAL" JoinType OuterOpt "JOIN" TableRef
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $6.(ast.ResultSetNode), Tp: $3.(ast.JoinType), NaturalJoin: true}
}

JoinType:
"LEFT"
Expand Down
5 changes: 5 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{"select * from t1 join t2 left join t3 using (id)", true},
{"select * from t1 right join t2 using (id) left join t3 using (id)", true},
{"select * from t1 right join t2 using (id) left join t3", false},
{"select * from t1 natural join t2", true},
{"select * from t1 natural right join t2", true},
{"select * from t1 natural left outer join t2", true},
{"select * from t1 natural inner join t2", false},
{"select * from t1 natural cross join t2", false},

// for admin
{"admin show ddl;", true},
Expand Down
119 changes: 69 additions & 50 deletions plan/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package plan

import (
"fmt"
"sort"

"github.com/cznic/mathutil"
"github.com/juju/errors"
Expand Down Expand Up @@ -257,7 +256,12 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
}
}

if join.Using != nil {
if join.NaturalJoin {
if err := b.buildNaturalJoin(joinPlan, leftPlan, rightPlan, join); err != nil {
b.err = err
return nil
}
} else if join.Using != nil {
if err := b.buildUsingClause(joinPlan, leftPlan, rightPlan, join); err != nil {
b.err = err
return nil
Expand Down Expand Up @@ -295,72 +299,87 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
// Second, columns unique to the first table, in order in which they occur in that table.
// Third, columns unique to the second table, in order in which they occur in that table.
func (b *planBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
	// Collect the USING column names (keyed by Name.L) so that only these
	// columns are coalesced; every entry starts out as "not yet matched".
	usingNames := make(map[string]bool, len(join.Using))
	for i := range join.Using {
		usingNames[join.Using[i].Name.L] = true
	}

	isRightJoin := join.Tp == ast.RightJoin
	return b.coalesceCommonColumns(p, leftPlan, rightPlan, isRightJoin, usingNames)
}

// buildNaturalJoin builds the output schema of a natural join. It finds all the
// columns common to both sides and then reuses the mechanism behind
// buildUsingClause to eliminate the redundant columns and build the join conditions.
// According to standard SQL, the display order is:
// 	All the common columns.
// 	Every column in the first (left) table that is not a common column.
// 	Every column in the second (right) table that is not a common column.
func (b *planBuilder) buildNaturalJoin(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
	// A nil filter asks coalesceCommonColumns to coalesce every shared column name.
	isRightJoin := join.Tp == ast.RightJoin
	return b.coalesceCommonColumns(p, leftPlan, rightPlan, isRightJoin, nil)
}

// coalesceCommonColumns is used by buildUsingClause and buildNaturalJoin. The filter is used by buildUsingClause.
func (b *planBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, rightJoin bool, filter map[string]bool) error {
lsc := leftPlan.Schema().Clone()
rsc := rightPlan.Schema().Clone()
lColumns, rColumns := lsc.Columns, rsc.Columns
if rightJoin {
lColumns, rColumns = rsc.Columns, lsc.Columns
}

schemaCols := make([]*expression.Column, 0, len(lsc.Columns)+len(rsc.Columns)-len(join.Using))
redundantCols := make([]*expression.Column, 0, len(join.Using))
conds := make([]*expression.ScalarFunction, 0, len(join.Using))
// Find out all the common columns and put them ahead.
commonLen := 0
for i, lCol := range lColumns {
for j := commonLen; j < len(rColumns); j++ {
if lCol.ColName.L != rColumns[j].ColName.L {
continue
}

redundant := make(map[string]bool, len(join.Using))
for _, col := range join.Using {
var (
err error
lc, rc *expression.Column
cond expression.Expression
)
if len(filter) > 0 {
if !filter[lCol.ColName.L] {
break
}
// Mark this column exist.
filter[lCol.ColName.L] = false
}

if lc, err = lsc.FindColumn(col); err != nil {
return errors.Trace(err)
}
if rc, err = rsc.FindColumn(col); err != nil {
return errors.Trace(err)
}
redundant[col.Name.L] = true
if lc == nil || rc == nil {
// Same as MySQL.
return ErrUnknownColumn.GenByArgs(col.Name, "from clause")
}
col := rColumns[i]
copy(rColumns[commonLen+1:i+1], rColumns[commonLen:i])
rColumns[commonLen] = col

if cond, err = expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc); err != nil {
return errors.Trace(err)
}
conds = append(conds, cond.(*expression.ScalarFunction))
col = lColumns[j]
copy(lColumns[commonLen+1:j+1], lColumns[commonLen:j])
lColumns[commonLen] = col

if join.Tp == ast.RightJoin {
schemaCols = append(schemaCols, rc)
redundantCols = append(redundantCols, lc)
} else {
schemaCols = append(schemaCols, lc)
redundantCols = append(redundantCols, rc)
commonLen++
break
}
}

// Columns in using clause may not ordered in the order in which they occur in the first table, so reorder them.
sort.Slice(schemaCols, func(i, j int) bool {
return schemaCols[i].Position < schemaCols[j].Position
})

if join.Tp == ast.RightJoin {
lsc, rsc = rsc, lsc
}
for _, col := range lsc.Columns {
if !redundant[col.ColName.L] {
schemaCols = append(schemaCols, col)
if len(filter) > 0 && len(filter) != commonLen {
for col, notExist := range filter {
if notExist {
return ErrUnknownColumn.GenByArgs(col, "from clause")
}
}
}
for _, col := range rsc.Columns {
if !redundant[col.ColName.L] {
schemaCols = append(schemaCols, col)

schemaCols := make([]*expression.Column, len(lColumns)+len(rColumns)-commonLen)
copy(schemaCols[:len(lColumns)], lColumns)
copy(schemaCols[len(lColumns):], rColumns[commonLen:])

conds := make([]*expression.ScalarFunction, 0, commonLen)
for i := 0; i < commonLen; i++ {
lc, rc := lsc.Columns[i], rsc.Columns[i]
cond, err := expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc)
if err != nil {
return errors.Trace(err)
}
conds = append(conds, cond.(*expression.ScalarFunction))
}

p.SetSchema(expression.NewSchema(schemaCols...))
p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rColumns[:commonLen]...))
p.EqualConditions = append(conds, p.EqualConditions...)

// p.redundantSchema may contains columns which are merged from sub join, so merge it with redundantCols.
p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(redundantCols...))

return nil
}

Expand Down

0 comments on commit 7874e98

Please sign in to comment.