Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalengine: Improve weight string support #13658

Merged
merged 2 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package datetime

import (
"encoding/binary"
"time"

"vitess.io/vitess/go/mysql/decimal"
Expand Down Expand Up @@ -641,6 +642,20 @@ func (dt *DateTime) addInterval(itv *Interval) bool {
}
}

func (dt DateTime) WeightString(dst []byte) []byte {
// This logic does the inverse of what we do in the binlog parser for the datetime2 type.
year, month, day := dt.Date.Year(), dt.Date.Month(), dt.Date.Day()
ymd := uint64(year*13+month)<<5 | uint64(day)
hms := uint64(dt.Time.Hour())<<12 | uint64(dt.Time.Minute())<<6 | uint64(dt.Time.Second())
raw := (ymd<<17|hms)<<24 + uint64(dt.Time.Nanosecond()/1000)
if dt.Time.Neg() {
raw = -raw
}

raw = raw ^ (1 << 63)
return binary.BigEndian.AppendUint64(dst, raw)
}

func NewDateFromStd(t time.Time) Date {
year, month, day := t.Date()
return Date{
Expand Down
5 changes: 5 additions & 0 deletions go/mysql/datetime/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,10 @@ func ParseTimeDecimal(d decimal.Decimal, l int32, prec int) (Time, int, bool) {
} else {
t = t.Round(prec)
}
// We only support a maximum of nanosecond precision,
// so if the decimal has any larger precision we truncate it.
if prec > 9 {
prec = 9
}
return t, prec, ok
}
56 changes: 56 additions & 0 deletions go/mysql/decimal/weights.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package decimal

// Our weight string format is normalizing the weight string to a fixed length,
// so it becomes byte-ordered. The byte lengths are pre-computed based on
// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
// and generated empirically with a manual loop:
//
// for i := 1; i <= 65; i++ {
// dec, err := NewFromMySQL(bytes.Repeat([]byte("9"), i))
// if err != nil {
// t.Fatal(err)
// }
//
// byteLengths = append(byteLengths, len(dec.value.Bytes()))
// }
var weightStringLengths = []int{
0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 8,
9, 9, 10, 10, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 15,
16, 16, 17, 17, 18, 18, 18, 19, 19, 20, 20, 20, 21, 21, 22, 22,
23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27,
}

func (d Decimal) WeightString(dst []byte, length, precision int32) []byte {
dec := d.rescale(-precision)
dec = dec.Clamp(length-precision, precision)

buf := make([]byte, weightStringLengths[length]+1)
dec.value.FillBytes(buf[:])

if dec.value.Sign() < 0 {
for i := range buf {
buf[i] ^= 0xff
}
}
// Use the same trick as used for signed numbers on the first byte.
buf[0] ^= 0x80

dst = append(dst, buf[:]...)
return dst
}
20 changes: 13 additions & 7 deletions go/vt/vtgate/evalengine/api_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,19 @@ func randTime() time.Time {
return time.Unix(sec, 0)
}

func randomNull() sqltypes.Value { return sqltypes.NULL }
func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) }
func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) }
func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) }
func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) }
func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) }
func randomDecimal() sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", rand.Int63())) }
func randomNull() sqltypes.Value { return sqltypes.NULL }
func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) }
func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) }
func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) }
func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) }
func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) }
func randomDecimal() sqltypes.Value {
dec := fmt.Sprintf("%d.%d", rand.Intn(9999999999), rand.Intn(9999999999))
if rand.Int()&0x1 == 1 {
dec = "-" + dec
}
return sqltypes.NewDecimal(dec)
}
func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) }
func randomDate() sqltypes.Value { return sqltypes.NewDate(randTime().Format(time.DateOnly)) }
func randomDatetime() sqltypes.Value { return sqltypes.NewDatetime(randTime().Format(time.DateTime)) }
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 8 additions & 13 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2732,22 +2732,17 @@ func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollatio
}, "FN TO_BASE64 VARCHAR(SP-1)")
}

func (asm *assembler) Fn_WEIGHT_STRING_b(length int) {
func (asm *assembler) Fn_WEIGHT_STRING(length int) {
asm.emit(func(env *ExpressionEnv) int {
str := env.vm.stack[env.vm.sp-1].(*evalBytes)
w := collations.Binary.WeightString(make([]byte, 0, length), str.bytes, collations.PadToMax)
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w)
return 1
}, "FN WEIGHT_STRING VARBINARY(SP-1)")
}

func (asm *assembler) Fn_WEIGHT_STRING_c(col collations.Collation, length int) {
asm.emit(func(env *ExpressionEnv) int {
str := env.vm.stack[env.vm.sp-1].(*evalBytes)
w := col.WeightString(nil, str.bytes, length)
input := env.vm.stack[env.vm.sp-1]
w, _, err := evalWeightString(nil, input, length, 0)
if err != nil {
env.vm.err = err
return 1
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w)
return 1
}, "FN WEIGHT_STRING VARCHAR(SP-1)")
}, "FN WEIGHT_STRING (SP-1)")
}

func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) {
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ func TestCompilerSingle(t *testing.T) {
expression: `concat('test', _latin1 0xff)`,
result: `VARCHAR("testÿ")`,
},
{
expression: `WEIGHT_STRING('foobar' as char(3))`,
result: `VARBINARY("\x1c\xe5\x1d\xdd\x1d\xdd")`,
},
}

