Skip to content

Commit

Permalink
[parser] parser: support window function ast (pingcap#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and xhebox committed Oct 8, 2021
1 parent 08af09a commit c583f72
Show file tree
Hide file tree
Showing 6 changed files with 585 additions and 138 deletions.
174 changes: 174 additions & 0 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ var (
_ Node = &TableSource{}
_ Node = &UnionSelectList{}
_ Node = &WildCardField{}
_ Node = &WindowSpec{}
_ Node = &PartitionByClause{}
_ Node = &FrameClause{}
_ Node = &FrameBound{}
)

// JoinType is join type, including cross/left/right/full.
Expand Down Expand Up @@ -468,6 +472,8 @@ type SelectStmt struct {
GroupBy *GroupByClause
// Having is the having condition.
Having *HavingClause
// WindowSpecs is the window specification list.
WindowSpecs []WindowSpec
// OrderBy is the ordering expression list.
OrderBy *OrderByClause
// Limit is the limit clause.
Expand Down Expand Up @@ -542,6 +548,14 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
n.Having = node.(*HavingClause)
}

for i, spec := range n.WindowSpecs {
node, ok := spec.Accept(v)
if !ok {
return n, false
}
n.WindowSpecs[i] = *node.(*WindowSpec)
}

if n.OrderBy != nil {
node, ok := n.OrderBy.Accept(v)
if !ok {
Expand Down Expand Up @@ -1035,3 +1049,163 @@ func (n *ShowStmt) Accept(v Visitor) (Node, bool) {
}
return v.Leave(n)
}

// WindowSpec is the specification of a window.
type WindowSpec struct {
node

Name model.CIStr
// Ref is the reference window of this specification. For example, in `w2 as (w1 order by a)`,
// the definition of `w2` references `w1`.
Ref model.CIStr

PartitionBy *PartitionByClause
OrderBy *OrderByClause
Frame *FrameClause
}

// Accept implements Node Accept interface.
func (n *WindowSpec) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*WindowSpec)
if n.PartitionBy != nil {
node, ok := n.PartitionBy.Accept(v)
if !ok {
return n, false
}
n.PartitionBy = node.(*PartitionByClause)
}
if n.OrderBy != nil {
node, ok := n.OrderBy.Accept(v)
if !ok {
return n, false
}
n.OrderBy = node.(*OrderByClause)
}
if n.Frame != nil {
node, ok := n.Frame.Accept(v)
if !ok {
return n, false
}
n.Frame = node.(*FrameClause)
}
return v.Leave(n)
}

// PartitionByClause represents partition by clause.
type PartitionByClause struct {
node

Items []*ByItem
}

// Accept implements Node Accept interface.
func (n *PartitionByClause) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*PartitionByClause)
for i, val := range n.Items {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Items[i] = node.(*ByItem)
}
return v.Leave(n)
}

// FrameType is the type of window function frame.
type FrameType int

// Window function frame types.
// MySQL only supports `ROWS` and `RANGES`.
const (
Rows = iota
Ranges
Groups
)

// FrameClause represents frame clause.
type FrameClause struct {
node

Type FrameType
Extent FrameExtent
}

// Accept implements Node Accept interface.
func (n *FrameClause) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*FrameClause)
node, ok := n.Extent.Start.Accept(v)
if !ok {
return n, false
}
n.Extent.Start = *node.(*FrameBound)
node, ok = n.Extent.End.Accept(v)
if !ok {
return n, false
}
n.Extent.End = *node.(*FrameBound)
return v.Leave(n)
}

// FrameExtent represents frame extent.
type FrameExtent struct {
Start FrameBound
End FrameBound
}

// FrameType is the type of window function frame bound.
type BoundType int

// Frame bound types.
const (
Following = iota
Preceding
CurrentRow
)

// FrameBound represents frame bound.
type FrameBound struct {
node

Type BoundType
UnBounded bool
Expr ExprNode
// `Unit` is used to indicate the units in which the `Expr` should be interpreted.
// For example: '2:30' MINUTE_SECOND.
Unit ExprNode
}

