diff --git a/expression/column.go b/expression/column.go index 52557377c34d0..f63cf44a117a5 100644 --- a/expression/column.go +++ b/expression/column.go @@ -375,3 +375,18 @@ func IndexInfo2Cols(cols []*Column, index *model.IndexInfo) ([]*Column, []int) { } return retCols, lengths } + +// FindColumnsByUniqueIDs will find columns by checking the unique id. +// Note: `ids` must be a subset of the column slice. +func FindColumnsByUniqueIDs(cols []*Column, ids []int) []*Column { + retCols := make([]*Column, 0, len(ids)) + for _, id := range ids { + for _, col := range cols { + if col.UniqueID == id { + retCols = append(retCols, col) + break + } + } + } + return retCols +} diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 5a5b4450b5c1e..6d752fca78ff6 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -56,6 +56,33 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo return rewriter.pop(), nil } +func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) { + exprStr = "select " + exprStr + stmts, err := parser.New().Parse(exprStr, "", "") + if err != nil { + return nil, errors.Trace(err) + } + fields := stmts[0].(*ast.SelectStmt).Fields.Fields + exprs := make([]Expression, 0, len(fields)) + for _, field := range fields { + expr, err := RewriteSimpleExprWithSchema(ctx, field.Expr, schema) + if err != nil { + return nil, errors.Trace(err) + } + exprs = append(exprs, expr) + } + return exprs, nil +} + +func RewriteSimpleExprWithSchema(ctx sessionctx.Context, expr ast.ExprNode, schema *Schema) (Expression, error) { + rewriter := &simpleRewriter{ctx: ctx, schema: schema} + expr.Accept(rewriter) + if rewriter.err != nil { + return nil, errors.Trace(rewriter.err) + } + return rewriter.pop(), nil +} + func (sr *simpleRewriter) rewriteColumn(nodeColName *ast.ColumnNameExpr) (*Column, error) { col := sr.schema.FindColumnByName(nodeColName.Name.Name.L) if col != nil { diff --git a/plan/cbo_test.go b/plan/cbo_test.go index 875709e3c54e4..832ac06ede355 100644 --- a/plan/cbo_test.go +++ b/plan/cbo_test.go @@ -623,7 +623,7 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { store.Close() }() tk.MustExec("use test") - tk.MustExec("create table t(a int, b int, c int)") + tk.MustExec("create table t(a int, b int, c int, index idx(c))") tk.MustExec("insert into t values(1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5), (6,6,6), (7,7,7), (8,8,8), (9,9,9),(10,10,10)") tk.MustExec("analyze table t") tk.MustQuery("explain select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t;"). @@ -640,6 +640,19 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { " └─TableReader_27 10.00 root data:TableScan_26", " └─TableScan_26 10.00 cop table:t1, range:[-inf,+inf], keep order:false", )) + tk.MustQuery("explain select (select concat(t1.a, \",\", t1.b) from t t1 where t1.a=t.a and t1.c=t.c) from t"). + Check(testkit.Rows( + "Projection_8 10.00 root concat(t1.a, \",\", t1.b)", + "└─Apply_10 10.00 root left outer join, inner:MaxOneRow_13", + " ├─TableReader_12 10.00 root data:TableScan_11", + " │ └─TableScan_11 10.00 cop table:t, range:[-inf,+inf], keep order:false", + " └─MaxOneRow_13 1.00 root ", + " └─Projection_14 0.80 root concat(cast(t1.a), \",\", cast(t1.b))", + " └─IndexLookUp_21 0.80 root ", + " ├─IndexScan_18 1.00 cop table:t1, index:c, range: decided by [eq(t1.c, test.t.c)], keep order:false", + " └─Selection_20 0.80 cop eq(t1.a, test.t.a)", + " └─TableScan_19 1.00 cop table:t, keep order:false", + )) } func (s *testAnalyzeSuite) TestInconsistentEstimation(c *C) { diff --git a/plan/logical_plans.go b/plan/logical_plans.go index 1e4ad4a859292..12d2ff433106d 100644 --- a/plan/logical_plans.go +++ b/plan/logical_plans.go @@ -484,14 +484,15 @@ func (path *accessPath) splitCorColAccessCondFromFilters() (access, remained []e for i := path.eqCondCount; i < len(path.idxCols); i++ { matched := false for j, filter := range path.tableFilters { - if !isColEqCorColOrConstant(filter, path.idxCols[i]) { - break + if used[j] || !isColEqCorColOrConstant(filter, path.idxCols[i]) { + continue } matched = true access[i-path.eqCondCount] = filter if path.idxColLens[i] == types.UnspecifiedLength { used[j] = true } + break } if !matched { access = access[:i-path.eqCondCount] diff --git a/plan/logical_plans_test.go b/plan/logical_plans_test.go new file mode 100644 index 0000000000000..6ec6894d51381 --- /dev/null +++ b/plan/logical_plans_test.go @@ -0,0 +1,195 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain col1 copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "fmt" + + "github.com/juju/errors" + . "github.com/pingcap/check" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testUnitTestSuit{}) + +type testUnitTestSuit struct { + ctx sessionctx.Context +} + +func (s *testUnitTestSuit) SetUpSuite(c *C) { + s.ctx = mockContext() +} + +func (s *testUnitTestSuit) newTypeWithFlen(typeByte byte, flen int) *types.FieldType { + tp := types.NewFieldType(typeByte) + tp.Flen = flen + return tp +} + +func (s *testUnitTestSuit) SubstituteCol2CorCol(expr expression.Expression, colIDs map[int]struct{}) (expression.Expression, error) { + switch x := expr.(type) { + case *expression.ScalarFunction: + newArgs := make([]expression.Expression, 0, len(x.GetArgs())) + for _, arg := range x.GetArgs() { + newArg, err := s.SubstituteCol2CorCol(arg, colIDs) + if err != nil { + return nil, errors.Trace(err) + } + newArgs = append(newArgs, newArg) + } + newSf, err := expression.NewFunction(x.GetCtx(), x.FuncName.L, x.GetType(), newArgs...) + return newSf, errors.Trace(err) + case *expression.Column: + if _, ok := colIDs[x.UniqueID]; ok { + return &expression.CorrelatedColumn{Column: *x}, nil + } + } + return expr, nil +} + +func (s *testUnitTestSuit) TestIndexPathSplitCorColCond(c *C) { + defer testleak.AfterTest(c)() + totalSchema := expression.NewSchema() + totalSchema.Append(&expression.Column{ + ColName: model.NewCIStr("col1"), + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + }) + totalSchema.Append(&expression.Column{ + ColName: model.NewCIStr("col2"), + UniqueID: 2, + RetType: types.NewFieldType(mysql.TypeLonglong), + }) + totalSchema.Append(&expression.Column{ + ColName: model.NewCIStr("col3"), + UniqueID: 3, + RetType: s.newTypeWithFlen(mysql.TypeVarchar, 10), + }) + totalSchema.Append(&expression.Column{ + ColName: model.NewCIStr("col4"), + UniqueID: 4, + RetType: s.newTypeWithFlen(mysql.TypeVarchar, 10), + }) + totalSchema.Append(&expression.Column{ + ColName: model.NewCIStr("col5"), + UniqueID: 5, + RetType: types.NewFieldType(mysql.TypeLonglong), + }) + testCases := []struct { + expr string + corColIDs []int + idxColIDs []int + idxColLens []int + access string + remained string + }{ + { + expr: "col1 = col2", + corColIDs: []int{2}, + idxColIDs: []int{1}, + idxColLens: []int{types.UnspecifiedLength}, + access: "[eq(col1, col2)]", + remained: "[]", + }, + { + expr: "col1 = col5 and col2 = 1", + corColIDs: []int{5}, + idxColIDs: []int{1, 2}, + idxColLens: []int{types.UnspecifiedLength, types.UnspecifiedLength}, + access: "[eq(col1, col5) eq(col2, 1)]", + remained: "[]", + }, + { + expr: "col1 = col5 and col2 = 1", + corColIDs: []int{5}, + idxColIDs: []int{2, 1}, + idxColLens: []int{types.UnspecifiedLength, types.UnspecifiedLength}, + access: "[eq(col2, 1) eq(col1, col5)]", + remained: "[]", + }, + { + expr: "col1 = col5 and col2 = 1", + corColIDs: []int{5}, + idxColIDs: []int{1}, + idxColLens: []int{types.UnspecifiedLength}, + access: "[eq(col1, col5)]", + remained: "[eq(col2, 1)]", + }, + { + expr: "col2 = 1 and col1 = col5", + corColIDs: []int{5}, + idxColIDs: []int{1}, + idxColLens: []int{types.UnspecifiedLength}, + access: "[eq(col1, col5)]", + remained: "[eq(col2, 1)]", + }, + { + expr: "col1 = col2 and col3 = col4 and col5 = 1", + corColIDs: []int{2, 4}, + idxColIDs: []int{1, 3}, + idxColLens: []int{types.UnspecifiedLength, types.UnspecifiedLength}, + access: "[eq(col1, col2) eq(col3, col4)]", + remained: "[eq(col5, 1)]", + }, + { + expr: "col1 = col2 and col3 = col4 and col5 = 1", + corColIDs: []int{2, 4}, + idxColIDs: []int{1, 3}, + idxColLens: []int{types.UnspecifiedLength, 2}, + access: "[eq(col1, col2) eq(col3, col4)]", + remained: "[eq(col3, col4) eq(col5, 1)]", + }, + { + expr: `col1 = col5 and col3 = "col1" and col2 = col5`, + corColIDs: []int{5}, + idxColIDs: []int{1, 2, 3}, + idxColLens: []int{types.UnspecifiedLength, types.UnspecifiedLength, types.UnspecifiedLength}, + access: "[eq(col1, col5) eq(col2, col5) eq(col3, col1)]", + remained: "[]", + }, + } + for _, tt := range testCases { + comment := Commentf("failed at case:\nexpr: %v\ncorColIDs: %v\nidxColIDs: %v\nidxColLens: %v\naccess: %v\nremained: %v\n", tt.expr, tt.corColIDs, tt.idxColIDs, tt.idxColLens, tt.access, tt.remained) + filters, err := expression.ParseSimpleExprsWithSchema(s.ctx, tt.expr, totalSchema) + if sf, ok := filters[0].(*expression.ScalarFunction); ok && sf.FuncName.L == ast.LogicAnd { + filters = expression.FlattenCNFConditions(sf) + } + c.Assert(err, IsNil, comment) + trueFilters := make([]expression.Expression, 0, len(filters)) + idMap := make(map[int]struct{}) + for _, id := range tt.corColIDs { + idMap[id] = struct{}{} + } + for _, filter := range filters { + trueFilter, err := s.SubstituteCol2CorCol(filter, idMap) + c.Assert(err, IsNil, comment) + trueFilters = append(trueFilters, trueFilter) + } + path := accessPath{ + eqCondCount: 0, + tableFilters: trueFilters, + idxCols: expression.FindColumnsByUniqueIDs(totalSchema.Columns, tt.idxColIDs), + idxColLens: tt.idxColLens, + } + access, remained := path.splitCorColAccessCondFromFilters() + c.Assert(fmt.Sprintf("%s", access), Equals, tt.access, comment) + c.Assert(fmt.Sprintf("%s", remained), Equals, tt.remained, comment) + } +}