Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: support window function first_value and last_value #9560

Merged
merged 2 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildRank(ordinal, orderByCols, true)
case ast.WindowFuncRowNumber:
return buildRowNumber(windowFuncDesc, ordinal)
case ast.WindowFuncFirstValue:
return buildFirstValue(windowFuncDesc, ordinal)
case ast.WindowFuncLastValue:
return buildLastValue(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
Expand Down Expand Up @@ -348,3 +352,19 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF
}
return r
}

func buildFirstValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &firstValue{baseAggFunc: base, tp: aggFuncDesc.RetTp}
}

func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp}
}
302 changes: 302 additions & 0 deletions executor/aggfuncs/func_value.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)

// valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`,
// `lead` and `lag`.
type valueEvaluator interface {
// evaluateRow evaluates the expression using row and stores the result inside.
evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error
// appendResult appends the result to chunk.
appendResult(chk *chunk.Chunk, colIdx int)
}

type value4Int struct {
val int64
isNull bool
}

func (v *value4Int) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalInt(ctx, row)
return err
}

func (v *value4Int) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendInt64(colIdx, v.val)
}
}

type value4Float32 struct {
val float32
isNull bool
}

func (v *value4Float32) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
var val float64
val, v.isNull, err = expr.EvalReal(ctx, row)
v.val = float32(val)
return err
}

func (v *value4Float32) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendFloat32(colIdx, v.val)
}
}

type value4Decimal struct {
val *types.MyDecimal
isNull bool
}

func (v *value4Decimal) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalDecimal(ctx, row)
return err
}

func (v *value4Decimal) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendMyDecimal(colIdx, v.val)
}
}

type value4Float64 struct {
val float64
isNull bool
}

func (v *value4Float64) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalReal(ctx, row)
return err
}

func (v *value4Float64) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendFloat64(colIdx, v.val)
}
}

type value4String struct {
val string
isNull bool
}

func (v *value4String) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalString(ctx, row)
return err
}

func (v *value4String) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendString(colIdx, v.val)
}
}

type value4Time struct {
val types.Time
isNull bool
}

func (v *value4Time) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalTime(ctx, row)
return err
}

func (v *value4Time) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendTime(colIdx, v.val)
}
}

type value4Duration struct {
val types.Duration
isNull bool
}

func (v *value4Duration) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalDuration(ctx, row)
return err
}

func (v *value4Duration) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendDuration(colIdx, v.val)
}
}

type value4JSON struct {
val json.BinaryJSON
isNull bool
}

func (v *value4JSON) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error {
var err error
v.val, v.isNull, err = expr.EvalJSON(ctx, row)
return err
}

func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendJSON(colIdx, v.val)
}
}

func buildValueEvaluator(tp *types.FieldType) valueEvaluator {
evalType := tp.EvalType()
if tp.Tp == mysql.TypeBit {
evalType = types.ETString
}
switch evalType {
case types.ETInt:
return &value4Int{}
case types.ETReal:
switch tp.Tp {
case mysql.TypeFloat:
return &value4Float32{}
case mysql.TypeDouble:
return &value4Float64{}
}
case types.ETDecimal:
return &value4Decimal{}
case types.ETDatetime, types.ETTimestamp:
return &value4Time{}
case types.ETDuration:
return &value4Duration{}
case types.ETString:
return &value4String{}
case types.ETJson:
return &value4JSON{}
}
return nil
}

type firstValue struct {
baseAggFunc

tp *types.FieldType
}

type partialResult4FirstValue struct {
gotFirstValue bool
evaluator valueEvaluator
}

func (v *firstValue) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4FirstValue{evaluator: buildValueEvaluator(v.tp)})
}

func (v *firstValue) ResetPartialResult(pr PartialResult) {
p := (*partialResult4FirstValue)(pr)
p.gotFirstValue = false
}

func (v *firstValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4FirstValue)(pr)
if p.gotFirstValue {
return nil
}
if len(rowsInGroup) > 0 {
p.gotFirstValue = true
err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[0])
if err != nil {
return err
}
}
return nil
}

func (v *firstValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4FirstValue)(pr)
if !p.gotFirstValue {
chk.AppendNull(v.ordinal)
} else {
p.evaluator.appendResult(chk, v.ordinal)
}
return nil
}

type lastValue struct {
baseAggFunc

tp *types.FieldType
}

type partialResult4LastValue struct {
gotLastValue bool
evaluator valueEvaluator
}

func (v *lastValue) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4LastValue{evaluator: buildValueEvaluator(v.tp)})
}

func (v *lastValue) ResetPartialResult(pr PartialResult) {
p := (*partialResult4LastValue)(pr)
p.gotLastValue = false
}

func (v *lastValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4LastValue)(pr)
if len(rowsInGroup) > 0 {
p.gotLastValue = true
err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[len(rowsInGroup)-1])
if err != nil {
return err
}
}
return nil
}

func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4LastValue)(pr)
if !p.gotLastValue {
chk.AppendNull(v.ordinal)
} else {
p.evaluator.appendResult(chk, v.ordinal)
}
return nil
}
10 changes: 10 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,14 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 2", "1 2", "2 6", "2 6"))
result = tk.MustQuery("select a, sum(a) over(order by a, b) from t")
result.Check(testkit.Rows("1 1", "1 2", "2 4", "2 6"))

result = tk.MustQuery("select a, first_value(a) over(), last_value(a) over() from t")
result.Check(testkit.Rows("1 1 2", "1 1 2", "2 1 2", "2 1 2"))
result = tk.MustQuery("select a, first_value(a) over(rows between 1 preceding and 1 following), last_value(a) over(rows between 1 preceding and 1 following) from t")
result.Check(testkit.Rows("1 1 1", "1 1 2", "2 1 2", "2 2 2"))
result = tk.MustQuery("select a, first_value(a) over(rows between 1 following and 1 following), last_value(a) over(rows between 1 following and 1 following) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 2 2", "2 <nil> <nil>"))
result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t")
result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039",
"2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039"))
}
3 changes: 2 additions & 1 deletion expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4Avg(ctx)
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow:
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
ast.WindowFuncFirstValue, ast.WindowFuncLastValue:
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
Expand Down