Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: introduce cast(... as ... array) in expression index #39992

Merged
merged 5 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func collectGeneratedColumns(se *session, meta *model.TableInfo, cols []*table.C
var genCols []genCol
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names)
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names, false)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6164,7 +6164,7 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -6274,7 +6274,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr)
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true)
if err != nil {
// TODO: refine the error message.
return nil, err
Expand Down Expand Up @@ -6389,7 +6389,7 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down
24 changes: 18 additions & 6 deletions ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,14 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol
}

type illegalFunctionChecker struct {
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
otherErr error
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
hasCastArrayFunc bool
disallowCastArrayFunc bool
otherErr error
}

func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
Expand Down Expand Up @@ -308,7 +310,14 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC
case *ast.WindowFuncExpr:
c.hasWindowFunc = true
return inNode, true
case *ast.FuncCastExpr:
c.hasCastArrayFunc = c.hasCastArrayFunc || node.Tp.IsArray()
if c.disallowCastArrayFunc && node.Tp.IsArray() {
c.otherErr = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
return inNode, true
}
}
c.disallowCastArrayFunc = true
return inNode, false
}

Expand Down Expand Up @@ -355,6 +364,9 @@ func checkIllegalFn4Generated(name string, genType int, expr ast.ExprNode) error
if genType == typeIndex && c.hasNotGAFunc4ExprIdx && !config.GetGlobalConfig().Experimental.AllowsExpressionIndex {
return dbterror.ErrUnsupportedExpressionIndex
}
if genType == typeColumn && c.hasCastArrayFunc {
return expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
}
return nil
}

Expand Down
21 changes: 12 additions & 9 deletions ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,28 @@ var (
telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt
)

func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, error) {
func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) {
// Build offsets.
idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications))
var col *model.ColumnInfo
var mvIndex bool
maxIndexLength := config.GetGlobalConfig().MaxIndexLength
// The sum of length of all index columns.
sumLength := 0
for _, ip := range indexPartSpecifications {
col = model.FindColumnInfo(columns, ip.Column.Name.L)
if col == nil {
return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
}

if err := checkIndexColumn(ctx, col, ip.Length); err != nil {
return nil, err
return nil, false, err
}
mvIndex = mvIndex || col.FieldType.IsArray()
indexColLen := ip.Length
indexColumnLength, err := getIndexColumnLength(col, ip.Length)
if err != nil {
return nil, err
return nil, false, err
}
sumLength += indexColumnLength

Expand All @@ -92,12 +94,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
// The multiple column index and the unique index in which the length sum exceeds the maximum size
// will return an error instead produce a warning.
if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 {
return nil, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
}
// truncate index length and produce warning message in non-restrict sql mode.
colLenPerUint, err := getIndexColumnLength(col, 1)
if err != nil {
return nil, err
return nil, false, err
}
indexColLen = maxIndexLength / colLenPerUint
// produce warning message
Expand All @@ -111,7 +113,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
})
}

return idxParts, nil
return idxParts, mvIndex, nil
}

// CheckPKOnGeneratedColumn checks the specification of PK is valid.
Expand Down Expand Up @@ -154,7 +156,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn
}

// JSON column cannot index.
if col.FieldType.GetType() == mysql.TypeJSON {
if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() {
if col.Hidden {
return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction
}
Expand Down Expand Up @@ -263,7 +265,7 @@ func BuildIndexInfo(
return nil, errors.Trace(err)
}

idxColumns, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -276,6 +278,7 @@ func BuildIndexInfo(
Primary: isPrimary,
Unique: isUnique,
Global: isGlobal,
MVIndex: mvIndex,
}

if indexOption != nil {
Expand Down
2 changes: 1 addition & 1 deletion ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, tblInfo *
return nil
}

e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr)
e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr, false)
if err != nil {
return errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ go_test(
"integration_serial_test.go",
"integration_test.go",
"main_test.go",
"multi_valued_index_test.go",
"scalar_function_test.go",
"schema_test.go",
"typeinfer_test.go",
Expand Down
83 changes: 79 additions & 4 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package expression

import (
"fmt"
"math"
"strconv"
"strings"
Expand Down Expand Up @@ -407,6 +408,70 @@ func (c *castAsDurationFunctionClass) getFunction(ctx sessionctx.Context, args [
return sig, nil
}

type castAsArrayFunctionClass struct {
baseFunctionClass

tp *types.FieldType
}

func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}

if args[0].GetType().EvalType() != types.ETJson {
return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array")
}

return nil
}

func (c *castAsArrayFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
arrayType := c.tp.ArrayType()
switch arrayType.GetType() {
case mysql.TypeYear, mysql.TypeJSON:
return nil, ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("CAST-ing data to array of %s", arrayType.String()))
}
if arrayType.EvalType() == types.ETString && arrayType.GetCharset() != charset.CharsetUTF8MB4 && arrayType.GetCharset() != charset.CharsetBin {
return nil, ErrNotSupportedYet.GenWithStackByArgs("specifying charset for multi-valued index", arrayType.String())
}

bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
if err != nil {
return nil, err
}
sig = &castJSONAsArrayFunctionSig{bf}
return sig, nil
}

type castJSONAsArrayFunctionSig struct {
baseBuiltinFunc
}

func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc {
newSig := &castJSONAsArrayFunctionSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) {
val, isNull, err := b.args[0].EvalJSON(b.ctx, row)
if isNull || err != nil {
return res, isNull, err
}

if val.TypeCode != types.JSONTypeCodeArray {
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array")
}

// TODO: impl the cast(... as ... array) function

return types.BinaryJSON{}, false, nil
}

type castAsJSONFunctionClass struct {
baseFunctionClass

Expand Down Expand Up @@ -1914,6 +1979,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp)
terror.Log(err)
return
}

// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any.
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) {
argType := expr.GetType()
// If source argument's nullable, then target type should be nullable
if !mysql.HasNotNullFlag(argType.GetFlag()) {
Expand All @@ -1933,15 +2005,18 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
case types.ETDuration:
fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETJson:
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
if tp.IsArray() {
fc = &castAsArrayFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
} else {
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
if expr.GetType().GetType() == mysql.TypeBit {
tp.SetFlen((expr.GetType().GetFlen() + 7) / 8)
}
}
f, err := fc.getFunction(ctx, []Expression{expr})
terror.Log(err)
res = &ScalarFunction{
FuncName: model.NewCIStr(ast.Cast),
RetType: tp,
Expand All @@ -1950,10 +2025,10 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
// We do not fold CAST if the eval type of this scalar function is ETJson
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson {
if tp.EvalType() != types.ETJson && err == nil {
res = FoldConstant(res)
}
return res
return res, err
}

// WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not
Expand Down
1 change: 1 addition & 0 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
ErrInvalidTableSample = dbterror.ClassExpression.NewStd(mysql.ErrInvalidTableSample)
ErrInternal = dbterror.ClassOptimizer.NewStd(mysql.ErrInternal)
ErrNoDB = dbterror.ClassOptimizer.NewStd(mysql.ErrNoDB)
ErrNotSupportedYet = dbterror.ClassExpression.NewStd(mysql.ErrNotSupportedYet)

// All the un-exported errors are defined here:
errFunctionNotExists = dbterror.ClassExpression.NewStd(mysql.ErrSpDoesNotExist)
Expand Down
4 changes: 2 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, e
// RewriteAstExpr rewrites ast expression directly.
// Note: initialized in planner/core
// import expression and planner/core together to use EvalAstExpr
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice) (Expression, error)
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice, allowCastArray bool) (Expression, error)

// VecExpr contains all vectorized evaluation methods.
type VecExpr interface {
Expand Down Expand Up @@ -998,7 +998,7 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C
if err != nil {
return nil, nil, errors.Trace(err)
}
e, err := RewriteAstExpr(ctx, expr, mockSchema, names)
e, err := RewriteAstExpr(ctx, expr, mockSchema, names, false)
if err != nil {
return nil, nil, errors.Trace(err)
}
Expand Down
47 changes: 47 additions & 0 deletions expression/multi_valued_index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression_test

import (
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/testkit"
)

func TestMultiValuedIndexDDL(t *testing.T) {
store := testkit.CreateMockStore(t)

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")

tk.MustExec("create table t(a json);")
tk.MustGetErrCode("select cast(a as signed array) from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select json_extract(cast(a as signed array), '$[0]') from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select * from t where cast(a as signed array)", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)

tk.MustExec("drop table t")
tk.MustGetErrCode("CREATE TABLE t(x INT, KEY k ((1 AND CAST(JSON_ARRAY(x) AS UNSIGNED ARRAY))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(cast(f1 as unsigned array) as unsigned array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->>'$[*]' as unsigned array))));", errno.ErrInvalidJSONData)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as year array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as json array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as char(10) charset gbk array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as ((concat(cast(j->'$[*]' as unsigned array),\"x\"))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as (cast(j->'$[*]' as unsigned array)));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create view v as select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)
tk.MustExec("create table t(a json, index idx((cast(a as signed array))));")
}
Loading