Skip to content

Commit

Permalink
sql: add GROUP BY support. Closes #52. (#86)
Browse files Browse the repository at this point in the history
* sql: add AggregationExpression interface.
* sql: add function registry to Catalog.
* sql/expression: add Count and First implementations.
* sql/plan: add GroupBy node.
  • Loading branch information
smola authored Dec 23, 2016
1 parent a2cbdbc commit 14cd44c
Show file tree
Hide file tree
Showing 20 changed files with 957 additions and 43 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ gitql supports a subset of the SQL standard, currently including:
* `WHERE`
* `ORDER BY` (with `ASC` and `DESC`)
* `LIMIT`
* `GROUP BY` (with `COUNT` and `FIRST`)
* `SHOW TABLES`
* `DESCRIBE TABLE`

Expand Down
8 changes: 7 additions & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gitql/gitql/sql"
"github.com/gitql/gitql/sql/analyzer"
"github.com/gitql/gitql/sql/parse"
"github.com/gitql/gitql/sql/expression"
)

type Engine struct {
Expand All @@ -12,7 +13,12 @@ type Engine struct {
}

func New() *Engine {
c := &sql.Catalog{}
c := sql.NewCatalog()
err := expression.RegisterDefaults(c)
if err != nil {
panic(err)
}

a := analyzer.New(c)
return &Engine{c, a}
}
Expand Down
7 changes: 7 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func TestEngine_Query(t *testing.T) {
sql.NewMemoryRow(int64(1)),
},
)

testQuery(t, e,
"SELECT COUNT(*) FROM mytable;",
[]sql.Row{
sql.NewMemoryRow(int32(3)),
},
)
}

func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) {
Expand Down
27 changes: 27 additions & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ var DefaultRules = []Rule{
{"resolve_columns", resolveColumns},
{"resolve_database", resolveDatabase},
{"resolve_star", resolveStar},
{"resolve_functions", resolveFunctions},
}

func resolveDatabase(a *Analyzer, n sql.Node) sql.Node {
Expand Down Expand Up @@ -102,3 +103,29 @@ func resolveColumns(a *Analyzer, n sql.Node) sql.Node {
return gf
})
}

func resolveFunctions(a *Analyzer, n sql.Node) sql.Node {
if n.Resolved() {
return n
}

return n.TransformExpressionsUp(func(e sql.Expression) sql.Expression {
uf, ok := e.(*expression.UnresolvedFunction)
if !ok {
return e
}

n := uf.Name()
f, err := a.Catalog.Function(n)
if err != nil {
return e
}

rf, err := f.Build(uf.Children...)
if err != nil {
return e
}

return rf
})
}
107 changes: 107 additions & 0 deletions sql/catalog.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package sql

import (
"errors"
"fmt"
"reflect"
)

type Catalog struct {
Databases []Database
Functions map[string]*FunctionEntry
}

func NewCatalog() *Catalog {
return &Catalog{
Functions: map[string]*FunctionEntry{},
}
}

func (c Catalog) Database(name string) (Database, error) {
Expand All @@ -32,3 +41,101 @@ func (c Catalog) Table(dbName string, tableName string) (Table, error) {

return table, nil
}

func (c Catalog) RegisterFunction(name string, f interface{}) error {
e, err := inspectFunction(f)
if err != nil {
return err
}

c.Functions[name] = e
return nil
}

func (c Catalog) Function(name string) (*FunctionEntry, error) {
e, ok := c.Functions[name]
if !ok {
return nil, fmt.Errorf("function not found: %s", name)
}

return e, nil
}

type FunctionEntry struct {
v reflect.Value
}

func (e *FunctionEntry) Build(args ...Expression) (Expression, error) {
t := e.v.Type()
if !t.IsVariadic() && len(args) != t.NumIn() {
return nil, fmt.Errorf("expected %d args, got %d",
t.NumIn(), len(args))
}

if t.IsVariadic() && len(args) < t.NumIn()-1 {
return nil, fmt.Errorf("expected at least %d args, got %d",
t.NumIn(), len(args))
}

var in []reflect.Value
for _, arg := range args {
in = append(in, reflect.ValueOf(arg))
}

out := e.v.Call(in)
if len(out) != 1 {
return nil, fmt.Errorf("expected 1 return value, got %d: ", len(out))
}

expr, ok := out[0].Interface().(Expression)
if !ok {
return nil, errors.New("return value doesn't implement Expression")
}

return expr, nil
}

var (
expressionType = buildExpressionType()
expressionSliceType = buildExpressionSliceType()
)

func buildExpressionType() reflect.Type {
var v Expression
return reflect.ValueOf(&v).Elem().Type()
}

func buildExpressionSliceType() reflect.Type {
var v []Expression
return reflect.ValueOf(&v).Elem().Type()
}

func inspectFunction(f interface{}) (*FunctionEntry, error) {
v := reflect.ValueOf(f)
t := v.Type()
if t.Kind() != reflect.Func {
return nil, fmt.Errorf("expected function, got: %s", t.Kind())
}

if t.NumOut() != 1 {
return nil, errors.New("function builders must return a single Expression")
}

out := t.Out(0)
if !out.Implements(expressionType) {
return nil, fmt.Errorf("return value doesn't implement Expression: %s", out)
}

for i := 0; i < t.NumIn(); i++ {
in := t.In(i)
if i == t.NumIn()-1 && t.IsVariadic() && in == expressionSliceType {
continue
}

if in != expressionType {
return nil, fmt.Errorf("input argument %d is not a Expression", i)
}
}

return &FunctionEntry{v}, nil
}
146 changes: 144 additions & 2 deletions sql/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import (
"github.com/gitql/gitql/mem"
"github.com/gitql/gitql/sql"

"github.com/gitql/gitql/sql/expression"
"github.com/stretchr/testify/assert"
)

func TestCatalog_Database(t *testing.T) {
assert := assert.New(t)

c := sql.Catalog{}
c := sql.NewCatalog()
db, err := c.Database("foo")
assert.EqualError(err, "database not found: foo")
assert.Nil(db)
Expand All @@ -28,7 +29,7 @@ func TestCatalog_Database(t *testing.T) {
func TestCatalog_Table(t *testing.T) {
assert := assert.New(t)

c := sql.Catalog{}
c := sql.NewCatalog()

table, err := c.Table("foo", "bar")
assert.EqualError(err, "database not found: foo")
Expand All @@ -48,3 +49,144 @@ func TestCatalog_Table(t *testing.T) {
assert.NoError(err)
assert.Equal(mytable, table)
}

func TestCatalog_RegisterFunction_NoArgs(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func() sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar())
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.NotNil(err)
assert.Nil(e)
}

func TestCatalog_RegisterFunction_OneArg(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.NotNil(err)
assert.Nil(e)
}

func TestCatalog_RegisterFunction_Variadic(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(...sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)
}

func TestCatalog_RegisterFunction_OneAndVariadic(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(sql.Expression, ...sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)
}

func TestCatalog_RegisterFunction_Invalid(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
err := c.RegisterFunction(name, func(sql.Table) sql.Expression {
return nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, func(sql.Expression) sql.Table {
return nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, func(sql.Expression) (sql.Table, error) {
return nil, nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, 1)
assert.NotNil(err)
}

func TestCatalog_Function_NotExists(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
f, err := c.Function("func")
assert.NotNil(err)
assert.Nil(f)
}
Loading

0 comments on commit 14cd44c

Please sign in to comment.