diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go
index 7d5c559d26ffc..a3f1e8256b449 100644
--- a/executor/aggfuncs/builder.go
+++ b/executor/aggfuncs/builder.go
@@ -69,6 +69,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
 		return buildFirstValue(windowFuncDesc, ordinal)
 	case ast.WindowFuncLastValue:
 		return buildLastValue(windowFuncDesc, ordinal)
+	case ast.WindowFuncCumeDist:
+		return buildCumeDist(ordinal, orderByCols)
 	default:
 		return Build(ctx, windowFuncDesc, ordinal)
 	}
@@ -345,11 +347,7 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF
 	base := baseAggFunc{
 		ordinal: ordinal,
 	}
-	r := &rank{baseAggFunc: base, isDense: isDense}
-	for _, col := range orderByCols {
-		r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType))
-		r.colIdx = append(r.colIdx, col.Index)
-	}
+	r := &rank{baseAggFunc: base, isDense: isDense, rowComparer: buildRowComparer(orderByCols)}
 	return r
 }
 
@@ -368,3 +366,13 @@ func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
 	}
 	return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp}
 }
+
+// buildCumeDist builds the cumeDist window function that writes its result to
+// column ordinal and compares rows on orderByCols.
+func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
+	base := baseAggFunc{
+		ordinal: ordinal,
+	}
+	r := &cumeDist{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
+	return r
+}
diff --git a/executor/aggfuncs/func_cume_dist.go b/executor/aggfuncs/func_cume_dist.go
new file mode 100644
index 0000000000000..37e1ffb1636a5
--- /dev/null
+++ b/executor/aggfuncs/func_cume_dist.go
@@ -0,0 +1,74 @@
+// Copyright 2019 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package aggfuncs
+
+import (
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/util/chunk"
+)
+
+// cumeDist implements the CUME_DIST window function: for each row it yields
+// lastRank divided by the number of rows buffered for the partition.
+type cumeDist struct {
+	baseAggFunc
+	rowComparer
+}
+
+// partialResult4CumeDist is the state cumeDist keeps for one partition.
+type partialResult4CumeDist struct {
+	// curIdx is the index in rows of the next row to produce a result for.
+	curIdx int
+	// lastRank counts the leading rows already matched up to and including the
+	// peers of the current row. NOTE(review): this relies on rows arriving
+	// ordered by the ORDER BY columns - confirm against the window executor.
+	lastRank int
+	// rows buffers every row passed to UpdatePartialResult.
+	rows []chunk.Row
+}
+
+// AllocPartialResult allocates an empty partialResult4CumeDist.
+func (r *cumeDist) AllocPartialResult() PartialResult {
+	return PartialResult(&partialResult4CumeDist{})
+}
+
+// ResetPartialResult clears the state so it can be reused, keeping the
+// capacity of the row buffer.
+func (r *cumeDist) ResetPartialResult(pr PartialResult) {
+	p := (*partialResult4CumeDist)(pr)
+	p.curIdx = 0
+	p.lastRank = 0
+	p.rows = p.rows[:0]
+}
+
+// UpdatePartialResult buffers rowsInGroup; nothing is computed until
+// AppendFinalResult2Chunk is called.
+func (r *cumeDist) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
+	p := (*partialResult4CumeDist)(pr)
+	p.rows = append(p.rows, rowsInGroup...)
+	return nil
+}
+
+// AppendFinalResult2Chunk appends the cumulative distribution of the current
+// row to chk and moves on to the next row.
+func (r *cumeDist) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
+	p := (*partialResult4CumeDist)(pr)
+	numRows := len(p.rows)
+	// Advance lastRank over every row that compares equal to the current row.
+	for p.lastRank < numRows && r.compareRows(p.rows[p.curIdx], p.rows[p.lastRank]) == 0 {
+		p.lastRank++
+	}
+	p.curIdx++
+	chk.AppendFloat64(r.ordinal, float64(p.lastRank)/float64(numRows))
+	return nil
+}
diff --git a/executor/aggfuncs/func_rank.go b/executor/aggfuncs/func_rank.go
index e73c46e3c45ec..448760e605bb4 100644
--- a/executor/aggfuncs/func_rank.go
+++ b/executor/aggfuncs/func_rank.go
@@ -14,15 +14,15 @@
 package aggfuncs
 
 import (
+	"github.com/pingcap/tidb/expression"
 	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/util/chunk"
 )
 
 type rank struct {
 	baseAggFunc
-	isDense  bool
-	cmpFuncs []chunk.CompareFunc
-	colIdx   []int
+	isDense bool
+	rowComparer
 }
 
 type partialResult4Rank struct {
@@ -48,16 +48,6 @@ func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.
 	return nil
 }
 
