diff --git a/expression/builtin_cast_vec.go b/expression/builtin_cast_vec.go index 0358bfe094cba..cd70ecd65df97 100644 --- a/expression/builtin_cast_vec.go +++ b/expression/builtin_cast_vec.go @@ -530,11 +530,41 @@ func (b *builtinCastRealAsTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk } func (b *builtinCastDecimalAsDecimalSig) vectorized() bool { - return false + return true } func (b *builtinCastDecimalAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { - return errors.Errorf("not implemented") + n := input.NumRows() + buf, err := b.bufAllocator.get(types.ETDecimal, n) + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalDecimal(b.ctx, input, buf); err != nil { + return err + } + + result.ResizeDecimal(n, false) + result.MergeNulls(buf) + oldDecs := buf.Decimals() + newDecs := result.Decimals() + sc := b.ctx.GetSessionVars().StmtCtx + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + dec := &types.MyDecimal{} + if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && oldDecs[i].IsNegative()) { + *dec = oldDecs[i] + } + dec, err = types.ProduceDecWithSpecifiedTp(dec, b.tp, sc) + if err != nil { + return err + } + newDecs[i] = *dec + } + return nil } func (b *builtinCastDurationAsTimeSig) vectorized() bool { diff --git a/expression/builtin_cast_vec_test.go b/expression/builtin_cast_vec_test.go index 5033506fd4a96..8844f5f2b9493 100644 --- a/expression/builtin_cast_vec_test.go +++ b/expression/builtin_cast_vec_test.go @@ -79,6 +79,7 @@ var vecBuiltinCastCases = map[string][]vecExprBenchCase{ geners: []dataGenerator{ &jsonTimeGener{}, }}, + {retEvalType: types.ETDecimal, childrenTypes: []types.EvalType{types.ETDecimal}}, }, }