Skip to content

Commit

Permalink
executor: support window function first_value and last_value (#9560)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and zz-jason committed Mar 6, 2019
1 parent 2b646cb commit 247777d
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 1 deletion.
20 changes: 20 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildRank(ordinal, orderByCols, true)
case ast.WindowFuncRowNumber:
return buildRowNumber(windowFuncDesc, ordinal)
case ast.WindowFuncFirstValue:
return buildFirstValue(windowFuncDesc, ordinal)
case ast.WindowFuncLastValue:
return buildLastValue(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
Expand Down Expand Up @@ -348,3 +352,19 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF
}
return r
}

// buildFirstValue constructs the AggFunc implementing the `first_value`
// window function. The function arguments and the return field type are taken
// from aggFuncDesc; ordinal is stored in the embedded baseAggFunc and is later
// used as the output column index.
func buildFirstValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
	return &firstValue{
		baseAggFunc: baseAggFunc{
			args:    aggFuncDesc.Args,
			ordinal: ordinal,
		},
		tp: aggFuncDesc.RetTp,
	}
}

// buildLastValue constructs the AggFunc implementing the `last_value` window
// function. The function arguments and the return field type are taken from
// aggFuncDesc; ordinal is stored in the embedded baseAggFunc and is later used
// as the output column index.
func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
	return &lastValue{
		baseAggFunc: baseAggFunc{
			args:    aggFuncDesc.Args,
			ordinal: ordinal,
		},
		tp: aggFuncDesc.RetTp,
	}
}
302 changes: 302 additions & 0 deletions executor/aggfuncs/func_value.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)

// valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`,
// `lead` and `lag`.
//
// Each implementation in this file keeps a single typed value together with a
// NULL flag: evaluateRow overwrites that state, appendResult writes it out.
type valueEvaluator interface {
// evaluateRow evaluates the expression using row and stores the result inside.
// The error returned by the expression evaluation is passed back to the caller.
evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error
// appendResult appends the stored result to column colIdx of chk.
appendResult(chk *chunk.Chunk, colIdx int)
}

// value4Int is the valueEvaluator for expressions evaluated through EvalInt.
type value4Int struct {
	val    int64 // last evaluated value
	isNull bool  // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalInt and records the value and
// NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4Int) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalInt(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded value to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Int) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendInt64(colIdx, v.val)
}

// value4Float32 is the valueEvaluator that stores its result as a float32.
type value4Float32 struct {
	val    float32 // last evaluated value, narrowed from float64
	isNull bool    // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalReal, narrows the float64 result
// to float32 and records it together with the NULL flag. The evaluation
// error, if any, is returned as is.
func (v *value4Float32) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	real64, null, evalErr := expr.EvalReal(ctx, row)
	v.val, v.isNull = float32(real64), null
	return evalErr
}

// appendResult appends the recorded value to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Float32) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendFloat32(colIdx, v.val)
}

// value4Decimal is the valueEvaluator for expressions evaluated through
// EvalDecimal.
type value4Decimal struct {
	val    *types.MyDecimal // pointer returned by the last evaluation
	isNull bool             // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalDecimal and records the returned
// pointer and NULL flag. The evaluation error, if any, is returned as is.
func (v *value4Decimal) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalDecimal(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded decimal to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Decimal) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendMyDecimal(colIdx, v.val)
}

// value4Float64 is the valueEvaluator that stores its result as a float64.
type value4Float64 struct {
	val    float64 // last evaluated value
	isNull bool    // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalReal and records the value and
// NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4Float64) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalReal(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded value to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Float64) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendFloat64(colIdx, v.val)
}

// value4String is the valueEvaluator for expressions evaluated through
// EvalString.
type value4String struct {
	val    string // last evaluated value
	isNull bool   // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalString and records the value and
// NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4String) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalString(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded string to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4String) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendString(colIdx, v.val)
}