// Accept implements Node Accept interface.
func (n *FrameBound) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*FrameBound)
if n.Expr != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return n, false
}
n.Expr = node.(ExprNode)
}
if n.Unit != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return n, false
}
n.Unit = node.(ExprNode)
}
return v.Leave(n)
}
4 changes: 4 additions & 0 deletions parser/ast/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (ts *testDMLSuite) TestDMLVisitorCover(c *C) {
{&SelectStmt{}, 0, 0},
{&FieldList{}, 0, 0},
{&UnionSelectList{}, 0, 0},
{&WindowSpec{}, 0, 0},
{&PartitionByClause{}, 0, 0},
{&FrameClause{}, 0, 0},
{&FrameBound{}, 0, 0},
}

for _, v := range stmts {
Expand Down
74 changes: 74 additions & 0 deletions parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var (
_ FuncNode = &AggregateFuncExpr{}
_ FuncNode = &FuncCallExpr{}
_ FuncNode = &FuncCastExpr{}
_ FuncNode = &WindowFuncExpr{}
)

// List scalar function names.
Expand Down Expand Up @@ -522,3 +523,76 @@ func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
}
return v.Leave(n)
}

const (
// WindowFuncRowNumber is the name of row_number function.
WindowFuncRowNumber = "row_number"
// WindowFuncRank is the name of rank function.
WindowFuncRank = "rank"
// WindowFuncDenseRank is the name of dense_rank function.
WindowFuncDenseRank = "dense_rank"
// WindowFuncCumeDist is the name of cume_dist function.
WindowFuncCumeDist = "cume_dist"
// WindowFuncPercentRank is the name of percent_rank function.
WindowFuncPercentRank = "percent_rank"
// WindowFuncNtile is the name of ntile function.
WindowFuncNtile = "ntile"
// WindowFuncLead is the name of lead function.
WindowFuncLead = "lead"
// WindowFuncLag is the name of lag function.
WindowFuncLag = "lag"
// WindowFuncFirstValue is the name of first_value function.
WindowFuncFirstValue = "first_value"
// WindowFuncLastValue is the name of last_value function.
WindowFuncLastValue = "last_value"
// WindowFuncNthValue is the name of nth_value function.
WindowFuncNthValue = "nth_value"
)

// WindowFuncExpr represents window function expression.
type WindowFuncExpr struct {
funcNode

// F is the function name.
F string
// Args is the function args.
Args []ExprNode
// Distinct cannot be true for most window functions, except `max` and `min`.
// We need to raise error if it is not allowed to be true.
Distinct bool
// IgnoreNull indicates how to handle null value.
// MySQL only supports `RESPECT NULLS`, so we need to raise error if it is true.
IgnoreNull bool
// FromLast indicates the calculation direction of this window function.
// MySQL only supports calculation from first, so we need to raise error if it is true.
FromLast bool
// Spec is the specification of this window.
Spec WindowSpec
}

// Format formats the window function expression into a Writer.
func (n *WindowFuncExpr) Format(w io.Writer) {
panic("Not implemented")
}

// Accept implements Node Accept interface.
func (n *WindowFuncExpr) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*WindowFuncExpr)
for i, val := range n.Args {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Args[i] = node.(ExprNode)
}
node, ok := n.Spec.Accept(v)
if !ok {
return n, false
}
n.Spec = *node.(*WindowSpec)
return v.Leave(n)
}
1 change: 1 addition & 0 deletions parser/ast/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func (ts *testFunctionsSuite) TestFunctionsVisitorCover(c *C) {
&AggregateFuncExpr{Args: []ExprNode{valueExpr}},
&FuncCallExpr{Args: []ExprNode{valueExpr}},
&FuncCastExpr{Expr: valueExpr},
&WindowFuncExpr{Spec: WindowSpec{}},
}

for _, stmt := range stmts {
Expand Down
Loading

0 comments on commit c583f72

Please sign in to comment.