Skip to content

Commit

Permalink
*: add builtin aggregate function json_objectagg (#11154)
Browse files Browse the repository at this point in the history
  • Loading branch information
hg2990656 authored Feb 4, 2020
1 parent 7720d7d commit ebc6a2d
Show file tree
Hide file tree
Showing 13 changed files with 490 additions and 22 deletions.
1 change: 1 addition & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ func init() {
mysql.ErrInvalidStoreVersion: mysql.ErrInvalidStoreVersion,
mysql.ErrInvalidUseOfNull: mysql.ErrInvalidUseOfNull,
mysql.ErrJSONUsedAsKey: mysql.ErrJSONUsedAsKey,
mysql.ErrJSONDocumentNULLKey: mysql.ErrJSONDocumentNULLKey,
mysql.ErrKeyColumnDoesNotExits: mysql.ErrKeyColumnDoesNotExits,
mysql.ErrLockWaitTimeout: mysql.ErrLockWaitTimeout,
mysql.ErrNoParts: mysql.ErrNoParts,
Expand Down
172 changes: 172 additions & 0 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ type aggTest struct {
results []types.Datum
}

// multiArgsAggTest describes one test case for an aggregate function that
// takes several arguments (for example json_objectagg(c1, c2)).
type multiArgsAggTest struct {
	// dataTypes holds the field type of every argument column; it is used both
	// to allocate the source chunk and to build the argument column expressions.
	dataTypes []*types.FieldType
	// retType is the field type used to allocate the result chunk and to read
	// the result datum back in the merge test.
	retType *types.FieldType
	// numRows is the number of rows generated for the source chunk.
	numRows int
	// dataGens[j](i) produces the datum stored in column j of row i.
	dataGens []func(i int) types.Datum
	// funcName is the aggregate function name handed to NewAggFuncDesc.
	funcName string
	// results holds the expected datums; the test helpers index into it
	// directly, so its length and ordering must match what each helper reads.
	results []types.Datum
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
for i := 0; i < p.numRows; i++ {
Expand Down Expand Up @@ -150,6 +159,99 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i
return pt
}

// testMultiArgsMergePartialResult exercises the partial/final split of a
// multi-argument aggregate function.
//
// It builds a source chunk from p.dataGens, runs the partial function over two
// passes of the input, merges both partial results into the final function and
// compares the three produced values against p.results[0], p.results[1] and
// p.results[2] respectively. Every error returned by the aggregate function is
// asserted to be nil so that a failing update or append is reported at its
// origin rather than as a later result mismatch.
func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
	// Fill the source chunk column by column for every generated row.
	srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
	for i := 0; i < p.numRows; i++ {
		for j := 0; j < len(p.dataGens); j++ {
			fdt := p.dataGens[j](i)
			srcChk.AppendDatum(j, &fdt)
		}
	}
	iter := chunk.NewIterator4Chunk(srcChk)

	// One column expression per argument, bound to the matching chunk column.
	args := make([]expression.Expression, len(p.dataTypes))
	for k := 0; k < len(p.dataTypes); k++ {
		args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
	}

	desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
	c.Assert(err, IsNil)
	partialDesc, finalDesc := desc.Split([]int{0, 1})

	// build partial func for partial phase.
	partialFunc := aggfuncs.Build(s.ctx, partialDesc, 0)
	partialResult := partialFunc.AllocPartialResult()

	// build final func for final phase.
	finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0)
	finalPr := finalFunc.AllocPartialResult()
	resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.retType}, 1)

	// First pass: feed every row to the partial function.
	for row := iter.Begin(); row != iter.End(); row = iter.Next() {
		err = partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
		c.Assert(err, IsNil)
	}
	err = partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
	c.Assert(err, IsNil)
	dt := resultChk.GetRow(0).GetDatum(0, p.retType)
	result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
	c.Assert(err, IsNil)
	c.Assert(result, Equals, 0)

	err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
	c.Assert(err, IsNil)
	partialFunc.ResetPartialResult(partialResult)

	// Second pass: rewind the iterator and advance it before the loop's own
	// initial Next call, so this pass covers only a suffix of the input.
	// NOTE(review): presumably the loop starts at the third row — confirm
	// against the Iterator4Chunk implementation.
	iter.Begin()
	iter.Next()
	for row := iter.Next(); row != iter.End(); row = iter.Next() {
		err = partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
		c.Assert(err, IsNil)
	}
	resultChk.Reset()
	err = partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
	c.Assert(err, IsNil)
	dt = resultChk.GetRow(0).GetDatum(0, p.retType)
	result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
	c.Assert(err, IsNil)
	c.Assert(result, Equals, 0)
	err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
	c.Assert(err, IsNil)

	// Final phase: the merged result of both passes.
	resultChk.Reset()
	err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
	c.Assert(err, IsNil)

	dt = resultChk.GetRow(0).GetDatum(0, p.retType)
	result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2])
	c.Assert(err, IsNil)
	c.Assert(result, Equals, 0)
}

