From cc085693f10ee4a840a36d3b371de62a4f660044 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Mon, 18 Feb 2019 17:14:37 +0800 Subject: [PATCH] executor: support window function row number (#9098) --- executor/aggfuncs/builder.go | 11 +++++++ executor/aggfuncs/row_number.go | 47 +++++++++++++++++++++++++++++ executor/window_test.go | 5 +++ expression/aggregation/base_func.go | 22 ++++++++++++-- 4 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 executor/aggfuncs/row_number.go diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 0f2162d567b2c..b01dcf8203598 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -52,6 +52,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal return buildBitXor(aggFuncDesc, ordinal) case ast.AggFuncBitAnd: return buildBitAnd(aggFuncDesc, ordinal) + case ast.WindowFuncRowNumber: + return buildRowNumber(aggFuncDesc, ordinal) } return nil } @@ -313,3 +315,12 @@ func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { } return &bitAndUint64{baseBitAggFunc{base}} } + +// buildRowNumber builds the AggFunc implementation for function "ROW_NUMBER". +func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + return &rowNumber{base} +} diff --git a/executor/aggfuncs/row_number.go b/executor/aggfuncs/row_number.go new file mode 100644 index 0000000000000..766eb5750b7ea --- /dev/null +++ b/executor/aggfuncs/row_number.go @@ -0,0 +1,47 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type rowNumber struct { + baseAggFunc +} + +type partialResult4RowNumber struct { + curIdx int64 +} + +func (rn *rowNumber) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4RowNumber{}) +} + +func (rn *rowNumber) ResetPartialResult(pr PartialResult) { + p := (*partialResult4RowNumber)(pr) + p.curIdx = 0 +} + +func (rn *rowNumber) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + return nil +} + +func (rn *rowNumber) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4RowNumber)(pr) + p.curIdx++ + chk.AppendInt64(rn.ordinal, p.curIdx) + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index 0e966664add41..0e6e5ba4b01ef 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -40,4 +40,9 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result.Check(testkit.Rows("21", "21", "21", "21", "21", "21", "21", "21", "21")) result = tk.MustQuery("select _tidb_rowid, sum(t.a) over() from t") result.Check(testkit.Rows("1 7", "2 7", "3 7")) + + result = tk.MustQuery("select a, row_number() over() from t") + result.Check(testkit.Rows("1 1", "4 2", "2 3")) + result = tk.MustQuery("select a, row_number() over(partition by a) from t") + result.Check(testkit.Rows("1 1", "2 1", "4 1")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 4f43c0bfa4b82..2f241bca2a1ee 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -95,6 +95,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx) + case ast.WindowFuncRowNumber: + a.typeInfer4RowNumber() default: panic("unsupported agg function: " + a.Name) } @@ -184,6 +186,12 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) } +func (a *baseFuncDesc) typeInfer4RowNumber() { + a.RetTp = types.NewFieldType(mysql.TypeLonglong) + a.RetTp.Flen = 21 + types.SetBinChsClnFlag(a.RetTp) +} + // GetDefaultValue gets the default value when the function's input is null. // According to MySQL, default values of the function are listed as follows: // e.g. @@ -213,11 +221,19 @@ func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) { return } +// We do not need to wrap cast upon these functions, +// since the EvalXXX method called by the arg is determined by the corresponding arg type. +var noNeedCastAggFuncs = map[string]struct{}{ + ast.AggFuncCount: {}, + ast.AggFuncMax: {}, + ast.AggFuncMin: {}, + ast.AggFuncFirstRow: {}, + ast.WindowFuncRowNumber: {}, +} + // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { - // We do not need to wrap cast upon these functions, - // since the EvalXXX method called by the arg is determined by the corresponding arg type. - if a.Name == ast.AggFuncCount || a.Name == ast.AggFuncMin || a.Name == ast.AggFuncMax || a.Name == ast.AggFuncFirstRow { + if _, ok := noNeedCastAggFuncs[a.Name]; ok { return } var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression