From 14cd44c91ff99f1a4a6f59b2304340104d4336cc Mon Sep 17 00:00:00 2001 From: "Santiago M. Mola" Date: Fri, 23 Dec 2016 16:12:42 +0100 Subject: [PATCH] sql: add GROUP BY support. Closes #52. (#86) * sql: add AggregationExpression interface. * sql: add function registry to Catalog. * sql/expression: add Count and First implementations. * sql/plan: add GroupBy node. --- README.md | 1 + engine.go | 8 +- engine_test.go | 7 ++ sql/analyzer/rules.go | 27 +++++ sql/catalog.go | 107 ++++++++++++++++ sql/catalog_test.go | 146 +++++++++++++++++++++- sql/core.go | 26 +++- sql/expression/aggregation.go | 115 ++++++++++++++++++ sql/expression/aggregation_test.go | 96 +++++++++++++++ sql/expression/common.go | 15 +++ sql/expression/unresolved.go | 36 ++++++ sql/parse/parse.go | 64 ++++++++-- sql/parse/parse_test.go | 19 +++ sql/plan/common.go | 26 ++++ sql/plan/filter.go | 4 - sql/plan/group_by.go | 189 +++++++++++++++++++++++++++++ sql/plan/group_by_test.go | 88 ++++++++++++++ sql/plan/limit.go | 4 - sql/plan/project.go | 18 +-- sql/plan/sort.go | 4 - 20 files changed, 957 insertions(+), 43 deletions(-) create mode 100644 sql/expression/aggregation.go create mode 100644 sql/expression/aggregation_test.go create mode 100644 sql/plan/group_by.go create mode 100644 sql/plan/group_by_test.go diff --git a/README.md b/README.md index 9d2c00cef..b670fb852 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ gitql supports a subset of the SQL standard, currently including: * `WHERE` * `ORDER BY` (with `ASC` and `DESC`) * `LIMIT` +* `GROUP BY` (with `COUNT` and `FIRST`) * `SHOW TABLES` * `DESCRIBE TABLE` diff --git a/engine.go b/engine.go index f72c1b88d..388e3564f 100644 --- a/engine.go +++ b/engine.go @@ -4,6 +4,7 @@ import ( "github.com/gitql/gitql/sql" "github.com/gitql/gitql/sql/analyzer" "github.com/gitql/gitql/sql/parse" + "github.com/gitql/gitql/sql/expression" ) type Engine struct { @@ -12,7 +13,12 @@ type Engine struct { } func New() *Engine { - c := &sql.Catalog{} + c := sql.NewCatalog() + err := expression.RegisterDefaults(c) + if err != nil { + panic(err) + } + a := analyzer.New(c) return &Engine{c, a} } diff --git a/engine_test.go b/engine_test.go index 408ada82e..be0a03a81 100644 --- a/engine_test.go +++ b/engine_test.go @@ -51,6 +51,13 @@ func TestEngine_Query(t *testing.T) { sql.NewMemoryRow(int64(1)), }, ) + + testQuery(t, e, + "SELECT COUNT(*) FROM mytable;", + []sql.Row{ + sql.NewMemoryRow(int32(3)), + }, + ) } func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) { diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index cee432dea..2b9773b49 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -11,6 +11,7 @@ var DefaultRules = []Rule{ {"resolve_columns", resolveColumns}, {"resolve_database", resolveDatabase}, {"resolve_star", resolveStar}, + {"resolve_functions", resolveFunctions}, } func resolveDatabase(a *Analyzer, n sql.Node) sql.Node { @@ -102,3 +103,29 @@ func resolveColumns(a *Analyzer, n sql.Node) sql.Node { return gf }) } + +func resolveFunctions(a *Analyzer, n sql.Node) sql.Node { + if n.Resolved() { + return n + } + + return n.TransformExpressionsUp(func(e sql.Expression) sql.Expression { + uf, ok := e.(*expression.UnresolvedFunction) + if !ok { + return e + } + + n := uf.Name() + f, err := a.Catalog.Function(n) + if err != nil { + return e + } + + rf, err := f.Build(uf.Children...) + if err != nil { + return e + } + + return rf + }) +} diff --git a/sql/catalog.go b/sql/catalog.go index d61d60d01..b9efa9021 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -1,11 +1,20 @@ package sql import ( + "errors" "fmt" + "reflect" ) type Catalog struct { Databases []Database + Functions map[string]*FunctionEntry +} + +func NewCatalog() *Catalog { + return &Catalog{ + Functions: map[string]*FunctionEntry{}, + } } func (c Catalog) Database(name string) (Database, error) { @@ -32,3 +41,101 @@ func (c Catalog) Table(dbName string, tableName string) (Table, error) { return table, nil } + +func (c Catalog) RegisterFunction(name string, f interface{}) error { + e, err := inspectFunction(f) + if err != nil { + return err + } + + c.Functions[name] = e + return nil +} + +func (c Catalog) Function(name string) (*FunctionEntry, error) { + e, ok := c.Functions[name] + if !ok { + return nil, fmt.Errorf("function not found: %s", name) + } + + return e, nil +} + +type FunctionEntry struct { + v reflect.Value +} + +func (e *FunctionEntry) Build(args ...Expression) (Expression, error) { + t := e.v.Type() + if !t.IsVariadic() && len(args) != t.NumIn() { + return nil, fmt.Errorf("expected %d args, got %d", + t.NumIn(), len(args)) + } + + if t.IsVariadic() && len(args) < t.NumIn()-1 { + return nil, fmt.Errorf("expected at least %d args, got %d", + t.NumIn(), len(args)) + } + + var in []reflect.Value + for _, arg := range args { + in = append(in, reflect.ValueOf(arg)) + } + + out := e.v.Call(in) + if len(out) != 1 { + return nil, fmt.Errorf("expected 1 return value, got %d: ", len(out)) + } + + expr, ok := out[0].Interface().(Expression) + if !ok { + return nil, errors.New("return value doesn't implement Expression") + } + + return expr, nil +} + +var ( + expressionType = buildExpressionType() + expressionSliceType = buildExpressionSliceType() +) + +func buildExpressionType() reflect.Type { + var v Expression + return reflect.ValueOf(&v).Elem().Type() +} + +func buildExpressionSliceType() reflect.Type { + var v []Expression + return reflect.ValueOf(&v).Elem().Type() +} + +func inspectFunction(f interface{}) (*FunctionEntry, error) { + v := reflect.ValueOf(f) + t := v.Type() + if t.Kind() != reflect.Func { + return nil, fmt.Errorf("expected function, got: %s", t.Kind()) + } + + if t.NumOut() != 1 { + return nil, errors.New("function builders must return a single Expression") + } + + out := t.Out(0) + if !out.Implements(expressionType) { + return nil, fmt.Errorf("return value doesn't implement Expression: %s", out) + } + + for i := 0; i < t.NumIn(); i++ { + in := t.In(i) + if i == t.NumIn()-1 && t.IsVariadic() && in == expressionSliceType { + continue + } + + if in != expressionType { + return nil, fmt.Errorf("input argument %d is not a Expression", i) + } + } + + return &FunctionEntry{v}, nil +} diff --git a/sql/catalog_test.go b/sql/catalog_test.go index bf2155811..0d3d7e928 100644 --- a/sql/catalog_test.go +++ b/sql/catalog_test.go @@ -6,13 +6,14 @@ import ( "github.com/gitql/gitql/mem" "github.com/gitql/gitql/sql" + "github.com/gitql/gitql/sql/expression" "github.com/stretchr/testify/assert" ) func TestCatalog_Database(t *testing.T) { assert := assert.New(t) - c := sql.Catalog{} + c := sql.NewCatalog() db, err := c.Database("foo") assert.EqualError(err, "database not found: foo") assert.Nil(db) @@ -28,7 +29,7 @@ func TestCatalog_Database(t *testing.T) { func TestCatalog_Table(t *testing.T) { assert := assert.New(t) - c := sql.Catalog{} + c := sql.NewCatalog() table, err := c.Table("foo", "bar") assert.EqualError(err, "database not found: foo") @@ -48,3 +49,144 @@ func TestCatalog_Table(t *testing.T) { assert.NoError(err) assert.Equal(mytable, table) } + +func TestCatalog_RegisterFunction_NoArgs(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + name := "func" + var expected sql.Expression = expression.NewStar() + err := c.RegisterFunction(name, func() sql.Expression { + return expected + }) + assert.Nil(err) + + f, err := c.Function(name) + assert.Nil(err) + + e, err := f.Build() + assert.Nil(err) + assert.Equal(expected, e) + + e, err = f.Build(expression.NewStar()) + assert.NotNil(err) + assert.Nil(e) + + e, err = f.Build(expression.NewStar(), expression.NewStar()) + assert.NotNil(err) + assert.Nil(e) +} + +func TestCatalog_RegisterFunction_OneArg(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + name := "func" + var expected sql.Expression = expression.NewStar() + err := c.RegisterFunction(name, func(sql.Expression) sql.Expression { + return expected + }) + assert.Nil(err) + + f, err := c.Function(name) + assert.Nil(err) + + e, err := f.Build() + assert.NotNil(err) + assert.Nil(e) + + e, err = f.Build(expression.NewStar()) + assert.Nil(err) + assert.Equal(expected, e) + + e, err = f.Build(expression.NewStar(), expression.NewStar()) + assert.NotNil(err) + assert.Nil(e) +} + +func TestCatalog_RegisterFunction_Variadic(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + name := "func" + var expected sql.Expression = expression.NewStar() + err := c.RegisterFunction(name, func(...sql.Expression) sql.Expression { + return expected + }) + assert.Nil(err) + + f, err := c.Function(name) + assert.Nil(err) + + e, err := f.Build() + assert.Nil(err) + assert.Equal(expected, e) + + e, err = f.Build(expression.NewStar()) + assert.Nil(err) + assert.Equal(expected, e) + + e, err = f.Build(expression.NewStar(), expression.NewStar()) + assert.Nil(err) + assert.Equal(expected, e) +} + +func TestCatalog_RegisterFunction_OneAndVariadic(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + name := "func" + var expected sql.Expression = expression.NewStar() + err := c.RegisterFunction(name, func(sql.Expression, ...sql.Expression) sql.Expression { + return expected + }) + assert.Nil(err) + + f, err := c.Function(name) + assert.Nil(err) + + e, err := f.Build() + assert.NotNil(err) + assert.Nil(e) + + e, err = f.Build(expression.NewStar()) + assert.Nil(err) + assert.Equal(expected, e) + + e, err = f.Build(expression.NewStar(), expression.NewStar()) + assert.Nil(err) + assert.Equal(expected, e) +} + +func TestCatalog_RegisterFunction_Invalid(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + name := "func" + err := c.RegisterFunction(name, func(sql.Table) sql.Expression { + return nil + }) + assert.NotNil(err) + + err = c.RegisterFunction(name, func(sql.Expression) sql.Table { + return nil + }) + assert.NotNil(err) + + err = c.RegisterFunction(name, func(sql.Expression) (sql.Table, error) { + return nil, nil + }) + assert.NotNil(err) + + err = c.RegisterFunction(name, 1) + assert.NotNil(err) +} + +func TestCatalog_Function_NotExists(t *testing.T) { + assert := assert.New(t) + + c := sql.NewCatalog() + f, err := c.Function("func") + assert.NotNil(err) + assert.Nil(f) +} diff --git a/sql/core.go b/sql/core.go index cf289428a..353a0894d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -1,6 +1,8 @@ package sql -import "errors" +import ( + "errors" +) type Nameable interface { Name() string @@ -23,6 +25,28 @@ type Expression interface { TransformUp(func(Expression) Expression) Expression } +// AggregationExpression implements an aggregation expression, where an +// aggregation buffer is created for each grouping (NewBuffer) and rows in the +// grouping are fed to the buffer (Update). Multiple buffers can be merged +// (Merge), making partial aggregations possible. +// Note that Eval must be called with the final aggregation buffer in order to +// get the final result. +type AggregationExpression interface { + Expression + // NewBuffer creates a new aggregation buffer and returns it as a Row. + NewBuffer() Row + // Update updates the given buffer with the given row. + Update(buffer, row Row) + // Merge merges a partial buffer into a global one. + Merge(buffer, partial Row) +} + +type Aggregation interface { + Update(Row) (Row, error) + Merge(Row) + Eval() interface{} +} + type Node interface { Resolvable Transformable diff --git a/sql/expression/aggregation.go b/sql/expression/aggregation.go new file mode 100644 index 000000000..1e45e5b89 --- /dev/null +++ b/sql/expression/aggregation.go @@ -0,0 +1,115 @@ +package expression + +import ( + "fmt" + + "github.com/gitql/gitql/sql" +) + +type Count struct { + UnaryExpression +} + +func NewCount(e sql.Expression) *Count { + return &Count{UnaryExpression{e}} +} + +func (c *Count) NewBuffer() sql.Row { + return sql.NewMemoryRow(int32(0)) +} + +func (c *Count) Type() sql.Type { + return sql.Integer +} + +func (c *Count) Resolved() bool { + if _, ok := c.Child.(*Star); ok { + return true + } + + return c.Child.Resolved() +} + +func (c *Count) Name() string { + return fmt.Sprintf("count(%s)", c.Child.Name()) +} + +func (c *Count) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression { + nc := c.UnaryExpression.Child.TransformUp(f) + return f(NewCount(nc)) +} + +func (c *Count) Update(buffer, row sql.Row) { + mr := buffer.(sql.MemoryRow) + var inc bool + if _, ok := c.Child.(*Star); ok { + inc = true + } else { + v := c.Child.Eval(row) + if v != nil { + inc = true + } + } + + if inc { + mr[0] = getInt32At(buffer, 0) + int32(1) + } +} + +func (c *Count) Merge(buffer, partial sql.Row) { + mb := buffer.(sql.MemoryRow) + mb[0] = getInt32At(buffer, 0) + getInt32At(partial, 0) +} + +func (c *Count) Eval(buffer sql.Row) interface{} { + return getInt32At(buffer, 0) +} + +type First struct { + UnaryExpression +} + +func NewFirst(e sql.Expression) *First { + return &First{UnaryExpression{e}} +} + +func (e *First) NewBuffer() sql.Row { + return sql.NewMemoryRow(nil) +} + +func (e *First) Type() sql.Type { + return e.Child.Type() +} + +func (e *First) Name() string { + return fmt.Sprintf("first(%s)", e.Child.Name()) +} + +func (e *First) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression { + nc := e.UnaryExpression.Child.TransformUp(f) + return f(NewFirst(nc)) +} + +func (e *First) Update(buffer, row sql.Row) { + mr := buffer.(sql.MemoryRow) + if mr[0] == nil { + mr[0] = e.Child.Eval(row) + } +} + +func (e *First) Merge(buffer, partial sql.Row) { + mb := buffer.(sql.MemoryRow) + if mb[0] == nil { + mp := partial.(sql.MemoryRow) + mb[0] = mp[0] + } +} + +func (e *First) Eval(buffer sql.Row) interface{} { + return buffer.Fields()[0] +} + +func getInt32At(row sql.Row, i int) int32 { + f := row.Fields() + return f[i].(int32) +} diff --git a/sql/expression/aggregation_test.go b/sql/expression/aggregation_test.go new file mode 100644 index 000000000..ff79becda --- /dev/null +++ b/sql/expression/aggregation_test.go @@ -0,0 +1,96 @@ +package expression + +import ( + "testing" + + "github.com/gitql/gitql/sql" + + "github.com/stretchr/testify/require" +) + +func TestCount_Name(t *testing.T) { + assert := require.New(t) + + c := NewCount(NewLiteral("foo", sql.String)) + assert.Equal("count(literal_string)", c.Name()) +} + +func TestCount_Eval_1(t *testing.T) { + assert := require.New(t) + + c := NewCount(NewLiteral(1, sql.Integer)) + b := c.NewBuffer() + assert.Equal(int32(0), c.Eval(b)) + + c.Update(b, sql.NewMemoryRow()) + c.Update(b, sql.NewMemoryRow("foo")) + c.Update(b, sql.NewMemoryRow(1)) + c.Update(b, sql.NewMemoryRow(1, 2, 3)) + assert.Equal(int32(4), c.Eval(b)) + + b2 := c.NewBuffer() + c.Update(b2, sql.NewMemoryRow()) + c.Update(b2, sql.NewMemoryRow("foo")) + c.Merge(b, b2) + assert.Equal(int32(6), c.Eval(b)) +} + +func TestCount_Eval_Star(t *testing.T) { + assert := require.New(t) + + c := NewCount(NewStar()) + b := c.NewBuffer() + assert.Equal(int32(0), c.Eval(b)) + + c.Update(b, sql.NewMemoryRow()) + c.Update(b, sql.NewMemoryRow("foo")) + c.Update(b, sql.NewMemoryRow(1)) + c.Update(b, sql.NewMemoryRow(1, 2, 3)) + assert.Equal(int32(4), c.Eval(b)) + + b2 := c.NewBuffer() + c.Update(b2, sql.NewMemoryRow()) + c.Update(b2, sql.NewMemoryRow("foo")) + c.Merge(b, b2) + assert.Equal(int32(6), c.Eval(b)) +} + +func TestCount_Eval_String(t *testing.T) { + assert := require.New(t) + + c := NewCount(NewGetField(0, sql.String, "")) + b := c.NewBuffer() + assert.Equal(int32(0), c.Eval(b)) + + c.Update(b, sql.NewMemoryRow("foo")) + assert.Equal(int32(1), c.Eval(b)) + + c.Update(b, sql.NewMemoryRow(nil)) + assert.Equal(int32(1), c.Eval(b)) +} + +func TestFirst_Name(t *testing.T) { + assert := require.New(t) + + c := NewFirst(NewGetField(0, sql.Integer, "field")) + assert.Equal("first(field)", c.Name()) +} + +func TestFirst_Eval(t *testing.T) { + assert := require.New(t) + + c := NewFirst(NewGetField(0, sql.Integer, "field")) + b := c.NewBuffer() + assert.Nil(c.Eval(b)) + + c.Update(b, sql.NewMemoryRow(int32(1))) + assert.Equal(int32(1), c.Eval(b)) + + c.Update(b, sql.NewMemoryRow(int32(2))) + assert.Equal(int32(1), c.Eval(b)) + + b2 := c.NewBuffer() + c.Update(b2, sql.NewMemoryRow(int32(2))) + c.Merge(b, b2) + assert.Equal(int32(1), c.Eval(b)) +} diff --git a/sql/expression/common.go b/sql/expression/common.go index aeca52121..1211ac5f5 100644 --- a/sql/expression/common.go +++ b/sql/expression/common.go @@ -18,3 +18,18 @@ type BinaryExpression struct { func (p BinaryExpression) Resolved() bool { return p.Left.Resolved() && p.Right.Resolved() } + +var defaultFunctions = map[string]interface{}{ + "count": NewCount, + "first": NewFirst, +} + +func RegisterDefaults(c *sql.Catalog) error { + for k, v := range defaultFunctions { + if err := c.RegisterFunction(k, v); err != nil { + return err + } + } + + return nil +} diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index c99fc4d01..5be83bd51 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -30,3 +30,39 @@ func (p *UnresolvedColumn) TransformUp(f func(sql.Expression) sql.Expression) sq n := *p return f(&n) } + +type UnresolvedFunction struct { + name string + IsAggregate bool + Children []sql.Expression +} + +func NewUnresolvedFunction(name string, agg bool, + children ...sql.Expression) *UnresolvedFunction { + return &UnresolvedFunction{name, agg, children} +} + +func (UnresolvedFunction) Resolved() bool { + return false +} + +func (UnresolvedFunction) Type() sql.Type { + return sql.String //FIXME +} + +func (c UnresolvedFunction) Name() string { + return c.name +} + +func (UnresolvedFunction) Eval(r sql.Row) interface{} { + return "FAIL" //FIXME +} + +func (p *UnresolvedFunction) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression { + var rc []sql.Expression + for _, c := range p.Children { + rc = append(rc, f(c)) + } + + return f(NewUnresolvedFunction(p.name, p.IsAggregate, rc...)) +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index baed55810..4d0cbad1a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -69,10 +69,6 @@ func convertSelect(s *sqlparser.Select) (sql.Node, error) { return nil, errUnsupportedFeature("DISTINCT") } - if len(s.GroupBy) != 0 { - return nil, errUnsupportedFeature("GROUP BY") - } - if s.Having != nil { return nil, errUnsupportedFeature("HAVING") } @@ -99,12 +95,7 @@ func convertSelect(s *sqlparser.Select) (sql.Node, error) { } } - node, err = selectToProject(s.SelectExprs, node) - if err != nil { - return nil, err - } - - return node, nil + return selectToProjectOrGroupBy(s.SelectExprs, s.GroupBy, node) } func tableExprsToTable(te sqlparser.TableExprs) (sql.Node, error) { @@ -199,7 +190,34 @@ func limitToLimit(o sqlparser.ValExpr, child sql.Node) (*plan.Limit, error) { return plan.NewLimit(n, child), nil } -func selectToProject(se sqlparser.SelectExprs, child sql.Node) (*plan.Project, error) { +func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, child sql.Node) (sql.Node, error) { + selectExprs, err := selectExprsToExpressions(se) + if err != nil { + return nil, err + } + + isAgg := len(g) > 0 + if !isAgg { + for _, e := range selectExprs { + if u, ok := e.(*expression.UnresolvedFunction); ok { + isAgg = u.IsAggregate + } + } + } + + if isAgg { + groupingExprs, err := groupByToExpressions(g) + if err != nil { + return nil, err + } + + return plan.NewGroupBy(selectExprs, groupingExprs, child), nil + } + + return plan.NewProject(selectExprs, child), nil +} + +func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error) { var exprs []sql.Expression for _, e := range se { pe, err := selectExprToExpression(e) @@ -210,7 +228,7 @@ func selectToProject(se sqlparser.SelectExprs, child sql.Node) (*plan.Project, e exprs = append(exprs, pe) } - return plan.NewProject(exprs, child), nil + return exprs, nil } func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { @@ -281,6 +299,20 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, } } +func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { + es := make([]sql.Expression, len(g)) + for i, ve := range g { + e, err := valExprToExpression(ve) + if err != nil { + return nil, err + } + + es[i] = e + } + + return es, nil +} + func valExprToExpression(ve sqlparser.ValExpr) (sql.Expression, error) { switch v := ve.(type) { default: @@ -302,6 +334,14 @@ func valExprToExpression(ve sqlparser.ValExpr) (sql.Expression, error) { case *sqlparser.ColName: //TODO: add handling of case sensitiveness. return expression.NewUnresolvedColumn(v.Name.Lowered()), nil + case *sqlparser.FuncExpr: + exprs, err := selectExprsToExpressions(v.Exprs) + if err != nil { + return nil, err + } + + return expression.NewUnresolvedFunction(v.Name.Lowered(), + v.IsAggregate(), exprs...), nil } } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 80413f0a0..56acf42fc 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -142,6 +142,25 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t2"), ), ), + `SELECT foo, bar FROM t1 GROUP BY foo, bar;`: plan.NewGroupBy( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + plan.NewUnresolvedTable("t1"), + ), + `SELECT COUNT(*) FROM t1;`: plan.NewGroupBy( + []sql.Expression{ + expression.NewUnresolvedFunction("count", true, + expression.NewStar()), + }, + []sql.Expression{}, + plan.NewUnresolvedTable("t1"), + ), } func TestParse(t *testing.T) { diff --git a/sql/plan/common.go b/sql/plan/common.go index 972ec903a..5a857176e 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -6,6 +6,10 @@ type UnaryNode struct { Child sql.Node } +func (n *UnaryNode) Schema() sql.Schema { + return n.Child.Schema() +} + func (n UnaryNode) Resolved() bool { return n.Child.Resolved() } @@ -22,3 +26,25 @@ type BinaryNode struct { func (n BinaryNode) Children() []sql.Node { return []sql.Node{n.Left, n.Right} } + +func expressionsResolved(exprs ...sql.Expression) bool { + for _, e := range exprs { + if !e.Resolved() { + return false + } + } + + return true +} + +func transformExpressionsUp(f func(sql.Expression) sql.Expression, + exprs []sql.Expression) []sql.Expression { + + var es []sql.Expression + for _, e := range exprs { + te := e.TransformUp(f) + es = append(es, te) + } + + return es +} diff --git a/sql/plan/filter.go b/sql/plan/filter.go index a9f4df0d3..77dd1c641 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -14,10 +14,6 @@ func NewFilter(expression sql.Expression, child sql.Node) *Filter { } } -func (p *Filter) Schema() sql.Schema { - return p.UnaryNode.Child.Schema() -} - func (p *Filter) Resolved() bool { return p.UnaryNode.Child.Resolved() && p.expression.Resolved() } diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go new file mode 100644 index 000000000..8419c56bf --- /dev/null +++ b/sql/plan/group_by.go @@ -0,0 +1,189 @@ +package plan + +import ( + "fmt" + "io" + "strings" + + "github.com/gitql/gitql/sql" + "github.com/gitql/gitql/sql/expression" +) + +type GroupBy struct { + UnaryNode + aggregate []sql.Expression + grouping []sql.Expression +} + +func NewGroupBy(aggregate []sql.Expression, grouping []sql.Expression, + child sql.Node) *GroupBy { + + return &GroupBy{ + UnaryNode: UnaryNode{Child: child}, + aggregate: aggregate, + grouping: grouping, + } +} + +func (p *GroupBy) Resolved() bool { + return p.UnaryNode.Child.Resolved() && + expressionsResolved(p.aggregate...) && + expressionsResolved(p.grouping...) +} + +func (p *GroupBy) Schema() sql.Schema { + s := sql.Schema{} + for _, e := range p.aggregate { + s = append(s, sql.Field{ + Name: e.Name(), + Type: e.Type(), + }) + } + + return s +} + +func (p *GroupBy) RowIter() (sql.RowIter, error) { + i, err := p.Child.RowIter() + if err != nil { + return nil, err + } + return newGroupByIter(p, i), nil +} + +func (p *GroupBy) TransformUp(f func(sql.Node) sql.Node) sql.Node { + c := p.UnaryNode.Child.TransformUp(f) + n := NewGroupBy(p.aggregate, p.grouping, c) + + return f(n) +} + +func (p *GroupBy) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node { + c := p.UnaryNode.Child.TransformExpressionsUp(f) + aes := transformExpressionsUp(f, p.aggregate) + ges := transformExpressionsUp(f, p.grouping) + n := NewGroupBy(aes, ges, c) + + return n +} + +type groupByIter struct { + p *GroupBy + childIter sql.RowIter + rows []sql.Row + idx int +} + +func newGroupByIter(p *GroupBy, child sql.RowIter) *groupByIter { + return &groupByIter{ + p: p, + childIter: child, + rows: nil, + idx: -1, + } +} + +func (i *groupByIter) Next() (sql.Row, error) { + if i.idx == -1 { + err := i.computeRows() + if err != nil { + return nil, err + } + i.idx = 0 + } + if i.idx >= len(i.rows) { + return nil, io.EOF + } + row := i.rows[i.idx] + i.idx++ + return row, nil +} + +func (i *groupByIter) computeRows() error { + rows := []sql.Row{} + for { + childRow, err := i.childIter.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + rows = append(rows, childRow) + } + + rows, err := groupBy(rows, i.p.aggregate, i.p.grouping) + if err != nil { + return err + } + + i.rows = rows + return nil +} + +func groupBy(rows []sql.Row, aggExpr []sql.Expression, + groupExpr []sql.Expression) ([]sql.Row, error) { + + //TODO: currently, we first group all rows, and then + // compute aggregations in a separate stage. We should + // compute aggregations incrementally instead. + + hrows := map[interface{}][]sql.Row{} + for _, row := range rows { + key := groupingKey(groupExpr, row) + hrows[key] = append(hrows[key], row) + } + + result := make([]sql.Row, 0, len(hrows)) + for _, rows := range hrows { + row := aggregate(aggExpr, rows) + result = append(result, row) + } + + return result, nil +} + +func groupingKey(exprs []sql.Expression, row sql.Row) interface{} { + //TODO: use a more robust/efficient way of calculating grouping keys. + vals := make([]string, 0, len(exprs)) + for _, expr := range exprs { + vals = append(vals, fmt.Sprintf("%#v", expr.Eval(row))) + } + + return strings.Join(vals, ",") +} + +func aggregate(exprs []sql.Expression, rows []sql.Row) sql.Row { + aggs := exprsToAggregateExprs(exprs) + + buffers := make([]sql.Row, len(aggs)) + for i, agg := range aggs { + buffers[i] = agg.NewBuffer() + } + + for _, row := range rows { + for i, agg := range aggs { + agg.Update(buffers[i], row) + } + } + + fields := make([]interface{}, 0, len(exprs)) + for i, agg := range aggs { + fields = append(fields, agg.Eval(buffers[i])) + } + + return sql.NewMemoryRow(fields...) +} + +func exprsToAggregateExprs(exprs []sql.Expression) []sql.AggregationExpression { + var r []sql.AggregationExpression + for _, e := range exprs { + if ae, ok := e.(sql.AggregationExpression); ok { + r = append(r, ae) + } else { + r = append(r, expression.NewFirst(e)) + } + } + + return r +} diff --git a/sql/plan/group_by_test.go b/sql/plan/group_by_test.go new file mode 100644 index 000000000..5c1d893d4 --- /dev/null +++ b/sql/plan/group_by_test.go @@ -0,0 +1,88 @@ +package plan + +import ( + "testing" + + "github.com/gitql/gitql/mem" + "github.com/gitql/gitql/sql" + "github.com/gitql/gitql/sql/expression" + + "github.com/stretchr/testify/assert" +) + +func TestGroupBy_Schema(t *testing.T) { + assert := assert.New(t) + + child := mem.NewTable("test", sql.Schema{}) + agg := []sql.Expression{ + expression.NewAlias(expression.NewLiteral("s", sql.String), "c1"), + expression.NewAlias(expression.NewCount(expression.NewStar()), "c2"), + } + gb := NewGroupBy(agg, nil, child) + assert.Equal(sql.Schema{ + sql.Field{Name: "c1", Type: sql.String}, + sql.Field{Name: "c2", Type: sql.Integer}, + }, gb.Schema()) +} + +func TestGroupBy_Resolved(t *testing.T) { + assert := assert.New(t) + + child := mem.NewTable("test", sql.Schema{}) + agg := []sql.Expression{ + expression.NewAlias(expression.NewCount(expression.NewStar()), "c2"), + } + gb := NewGroupBy(agg, nil, child) + assert.True(gb.Resolved()) + + agg = []sql.Expression{ + expression.NewStar(), + } + gb = NewGroupBy(agg, nil, child) + assert.False(gb.Resolved()) +} + +func TestGroupBy_RowIter(t *testing.T) { + assert := assert.New(t) + childSchema := sql.Schema{ + sql.Field{"col1", sql.String}, + sql.Field{"col2", sql.BigInteger}, + } + child := mem.NewTable("test", childSchema) + child.Insert("col1_1", int64(1111)) + child.Insert("col1_1", int64(1111)) + child.Insert("col1_2", int64(4444)) + child.Insert("col1_1", int64(1111)) + child.Insert("col1_2", int64(4444)) + + p := NewSort( + []SortField{ + { + Column: expression.NewGetField(0, sql.String, "col1"), + Order: Ascending, + }, { + Column: expression.NewGetField(1, sql.BigInteger, "col2"), + Order: Ascending, + }, + }, + NewGroupBy( + []sql.Expression{ + expression.NewGetField(0, sql.String, "col1"), + expression.NewGetField(1, sql.BigInteger, "col2"), + }, + []sql.Expression{ + expression.NewGetField(0, sql.String, "col1"), + expression.NewGetField(1, sql.BigInteger, "col2"), + }, + child, + )) + + assert.Equal(1, len(p.Children())) + + rows, err := sql.NodeToRows(p) + assert.NoError(err) + assert.Len(rows, 2) + + assert.Equal(sql.NewMemoryRow("col1_1", int64(1111)), rows[0]) + assert.Equal(sql.NewMemoryRow("col1_2", int64(4444)), rows[1]) +} diff --git a/sql/plan/limit.go b/sql/plan/limit.go index b8751adb0..ccdefbd6f 100644 --- a/sql/plan/limit.go +++ b/sql/plan/limit.go @@ -18,10 +18,6 @@ func NewLimit(size int64, child sql.Node) *Limit { } } -func (l *Limit) Schema() sql.Schema { - return l.UnaryNode.Child.Schema() -} - func (p *Limit) Resolved() bool { return p.UnaryNode.Child.Resolved() } diff --git a/sql/plan/project.go b/sql/plan/project.go index bf78d8834..60561abed 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -29,16 +29,8 @@ func (p *Project) Schema() sql.Schema { } func (p *Project) Resolved() bool { - return p.UnaryNode.Child.Resolved() && p.expressionsResolved() -} - -func (p *Project) expressionsResolved() bool { - for _, e := range p.Expressions { - if !e.Resolved() { - return false - } - } - return true + return p.UnaryNode.Child.Resolved() && + expressionsResolved(p.Expressions...) } func (p *Project) RowIter() (sql.RowIter, error) { @@ -58,11 +50,7 @@ func (p *Project) TransformUp(f func(sql.Node) sql.Node) sql.Node { func (p *Project) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node { c := p.UnaryNode.Child.TransformExpressionsUp(f) - es := []sql.Expression{} - for _, e := range p.Expressions { - te := e.TransformUp(f) - es = append(es, te) - } + es := transformExpressionsUp(f, p.Expressions) n := NewProject(es, c) return n diff --git a/sql/plan/sort.go b/sql/plan/sort.go index c1848f1e5..f4f184c2f 100644 --- a/sql/plan/sort.go +++ b/sql/plan/sort.go @@ -44,10 +44,6 @@ func (p *Sort) expressionsResolved() bool { return true } -func (s *Sort) Schema() sql.Schema { - return s.UnaryNode.Child.Schema() -} - func (s *Sort) RowIter() (sql.RowIter, error) { i, err := s.UnaryNode.Child.RowIter()