// buildMultiArgsAggTester creates a multiArgsAggTest for aggregate functions
// that take several arguments, such as json_objectagg(c1, c2). Each type code
// in tps and the return type code rt is turned into a field type with
// types.NewFieldType before delegating to buildMultiArgsAggTesterWithFieldType.
func buildMultiArgsAggTester(funcName string, tps []byte, rt byte, numRows int, results ...interface{}) multiArgsAggTest {
	argTypes := make([]*types.FieldType, 0, len(tps))
	for _, tp := range tps {
		argTypes = append(argTypes, types.NewFieldType(tp))
	}
	retType := types.NewFieldType(rt)
	return buildMultiArgsAggTesterWithFieldType(funcName, argTypes, retType, numRows, results...)
}

// buildMultiArgsAggTesterWithFieldType assembles a multiArgsAggTest from
// already constructed field types. One data generator is looked up per
// argument type via getDataGenFunc, and every expected result is wrapped into
// a types.Datum in the order given.
func buildMultiArgsAggTesterWithFieldType(funcName string, fts []*types.FieldType, rt *types.FieldType, numRows int, results ...interface{}) multiArgsAggTest {
	gens := make([]func(i int) types.Datum, 0, len(fts))
	for _, ft := range fts {
		gens = append(gens, getDataGenFunc(ft))
	}

	// Left nil when no expected results are supplied.
	var expected []types.Datum
	for _, r := range results {
		expected = append(expected, types.NewDatum(r))
	}

	return multiArgsAggTest{
		dataTypes: fts,
		retType:   rt,
		numRows:   numRows,
		funcName:  funcName,
		dataGens:  gens,
		results:   expected,
	}
}

func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
switch ft.Tp {
case mysql.TypeLonglong:
Expand Down Expand Up @@ -250,3 +352,73 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
}

// testMultiArgsAggFunc runs a multi-argument aggregate function over the
// generated input, first without and then with the distinct flag, and checks
// the outcome against p.results[1]. After each run the partial result is reset
// and the value produced for empty input is checked against p.results[0].
func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
	input := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
	for rowIdx := 0; rowIdx < p.numRows; rowIdx++ {
		for colIdx := range p.dataGens {
			d := p.dataGens[colIdx](rowIdx)
			input.AppendDatum(colIdx, &d)
		}
	}
	// One extra zero-value datum is appended to column 0 only.
	input.AppendDatum(0, &types.Datum{})

	args := make([]expression.Expression, len(p.dataTypes))
	for idx, tp := range p.dataTypes {
		args[idx] = &expression.Column{RetType: tp, Index: idx}
	}

	desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
	c.Assert(err, IsNil)
	finalFunc := aggfuncs.Build(s.ctx, desc, 0)
	finalPr := finalFunc.AllocPartialResult()
	resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
	iter := chunk.NewIterator4Chunk(input)

	// feedAll pushes every row of the current iterator into the current
	// aggregate function; the returned errors are not inspected here.
	feedAll := func() {
		for row := iter.Begin(); row != iter.End(); row = iter.Next() {
			finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
		}
	}
	// expectResult appends the final result to resultChk and asserts that it
	// compares equal to p.results[want].
	expectResult := func(want int) {
		finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
		got := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
		cmp, err := got.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[want])
		c.Assert(err, IsNil)
		c.Assert(cmp, Equals, 0)
	}

	feedAll()
	expectResult(1)

	// test the empty input
	resultChk.Reset()
	finalFunc.ResetPartialResult(finalPr)
	expectResult(0)

	// test the agg func with distinct
	desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true)
	c.Assert(err, IsNil)
	finalFunc = aggfuncs.Build(s.ctx, desc, 0)
	finalPr = finalFunc.AllocPartialResult()

	resultChk.Reset()
	iter = chunk.NewIterator4Chunk(input)
	// The input is fed twice; the expected result is the same as for one pass.
	feedAll()
	feedAll()
	expectResult(1)

	// test the empty input
	resultChk.Reset()
	finalFunc.ResetPartialResult(finalPr)
	expectResult(0)
}
3 changes: 3 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ var (

// All the AggFunc implementations for "BIT_AND" are listed here.
_ AggFunc = (*bitAndUint64)(nil)

// All the AggFunc implementations for "JSON_OBJECTAGG" are listed here.
_ AggFunc = (*jsonObjectAgg)(nil)
)

// PartialResult represents data structure to store the partial result for the
Expand Down
16 changes: 16 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildBitAnd(aggFuncDesc, ordinal)
case ast.AggFuncVarPop:
return buildVarPop(aggFuncDesc, ordinal)
case ast.AggFuncJsonObjectAgg:
return buildJSONObjectAgg(aggFuncDesc, ordinal)
}
return nil
}
Expand Down Expand Up @@ -371,6 +373,20 @@ func buildVarPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
}

