Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util: replace compareDatum by compare, point part #30575

Merged
merged 10 commits into from
Dec 13, 2021
2 changes: 1 addition & 1 deletion util/ranger/detacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func ExtractEqAndInCondition(sctx sessionctx.Context, conditions []expression.Ex
mergedAccesses[offset] = accesses[offset]
points[offset] = rb.build(accesses[offset])
}
points[offset] = rb.intersection(points[offset], rb.build(cond))
points[offset] = rb.intersection(points[offset], rb.build(cond), collate.GetCollator(cols[offset].GetType().Collate))
if len(points[offset]) == 0 { // Early termination if false expression found
if expression.MaybeOverOptimized4PlanCache(sctx, conditions) {
// cannot return an empty-range for plan-cache since the range may become non-empty as parameters change
Expand Down
37 changes: 20 additions & 17 deletions util/ranger/points.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ func (rp *point) Clone(value types.Datum) *point {
}

type pointSorter struct {
points []*point
err error
sc *stmtctx.StatementContext
points []*point
err error
sc *stmtctx.StatementContext
collator collate.Collator
}

func (r *pointSorter) Len() int {
Expand All @@ -95,18 +96,18 @@ func (r *pointSorter) Len() int {
func (r *pointSorter) Less(i, j int) bool {
a := r.points[i]
b := r.points[j]
less, err := rangePointLess(r.sc, a, b)
less, err := rangePointLess(r.sc, a, b, r.collator)
if err != nil {
r.err = err
}
return less
}

func rangePointLess(sc *stmtctx.StatementContext, a, b *point) (bool, error) {
func rangePointLess(sc *stmtctx.StatementContext, a, b *point, collator collate.Collator) (bool, error) {
if a.value.Kind() == types.KindMysqlEnum && b.value.Kind() == types.KindMysqlEnum {
return rangePointEnumLess(sc, a, b)
}
cmp, err := a.value.CompareDatum(sc, &b.value)
cmp, err := a.value.Compare(sc, &b.value, collator)
if cmp != 0 {
return cmp < 0, nil
}
Expand Down Expand Up @@ -604,7 +605,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
endPoint := &point{value: endValue}
rangePoints = append(rangePoints, startPoint, endPoint)
}
sorter := pointSorter{points: rangePoints, sc: r.sc}
sorter := pointSorter{points: rangePoints, sc: r.sc, collator: collate.GetCollator(colCollate)}
sort.Sort(&sorter)
if sorter.err != nil {
r.err = sorter.err
Expand Down Expand Up @@ -765,13 +766,15 @@ func (r *builder) buildFromNot(expr *expression.ScalarFunction) []*point {
}

func (r *builder) buildFromScalarFunc(expr *expression.ScalarFunction) []*point {
_, coll := expr.CharsetAndCollation()
collator := collate.GetCollator(coll)
switch op := expr.FuncName.L; op {
case ast.GE, ast.GT, ast.LT, ast.LE, ast.EQ, ast.NE, ast.NullEQ:
return r.buildFromBinOp(expr)
case ast.LogicAnd:
return r.intersection(r.build(expr.GetArgs()[0]), r.build(expr.GetArgs()[1]))
return r.intersection(r.build(expr.GetArgs()[0]), r.build(expr.GetArgs()[1]), collator)
case ast.LogicOr:
return r.union(r.build(expr.GetArgs()[0]), r.build(expr.GetArgs()[1]))
return r.union(r.build(expr.GetArgs()[0]), r.build(expr.GetArgs()[1]), collator)
case ast.IsTruthWithoutNull:
return r.buildFromIsTrue(expr, 0, false)
case ast.IsTruthWithNull:
Expand All @@ -794,19 +797,19 @@ func (r *builder) buildFromScalarFunc(expr *expression.ScalarFunction) []*point
return nil
}

func (r *builder) intersection(a, b []*point) []*point {
return r.merge(a, b, false)
func (r *builder) intersection(a, b []*point, collator collate.Collator) []*point {
return r.merge(a, b, false, collator)
}

func (r *builder) union(a, b []*point) []*point {
return r.merge(a, b, true)
func (r *builder) union(a, b []*point, collator collate.Collator) []*point {
return r.merge(a, b, true, collator)
}

func (r *builder) mergeSorted(a, b []*point) []*point {
func (r *builder) mergeSorted(a, b []*point, collator collate.Collator) []*point {
ret := make([]*point, 0, len(a)+len(b))
i, j := 0, 0
for i < len(a) && j < len(b) {
less, err := rangePointLess(r.sc, a[i], b[j])
less, err := rangePointLess(r.sc, a[i], b[j], collator)
if err != nil {
r.err = err
return nil
Expand All @@ -827,8 +830,8 @@ func (r *builder) mergeSorted(a, b []*point) []*point {
return ret
}

func (r *builder) merge(a, b []*point, union bool) []*point {
mergedPoints := r.mergeSorted(a, b)
func (r *builder) merge(a, b []*point, union bool, collator collate.Collator) []*point {
mergedPoints := r.mergeSorted(a, b, collator)
if r.err != nil {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions util/ranger/ranger.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func buildColumnRange(accessConditions []expression.Expression, sctx sessionctx.
rb := builder{sc: sctx.GetSessionVars().StmtCtx}
rangePoints := getFullRange()
for _, cond := range accessConditions {
rangePoints = rb.intersection(rangePoints, rb.build(cond))
rangePoints = rb.intersection(rangePoints, rb.build(cond), collate.GetCollator(tp.Collate))
if rb.err != nil {
return nil, errors.Trace(rb.err)
}
Expand Down Expand Up @@ -375,7 +375,7 @@ func (d *rangeDetacher) buildCNFIndexRange(newTp []*types.FieldType,
rangePoints := getFullRange()
// Build rangePoints for non-equal access conditions.
for i := eqAndInCount; i < len(accessCondition); i++ {
rangePoints = rb.intersection(rangePoints, rb.build(accessCondition[i]))
rangePoints = rb.intersection(rangePoints, rb.build(accessCondition[i]), collate.GetCollator(newTp[eqAndInCount].Collate))
if rb.err != nil {
return nil, errors.Trace(rb.err)
}
Expand Down