From 3a48d9c87700fbde3f9efdff7ae220b5d64933b3 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Fri, 15 Mar 2019 16:01:56 +0800 Subject: [PATCH] executor: add window function NTILE (#9682) --- executor/aggfuncs/builder.go | 11 ++++ executor/aggfuncs/func_ntile.go | 80 +++++++++++++++++++++++++++ executor/window_test.go | 7 +++ expression/aggregation/base_func.go | 9 +++ expression/aggregation/window_func.go | 6 ++ planner/core/logical_plan_test.go | 8 +++ 6 files changed, 121 insertions(+) create mode 100644 executor/aggfuncs/func_ntile.go diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 77f08de8515b6..0253f86624cfa 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -73,6 +73,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildCumeDist(ordinal, orderByCols) case ast.WindowFuncNthValue: return buildNthValue(windowFuncDesc, ordinal) + case ast.WindowFuncNtile: + return buildNtile(windowFuncDesc, ordinal) case ast.WindowFuncPercentRank: return buildPercenRank(ordinal, orderByCols) case ast.WindowFuncLead: @@ -393,6 +395,15 @@ func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth} } +func buildNtile(aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDes.Args, + ordinal: ordinal, + } + n, _, _ := expression.GetUint64FromConstant(aggFuncDes.Args[0]) + return &ntile{baseAggFunc: base, n: n} +} + func buildPercenRank(ordinal int, orderByCols []*expression.Column) AggFunc { base := baseAggFunc{ ordinal: ordinal, diff --git a/executor/aggfuncs/func_ntile.go b/executor/aggfuncs/func_ntile.go new file mode 100644 index 0000000000000..1adbb326d7609 --- /dev/null +++ b/executor/aggfuncs/func_ntile.go @@ -0,0 +1,80 @@ +// Copyright 2019 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +// ntile divides the partition into n ranked groups and returns the group number a row belongs to. +// e.g. We have 11 rows and n = 3. They will be divided into 3 groups. +// First 4 rows belong to group 1. Following 4 rows belong to group 2. The last 3 rows belong to group 3. +type ntile struct { + n uint64 + baseAggFunc +} + +type partialResult4Ntile struct { + curIdx uint64 + curGroupIdx uint64 + remainder uint64 + quotient uint64 + rows []chunk.Row +} + +func (n *ntile) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4Ntile{curGroupIdx: 1}) +} + +func (n *ntile) ResetPartialResult(pr PartialResult) { + p := (*partialResult4Ntile)(pr) + p.curIdx = 0 + p.curGroupIdx = 1 + p.rows = p.rows[:0] +} + +func (n *ntile) UpdatePartialResult(_ sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Ntile)(pr) + p.rows = append(p.rows, rowsInGroup...) + // Update the quotient and remainder. + if n.n != 0 { + p.quotient = uint64(len(p.rows)) / n.n + p.remainder = uint64(len(p.rows)) % n.n + } + return nil +} + +func (n *ntile) AppendFinalResult2Chunk(_ sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4Ntile)(pr) + + // If the divisor is 0, the arg of NTILE would be NULL. So we just return NULL. 
+ if n.n == 0 { + chk.AppendNull(n.ordinal) + return nil + } + + chk.AppendUint64(n.ordinal, p.curGroupIdx) + + p.curIdx++ + curMaxIdx := p.quotient + if p.curGroupIdx <= p.remainder { + curMaxIdx++ + } + if p.curIdx == curMaxIdx { + p.curIdx = 0 + p.curGroupIdx++ + } + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index dc17af8a3e870..5ea4e6345c6f2 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -120,6 +120,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result = tk.MustQuery("select a, nth_value(a, 5) over() from t") result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 ")) + result = tk.MustQuery("select ntile(3) over() from t") + result.Check(testkit.Rows("1", "1", "2", "3")) + result = tk.MustQuery("select ntile(2) over() from t") + result.Check(testkit.Rows("1", "1", "2", "2")) + result = tk.MustQuery("select ntile(null) over() from t") + result.Check(testkit.Rows("", "", "", "")) + result = tk.MustQuery("select a, percent_rank() over() from t") result.Check(testkit.Rows("1 0", "1 0", "2 0", "2 0")) result = tk.MustQuery("select a, percent_rank() over(order by a) from t") diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 21b986702a496..ba0e716853476 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -100,6 +100,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { a.typeInfer4NumberFuncs() case ast.WindowFuncCumeDist: a.typeInfer4CumeDist() + case ast.WindowFuncNtile: + a.typeInfer4Ntile() case ast.WindowFuncPercentRank: a.typeInfer4PercentRank() case ast.WindowFuncLead, ast.WindowFuncLag: @@ -204,6 +206,13 @@ func (a *baseFuncDesc) typeInfer4CumeDist() { a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec } +func (a *baseFuncDesc) typeInfer4Ntile() { + a.RetTp = types.NewFieldType(mysql.TypeLonglong) + a.RetTp.Flen = 21 + types.SetBinChsClnFlag(a.RetTp) + a.RetTp.Flag |= mysql.UnsignedFlag +} + 
func (a *baseFuncDesc) typeInfer4PercentRank() { a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index 41c17f16682fb..28ccfed44e98d 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -35,6 +35,12 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex if !ok || (val == 0 && !isNull) { return nil } + case ast.WindowFuncNtile: + val, isNull, ok := expression.GetUint64FromConstant(args[0]) + // ntile does not allow `0`, but allows `null`. + if !ok || (val == 0 && !isNull) { + return nil + } case ast.WindowFuncLead, ast.WindowFuncLag: if len(args) < 2 { break diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index ef8b0d1f8e2e5..9088421ffddf6 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -2211,6 +2211,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "select nth_value(a, 0) over() from t", result: "[planner:1210]Incorrect arguments to nth_value", }, + { + sql: "select ntile(0) over() from t", + result: "[planner:1210]Incorrect arguments to ntile", + }, + { + sql: "select ntile(null) over() from t", + result: "TableReader(Table(t))->Window(ntile() over())->Projection", + }, { sql: "select avg(a) over w from t window w as(partition by b)", result: "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a)) over(partition by test.t.b))->Projection",