From 0eb6d29a4c6503ec52bb012b3c427978ab81fa64 Mon Sep 17 00:00:00 2001 From: Xuyuan Pang Date: Sun, 11 Aug 2019 14:59:22 +0800 Subject: [PATCH 1/2] Added Window Function support --- dialect/mysql/mysql.go | 8 ++ dialect/mysql/mysql_test.go | 37 ++++++ dialect/sqlite3/sqlite3.go | 1 + docs/selecting.md | 24 +++- exp/col.go | 2 +- exp/exp.go | 27 +++++ exp/select_clauses.go | 29 +++++ exp/select_clauses_test.go | 48 ++++++++ exp/window.go | 74 ++++++++++++ exp/window_func.go | 106 ++++++++++++++++ exp/window_func_test.go | 186 ++++++++++++++++++++++++++++ exp/window_test.go | 90 ++++++++++++++ expressions.go | 81 ++++++++++++- expressions_example_test.go | 43 +++++++ select_dataset.go | 5 + select_dataset_test.go | 86 +++++++++++++ sql_dialect.go | 81 +++++++++++++ sql_dialect_options.go | 22 +++- sql_dialect_test.go | 233 ++++++++++++++++++++++++++++++++++++ 19 files changed, 1178 insertions(+), 5 deletions(-) create mode 100644 exp/window.go create mode 100644 exp/window_func.go create mode 100644 exp/window_func_test.go create mode 100644 exp/window_test.go diff --git a/dialect/mysql/mysql.go b/dialect/mysql/mysql.go index eab00cbe..f0e99551 100644 --- a/dialect/mysql/mysql.go +++ b/dialect/mysql/mysql.go @@ -19,6 +19,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.SupportsWithCTE = false opts.SupportsWithCTERecursive = false opts.SupportsDistinctOn = false + opts.SupportsWindowFunction = false opts.UseFromClauseForMultipleUpdateTables = false @@ -65,6 +66,13 @@ func DialectOptions() *goqu.SQLDialectOptions { return opts } +func DialectOptionsV8() *goqu.SQLDialectOptions { + opts := DialectOptions() + opts.SupportsWindowFunction = true + return opts +} + func init() { goqu.RegisterDialect("mysql", DialectOptions()) + goqu.RegisterDialect("mysql8", DialectOptionsV8()) } diff --git a/dialect/mysql/mysql_test.go b/dialect/mysql/mysql_test.go index 876f95d7..59b1c757 100644 --- a/dialect/mysql/mysql_test.go +++ b/dialect/mysql/mysql_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "os" + "strconv" + "strings" "testing" "time" @@ -393,6 +395,41 @@ func (mt *mysqlTest) TestInsert_OnConflict() { mt.EqualError(err, "goqu: dialect does not support upsert with where clause [dialect=mysql]") } +func (mt *mysqlTest) TestWindowFunction() { + var version string + ok, err := mt.db.Select(goqu.Func("version")).ScanVal(&version) + mt.NoError(err) + mt.True(ok) + + fields := strings.Split(version, ".") + mt.True(len(fields) > 0) + major, err := strconv.Atoi(fields[0]) + mt.NoError(err) + if major < 8 { + return + } + + ds := mt.db.From("entry").Select("int", goqu.ROW_NUMBER().OverName("w").As("id")).Windows(goqu.W("w").OrderBy(goqu.I("int").Desc())) + + var entries []entry + mt.NoError(ds.WithDialect("mysql8").ScanStructs(&entries)) + + mt.Equal([]entry{ + {Int: 9, ID: 1}, + {Int: 8, ID: 2}, + {Int: 7, ID: 3}, + {Int: 6, ID: 4}, + {Int: 5, ID: 5}, + {Int: 4, ID: 6}, + {Int: 3, ID: 7}, + {Int: 2, ID: 8}, + {Int: 1, ID: 9}, + {Int: 0, ID: 10}, + }, entries) + + mt.Error(ds.WithDialect("mysql").ScanStructs(&entries), "goqu: adapter does not support window function clause") +} + func TestMysqlSuite(t *testing.T) { suite.Run(t, new(mysqlTest)) } diff --git a/dialect/sqlite3/sqlite3.go b/dialect/sqlite3/sqlite3.go index 32f3b1f2..5318645f 100644 --- a/dialect/sqlite3/sqlite3.go +++ b/dialect/sqlite3/sqlite3.go @@ -19,6 +19,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.SupportsMultipleUpdateTables = false opts.WrapCompoundsInParens = false opts.SupportsDistinctOn = false + opts.SupportsWindowFunction = false opts.PlaceHolderRune = '?' opts.IncludePlaceholderNum = false diff --git a/docs/selecting.md b/docs/selecting.md index e6941297..09ba14c5 100644 --- a/docs/selecting.md +++ b/docs/selecting.md @@ -11,6 +11,7 @@ * [`Offset`](#offset) * [`GroupBy`](#group_by) * [`Having`](#having) + * [`Window`](#window) * Executing Queries * [`ScanStructs`](#scan-structs) - Scans rows into a slice of structs * [`ScanStruct`](#scan-struct) - Scans a row into a slice a struct, returns false if a row wasnt found @@ -610,6 +611,27 @@ Output: SELECT * FROM "test" GROUP BY "age" HAVING (SUM("income") > 1000) ``` + + +**[`Window Function`](https://godoc.org/github.com/doug-martin/goqu/#SelectDataset.Windows)** + +```go +sql, _, _ = goqu.From("test").Select(goqu.ROW_NUMBER().Over(goqu.W().PartitionBy("a").OrderBy("b"))) +fmt.Println(sql) + +sql, _, _ = goqu.From("test").Select(goqu.ROW_NUMBER().OverName("w")).Windows(goqu.W("w").PartitionBy("a").OrderBy("b")) +fmt.Println(sql) +``` + +Output: + +``` +SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "b") FROM "test" +SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w" AS (PARTITION BY "a" ORDER BY "b") +``` + +**NOTE** currently only the `postgres`, `mysql8`(NOT `mysql`) and the default dialect support `Window Function` + ## Executing Queries To execute your query use [`goqu.Database#From`](https://godoc.org/github.com/doug-martin/goqu/#Database.From) to create your dataset @@ -748,4 +770,4 @@ if err := db.From("user").Pluck(&ids, "id"); err != nil{ return } fmt.Printf("\nIds := %+v", ids) -``` \ No newline at end of file +``` diff --git a/exp/col.go b/exp/col.go index cd6f686c..dfd5597b 100644 --- a/exp/col.go +++ b/exp/col.go @@ -12,7 +12,7 @@ type columnList struct { } func NewColumnListExpression(vals ...interface{}) ColumnListExpression { - var cols []Expression + cols := []Expression{} for _, val := range vals { switch t := val.(type) { case string: diff --git a/exp/exp.go b/exp/exp.go index 71eaa127..afb72110 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -360,6 +360,33 @@ type ( Col() IdentifierExpression Val() interface{} } + + SQLWindowFunctionExpression interface { + SQLFunctionExpression + + Window() WindowExpression + WindowName() string + + Over(WindowExpression) SQLWindowFunctionExpression + OverName(string) SQLWindowFunctionExpression + + HasWindow() bool + HasWindowName() bool + } + + WindowExpression interface { + Expression + + Name() string + + Parent() string + PartitionCols() ColumnListExpression + OrderCols() ColumnListExpression + + Inherit(parent string) WindowExpression + PartitionBy(cols ...interface{}) WindowExpression + OrderBy(cols ...interface{}) WindowExpression + } ) const ( diff --git a/exp/select_clauses.go b/exp/select_clauses.go index 400e3544..1555120e 100644 --- a/exp/select_clauses.go +++ b/exp/select_clauses.go @@ -58,6 +58,11 @@ type ( CommonTables() []CommonTableExpression CommonTablesAppend(cte CommonTableExpression) SelectClauses + + Windows() []WindowExpression + SetWindows(ws []WindowExpression) SelectClauses + WindowsAppend(ws []WindowExpression) SelectClauses + ClearWindows() SelectClauses } selectClauses struct { commonTables []CommonTableExpression @@ -74,6 +79,7 @@ type ( offset uint compounds []CompoundExpression lock Lock + windows []WindowExpression } ) @@ -116,6 +122,7 @@ func (c *selectClauses) clone() *selectClauses { offset: c.offset, compounds: c.compounds, lock: c.lock, + windows: c.windows, } } @@ -331,3 +338,25 @@ func (c *selectClauses) CompoundsAppend(ce CompoundExpression) SelectClauses { ret.compounds = append(ret.compounds, ce) return ret } + +func (c *selectClauses) Windows() []WindowExpression { + return c.windows +} + +func (c *selectClauses) SetWindows(ws []WindowExpression) SelectClauses { + ret := c.clone() + ret.windows = ws + return ret +} + +func (c *selectClauses) WindowsAppend(ws []WindowExpression) SelectClauses { + ret := c.clone() + ret.windows = append(ret.windows, ws...) + return ret +} + +func (c *selectClauses) ClearWindows() SelectClauses { + ret := c.clone() + ret.windows = nil + return ret +} diff --git a/exp/select_clauses_test.go b/exp/select_clauses_test.go index f6a52d47..e6869de1 100644 --- a/exp/select_clauses_test.go +++ b/exp/select_clauses_test.go @@ -255,6 +255,54 @@ func (scs *selectClausesSuite) TestHavingAppend() { scs.Equal(NewExpressionList(AndType, w, w2), c4.Having()) } +func (scs *selectClausesSuite) TestWindow() { + w := NewWindowExpression("w", "", nil, nil) + + c := NewSelectClauses() + c2 := c.WindowsAppend([]WindowExpression{w}) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestSetWindows() { + w := NewWindowExpression("w", "", nil, nil) + + c := NewSelectClauses() + c2 := c.SetWindows([]WindowExpression{w}) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestWindowsAppend() { + w1 := NewWindowExpression("w1", "", nil, nil) + w2 := NewWindowExpression("w2", "", nil, nil) + + c := NewSelectClauses() + c2 := c.WindowsAppend([]WindowExpression{w1}).WindowsAppend([]WindowExpression{w2}) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w1, w2}, c2.Windows()) +} + +func (scs *selectClausesSuite) TestClearWindows() { + w := NewWindowExpression("w", "", nil, nil) + + c := NewSelectClauses() + c2 := c.SetWindows([]WindowExpression{w}) + + scs.Nil(c.Windows()) + + scs.Equal([]WindowExpression{w}, c2.Windows()) + + c3 := c.ClearWindows() + scs.Nil(c3.Windows()) +} + func (scs *selectClausesSuite) TestOrder() { oe := NewIdentifierExpression("", "", "a").Desc() diff --git a/exp/window.go b/exp/window.go new file mode 100644 index 00000000..03926c05 --- /dev/null +++ b/exp/window.go @@ -0,0 +1,74 @@ +package exp + +type sqlWindowExpression struct { + name string + parent string + partitionCols ColumnListExpression + orderCols ColumnListExpression +} + +func NewWindowExpression(window string, parent string, partitionCols, orderCols ColumnListExpression) WindowExpression { + if partitionCols == nil { + partitionCols = NewColumnListExpression() + } + if orderCols == nil { + orderCols = NewColumnListExpression() + } + return sqlWindowExpression{ + name: window, + parent: parent, + partitionCols: partitionCols, + orderCols: orderCols, + } +} + +func (we sqlWindowExpression) clone() sqlWindowExpression { + return sqlWindowExpression{ + name: we.name, + parent: we.parent, + partitionCols: we.partitionCols.Clone().(ColumnListExpression), + orderCols: we.orderCols.Clone().(ColumnListExpression), + } +} + +func (we sqlWindowExpression) Clone() Expression { + return we.clone() +} + +func (we sqlWindowExpression) Expression() Expression { + return we +} + +func (we sqlWindowExpression) Name() string { + return we.name +} + +func (we sqlWindowExpression) Parent() string { + return we.parent +} + +func (we sqlWindowExpression) PartitionCols() ColumnListExpression { + return we.partitionCols +} + +func (we sqlWindowExpression) OrderCols() ColumnListExpression { + return we.orderCols +} + +func (we sqlWindowExpression) PartitionBy(cols ...interface{}) WindowExpression { + ret := we.clone() + ret.partitionCols = NewColumnListExpression(cols...) + return ret +} + +func (we sqlWindowExpression) OrderBy(cols ...interface{}) WindowExpression { + ret := we.clone() + ret.orderCols = NewColumnListExpression(cols...) + return ret +} + +func (we sqlWindowExpression) Inherit(parent string) WindowExpression { + ret := we.clone() + ret.parent = parent + return ret +} diff --git a/exp/window_func.go b/exp/window_func.go new file mode 100644 index 00000000..83366171 --- /dev/null +++ b/exp/window_func.go @@ -0,0 +1,106 @@ +package exp + +type sqlWindowFunctionExpression struct { + name string + args []interface{} + windowName string + window WindowExpression +} + +func NewSQLWindowFunctionExpression(name string, args ...interface{}) SQLWindowFunctionExpression { + return sqlWindowFunctionExpression{ + name: name, + args: args, + } +} + +func (swfe sqlWindowFunctionExpression) clone() sqlWindowFunctionExpression { + return sqlWindowFunctionExpression{ + name: swfe.name, + args: swfe.args, + windowName: swfe.windowName, + window: swfe.window, + } +} + +func (swfe sqlWindowFunctionExpression) Clone() Expression { + return swfe.clone() +} +func (swfe sqlWindowFunctionExpression) Expression() Expression { + return swfe +} +func (swfe sqlWindowFunctionExpression) As(val interface{}) AliasedExpression { + return aliased(swfe, val) +} +func (swfe sqlWindowFunctionExpression) Eq(val interface{}) BooleanExpression { return eq(swfe, val) } +func (swfe sqlWindowFunctionExpression) Neq(val interface{}) BooleanExpression { return neq(swfe, val) } +func (swfe sqlWindowFunctionExpression) Gt(val interface{}) BooleanExpression { return gt(swfe, val) } +func (swfe sqlWindowFunctionExpression) Gte(val interface{}) BooleanExpression { return gte(swfe, val) } +func (swfe sqlWindowFunctionExpression) Lt(val interface{}) BooleanExpression { return lt(swfe, val) } +func (swfe sqlWindowFunctionExpression) Lte(val interface{}) BooleanExpression { return lte(swfe, val) } +func (swfe sqlWindowFunctionExpression) Between(val RangeVal) RangeExpression { + return between(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotBetween(val RangeVal) RangeExpression { + return notBetween(swfe, val) +} +func (swfe sqlWindowFunctionExpression) Like(val interface{}) BooleanExpression { + return like(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotLike(val interface{}) BooleanExpression { + return notLike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) ILike(val interface{}) BooleanExpression { + return iLike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) NotILike(val interface{}) BooleanExpression { + return notILike(swfe, val) +} +func (swfe sqlWindowFunctionExpression) In(vals ...interface{}) BooleanExpression { + return in(swfe, vals...) +} +func (swfe sqlWindowFunctionExpression) NotIn(vals ...interface{}) BooleanExpression { + return notIn(swfe, vals...) +} +func (swfe sqlWindowFunctionExpression) Is(val interface{}) BooleanExpression { return is(swfe, val) } +func (swfe sqlWindowFunctionExpression) IsNot(val interface{}) BooleanExpression { + return isNot(swfe, val) +} +func (swfe sqlWindowFunctionExpression) IsNull() BooleanExpression { return is(swfe, nil) } +func (swfe sqlWindowFunctionExpression) IsNotNull() BooleanExpression { return isNot(swfe, nil) } +func (swfe sqlWindowFunctionExpression) IsTrue() BooleanExpression { return is(swfe, true) } +func (swfe sqlWindowFunctionExpression) IsNotTrue() BooleanExpression { return isNot(swfe, true) } +func (swfe sqlWindowFunctionExpression) IsFalse() BooleanExpression { return is(swfe, false) } +func (swfe sqlWindowFunctionExpression) IsNotFalse() BooleanExpression { return isNot(swfe, false) } + +func (swfe sqlWindowFunctionExpression) Name() string { return swfe.name } + +func (swfe sqlWindowFunctionExpression) Args() []interface{} { return swfe.args } + +func (swfe sqlWindowFunctionExpression) Window() WindowExpression { + return swfe.window +} + +func (swfe sqlWindowFunctionExpression) WindowName() string { + return swfe.windowName +} + +func (swfe sqlWindowFunctionExpression) Over(we WindowExpression) SQLWindowFunctionExpression { + ret := swfe.clone() + ret.window = we + return ret +} + +func (swfe sqlWindowFunctionExpression) OverName(name string) SQLWindowFunctionExpression { + ret := swfe.clone() + ret.windowName = name + return ret +} + +func (swfe sqlWindowFunctionExpression) HasWindow() bool { + return swfe.window != nil +} + +func (swfe sqlWindowFunctionExpression) HasWindowName() bool { + return swfe.windowName != "" +} diff --git a/exp/window_func_test.go b/exp/window_func_test.go new file mode 100644 index 00000000..071213ec --- /dev/null +++ b/exp/window_func_test.go @@ -0,0 +1,186 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type sqlWindowFunctionExpressionTest struct { + suite.Suite +} + +func TestSQLWindowFunctionExpressionSuite(t *testing.T) { + suite.Run(t, new(sqlWindowFunctionExpressionTest)) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestClone() { + t := swfet.T() + wf := NewSQLWindowFunctionExpression("f1", "a") + wf2 := wf.Clone() + assert.Equal(t, wf, wf2) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestExpression() { + t := swfet.T() + wf := NewSQLWindowFunctionExpression("f1", "a") + wf2 := wf.Expression() + assert.Equal(t, wf, wf2) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestName() { + t := swfet.T() + wf := NewSQLWindowFunctionExpression("f1", "a") + assert.Equal(t, wf.Name(), "f1") +} + +func (swfet *sqlWindowFunctionExpressionTest) TestArgs() { + t := swfet.T() + wf := NewSQLWindowFunctionExpression("f1", "a") + assert.Equal(t, wf.Args(), []interface{}{"a"}) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestWindow() { + t := swfet.T() + w := NewWindowExpression("w", "", nil, nil) + wf := NewSQLWindowFunctionExpression("f1", "a") + assert.False(t, wf.HasWindow()) + + wf = wf.Over(w) + assert.True(t, wf.HasWindow()) + assert.Equal(t, wf.Window(), w) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestWindowName() { + t := swfet.T() + windowName := "w" + wf := NewSQLWindowFunctionExpression("f1", "a") + assert.False(t, wf.HasWindowName()) + + wf = wf.OverName(windowName) + assert.True(t, wf.HasWindowName()) + assert.Equal(t, wf.WindowName(), windowName) +} + +func (swfet *sqlWindowFunctionExpressionTest) TestAllOthers() { + t := swfet.T() + wf := NewSQLWindowFunctionExpression("f1", "a") + + expAs := wf.As("a") + assert.Equal(t, expAs.Aliased(), wf) + + expEq := wf.Eq(1) + assert.Equal(t, expEq.LHS(), wf) + assert.Equal(t, expEq.Op(), EqOp) + assert.Equal(t, expEq.RHS(), 1) + + expNeq := wf.Neq(1) + assert.Equal(t, expNeq.LHS(), wf) + assert.Equal(t, expNeq.Op(), NeqOp) + assert.Equal(t, expNeq.RHS(), 1) + + expGt := wf.Gt(1) + assert.Equal(t, expGt.LHS(), wf) + assert.Equal(t, expGt.Op(), GtOp) + assert.Equal(t, expGt.RHS(), 1) + + expGte := wf.Gte(1) + assert.Equal(t, expGte.LHS(), wf) + assert.Equal(t, expGte.Op(), GteOp) + assert.Equal(t, expGte.RHS(), 1) + + expLt := wf.Lt(1) + assert.Equal(t, expLt.LHS(), wf) + assert.Equal(t, expLt.Op(), LtOp) + assert.Equal(t, expLt.RHS(), 1) + + expLte := wf.Lte(1) + assert.Equal(t, expLte.LHS(), wf) + assert.Equal(t, expLte.Op(), LteOp) + assert.Equal(t, expLte.RHS(), 1) + + rv := NewRangeVal(1, 2) + expBetween := wf.Between(rv) + assert.Equal(t, expBetween.LHS(), wf) + assert.Equal(t, expBetween.Op(), BetweenOp) + assert.Equal(t, expBetween.RHS(), rv) + + expNotBetween := wf.NotBetween(rv) + assert.Equal(t, expNotBetween.LHS(), wf) + assert.Equal(t, expNotBetween.Op(), NotBetweenOp) + assert.Equal(t, expNotBetween.RHS(), rv) + + pattern := "a%" + expLike := wf.Like(pattern) + assert.Equal(t, expLike.LHS(), wf) + assert.Equal(t, expLike.Op(), LikeOp) + assert.Equal(t, expLike.RHS(), pattern) + + expNotLike := wf.NotLike(pattern) + assert.Equal(t, expNotLike.LHS(), wf) + assert.Equal(t, expNotLike.Op(), NotLikeOp) + assert.Equal(t, expNotLike.RHS(), pattern) + + expILike := wf.ILike(pattern) + assert.Equal(t, expILike.LHS(), wf) + assert.Equal(t, expILike.Op(), ILikeOp) + assert.Equal(t, expILike.RHS(), pattern) + + expNotILike := wf.NotILike(pattern) + assert.Equal(t, expNotILike.LHS(), wf) + assert.Equal(t, expNotILike.Op(), NotILikeOp) + assert.Equal(t, expNotILike.RHS(), pattern) + + vals := []interface{}{1, 2} + expIn := wf.In(vals) + assert.Equal(t, expIn.LHS(), wf) + assert.Equal(t, expIn.Op(), InOp) + assert.Equal(t, expIn.RHS(), vals) + + expNotIn := wf.NotIn(vals) + assert.Equal(t, expNotIn.LHS(), wf) + assert.Equal(t, expNotIn.Op(), NotInOp) + assert.Equal(t, expNotIn.RHS(), vals) + + obj := 1 + expIs := wf.Is(obj) + assert.Equal(t, expIs.LHS(), wf) + assert.Equal(t, expIs.Op(), IsOp) + assert.Equal(t, expIs.RHS(), obj) + + expIsNot := wf.IsNot(obj) + assert.Equal(t, expIsNot.LHS(), wf) + assert.Equal(t, expIsNot.Op(), IsNotOp) + assert.Equal(t, expIsNot.RHS(), obj) + + expIsNull := wf.IsNull() + assert.Equal(t, expIsNull.LHS(), wf) + assert.Equal(t, expIsNull.Op(), IsOp) + assert.Nil(t, expIsNull.RHS()) + + expIsNotNull := wf.IsNotNull() + assert.Equal(t, expIsNotNull.LHS(), wf) + assert.Equal(t, expIsNotNull.Op(), IsNotOp) + assert.Nil(t, expIsNotNull.RHS()) + + expIsTrue := wf.IsTrue() + assert.Equal(t, expIsTrue.LHS(), wf) + assert.Equal(t, expIsTrue.Op(), IsOp) + assert.Equal(t, expIsTrue.RHS(), true) + + expIsNotTrue := wf.IsNotTrue() + assert.Equal(t, expIsNotTrue.LHS(), wf) + assert.Equal(t, expIsNotTrue.Op(), IsNotOp) + assert.Equal(t, expIsNotTrue.RHS(), true) + + expIsFalse := wf.IsFalse() + assert.Equal(t, expIsFalse.LHS(), wf) + assert.Equal(t, expIsFalse.Op(), IsOp) + assert.Equal(t, expIsFalse.RHS(), false) + + expIsNotFalse := wf.IsNotFalse() + assert.Equal(t, expIsNotFalse.LHS(), wf) + assert.Equal(t, expIsNotFalse.Op(), IsNotOp) + assert.Equal(t, expIsNotFalse.RHS(), false) +} diff --git a/exp/window_test.go b/exp/window_test.go new file mode 100644 index 00000000..dccb08c6 --- /dev/null +++ b/exp/window_test.go @@ -0,0 +1,90 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type windowExpressionTest struct { + suite.Suite +} + +func TestWindowExpressionSuite(t *testing.T) { + suite.Run(t, new(windowExpressionTest)) +} + +func (wet *windowExpressionTest) TestClone() { + t := wet.T() + w := NewWindowExpression("w", "", nil, nil) + w2 := w.Clone() + + assert.Equal(t, w, w2) +} + +func (wet *windowExpressionTest) TestExpression() { + t := wet.T() + w := NewWindowExpression("w", "", nil, nil) + w2 := w.Expression() + + assert.Equal(t, w, w2) +} + +func (wet *windowExpressionTest) TestName() { + t := wet.T() + w := NewWindowExpression("w", "", nil, nil) + + assert.Equal(t, "w", w.Name()) +} + +func (wet *windowExpressionTest) TestPartitionCols() { + t := wet.T() + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression("w", "", cols, nil) + + assert.Equal(t, cols, w.PartitionCols()) + assert.Equal(t, cols, w.Clone().(WindowExpression).PartitionCols()) +} + +func (wet *windowExpressionTest) TestOrderCols() { + t := wet.T() + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression("w", "", nil, cols) + + assert.Equal(t, cols, w.OrderCols()) + assert.Equal(t, cols, w.Clone().(WindowExpression).OrderCols()) +} + +func (wet *windowExpressionTest) TestPartitonBy() { + t := wet.T() + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression("w", "", nil, nil).PartitionBy("a", "b") + + assert.Equal(t, cols, w.PartitionCols()) +} + +func (wet *windowExpressionTest) TestOrderBy() { + t := wet.T() + cols := NewColumnListExpression("a", "b") + w := NewWindowExpression("w", "", nil, nil).OrderBy("a", "b") + + assert.Equal(t, cols, w.OrderCols()) +} + +func (wet *windowExpressionTest) TestParent() { + t := wet.T() + w := NewWindowExpression("w", "w1", nil, nil) + + assert.Equal(t, "w1", w.Parent()) +} + +func (wet *windowExpressionTest) TestInherit() { + t := wet.T() + w := NewWindowExpression("w", "w1", nil, nil) + + assert.Equal(t, "w1", w.Parent()) + + w = w.Inherit("w2") + assert.Equal(t, "w2", w.Parent()) +} diff --git a/expressions.go b/expressions.go index 62dc1ac7..4d3e3770 100644 --- a/expressions.go +++ b/expressions.go @@ -15,6 +15,9 @@ type ( TruncateOptions = exp.TruncateOptions ) +// emptyWindow is an empty WINDOW clause without name +var emptyWindow = exp.NewWindowExpression("", "", nil, nil) + const ( Wait = exp.Wait NoWait = exp.NoWait @@ -69,6 +72,19 @@ func newIdentifierFunc(name string, col interface{}) exp.SQLFunctionExpression { return Func(name, col) } +// Create a new SQLWindowFunctionExpression with the given name and arguments +func WFunc(name string, args ...interface{}) exp.SQLWindowFunctionExpression { + return exp.NewSQLWindowFunctionExpression(name, args...) +} + +// used internally to normalize the column name if passed in as a string it should be turned into an identifier +func newIdentifierWinFunc(name string, col interface{}) exp.SQLWindowFunctionExpression { + if s, ok := col.(string); ok { + col = I(s) + } + return WFunc(name, col) +} + // Creates a new DISTINCT sql function // DISTINCT("a") -> DISTINCT("a") // DISTINCT(I("a")) -> DISTINCT("a") @@ -117,6 +133,45 @@ func COALESCE(vals ...interface{}) exp.SQLFunctionExpression { return exp.NewSQLFunctionExpression("COALESCE", vals...) } +func ROW_NUMBER() exp.SQLWindowFunctionExpression { + return WFunc("ROW_NUMBER") +} + +func RANK() exp.SQLWindowFunctionExpression { + return WFunc("RANK") +} + +func DENSE_RANK() exp.SQLWindowFunctionExpression { + return WFunc("DENSE_RANK") +} + +func PERCENT_RANK() exp.SQLWindowFunctionExpression { + return WFunc("PERCENT_RANK") +} + +func CUME_DIST() exp.SQLWindowFunctionExpression { + return WFunc("CUME_DIST") +} + +func NTILE(n int) exp.SQLWindowFunctionExpression { + return newIdentifierWinFunc("NTILE", n) +} + +func FIRST_VALUE(val interface{}) exp.SQLWindowFunctionExpression { + return newIdentifierWinFunc("FIRST_VALUE", val) +} + +func LAST_VALUE(val interface{}) exp.SQLWindowFunctionExpression { + return newIdentifierWinFunc("LAST_VALUE", val) +} + +func NTH_VALUE(val interface{}, nth int) exp.SQLWindowFunctionExpression { + if s, ok := val.(string); ok { + val = I(s) + } + return WFunc("NTH_VALUE", val, nth) +} + // Creates a new Identifier, the generated sql will use adapter specific quoting or '"' by default, this ensures case // sensitivity and in certain databases allows for special characters, (e.g. "curr-table", "my table"). // @@ -163,6 +218,28 @@ func T(table string) exp.IdentifierExpression { return exp.NewIdentifierExpression("", table, "") } +// Create a new WINDOW clause +// W() -> () +// W().PartitionBy("a") -> (PARTITION BY "a") +// W().PartitionBy("a").OrderBy("b") -> (PARTITION BY "a" ORDER BY "b") +// W().PartitionBy("a").OrderBy("b").Inherit("w1") -> ("w1" PARTITION BY "a" ORDER BY "b") +// W().PartitionBy("a").OrderBy(I("b").Desc()).Inherit("w1") -> ("w1" PARTITION BY "a" ORDER BY "b" DESC) +// W("w") -> "w" AS () +// W("w", "w1") -> "w" AS ("w1") +// W("w").Inherit("w1") -> "w" AS ("w1") +// W("w").PartitionBy("a") -> "w" AS (PARTITION BY "a") +// W("w", "w1").PartitionBy("a") -> "w" AS ("w1" PARTITION BY "a") +// W("w", "w1").PartitionBy("a").OrderBy("b") -> "w" AS ("w1" PARTITION BY "a" ORDER BY "b") +func W(ws ...string) exp.WindowExpression { + if l := len(ws); l > 0 { + if l == 1 { + return exp.NewWindowExpression(ws[0], "", nil, nil) + } + return exp.NewWindowExpression(ws[0], ws[1], nil, nil) + } + return emptyWindow +} + // Creates a new ON clause to be used within a join // ds.Join(goqu.T("my_table"), goqu.On( // goqu.I("my_table.fkey").Eq(goqu.I("other_table.id")), @@ -184,12 +261,12 @@ func Using(columns ...interface{}) exp.JoinCondition { // Literals can also contain placeholders for other expressions // L("(? AND ?) OR (?)", I("a").Eq(1), I("b").Eq("b"), I("c").In([]string{"a", "b", "c"})) func L(sql string, args ...interface{}) exp.LiteralExpression { - return exp.NewLiteralExpression(sql, args...) + return Literal(sql, args...) } // Alias for goqu.L func Literal(sql string, args ...interface{}) exp.LiteralExpression { - return L(sql, args...) + return exp.NewLiteralExpression(sql, args...) } // Create a new SQL value ( alias for goqu.L("?", val) ). The prrimary use case for this would be in selects. diff --git a/expressions_example_test.go b/expressions_example_test.go index 12449574..ee726780 100644 --- a/expressions_example_test.go +++ b/expressions_example_test.go @@ -1710,3 +1710,46 @@ func ExampleVals() { // Output: // INSERT INTO "user" ("first_name", "last_name", "is_verified") VALUES ('Greg', 'Farley', TRUE), ('Jimmy', 'Stewart', TRUE), ('Jeff', 'Jeffers', FALSE) [] } + +func ExampleW() { + ds := goqu.From("test"). + Select( + goqu.ROW_NUMBER().Over(goqu.W().PartitionBy("a").OrderBy(goqu.I("b").Asc())), + ) + query, args, _ := ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select( + goqu.ROW_NUMBER().OverName("w"), + ). + Windows( + goqu.W("w").PartitionBy("a").OrderBy(goqu.I("b").Asc()), + ) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test"). + Select( + goqu.ROW_NUMBER().OverName("w1"), + ). + Windows( + goqu.W("w1").PartitionBy("a"), + goqu.W("w").Inherit("w1").OrderBy(goqu.I("b").Asc()), + ) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + + ds = goqu.From("test").Select( + goqu.ROW_NUMBER().Over(goqu.W().Inherit("w").OrderBy("b")), + ).Windows( + goqu.W("w").PartitionBy("a"), + ) + query, args, _ = ds.ToSQL() + fmt.Println(query, args) + // Output + // SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "b" ASC) FROM "test" [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w" AS (PARTITION BY "a" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER "w" FROM "test" WINDOW "w1" AS (PARTITION BY "a"), "w" AS ("w1" ORDER BY "b" ASC) [] + // SELECT ROW_NUMBER() OVER ("w" ORDER BY "b") FROM "test" WINDOW "w" AS (PARTITION BY "a") [] +} diff --git a/select_dataset.go b/select_dataset.go index 0b3afb22..260d4618 100644 --- a/select_dataset.go +++ b/select_dataset.go @@ -492,6 +492,11 @@ func (sd *SelectDataset) As(alias string) *SelectDataset { return sd.copy(sd.clauses.SetAlias(T(alias))) } +// Sets the WINDOW clauses +func (sd *SelectDataset) Windows(ws ...exp.WindowExpression) *SelectDataset { + return sd.copy(sd.clauses.SetWindows(ws)) +} + // Generates a SELECT sql statement, if Prepared has been called with true then the parameters will not be interpolated. // See examples. // diff --git a/select_dataset_test.go b/select_dataset_test.go index 92436345..55ff1e21 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -1242,6 +1242,15 @@ func (sds *selectDatasetSuite) TestHaving() { sds.Equal(dsc, ds.GetClauses()) } +func (sds *selectDatasetSuite) TestWindows() { + ds := From("test") + dsc := ds.GetClauses() + w := W("w").PartitionBy("a").OrderBy("b") + ec := dsc.SetWindows([]exp.WindowExpression{w}) + sds.Equal(ec, ds.Windows(w).GetClauses()) + sds.Equal(dsc, ds.GetClauses()) +} + func (sds *selectDatasetSuite) TestHaving_ToSQL() { ds1 := From("test") @@ -2456,6 +2465,83 @@ func (sds *selectDatasetSuite) TestPluck_WithPreparedStatement() { sds.Equal([]string{"Bob", "Sally", "Billy"}, names) } +func (sds *selectDatasetSuite) TestWindowFunction() { + for _, tt := range []struct { + expectQuery string + returnRows *sqlmock.Rows + fn exp.Expression + expectValue []int32 + }{ + { + expectQuery: `SELECT ROW_NUMBER\(\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("1\n2\n1"), + fn: ROW_NUMBER().Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{1, 2, 1}, + }, + { + expectQuery: `SELECT RANK\(\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("1\n2\n1"), + fn: RANK().Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{1, 2, 1}, + }, + { + expectQuery: `SELECT DENSE_RANK\(\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("1\n2\n1"), + fn: DENSE_RANK().Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{1, 2, 1}, + }, + { + expectQuery: `SELECT PERCENT_RANK\(\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("1\n2\n1"), + fn: PERCENT_RANK().Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{1, 2, 1}, + }, + { + expectQuery: `SELECT CUME_DIST\(\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("1\n2\n1"), + fn: CUME_DIST().Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{1, 2, 1}, + }, + { + expectQuery: `SELECT NTILE\(2\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("100\n100\n99"), + fn: NTILE(2).Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{100, 100, 99}, + }, + { + expectQuery: `SELECT FIRST_VALUE\("score"\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("100\n100\n99"), + fn: FIRST_VALUE("score").Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{100, 100, 99}, + }, + { + expectQuery: `SELECT LAST_VALUE\("score"\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("100\n100\n99"), + fn: LAST_VALUE("score").Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{100, 100, 99}, + }, + { + expectQuery: `SELECT NTH_VALUE\("score", 3\) OVER \(PARTITION BY "class" ORDER BY "score"\) AS "r" FROM "test"`, + returnRows: sqlmock.NewRows([]string{"r"}).FromCSVString("100\n100\n99"), + fn: NTH_VALUE("score", 3).Over(W().PartitionBy("class").OrderBy("score")).As("r"), + expectValue: []int32{100, 100, 99}, + }, + } { + mDb, sqlMock, err := sqlmock.New() + sds.NoError(err) + qf := exec.NewQueryFactory(mDb) + ds := newDataset("mock", qf) + sqlMock.ExpectQuery(tt.expectQuery). + WillReturnRows(tt.returnRows) + var actualValue []int32 + sds.NoError(ds.Prepared(false). + Select(tt.fn). + From("test"). + ScanVals(&actualValue)) + sds.Equal(tt.expectValue, actualValue) + } +} + func TestSelectDataset(t *testing.T) { suite.Run(t, new(selectDatasetSuite)) } diff --git a/sql_dialect.go b/sql_dialect.go index 50b9cc45..096e89b9 100644 --- a/sql_dialect.go +++ b/sql_dialect.go @@ -48,6 +48,8 @@ var ( errNoSourceForTruncate = errors.New("no source found when generating truncate sql") errNoSetValuesForUpdate = errors.New("no set values found when generating UPDATE sql") errEmptyIdentifier = errors.New(`a empty identifier was encountered, please specify a "schema", "table" or "column"`) + errWindowFunctionNotSupported = errors.New("adapter does not support window function clause") + errNoWindowName = errors.New("window expresion has no valid name") ) func errNotSupportedFragment(sqlType string, f SQLFragmentType) error { @@ -166,6 +168,8 @@ func (d *sqlDialect) ToSelectSQL(b sb.SQLBuilder, clauses exp.SelectClauses) { d.GroupBySQL(b, clauses.GroupBy()) case HavingSQLFragment: d.HavingSQL(b, clauses.Having()) + case WindowSQLFragment: + d.WindowsSQL(b, clauses.Windows()...) case CompoundsSQLFragment: d.CompoundsSQL(b, clauses.Compounds()) case OrderSQLFragment: @@ -495,6 +499,72 @@ func (d *sqlDialect) HavingSQL(b sb.SQLBuilder, having exp.ExpressionList) { } } +func (d *sqlDialect) WindowsSQL(b sb.SQLBuilder, windows ...exp.WindowExpression) { + if b.Error() != nil { + return + } + l := len(windows) + if l == 0 { + return + } + if !d.dialectOptions.SupportsWindowFunction { + b.SetError(errWindowFunctionNotSupported) + return + } + b.Write(d.dialectOptions.WindowFragment) + d.WindowSQL(b, windows[0], true) + for _, we := range windows[1:] { + b.WriteRunes(d.dialectOptions.CommaRune, d.dialectOptions.SpaceRune) + d.WindowSQL(b, we, true) + } +} + +func (d *sqlDialect) WindowSQL(b sb.SQLBuilder, we exp.WindowExpression, withName bool) { + if b.Error() != nil { + return + } + if !d.dialectOptions.SupportsWindowFunction { + b.SetError(errWindowFunctionNotSupported) + return + } + if withName { + name := we.Name() + if len(name) == 0 { + b.SetError(errNoWindowName) + return + } + d.Literal(b, I(name)) + b.Write(d.dialectOptions.AsFragment) + } + b.WriteRunes(d.dialectOptions.LeftParenRune) + + parent, partitionCols, orderCols := we.Parent(), we.PartitionCols(), we.OrderCols() + hasParent := len(parent) > 0 + hasPartition := partitionCols != nil && !partitionCols.IsEmpty() + hasOrder := orderCols != nil && !orderCols.IsEmpty() + + if hasParent { + d.Literal(b, I(parent)) + if hasPartition || hasOrder { + b.WriteRunes(d.dialectOptions.SpaceRune) + } + } + + if hasPartition { + b.Write(d.dialectOptions.WindowPartitionByFragment) + d.Literal(b, partitionCols) + if hasOrder { + b.WriteRunes(d.dialectOptions.SpaceRune) + } + } + if hasOrder { + b.Write(d.dialectOptions.WindowOrderByFragment) + d.Literal(b, orderCols) + } + + b.WriteRunes(d.dialectOptions.RightParenRune) +} + // Generates the ORDER BY clause for an SQL statement func (d *sqlDialect) OrderSQL(b sb.SQLBuilder, order exp.ColumnListExpression) { if order != nil && len(order.Columns()) > 0 { @@ -1138,6 +1208,17 @@ func (d *sqlDialect) literalExpressionSQL(b sb.SQLBuilder, literal exp.LiteralEx func (d *sqlDialect) sqlFunctionExpressionSQL(b sb.SQLBuilder, sqlFunc exp.SQLFunctionExpression) { b.WriteStrings(sqlFunc.Name()) d.Literal(b, sqlFunc.Args()) + + if sqlWinFunc, ok := sqlFunc.(exp.SQLWindowFunctionExpression); ok { + b.Write(d.dialectOptions.WindowOverFragment) + if sqlWinFunc.HasWindowName() { + d.Literal(b, I(sqlWinFunc.WindowName())) + } else if sqlWinFunc.HasWindow() { + d.WindowSQL(b, sqlWinFunc.Window(), false) + } else { + d.WindowSQL(b, emptyWindow, false) + } + } } // Generates SQL for a CastExpression diff --git a/sql_dialect_options.go b/sql_dialect_options.go index 1afb8144..7b3baaf8 100644 --- a/sql_dialect_options.go +++ b/sql_dialect_options.go @@ -37,6 +37,9 @@ type ( // Set to false if the dialect does not require expressions to be wrapped in parens (DEFAULT=true) WrapCompoundsInParens bool + // Set to true if window function are supported in SELECT statement. (DEFAULT=true) + SupportsWindowFunction bool + // Set to true if the dialect requires join tables in UPDATE to be in a FROM clause (DEFAULT=true). UseFromClauseForMultipleUpdateTables bool @@ -85,8 +88,16 @@ type ( WhereFragment []byte // The SQL GROUP BY clause fragment(DEFAULT=[]byte(" GROUP BY ")) GroupByFragment []byte - // The SQL HAVING clause fragment(DELiFAULT=[]byte(" HAVING ")) + // The SQL HAVING clause fragment(DEFAULT=[]byte(" HAVING ")) HavingFragment []byte + // The SQL WINDOW clause fragment(DEFAULT=[]byte(" WINDOW ")) + WindowFragment []byte + // The SQL WINDOW clause PARTITION BY fragment(DEFAULT=[]byte("PARTITION BY ")) + WindowPartitionByFragment []byte + // The SQL WINDOW clause ORDER BY fragment(DEFAULT=[]byte("ORDER BY ")) + WindowOrderByFragment []byte + // The SQL WINDOW clause OVER fragment(DEFAULT=[]byte(" OVER ")) + WindowOverFragment []byte // The SQL ORDER BY clause fragment(DEFAULT=[]byte(" ORDER BY ")) OrderByFragment []byte // The SQL LIMIT BY clause fragment(DEFAULT=[]byte(" LIMIT ")) @@ -304,6 +315,7 @@ const ( InsertSQLFragment DeleteBeginSQLFragment TruncateSQLFragment + WindowSQLFragment ) // nolint:gocyclo @@ -351,6 +363,8 @@ func (sf SQLFragmentType) String() string { return "DeleteBeginSQLFragment" case TruncateSQLFragment: return "TruncateSQLFragment" + case WindowSQLFragment: + return "WindowSQLFragment" } return fmt.Sprintf("%d", sf) } @@ -369,6 +383,7 @@ func DefaultDialectOptions() *SQLDialectOptions { SupportsWithCTERecursive: true, SupportsDistinctOn: true, WrapCompoundsInParens: true, + SupportsWindowFunction: true, SupportsMultipleUpdateTables: true, UseFromClauseForMultipleUpdateTables: true, @@ -395,6 +410,10 @@ func DefaultDialectOptions() *SQLDialectOptions { WhereFragment: []byte(" WHERE "), GroupByFragment: []byte(" GROUP BY "), HavingFragment: []byte(" HAVING "), + WindowFragment: []byte(" WINDOW "), + WindowPartitionByFragment: []byte("PARTITION BY "), + WindowOrderByFragment: []byte("ORDER BY "), + WindowOverFragment: []byte(" OVER "), OrderByFragment: []byte(" ORDER BY "), LimitFragment: []byte(" LIMIT "), OffsetFragment: []byte(" OFFSET "), @@ -489,6 +508,7 @@ func DefaultDialectOptions() *SQLDialectOptions { WhereSQLFragment, GroupBySQLFragment, HavingSQLFragment, + WindowSQLFragment, CompoundsSQLFragment, OrderSQLFragment, LimitSQLFragment, diff --git a/sql_dialect_test.go b/sql_dialect_test.go index 1b86fcaa..377157c2 100644 --- a/sql_dialect_test.go +++ b/sql_dialect_test.go @@ -1380,6 +1380,191 @@ func (dts *dialectTestSuite) TestToSelectSQL_withHaving() { dts.assertPreparedSQL(b, `SELECT * FROM "test" having (("a" = ?) AND ("b" = ?))`, []interface{}{"b", "c"}) } +func (dts *dialectTestSuite) TestWindowsSQL() { + opts := DefaultDialectOptions() + + we := W("w").PartitionBy("a", "b").OrderBy("c", "d") + + opts.SupportsWindowFunction = false + d := sqlDialect{dialect: "test", dialectOptions: opts} + b := sb.NewSQLBuilder(false) + d.WindowsSQL(b, we) + dts.assertErrorSQL(b, errWindowFunctionNotSupported.Error()) + + opts.SupportsWindowFunction = true + d = sqlDialect{dialect: "test", dialectOptions: opts} + b = sb.NewSQLBuilder(false) + d.WindowsSQL(b) + dts.assertNotPreparedSQL(b, "") + + b = sb.NewSQLBuilder(false) + anErr := errors.New("something wrong") + b.SetError(anErr) + d.WindowsSQL(b, we) + dts.assertErrorSQL(b, anErr.Error()) + + b = sb.NewSQLBuilder(false) + d.WindowsSQL(b, we) + dts.assertNotPreparedSQL(b, ` WINDOW "w" AS (PARTITION BY "a", "b" ORDER BY "c", "d")`) + + b = sb.NewSQLBuilder(false) + w1 := W("w1").PartitionBy("a").OrderBy("b") + w2 := W("w2").PartitionBy("c").OrderBy("d") + d.WindowsSQL(b, w1, w2) + dts.assertNotPreparedSQL(b, ` WINDOW "w1" AS (PARTITION BY "a" ORDER BY "b"), "w2" AS (PARTITION BY "c" ORDER BY "d")`) + + w1 = W("w1").PartitionBy("a") + w2 = W("w2").Inherit("w1").OrderBy("b") + d.WindowsSQL(b.Clear(), w1, w2) + dts.assertNotPreparedSQL(b, ` WINDOW "w1" AS (PARTITION BY "a"), "w2" AS ("w1" ORDER BY "b")`) +} + +func (dts *dialectTestSuite) TestWindowSQL() { + opts := DefaultDialectOptions() + + opts.SupportsWindowFunction = false + d := sqlDialect{dialect: "test", dialectOptions: opts} + we := W("w").PartitionBy("a", "b").OrderBy("c", "d") + b := sb.NewSQLBuilder(false) + d.WindowSQL(b, we, true) + dts.assertErrorSQL(b, errWindowFunctionNotSupported.Error()) + + opts.SupportsWindowFunction = true + d = sqlDialect{dialect: "test", dialectOptions: opts} + b = sb.NewSQLBuilder(false) + + b = sb.NewSQLBuilder(false) + anErr := errors.New("something wrong") + b.SetError(anErr) + d.WindowSQL(b, we, true) + dts.assertErrorSQL(b, anErr.Error()) + + we = W().PartitionBy("a", "b").OrderBy("c", "d") + b = sb.NewSQLBuilder(false) + d.WindowSQL(b, we, true) + dts.assertErrorSQL(b, errNoWindowName.Error()) + + for _, tt := range []struct { + we exp.WindowExpression + prepared bool + withName bool + expectedSQL string + expectedArgs []interface{} + }{ + { + we: W(), + prepared: false, + withName: false, + expectedSQL: `()`, + }, + { + we: W().Inherit("w"), + prepared: false, + withName: false, + expectedSQL: `("w")`, + }, + { + we: W().PartitionBy("a"), + prepared: false, + withName: false, + expectedSQL: `(PARTITION BY "a")`, + }, + { + we: W().PartitionBy("a", "b"), + prepared: false, + withName: false, + expectedSQL: `(PARTITION BY "a", "b")`, + }, + { + we: W().OrderBy("c"), + prepared: false, + withName: false, + expectedSQL: `(ORDER BY "c")`, + }, + { + we: W().OrderBy("c", "d"), + prepared: false, + withName: false, + expectedSQL: `(ORDER BY "c", "d")`, + }, + { + we: W().PartitionBy("a", "b").OrderBy("c", "d"), + prepared: false, + withName: false, + expectedSQL: `(PARTITION BY "a", "b" ORDER BY "c", "d")`, + }, + { + we: W().Inherit("w1").PartitionBy("a", "b").OrderBy("c", "d"), + prepared: false, + withName: false, + expectedSQL: `("w1" PARTITION BY "a", "b" ORDER BY "c", "d")`, + }, + // withName + { + we: W("w"), + prepared: false, + withName: true, + expectedSQL: `"w" AS ()`, + }, + { + we: W("w1").Inherit("w"), + prepared: false, + withName: true, + expectedSQL: `"w1" AS ("w")`, + }, + { + we: W("w").PartitionBy("a"), + prepared: false, + withName: true, + expectedSQL: `"w" AS (PARTITION BY "a")`, + }, + { + we: W("w").PartitionBy("a", "b"), + prepared: false, + withName: true, + expectedSQL: `"w" AS (PARTITION BY "a", "b")`, + }, + { + we: W("w").OrderBy("c"), + prepared: false, + withName: true, + expectedSQL: `"w" AS (ORDER BY "c")`, + }, + { + we: W("w").OrderBy("c", "d"), + prepared: false, + withName: true, + expectedSQL: `"w" AS (ORDER BY "c", "d")`, + }, + { + we: W("w").PartitionBy("a", "b").OrderBy("c", "d"), + prepared: false, + withName: true, + expectedSQL: `"w" AS (PARTITION BY "a", "b" ORDER BY "c", "d")`, + }, + { + we: W("w").Inherit("w1").PartitionBy("a", "b").OrderBy("c", "d"), + prepared: false, + withName: true, + expectedSQL: `"w" AS ("w1" PARTITION BY "a", "b" ORDER BY "c", "d")`, + }, + { + we: W("w", "w1").PartitionBy("a", "b").OrderBy("c", "d"), + prepared: false, + withName: true, + expectedSQL: `"w" AS ("w1" PARTITION BY "a", "b" ORDER BY "c", "d")`, + }, + } { + b := sb.NewSQLBuilder(tt.prepared) + d.WindowSQL(b, tt.we, tt.withName) + if tt.prepared { + dts.assertPreparedSQL(b, tt.expectedSQL, tt.expectedArgs) + } else { + dts.assertNotPreparedSQL(b, tt.expectedSQL) + } + } +} + func (dts *dialectTestSuite) TestToSelectSQL_withOrder() { opts := DefaultDialectOptions() // override fragments to ensure they are used @@ -2344,6 +2529,21 @@ func (dts *dialectTestSuite) TestLiteral_SQLFunctionExpression() { } +func (dts *dialectTestSuite) TestLiteral_SQLWindowFunctionExpression() { + d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} + + b := sb.NewSQLBuilder(false) + d.Literal(b.Clear(), exp.NewSQLWindowFunctionExpression("RANK")) + dts.assertNotPreparedSQL(b, `RANK() OVER ()`) + + d.Literal(b.Clear(), exp.NewSQLWindowFunctionExpression("RANK").OverName("w")) + dts.assertNotPreparedSQL(b, `RANK() OVER "w"`) + + w := W().PartitionBy("a").OrderBy(I("b").Asc()) + d.Literal(b.Clear(), exp.NewSQLWindowFunctionExpression("RANK").Over(w)) + dts.assertNotPreparedSQL(b, `RANK() OVER (PARTITION BY "a" ORDER BY "b" ASC)`) +} + func (dts *dialectTestSuite) TestLiteral_CastExpression() { d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} b := sb.NewSQLBuilder(false) @@ -2618,6 +2818,39 @@ func (dts *dialectTestSuite) TestLiteral_ExpressionOrMap() { dts.assertPreparedSQL(b, `(("a" = ?) OR ("b" IN (?, ?, ?)))`, []interface{}{int64(1), "a", "b", "c"}) } +func (dts *dialectTestSuite) TestOptions_SQLFragmentType() { + for _, tt := range []struct { + typ SQLFragmentType + expectedStr string + }{ + {typ: CommonTableSQLFragment, expectedStr: "CommonTableSQLFragment"}, + {typ: SelectSQLFragment, expectedStr: "SelectSQLFragment"}, + {typ: FromSQLFragment, expectedStr: "FromSQLFragment"}, + {typ: JoinSQLFragment, expectedStr: "JoinSQLFragment"}, + {typ: WhereSQLFragment, expectedStr: "WhereSQLFragment"}, + {typ: GroupBySQLFragment, expectedStr: "GroupBySQLFragment"}, + {typ: HavingSQLFragment, expectedStr: "HavingSQLFragment"}, + {typ: CompoundsSQLFragment, expectedStr: "CompoundsSQLFragment"}, + {typ: OrderSQLFragment, expectedStr: "OrderSQLFragment"}, + {typ: LimitSQLFragment, expectedStr: "LimitSQLFragment"}, + {typ: OffsetSQLFragment, expectedStr: "OffsetSQLFragment"}, + {typ: ForSQLFragment, expectedStr: "ForSQLFragment"}, + {typ: UpdateBeginSQLFragment, expectedStr: "UpdateBeginSQLFragment"}, + {typ: SourcesSQLFragment, expectedStr: "SourcesSQLFragment"}, + {typ: IntoSQLFragment, expectedStr: "IntoSQLFragment"}, + {typ: UpdateSQLFragment, expectedStr: "UpdateSQLFragment"}, + {typ: UpdateFromSQLFragment, expectedStr: "UpdateFromSQLFragment"}, + {typ: ReturningSQLFragment, expectedStr: "ReturningSQLFragment"}, + {typ: InsertBeingSQLFragment, expectedStr: "InsertBeingSQLFragment"}, + {typ: DeleteBeginSQLFragment, expectedStr: "DeleteBeginSQLFragment"}, + {typ: TruncateSQLFragment, expectedStr: "TruncateSQLFragment"}, + {typ: WindowSQLFragment, expectedStr: "WindowSQLFragment"}, + {typ: SQLFragmentType(10000), expectedStr: "10000"}, + } { + dts.Equal(tt.expectedStr, tt.typ.String()) + } +} + func TestDialectSuite(t *testing.T) { suite.Run(t, new(dialectTestSuite)) } From 6966a186193d4e07382382a4f17e8ee94d6fdf5c Mon Sep 17 00:00:00 2001 From: Xuyuan Pang Date: Thu, 15 Aug 2019 19:21:13 +0800 Subject: [PATCH 2/2] Fixed lint issues --- exp/window.go | 2 +- sql_dialect.go | 23 +++++++++++++---------- sql_dialect_test.go | 1 - 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/exp/window.go b/exp/window.go index 03926c05..c618d70b 100644 --- a/exp/window.go +++ b/exp/window.go @@ -7,7 +7,7 @@ type sqlWindowExpression struct { orderCols ColumnListExpression } -func NewWindowExpression(window string, parent string, partitionCols, orderCols ColumnListExpression) WindowExpression { +func NewWindowExpression(window, parent string, partitionCols, orderCols ColumnListExpression) WindowExpression { if partitionCols == nil { partitionCols = NewColumnListExpression() } diff --git a/sql_dialect.go b/sql_dialect.go index 096e89b9..d0beed0c 100644 --- a/sql_dialect.go +++ b/sql_dialect.go @@ -529,7 +529,7 @@ func (d *sqlDialect) WindowSQL(b sb.SQLBuilder, we exp.WindowExpression, withNam } if withName { name := we.Name() - if len(name) == 0 { + if name == "" { b.SetError(errNoWindowName) return } @@ -1209,15 +1209,18 @@ func (d *sqlDialect) sqlFunctionExpressionSQL(b sb.SQLBuilder, sqlFunc exp.SQLFu b.WriteStrings(sqlFunc.Name()) d.Literal(b, sqlFunc.Args()) - if sqlWinFunc, ok := sqlFunc.(exp.SQLWindowFunctionExpression); ok { - b.Write(d.dialectOptions.WindowOverFragment) - if sqlWinFunc.HasWindowName() { - d.Literal(b, I(sqlWinFunc.WindowName())) - } else if sqlWinFunc.HasWindow() { - d.WindowSQL(b, sqlWinFunc.Window(), false) - } else { - d.WindowSQL(b, emptyWindow, false) - } + sqlWinFunc, ok := sqlFunc.(exp.SQLWindowFunctionExpression) + if !ok { + return + } + b.Write(d.dialectOptions.WindowOverFragment) + switch { + case sqlWinFunc.HasWindowName(): + d.Literal(b, I(sqlWinFunc.WindowName())) + case sqlWinFunc.HasWindow(): + d.WindowSQL(b, sqlWinFunc.Window(), false) + default: + d.WindowSQL(b, emptyWindow, false) } } diff --git a/sql_dialect_test.go b/sql_dialect_test.go index 377157c2..4a9bf43f 100644 --- a/sql_dialect_test.go +++ b/sql_dialect_test.go @@ -1431,7 +1431,6 @@ func (dts *dialectTestSuite) TestWindowSQL() { opts.SupportsWindowFunction = true d = sqlDialect{dialect: "test", dialectOptions: opts} - b = sb.NewSQLBuilder(false) b = sb.NewSQLBuilder(false) anErr := errors.New("something wrong")