// value4Time is the valueEvaluator for expressions evaluated through EvalTime.
type value4Time struct {
	val    types.Time // last evaluated value
	isNull bool       // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalTime and records the value and
// NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4Time) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalTime(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded time to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Time) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendTime(colIdx, v.val)
}

// value4Duration is the valueEvaluator for expressions evaluated through
// EvalDuration.
type value4Duration struct {
	val    types.Duration // last evaluated value
	isNull bool           // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalDuration and records the value
// and NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4Duration) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalDuration(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded duration to column colIdx of chk, or NULL
// when the recorded NULL flag is set.
func (v *value4Duration) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendDuration(colIdx, v.val)
}

// value4JSON is the valueEvaluator for expressions evaluated through EvalJSON.
type value4JSON struct {
	val    json.BinaryJSON // last evaluated value
	isNull bool            // NULL flag reported by the last evaluation
}

// evaluateRow evaluates expr on row with EvalJSON and records the value and
// NULL flag it reports. The evaluation error, if any, is returned as is.
func (v *value4JSON) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
	result, null, evalErr := expr.EvalJSON(ctx, row)
	v.val, v.isNull = result, null
	return evalErr
}

// appendResult appends the recorded JSON value to column colIdx of chk, or
// NULL when the recorded NULL flag is set.
func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) {
	if v.isNull {
		chk.AppendNull(colIdx)
		return
	}
	chk.AppendJSON(colIdx, v.val)
}

// buildValueEvaluator picks the valueEvaluator implementation matching the
// field type tp.
//
// The choice is driven by tp.EvalType(), with two refinements: a BIT field is
// always handled by the string evaluator, and a real-typed field is split into
// the float32 and float64 evaluators according to tp.Tp.
//
// It returns nil when no evaluator matches the type; callers must be prepared
// for that.
func buildValueEvaluator(tp *types.FieldType) valueEvaluator {
	kind := tp.EvalType()
	if tp.Tp == mysql.TypeBit {
		// BIT values are evaluated as strings regardless of EvalType().
		kind = types.ETString
	}

	switch kind {
	case types.ETInt:
		return &value4Int{}
	case types.ETReal:
		// Only FLOAT and DOUBLE have evaluators; any other real-typed field
		// ends up at the nil return below.
		if tp.Tp == mysql.TypeFloat {
			return &value4Float32{}
		}
		if tp.Tp == mysql.TypeDouble {
			return &value4Float64{}
		}
	case types.ETDecimal:
		return &value4Decimal{}
	case types.ETDatetime, types.ETTimestamp:
		return &value4Time{}
	case types.ETDuration:
		return &value4Duration{}
	case types.ETString:
		return &value4String{}
	case types.ETJson:
		return &value4JSON{}
	}
	return nil
}

// firstValue is the AggFunc implementing the `first_value` window function.
type firstValue struct {
baseAggFunc

// tp is the field type used to choose the value evaluator when a partial
// result is allocated.
tp *types.FieldType
}

// partialResult4FirstValue is the per-group state of firstValue.
type partialResult4FirstValue struct {
// gotFirstValue is set once a row has been evaluated for the current state;
// while it is set, further updates are ignored.
gotFirstValue bool
// evaluator holds the evaluated value and knows how to append it to a chunk.
evaluator valueEvaluator
}

// AllocPartialResult creates a fresh partial result whose evaluator is chosen
// from the function's field type.
func (v *firstValue) AllocPartialResult() PartialResult {
	state := &partialResult4FirstValue{
		evaluator: buildValueEvaluator(v.tp),
	}
	return PartialResult(state)
}

// ResetPartialResult clears the "value captured" flag so the next update
// evaluates a row again. The evaluator itself is kept and reused.
func (v *firstValue) ResetPartialResult(pr PartialResult) {
	(*partialResult4FirstValue)(pr).gotFirstValue = false
}

