Skip to content

Commit

Permalink
executor: support window function row number (#9098)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored Feb 18, 2019
1 parent a8664ef commit cc08569
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 3 deletions.
11 changes: 11 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildBitXor(aggFuncDesc, ordinal)
case ast.AggFuncBitAnd:
return buildBitAnd(aggFuncDesc, ordinal)
case ast.WindowFuncRowNumber:
return buildRowNumber(aggFuncDesc, ordinal)
}
return nil
}
Expand Down Expand Up @@ -313,3 +315,12 @@ func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
return &bitAndUint64{baseBitAggFunc{base}}
}

// buildRowNumber builds the AggFunc implementation for function "ROW_NUMBER".
func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &rowNumber{base}
}
47 changes: 47 additions & 0 deletions executor/aggfuncs/row_number.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type rowNumber struct {
baseAggFunc
}

type partialResult4RowNumber struct {
curIdx int64
}

func (rn *rowNumber) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4RowNumber{})
}

func (rn *rowNumber) ResetPartialResult(pr PartialResult) {
p := (*partialResult4RowNumber)(pr)
p.curIdx = 0
}

func (rn *rowNumber) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
return nil
}

func (rn *rowNumber) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4RowNumber)(pr)
p.curIdx++
chk.AppendInt64(rn.ordinal, p.curIdx)
return nil
}
5 changes: 5 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,9 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("21", "21", "21", "21", "21", "21", "21", "21", "21"))
result = tk.MustQuery("select _tidb_rowid, sum(t.a) over() from t")
result.Check(testkit.Rows("1 7", "2 7", "3 7"))

result = tk.MustQuery("select a, row_number() over() from t")
result.Check(testkit.Rows("1 1", "4 2", "2 3"))
result = tk.MustQuery("select a, row_number() over(partition by a) from t")
result.Check(testkit.Rows("1 1", "2 1", "4 1"))
}
22 changes: 19 additions & 3 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
case ast.WindowFuncRowNumber:
a.typeInfer4RowNumber()
default:
panic("unsupported agg function: " + a.Name)
}
Expand Down Expand Up @@ -184,6 +186,12 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
// TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *baseFuncDesc) typeInfer4RowNumber() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
}

// GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows:
// e.g.
Expand Down Expand Up @@ -213,11 +221,19 @@ func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) {
return
}

// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncCount: {},
ast.AggFuncMax: {},
ast.AggFuncMin: {},
ast.AggFuncFirstRow: {},
ast.WindowFuncRowNumber: {},
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
if a.Name == ast.AggFuncCount || a.Name == ast.AggFuncMin || a.Name == ast.AggFuncMax || a.Name == ast.AggFuncFirstRow {
if _, ok := noNeedCastAggFuncs[a.Name]; ok {
return
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
Expand Down

0 comments on commit cc08569

Please sign in to comment.