From 7154635629bb67cacfa59617ea086de94a0f36a2 Mon Sep 17 00:00:00 2001 From: Antonio Navarro Perez Date: Fri, 10 Feb 2017 17:27:47 +0100 Subject: [PATCH] expression: Add regexp support (#105) --- sql/expression/comparison.go | 49 ++++++++++++++++++++++++--- sql/expression/comparison_test.go | 55 +++++++++++++++++++++++++++++++ sql/parse/parse.go | 2 ++ sql/parse/parse_test.go | 12 +++++++ 4 files changed, 114 insertions(+), 4 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 79a9a4751..e82c38cf0 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -2,6 +2,7 @@ package expression import ( "fmt" + "regexp" "github.com/gitql/gitql/sql" ) @@ -42,6 +43,50 @@ func (c *Equals) TransformUp(f func(sql.Expression) sql.Expression) sql.Expressi return f(NewEquals(lc, rc)) } +func (e Equals) Name() string { + return e.Left.Name() + "==" + e.Right.Name() +} + +type Regexp struct { + Comparison +} + +func NewRegexp(left sql.Expression, right sql.Expression) *Regexp { + // FIXME: enable this again + // checkEqualTypes(left, right) + return &Regexp{Comparison{BinaryExpression{left, right}, left.Type()}} +} + +func (e Regexp) Eval(row sql.Row) interface{} { + l := e.Left.Eval(row) + r := e.Right.Eval(row) + + sl, okl := l.(string) + sr, okr := r.(string) + + if !okl || !okr { + return e.ChildType.Compare(l, r) == 0 + } + + reg, err := regexp.Compile(sr) + if err != nil { + return false + } + + return reg.MatchString(sl) +} + +func (c *Regexp) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression { + lc := c.BinaryExpression.Left.TransformUp(f) + rc := c.BinaryExpression.Right.TransformUp(f) + + return f(NewRegexp(lc, rc)) +} + +func (e Regexp) Name() string { + return e.Left.Name() + " REGEXP " + e.Right.Name() +} + type GreaterThan struct { Comparison } @@ -139,7 +184,3 @@ func checkEqualTypes(a sql.Expression, b sql.Expression) { panic(fmt.Errorf("both types should be equal: %v and %v\n", a, b)) } } - -func (e Equals) Name() string { - return e.Left.Name() + "==" + e.Right.Name() -} diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index 63d90ce2c..8ca2ab693 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -12,6 +12,8 @@ const ( testEqual = 1 testLess = 2 testGreater = 3 + testRegexp = 4 + testNotRegexp = 5 ) var comparisonCases = map[sql.Type]map[int][][]interface{}{ @@ -45,6 +47,33 @@ var comparisonCases = map[sql.Type]map[int][][]interface{}{ }, } +var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ + sql.String: { + testRegexp: { + {"foobar", ".*bar"}, + {"foobarfoo", ".*bar.*"}, + {"bar", "bar"}, + {"barfoo", "bar.*"}, + }, + testNotRegexp: { + {"foobara", ".*bar$"}, + {"foofoo", ".*bar.*"}, + {"bara", "bar$"}, + {"abarfoo", "^bar.*"}, + }, + }, + sql.Integer: { + testRegexp: { + {int32(1), int32(1)}, + {int32(0), int32(0)}, + }, + testNotRegexp: { + {int32(-1), int32(0)}, + {int32(1), int32(2)}, + }, + }, +} + func TestComparisons_Equals(t *testing.T) { assert := require.New(t) for resultType, cmpCase := range comparisonCases { @@ -122,3 +151,29 @@ func TestComparisons_GreaterThan(t *testing.T) { } } } + +func TestComparisons_Regexp(t *testing.T) { + assert := require.New(t) + for resultType, cmpCase := range likeComparisonCases { + get0 := NewGetField(0, resultType, "col1") + assert.NotNil(get0) + get1 := NewGetField(1, resultType, "col2") + assert.NotNil(get1) + eq := NewRegexp(get0, get1) + assert.NotNil(eq) + assert.Equal(sql.Boolean, eq.Type()) + for cmpResult, cases := range cmpCase { + for _, pair := range cases { + row := sql.NewRow(pair[0], pair[1]) + assert.NotNil(row) + cmp := eq.Eval(row) + assert.NotNil(cmp) + if cmpResult == testRegexp { + assert.Equal(true, cmp) + } else { + assert.Equal(false, cmp) + } + } + } + } +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index bf3c12b2b..f51856b64 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -306,6 +306,8 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, switch c.Operator { default: return nil, errUnsupportedFeature(c.Operator) + case sqlparser.RegexpStr: + return expression.NewRegexp(left, right), nil case sqlparser.EqualStr: return expression.NewEquals(left, right), nil case sqlparser.LessThanStr: diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 56acf42fc..41f29bfe9 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -161,6 +161,18 @@ var fixtures = map[string]sql.Node{ []sql.Expression{}, plan.NewUnresolvedTable("t1"), ), + `SELECT a FROM t1 where a regexp '.*test.*';`: plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + }, + plan.NewFilter( + expression.NewRegexp( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(".*test.*", sql.String), + ), + plan.NewUnresolvedTable("t1"), + ), + ), } func TestParse(t *testing.T) {