// UpdatePartialResult evaluates the function argument on the first row of
// rowsInGroup and stores the result in the partial result.
//
// It is a no-op when a value has already been captured since the last reset,
// or when rowsInGroup is empty. Any evaluation error is returned to the
// caller.
func (v *firstValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	state := (*partialResult4FirstValue)(pr)
	if state.gotFirstValue || len(rowsInGroup) == 0 {
		return nil
	}
	// The flag is raised before evaluating, so it stays set even when the
	// evaluation below reports an error.
	state.gotFirstValue = true
	return state.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[0])
}

// AppendFinalResult2Chunk writes the captured value into the function's output
// column of chk. NULL is appended when no row has been evaluated since the
// last reset. It always returns nil.
func (v *firstValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	state := (*partialResult4FirstValue)(pr)
	if state.gotFirstValue {
		state.evaluator.appendResult(chk, v.ordinal)
		return nil
	}
	chk.AppendNull(v.ordinal)
	return nil
}

// lastValue is the AggFunc implementing the `last_value` window function.
type lastValue struct {
baseAggFunc

// tp is the field type used to choose the value evaluator when a partial
// result is allocated.
tp *types.FieldType
}

// partialResult4LastValue is the per-group state of lastValue.
type partialResult4LastValue struct {
// gotLastValue is set once a row has been evaluated for the current state.
gotLastValue bool
// evaluator holds the evaluated value and knows how to append it to a chunk.
evaluator valueEvaluator
}

// AllocPartialResult creates a fresh partial result whose evaluator is chosen
// from the function's field type.
func (v *lastValue) AllocPartialResult() PartialResult {
	state := &partialResult4LastValue{
		evaluator: buildValueEvaluator(v.tp),
	}
	return PartialResult(state)
}

// ResetPartialResult clears the "value captured" flag. The evaluator itself is
// kept and reused.
func (v *lastValue) ResetPartialResult(pr PartialResult) {
	(*partialResult4LastValue)(pr).gotLastValue = false
}

// UpdatePartialResult evaluates the function argument on the last row of
// rowsInGroup and stores the result in the partial result, replacing any
// previously stored value.
//
// It is a no-op when rowsInGroup is empty. Any evaluation error is returned to
// the caller.
func (v *lastValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	rowCount := len(rowsInGroup)
	if rowCount == 0 {
		return nil
	}
	state := (*partialResult4LastValue)(pr)
	// The flag is raised before evaluating, so it stays set even when the
	// evaluation below reports an error.
	state.gotLastValue = true
	return state.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[rowCount-1])
}

// AppendFinalResult2Chunk writes the captured value into the function's output
// column of chk. NULL is appended when no row has been evaluated since the
// last reset. It always returns nil.
func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	state := (*partialResult4LastValue)(pr)
	if state.gotLastValue {
		state.evaluator.appendResult(chk, v.ordinal)
		return nil
	}
	chk.AppendNull(v.ordinal)
	return nil
}
10 changes: 10 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,14 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 2", "1 2", "2 6", "2 6"))
result = tk.MustQuery("select a, sum(a) over(order by a, b) from t")
result.Check(testkit.Rows("1 1", "1 2", "2 4", "2 6"))

result = tk.MustQuery("select a, first_value(a) over(), last_value(a) over() from t")
result.Check(testkit.Rows("1 1 2", "1 1 2", "2 1 2", "2 1 2"))
result = tk.MustQuery("select a, first_value(a) over(rows between 1 preceding and 1 following), last_value(a) over(rows between 1 preceding and 1 following) from t")
result.Check(testkit.Rows("1 1 1", "1 1 2", "2 1 2", "2 2 2"))
result = tk.MustQuery("select a, first_value(a) over(rows between 1 following and 1 following), last_value(a) over(rows between 1 following and 1 following) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 2 2", "2 <nil> <nil>"))
result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t")
result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039",
"2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039"))
}
3 changes: 2 additions & 1 deletion expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4Avg(ctx)
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow:
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
ast.WindowFuncFirstValue, ast.WindowFuncLastValue:
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
Expand Down

0 comments on commit 247777d

Please sign in to comment.