Skip to content

Commit d1848b0

Browse files
alivxxxngaut
authored andcommitted
expression: fix data race of rand function (#11168) (#11169)
1 parent 6e0de41 commit d1848b0

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

expression/builtin_math.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"math/rand"
2525
"strconv"
2626
"strings"
27+
"sync"
2728
"time"
2829

2930
"github.com/cznic/mathutil"
@@ -966,7 +967,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
966967
bt := bf
967968
if len(args) == 0 {
968969
seed := time.Now().UnixNano()
969-
sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))}
970+
sig = &builtinRandSig{bt, &sync.Mutex{}, rand.New(rand.NewSource(seed))}
970971
} else if _, isConstant := args[0].(*Constant); isConstant {
971972
// According to MySQL manual:
972973
// If an integer argument N is specified, it is used as the seed value:
@@ -979,7 +980,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
979980
if isNull {
980981
seed = time.Now().UnixNano()
981982
}
982-
sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))}
983+
sig = &builtinRandSig{bt, &sync.Mutex{}, rand.New(rand.NewSource(seed))}
983984
} else {
984985
sig = &builtinRandWithSeedSig{bt}
985986
}
@@ -988,19 +989,23 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
988989

989990
type builtinRandSig struct {
990991
baseBuiltinFunc
992+
mu *sync.Mutex
991993
randGen *rand.Rand
992994
}
993995

994996
func (b *builtinRandSig) Clone() builtinFunc {
995-
newSig := &builtinRandSig{randGen: b.randGen}
997+
newSig := &builtinRandSig{randGen: b.randGen, mu: b.mu}
996998
newSig.cloneFrom(&b.baseBuiltinFunc)
997999
return newSig
9981000
}
9991001

10001002
// evalReal evals RAND().
10011003
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand
10021004
func (b *builtinRandSig) evalReal(row chunk.Row) (float64, bool, error) {
1003-
return b.randGen.Float64(), false, nil
1005+
b.mu.Lock()
1006+
res := b.randGen.Float64()
1007+
b.mu.Unlock()
1008+
return res, false, nil
10041009
}
10051010

10061011
type builtinRandWithSeedSig struct {

expression/integration_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,8 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) {
550550
tk.MustExec("drop table if exists t")
551551
tk.MustExec("create table t(a int)")
552552
tk.MustExec("insert into t values(1),(2),(3)")
553+
tk.Se.GetSessionVars().MaxChunkSize = 1
554+
tk.MustQuery("select rand(1) from t").Sort().Check(testkit.Rows("0.6046602879796196", "0.6645600532184904", "0.9405090880450124"))
553555
tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.6046602879796196", "0.16729663442585624", "0.7199826688373036"))
554556
tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.6046602879796196 0.16729663442585624 0.7199826688373036"))
555557
}

0 commit comments

Comments
 (0)