// buildJSONObjectAgg builds the AggFunc implementation for function
// "json_objectagg". It returns nil when the descriptor is in DedupMode; every
// other mode gets a jsonObjectAgg bound to the descriptor's arguments and the
// given ordinal.
func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
	if aggFuncDesc.Mode == aggregation.DedupMode {
		return nil
	}
	return &jsonObjectAgg{baseAggFunc{
		args:    aggFuncDesc.Args,
		ordinal: ordinal,
	}}
}

// buildRowNumber builds the AggFunc implementation for function "ROW_NUMBER".
func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
Expand Down
114 changes: 114 additions & 0 deletions executor/aggfuncs/func_json_objectagg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)

// jsonObjectAgg is the aggregate function type whose methods collect
// key/value argument pairs into a single JSON object. It embeds baseAggFunc,
// which supplies the argument expressions and the output ordinal used by the
// methods below.
type jsonObjectAgg struct {
	baseAggFunc
}

// partialResult4JsonObjectAgg is the intermediate state kept by jsonObjectAgg.
type partialResult4JsonObjectAgg struct {
	// entries maps each stringified key to the value most recently stored for
	// it; a later write with the same key overwrites the earlier one.
	entries map[string]interface{}
}

// AllocPartialResult returns a fresh partial result whose entry map is
// allocated and empty.
func (e *jsonObjectAgg) AllocPartialResult() PartialResult {
	res := &partialResult4JsonObjectAgg{
		entries: make(map[string]interface{}),
	}
	return PartialResult(res)
}

// ResetPartialResult discards all collected entries by replacing the entry map
// of pr with a newly allocated empty map.
func (e *jsonObjectAgg) ResetPartialResult(pr PartialResult) {
	(*partialResult4JsonObjectAgg)(pr).entries = make(map[string]interface{})
}

// AppendFinalResult2Chunk writes the aggregated JSON object into column
// e.ordinal of chk. When no entry has been collected a NULL is appended
// instead. Values of type *types.MyDecimal are replaced by their float64 form
// and []uint8, types.Time and types.Duration values by their string form
// before the object is encoded; the replacement is stored back into the
// partial result. A conversion failure is returned as a traced error.
func (e *jsonObjectAgg) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	res := (*partialResult4JsonObjectAgg)(pr)
	if len(res.entries) == 0 {
		chk.AppendNull(e.ordinal)
		return nil
	}

	// appendBinary does not support some types such as []uint8 and
	// types.Time, so those values are converted here first.
	for name, raw := range res.entries {
		if dec, isDecimal := raw.(*types.MyDecimal); isDecimal {
			asFloat, err := dec.ToFloat64()
			if err != nil {
				return errors.Trace(err)
			}
			res.entries[name] = asFloat
			continue
		}

		switch raw.(type) {
		case []uint8, types.Time, types.Duration:
			asString, err := types.ToString(raw)
			if err != nil {
				return errors.Trace(err)
			}
			res.entries[name] = asString
		}
	}

	chk.AppendJSON(e.ordinal, json.CreateBinary(res.entries))
	return nil
}

// UpdatePartialResult folds every row of rowsInGroup into the partial result.
//
// For each row the first argument is evaluated as the key and the second as
// the value. The key is converted to a string and the value is stored under
// it, overwriting any value previously stored for the same key.
//
// It returns json.ErrJSONDocumentNULLKey when a key evaluates to NULL,
// json.ErrUnsupportedSecondArgumentType when the value has a Go type outside
// the accepted list, and a traced error when evaluating an argument or
// converting the key fails. Rows processed before the failing one remain in
// the partial result.
func (e *jsonObjectAgg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	p := (*partialResult4JsonObjectAgg)(pr)
	for _, row := range rowsInGroup {
		key, err := e.args[0].Eval(row)
		if err != nil {
			return errors.Trace(err)
		}

		// Reject a NULL key before evaluating the value argument: the row can
		// never be stored, so evaluating the value would be wasted work and an
		// error from it would hide the NULL-key error.
		if key.IsNull() {
			return json.ErrJSONDocumentNULLKey
		}

		value, err := e.args[1].Eval(row)
		if err != nil {
			return errors.Trace(err)
		}

		// the result json's key is string, so it needs to convert the first arg to string
		keyString, err := key.ToString()
		if err != nil {
			return errors.Trace(err)
		}

		realVal := value.GetValue()
		switch x := realVal.(type) {
		case nil, bool, int64, uint64, float64, string, json.BinaryJSON, *types.MyDecimal, []uint8, types.Time, types.Duration:
			// Last write wins for duplicate keys.
			p.entries[keyString] = realVal
		default:
			return json.ErrUnsupportedSecondArgumentType.GenWithStackByArgs(x)
		}
	}
	return nil
}

// MergePartialResult copies every entry of src into dst. An entry already
// present in dst under the same key is overwritten, so for duplicate keys the
// value coming from src is the one that ends up in the merged object.
func (e *jsonObjectAgg) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error {
	from := (*partialResult4JsonObjectAgg)(src)
	into := (*partialResult4JsonObjectAgg)(dst)
	for key := range from.entries {
		into.entries[key] = from.entries[key]
	}
	return nil
}
Loading

0 comments on commit ebc6a2d

Please sign in to comment.