for _, tc := range testCases {
Expand Down
20 changes: 19 additions & 1 deletion go/vt/vtgate/evalengine/expr_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
case "JSON":
return evalToJSON(e)
case "DATETIME":
switch p := c.Length; {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if dt := evalToDateTime(e, c.Length); dt != nil {
return dt, nil
}
Expand All @@ -130,6 +134,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
}
return nil, nil
case "TIME":
switch p := c.Length; {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if t := evalToTime(e, c.Length); t != nil {
return t, nil
}
Expand Down Expand Up @@ -227,6 +235,9 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
case "DOUBLE", "REAL":
convt = c.compileToFloat(arg, 1)

case "FLOAT":
return ctype{}, c.unsupported(conv)

case "SIGNED", "SIGNED INTEGER":
convt = c.compileToInt64(arg, 1)

Expand All @@ -244,9 +255,17 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
convt = c.compileToDate(arg, 1)

case "DATETIME":
switch p := conv.Length; {
case p > 6:
return ctype{}, c.unsupported(conv)
}
convt = c.compileToDateTime(arg, 1, conv.Length)

case "TIME":
switch p := conv.Length; {
case p > 6:
return ctype{}, c.unsupported(conv)
}
convt = c.compileToTime(arg, 1, conv.Length)

default:
Expand All @@ -256,7 +275,6 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
c.asm.jumpDestination(skip)
convt.Flag = arg.Flag | flagNullable
return convt, nil

}

func (c *ConvertUsingExpr) eval(env *ExpressionEnv) (eval, error) {
Expand Down
93 changes: 55 additions & 38 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type (
}

builtinWeightString struct {
String Expr
Expr Expr
Cast string
Len int
HasLen bool
Expand Down Expand Up @@ -455,76 +455,93 @@ func (expr *builtinCollation) compile(c *compiler) (ctype, error) {
}

func (c *builtinWeightString) callable() []Expr {
return []Expr{c.String}
return []Expr{c.Expr}
}

func (c *builtinWeightString) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
_, f := c.String.typeof(env, fields)
_, f := c.Expr.typeof(env, fields)
return sqltypes.VarBinary, f
}

func (c *builtinWeightString) eval(env *ExpressionEnv) (eval, error) {
var (
tc collations.TypedCollation
text []byte
weights []byte
length = c.Len
)

str, err := c.String.eval(env)
var weights []byte

input, err := c.Expr.eval(env)
if err != nil {
return nil, err
}

switch str := str.(type) {
case *evalInt64, *evalUint64:
// when calling WEIGHT_STRING with an integral value, MySQL returns the
// internal sort key that would be used in an InnoDB table... we do not
// support that
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %s", ErrEvaluatedExprNotSupported, FormatExpr(c))
if c.Cast == "binary" {
weights, _, err = evalWeightString(weights, evalToBinary(input), c.Len, 0)
if err != nil {
return nil, err
}
return newEvalBinary(weights), nil
}

switch val := input.(type) {
case *evalInt64, *evalUint64, *evalTemporal:
weights, _, err = evalWeightString(weights, val, 0, 0)
case *evalBytes:
text = str.bytes
tc = str.col
if val.isBinary() {
weights, _, err = evalWeightString(weights, val, 0, 0)
} else {
var strLen int
if c.Cast == "char" {
strLen = c.Len
}
weights, _, err = evalWeightString(weights, val, strLen, 0)
}
default:
return nil, nil
}

if c.Cast == "binary" {
tc = collationBinary
weights = make([]byte, 0, c.Len)
length = collations.PadToMax
if err != nil {
return nil, err
}

collation := tc.Collation.Get()
weights = collation.WeightString(weights, text, length)
return newEvalBinary(weights), nil
}

func (call *builtinWeightString) compile(c *compiler) (ctype, error) {
str, err := call.String.compile(c)
str, err := call.Expr.compile(c)
if err != nil {
return ctype{}, err
}

switch str.Type {
case sqltypes.Int64, sqltypes.Uint64:
return ctype{}, c.unsupported(call)

case sqltypes.VarChar, sqltypes.VarBinary:
skip := c.compileNullCheck1(str)
var flag typeFlag
if str.Flag&flagNullable != 0 {
flag = flag | flagNullable
}

if call.Cast == "binary" {
c.asm.Fn_WEIGHT_STRING_b(call.Len)
} else {
c.asm.Fn_WEIGHT_STRING_c(str.Col.Collation.Get(), call.Len)
skip := c.compileNullCheck1(str)
if call.Cast == "binary" {
if !sqltypes.IsBinary(str.Type) {
c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false)
}
c.asm.Fn_WEIGHT_STRING(call.Len)
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil
return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil
}

switch str.Type {
case sqltypes.Int64, sqltypes.Uint64, sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp, sqltypes.Time, sqltypes.VarBinary, sqltypes.Binary, sqltypes.Blob:
c.asm.Fn_WEIGHT_STRING(0)

case sqltypes.VarChar, sqltypes.Char, sqltypes.Text:
var strLen int
if call.Cast == "char" {
strLen = call.Len
}
c.asm.Fn_WEIGHT_STRING(strLen)

default:
c.asm.SetNull(1)
return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil
flag = flag | flagNull | flagNullable
}

c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.VarBinary, Flag: flag, Col: collationBinary}, nil
}

func (call builtinLeftRight) eval(env *ExpressionEnv) (eval, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (c *CallExpr) format(w *formatter, depth int) {

func (c *builtinWeightString) format(w *formatter, depth int) {
w.WriteString("WEIGHT_STRING(")
c.String.format(w, depth)
c.Expr.format(w, depth)

if c.Cast != "" {
fmt.Fprintf(w, " AS %s(%d)", strings.ToUpper(c.Cast), c.Len)
Expand Down
Loading