-func (r *rank) compareRows(prev, curr chunk.Row) int {
-	for i, idx := range r.colIdx {
-		res := r.cmpFuncs[i](prev, idx, curr, idx)
-		if res != 0 {
-			return res
-		}
-	}
-	return 0
-}
-
 func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
 	p := (*partialResult4Rank)(pr)
 	p.curIdx++
@@ -78,3 +68,34 @@ func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult
 	chk.AppendInt64(r.ordinal, p.lastRank)
 	return nil
 }
+
+// rowComparer compares two rows on a fixed list of columns.
+type rowComparer struct {
+	cmpFuncs []chunk.CompareFunc
+	colIdx   []int
+}
+
+// buildRowComparer creates a rowComparer over cols, recording each column's
+// Index and the compare function chunk.GetCompareFunc returns for its RetType.
+func buildRowComparer(cols []*expression.Column) rowComparer {
+	rc := rowComparer{}
+	rc.colIdx = make([]int, 0, len(cols))
+	rc.cmpFuncs = make([]chunk.CompareFunc, 0, len(cols))
+	for _, col := range cols {
+		rc.cmpFuncs = append(rc.cmpFuncs, chunk.GetCompareFunc(col.RetType))
+		rc.colIdx = append(rc.colIdx, col.Index)
+	}
+	return rc
+}
+
+// compareRows compares prev and curr column by column and returns the first
+// non-zero comparison result, or 0 when the rows agree on every column.
+func (rc *rowComparer) compareRows(prev, curr chunk.Row) int {
+	for i, idx := range rc.colIdx {
+		res := rc.cmpFuncs[i](prev, idx, curr, idx)
+		if res != 0 {
+			return res
+		}
+	}
+	return 0
+}
diff --git a/executor/window_test.go b/executor/window_test.go
index 4870251fa3b40..462afca70f414 100644
--- a/executor/window_test.go
+++ b/executor/window_test.go
@@ -103,4 +103,11 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
 
 	result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t")
 	result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039"))
+
+	result = tk.MustQuery("select a, b, cume_dist() over() from t")
+	result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
+	result = tk.MustQuery("select a, b, cume_dist() over(order by a) from t")
+	result.Check(testkit.Rows("1 1 0.5", "1 2 0.5", "2 1 1", "2 2 1"))
+	result = tk.MustQuery("select a, b, cume_dist() over(order by a, b) from t")
+	result.Check(testkit.Rows("1 1 0.25", "1 2 0.5", "2 1 0.75", "2 2 1"))
 }
diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go
index cfd1f16b74e39..306e47ec9569d 100644
--- a/expression/aggregation/base_func.go
+++ b/expression/aggregation/base_func.go
@@ -98,6 +98,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
 		a.typeInfer4BitFuncs(ctx)
 	case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
 		a.typeInfer4NumberFuncs()
+	case ast.WindowFuncCumeDist:
+		a.typeInfer4CumeDist()
 	default:
 		panic("unsupported agg function: " + a.Name)
 	}
@@ -193,6 +195,13 @@ func (a *baseFuncDesc) typeInfer4NumberFuncs() {
 	types.SetBinChsClnFlag(a.RetTp)
 }
 
+// typeInfer4CumeDist sets the return type to a TypeDouble field type with
+// Flen mysql.MaxRealWidth and Decimal mysql.NotFixedDec.
+func (a *baseFuncDesc) typeInfer4CumeDist() {
+	a.RetTp = types.NewFieldType(mysql.TypeDouble)
+	a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
+}
+
 // GetDefaultValue gets the default value when the function's input is null.
 // According to MySQL, default values of the function are listed as follows:
 // e.g.