Skip to content

Commit

Permalink
expression, executor, plan: rewrite builtin function trim. (#3936)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveZhangBit authored and shenli committed Aug 1, 2017
1 parent 58dca67 commit 3527750
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 91 deletions.
11 changes: 11 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,17 @@ func (s *testSuite) TestStringBuiltin(c *C) {
result.Check(testutil.RowsWithSep(",", "bar ,bar,,<nil>"))
result = tk.MustQuery(`select rtrim(' bar '), rtrim('bar'), rtrim(''), rtrim(null)`)
result.Check(testutil.RowsWithSep(",", " bar,bar,,<nil>"))

// for trim
result = tk.MustQuery(`select trim(' bar '), trim(leading 'x' from 'xxxbarxxx'), trim(trailing 'xyz' from 'barxxyz'), trim(both 'x' from 'xxxbarxxx')`)
result.Check(testkit.Rows("bar barxxx barx bar"))
result = tk.MustQuery(`select trim(leading from ' bar'), trim('x' from 'xxxbarxxx'), trim('x' from 'bar'), trim('' from ' bar ')`)
result.Check(testutil.RowsWithSep(",", "bar,bar,bar, bar "))
result = tk.MustQuery(`select trim(''), trim('x' from '')`)
result.Check(testutil.RowsWithSep(",", ","))
result = tk.MustQuery(`select trim(null from 'bar'), trim('x' from null), trim(null), trim(leading null from 'bar')`)
// FIXME: the result for trim(leading null from 'bar') should be <nil>, current is 'bar'
result.Check(testkit.Rows("<nil> <nil> <nil> bar"))
}

func (s *testSuite) TestEncryptionBuiltin(c *C) {
Expand Down
172 changes: 123 additions & 49 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ var (
_ builtinFunc = &builtinHexStrArgSig{}
_ builtinFunc = &builtinHexIntArgSig{}
_ builtinFunc = &builtinUnHexSig{}
_ builtinFunc = &builtinTrimSig{}
_ builtinFunc = &builtinTrim1ArgSig{}
_ builtinFunc = &builtinTrim2ArgsSig{}
_ builtinFunc = &builtinTrim3ArgsSig{}
_ builtinFunc = &builtinLTrimSig{}
_ builtinFunc = &builtinRTrimSig{}
_ builtinFunc = &builtinRpadSig{}
Expand Down Expand Up @@ -996,8 +998,6 @@ func (b *builtinLocateSig) eval(row []types.Datum) (d types.Datum, err error) {
return d, nil
}

const spaceChars = "\n\t\r "

type hexFunctionClass struct {
baseFunctionClass
}
Expand Down Expand Up @@ -1118,75 +1118,149 @@ func (b *builtinUnHexSig) evalString(row []types.Datum) (string, bool, error) {
return string(bs), false, nil
}

const spaceChars = "\n\t\r "

// trimFunctionClass is the function class for the builtin TRIM function.
// Its getFunction chooses a one-, two- or three-argument signature based on len(args).
type trimFunctionClass struct {
baseFunctionClass
}

// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)',
// but we will convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST.
func (c *trimFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinTrimSig{newBaseBuiltinFunc(args, ctx)}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}

switch len(args) {
case 1:
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
argType := args[0].GetType()
bf.tp.Flen = argType.Flen
if mysql.HasBinaryFlag(argType.Flag) {
types.SetBinChsClnFlag(bf.tp)
}
sig := &builtinTrim1ArgSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil

case 2:
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
argType := args[0].GetType()
bf.tp.Flen = argType.Flen
if mysql.HasBinaryFlag(argType.Flag) {
types.SetBinChsClnFlag(bf.tp)
}
sig := &builtinTrim2ArgsSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil

case 3:
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
argType := args[0].GetType()
bf.tp.Flen = argType.Flen
if mysql.HasBinaryFlag(argType.Flag) {
types.SetBinChsClnFlag(bf.tp)
}
sig := &builtinTrim3ArgsSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil

default:
return nil, errors.Trace(c.verifyArgs(args))
}
}

type builtinTrimSig struct {
baseBuiltinFunc
type builtinTrim1ArgSig struct {
baseStringBuiltinFunc
}

// eval evals a builtinTrimSig.
// evalString evals a builtinTrim1ArgSig, corresponding to trim(str)
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim
func (b *builtinTrimSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
func (b *builtinTrim1ArgSig) evalString(row []types.Datum) (d string, isNull bool, err error) {
d, isNull, err = b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
// args[0] -> Str
// args[1] -> RemStr
// args[2] -> Direction
// eval str
if args[0].IsNull() {
return d, nil
return strings.Trim(d, spaceChars), false, nil
}

// builtinTrim2ArgsSig is the builtin signature for trim(str, remstr).
type builtinTrim2ArgsSig struct {
baseStringBuiltinFunc
}

// evalString evals a builtinTrim2ArgsSig, corresponding to trim(str, remstr)
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim
func (b *builtinTrim2ArgsSig) evalString(row []types.Datum) (d string, isNull bool, err error) {
var str, remstr string

sc := b.ctx.GetSessionVars().StmtCtx
str, isNull, err = b.args[0].EvalString(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
str, err := args[0].ToString()
if err != nil {
return d, errors.Trace(err)
remstr, isNull, err = b.args[1].EvalString(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
remstr := ""
// eval remstr
if len(args) > 1 {
if args[1].Kind() != types.KindNull {
remstr, err = args[1].ToString()
if err != nil {
return d, errors.Trace(err)
}
}
d = trimLeft(str, remstr)
d = trimRight(d, remstr)
return d, false, nil
}

// builtinTrim3ArgsSig is the builtin signature for trim(str, remstr, direction).
type builtinTrim3ArgsSig struct {
baseStringBuiltinFunc
}

// evalString evals a builtinTrim3ArgsSig, corresponding to trim(str, remstr, direction)
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim
func (b *builtinTrim3ArgsSig) evalString(row []types.Datum) (d string, isNull bool, err error) {
var (
str, remstr string
x int64
direction ast.TrimDirectionType
isRemStrNull bool
)
sc := b.ctx.GetSessionVars().StmtCtx
str, isNull, err = b.args[0].EvalString(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
// do trim
var result string
var direction ast.TrimDirectionType
if len(args) > 2 {
direction = args[2].GetValue().(ast.TrimDirectionType)
} else {
direction = ast.TrimBothDefault
remstr, isRemStrNull, err = b.args[1].EvalString(row, sc)
if err != nil {
return d, isNull, errors.Trace(err)
}
x, isNull, err = b.args[2].EvalInt(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
direction = ast.TrimDirectionType(x)
if direction == ast.TrimLeading {
if len(remstr) > 0 {
result = trimLeft(str, remstr)
if isRemStrNull {
d = strings.TrimLeft(str, spaceChars)
} else {
result = strings.TrimLeft(str, spaceChars)
d = trimLeft(str, remstr)
}
} else if direction == ast.TrimTrailing {
if len(remstr) > 0 {
result = trimRight(str, remstr)
if isRemStrNull {
d = strings.TrimRight(str, spaceChars)
} else {
result = strings.TrimRight(str, spaceChars)
d = trimRight(str, remstr)
}
} else if len(remstr) > 0 {
x := trimLeft(str, remstr)
result = trimRight(x, remstr)
} else {
result = strings.Trim(str, spaceChars)
if isRemStrNull {
d = strings.Trim(str, spaceChars)
} else {
d = trimLeft(str, remstr)
d = trimRight(d, remstr)
}
}
d.SetString(result)
return d, nil
return d, false, nil
}

type lTrimFunctionClass struct {
Expand Down
84 changes: 44 additions & 40 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,51 +867,55 @@ func (s *testEvaluatorSuite) TestLocate(c *C) {

func (s *testEvaluatorSuite) TestTrim(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
str interface{}
remstr interface{}
dir ast.TrimDirectionType
result interface{}
cases := []struct {
args []interface{}
isNil bool
getErr bool
res string
}{
{" bar ", nil, ast.TrimBothDefault, "bar"},
{"xxxbarxxx", "x", ast.TrimLeading, "barxxx"},
{"xxxbarxxx", "x", ast.TrimBoth, "bar"},
{"barxxyz", "xyz", ast.TrimTrailing, "barx"},
{nil, "xyz", ast.TrimBoth, nil},
{1, 2, ast.TrimBoth, "1"},
{" \t\rbar\n ", nil, ast.TrimBothDefault, "bar"},
{[]interface{}{" bar "}, false, false, "bar"},
{[]interface{}{""}, false, false, ""},
{[]interface{}{nil}, true, false, ""},
{[]interface{}{"xxxbarxxx", "x"}, false, false, "bar"},
{[]interface{}{"bar", "x"}, false, false, "bar"},
{[]interface{}{" bar ", ""}, false, false, " bar "},
{[]interface{}{"", "x"}, false, false, ""},
{[]interface{}{"bar", nil}, true, false, ""},
{[]interface{}{nil, "x"}, true, false, ""},
{[]interface{}{"xxxbarxxx", "x", int(ast.TrimLeading)}, false, false, "barxxx"},
{[]interface{}{"barxxyz", "xyz", int(ast.TrimTrailing)}, false, false, "barx"},
{[]interface{}{"xxxbarxxx", "x", int(ast.TrimBoth)}, false, false, "bar"},
// FIXME: the result for this test should be nil, current is "bar"
{[]interface{}{"bar", nil, int(ast.TrimLeading)}, false, false, "bar"},
{[]interface{}{errors.New("must error")}, false, true, ""},
}
for _, v := range tbl {
fc := funcs[ast.Trim]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.remstr, v.dir)), s.ctx)
c.Assert(err, IsNil)
r, err := f.eval(nil)
for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.Trim, primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(v.result))
d, err := f.Eval(nil)
if t.getErr {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
} else {
c.Assert(d.GetString(), Equals, t.res)
}
}
}

for _, v := range []struct {
str, result interface{}
fn string
}{
{" ", "", ast.LTrim},
{" ", "", ast.RTrim},
{"foo0", "foo0", ast.LTrim},
{"bar0", "bar0", ast.RTrim},
{" foo1", "foo1", ast.LTrim},
{"bar1 ", "bar1", ast.RTrim},
{spaceChars + "foo2 ", "foo2 ", ast.LTrim},
{" bar2" + spaceChars, " bar2", ast.RTrim},
{nil, nil, ast.LTrim},
{nil, nil, ast.RTrim},
} {
fc := funcs[v.fn]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str)), s.ctx)
c.Assert(err, IsNil)
r, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(v.result))
}
f, err := funcs[ast.Trim].getFunction([]Expression{Zero}, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)

f, err = funcs[ast.Trim].getFunction([]Expression{Zero, Zero}, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)

f, err = funcs[ast.Trim].getFunction([]Expression{Zero, Zero, Zero}, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestLTrim(c *C) {
Expand Down
4 changes: 2 additions & 2 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -3523,15 +3523,15 @@ FunctionCallNonKeyword:
| "TRIM" '(' TrimDirection "FROM" Expression ')'
{
nilVal := ast.NewValueExpr(nil)
direction := ast.NewValueExpr($3)
direction := ast.NewValueExpr(int($3.(ast.TrimDirectionType)))
$$ = &ast.FuncCallExpr{
FnName: model.NewCIStr($1),
Args: []ast.ExprNode{$5.(ast.ExprNode), nilVal, direction},
}
}
| "TRIM" '(' TrimDirection Expression "FROM" Expression ')'
{
direction := ast.NewValueExpr($3)
direction := ast.NewValueExpr(int($3.(ast.TrimDirectionType)))
$$ = &ast.FuncCallExpr{
FnName: model.NewCIStr($1),
Args: []ast.ExprNode{$6.(ast.ExprNode),$4.(ast.ExprNode), direction},
Expand Down
2 changes: 2 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ func (s *testPlanSuite) TestInferType(c *C) {
{"ltrim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"rtrim(c_char)", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"rtrim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"trim(c_char)", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"trim(c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},

{"cot(c_int)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"cot(c_float)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
Expand Down

0 comments on commit 3527750

Please sign in to comment.