Skip to content

Commit

Permalink
executor: add window function NTILE (#9682)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored Mar 15, 2019
1 parent 5f8c4c7 commit 3a48d9c
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
11 changes: 11 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildCumeDist(ordinal, orderByCols)
case ast.WindowFuncNthValue:
return buildNthValue(windowFuncDesc, ordinal)
case ast.WindowFuncNtile:
return buildNtile(windowFuncDesc, ordinal)
case ast.WindowFuncPercentRank:
return buildPercenRank(ordinal, orderByCols)
case ast.WindowFuncLead:
Expand Down Expand Up @@ -393,6 +395,15 @@ func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
}

func buildNtile(aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDes.Args,
ordinal: ordinal,
}
n, _, _ := expression.GetUint64FromConstant(aggFuncDes.Args[0])
return &ntile{baseAggFunc: base, n: n}
}

func buildPercenRank(ordinal int, orderByCols []*expression.Column) AggFunc {
base := baseAggFunc{
ordinal: ordinal,
Expand Down
80 changes: 80 additions & 0 deletions executor/aggfuncs/func_ntile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

// ntile divides the partition into n ranked groups and returns the group number a row belongs to.
// e.g. We have 11 rows and n = 3. They will be divided into 3 groups.
// First 4 rows belongs to group 1. Following 4 rows belongs to group 2. The last 3 rows belongs to group 3.
type ntile struct {
n uint64
baseAggFunc
}

type partialResult4Ntile struct {
curIdx uint64
curGroupIdx uint64
remainder uint64
quotient uint64
rows []chunk.Row
}

func (n *ntile) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4Ntile{curGroupIdx: 1})
}

func (n *ntile) ResetPartialResult(pr PartialResult) {
p := (*partialResult4Ntile)(pr)
p.curIdx = 0
p.curGroupIdx = 1
p.rows = p.rows[:0]
}

func (n *ntile) UpdatePartialResult(_ sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4Ntile)(pr)
p.rows = append(p.rows, rowsInGroup...)
// Update the quotient and remainder.
if n.n != 0 {
p.quotient = uint64(len(p.rows)) / n.n
p.remainder = uint64(len(p.rows)) % n.n
}
return nil
}

func (n *ntile) AppendFinalResult2Chunk(_ sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4Ntile)(pr)

// If the divisor is 0, the arg of NTILE would be NULL. So we just return NULL.
if n.n == 0 {
chk.AppendNull(n.ordinal)
return nil
}

chk.AppendUint64(n.ordinal, p.curGroupIdx)

p.curIdx++
curMaxIdx := p.quotient
if p.curGroupIdx <= p.remainder {
curMaxIdx++
}
if p.curIdx == curMaxIdx {
p.curIdx = 0
p.curGroupIdx++
}
return nil
}
7 changes: 7 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result = tk.MustQuery("select a, nth_value(a, 5) over() from t")
result.Check(testkit.Rows("1 <nil>", "1 <nil>", "2 <nil>", "2 <nil>"))

result = tk.MustQuery("select ntile(3) over() from t")
result.Check(testkit.Rows("1", "1", "2", "3"))
result = tk.MustQuery("select ntile(2) over() from t")
result.Check(testkit.Rows("1", "1", "2", "2"))
result = tk.MustQuery("select ntile(null) over() from t")
result.Check(testkit.Rows("<nil>", "<nil>", "<nil>", "<nil>"))

result = tk.MustQuery("select a, percent_rank() over() from t")
result.Check(testkit.Rows("1 0", "1 0", "2 0", "2 0"))
result = tk.MustQuery("select a, percent_rank() over(order by a) from t")
Expand Down
9 changes: 9 additions & 0 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4NumberFuncs()
case ast.WindowFuncCumeDist:
a.typeInfer4CumeDist()
case ast.WindowFuncNtile:
a.typeInfer4Ntile()
case ast.WindowFuncPercentRank:
a.typeInfer4PercentRank()
case ast.WindowFuncLead, ast.WindowFuncLag:
Expand Down Expand Up @@ -204,6 +206,13 @@ func (a *baseFuncDesc) typeInfer4CumeDist() {
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

func (a *baseFuncDesc) typeInfer4Ntile() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
a.RetTp.Flag |= mysql.UnsignedFlag
}

func (a *baseFuncDesc) typeInfer4PercentRank() {
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
Expand Down
6 changes: 6 additions & 0 deletions expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
if !ok || (val == 0 && !isNull) {
return nil
}
case ast.WindowFuncNtile:
val, isNull, ok := expression.GetUint64FromConstant(args[0])
// ntile does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil
}
case ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) < 2 {
break
Expand Down
8 changes: 8 additions & 0 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2211,6 +2211,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) {
sql: "select nth_value(a, 0) over() from t",
result: "[planner:1210]Incorrect arguments to nth_value",
},
{
sql: "select ntile(0) over() from t",
result: "[planner:1210]Incorrect arguments to ntile",
},
{
sql: "select ntile(null) over() from t",
result: "TableReader(Table(t))->Window(ntile(<nil>) over())->Projection",
},
{
sql: "select avg(a) over w from t window w as(partition by b)",
result: "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a)) over(partition by test.t.b))->Projection",
Expand Down

0 comments on commit 3a48d9c

Please sign in to comment.