From da03da081f3aff5414d574c1d1092768b5c074c3 Mon Sep 17 00:00:00 2001 From: "binbin.zhang" Date: Sun, 9 Apr 2023 21:43:35 +0800 Subject: [PATCH] composite_sharding (#665) --- conf/config.yaml | 12 +++-- pkg/boot/discovery.go | 2 +- pkg/boot/misc.go | 14 ++--- pkg/config/equals.go | 8 +-- pkg/config/model.go | 14 +++-- pkg/proto/rule/rule.go | 29 +++++++++++ pkg/runtime/optimize/dml/insert.go | 84 ++++++++++++++++++------------ pkg/runtime/rule/shard_script.go | 31 ++++++++--- 8 files changed, 135 insertions(+), 59 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 84b7ea78..beb866f1 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -64,13 +64,17 @@ data: type: snowflake option: db_rules: - - column: uid + - columns: + - uid + - id type: scriptExpr - expr: parseInt($value % 32 / 8) + expr: parseInt(($uid % 32 +$id % 32) / 8) tbl_rules: - - column: uid + - columns: + - uid + - id type: scriptExpr - expr: $value % 32 + expr: $uid % 32 + $id % 32 step: 32 topology: db_pattern: employees_${0000..0003} diff --git a/pkg/boot/discovery.go b/pkg/boot/discovery.go index 31a6989c..83716df4 100644 --- a/pkg/boot/discovery.go +++ b/pkg/boot/discovery.go @@ -504,7 +504,7 @@ func toSharder(input *config.Rule) (rule.ShardComputer, error) { case rrule.HashCrc32Shard: computer = rrule.NewHashCrc32Shard(mod) case rrule.ScriptExpr: - computer, err = rrule.NewJavascriptShardComputer(input.Expr) + computer, err = rrule.NewJavascriptShardComputer(input.Expr, input.Columns) default: panic(fmt.Errorf("error config, unsupport shard type: %s", input.Type)) } diff --git a/pkg/boot/misc.go b/pkg/boot/misc.go index 9a3850e7..7f34d521 100644 --- a/pkg/boot/misc.go +++ b/pkg/boot/misc.go @@ -75,9 +75,10 @@ func makeVTable(tableName string, table *config.Table) (*rule.VTable, error) { if dbSteps == nil { dbSteps = make(map[string]int) } - dbSharder[it.Column] = shd - keys[it.Column] = struct{}{} - dbSteps[it.Column] = it.Step + columnKey := it.ColumnKey() + dbSharder[columnKey] = shd + keys[columnKey] = struct{}{} + dbSteps[columnKey] = it.Step } for _, it := range table.TblRules { @@ -94,9 +95,10 @@ func makeVTable(tableName string, table *config.Table) (*rule.VTable, error) { if tbSteps == nil { tbSteps = make(map[string]int) } - tbSharder[it.Column] = shd - keys[it.Column] = struct{}{} - tbSteps[it.Column] = it.Step + columnKey := it.ColumnKey() + tbSharder[columnKey] = shd + keys[columnKey] = struct{}{} + tbSteps[columnKey] = it.Step } for k := range keys { diff --git a/pkg/config/equals.go b/pkg/config/equals.go index 6a11858f..aae3622a 100644 --- a/pkg/config/equals.go +++ b/pkg/config/equals.go @@ -75,20 +75,20 @@ func (r Rules) Equals(o Rules) bool { oldTmp := map[string]*Rule{} for i := range r { - newTmp[r[i].Column] = r[i] + newTmp[r[i].ColumnKey()] = r[i] } for i := range o { - oldTmp[o[i].Column] = o[i] + oldTmp[o[i].ColumnKey()] = o[i] } for i := range r { - if _, ok := oldTmp[o[i].Column]; !ok { + if _, ok := oldTmp[o[i].ColumnKey()]; !ok { newT = append(newT, o[i]) } } for i := range o { - val, ok := newTmp[o[i].Column] + val, ok := newTmp[o[i].ColumnKey()] if !ok { deleteT = append(deleteT, o[i]) continue diff --git a/pkg/config/model.go b/pkg/config/model.go index 65546091..8b2f20f4 100644 --- a/pkg/config/model.go +++ b/pkg/config/model.go @@ -24,6 +24,7 @@ import ( "io" "os" "regexp" + "sort" "strconv" "strings" "time" @@ -191,10 +192,10 @@ type ( } Rule struct { - Column string `validate:"required" yaml:"column" json:"column"` - Type string `validate:"required" yaml:"type" json:"type"` - Expr string `validate:"required" yaml:"expr" json:"expr"` - Step int `yaml:"step" json:"step"` + Columns []string `validate:"required" yaml:"columns" json:"columns"` + Type string `validate:"required" yaml:"type" json:"type"` + Expr string `validate:"required" yaml:"expr" json:"expr"` + Step int `yaml:"step" json:"step"` } Topology struct { @@ -427,3 +428,8 @@ func (l *Listener) String() string { socketAddr := fmt.Sprintf("%s:%d", l.SocketAddress.Address, l.SocketAddress.Port) return fmt.Sprintf("Listener protocol_type:%s, socket_address:%s, server_version:%s", l.ProtocolType, socketAddr, l.ServerVersion) } + +func (l *Rule) ColumnKey() string { + sort.Strings(l.Columns) + return strings.Join(l.Columns, ",") +} diff --git a/pkg/proto/rule/rule.go b/pkg/proto/rule/rule.go index 77d6ef3b..cb269054 100644 --- a/pkg/proto/rule/rule.go +++ b/pkg/proto/rule/rule.go @@ -59,6 +59,26 @@ type VTable struct { autoIncrement *AutoIncrement topology *Topology shards map[string][2]*ShardMetadata // column -> [db shard metadata,table shard metadata] + shardsC []*VShards // todo use shardsC replace shards +} + +type VShards struct { + columns []string + shardMetadata [2]*ShardMetadata + key string +} + +func (vt *VShards) HasColumns(columns []string) []int { + var bingoList []int + for _, v := range vt.columns { + for i, column := range columns { + if v != column { + return []int{} + } + bingoList = append(bingoList, i) + } + } + return bingoList } func (vt *VTable) HasColumn(column string) bool { @@ -245,3 +265,12 @@ func (ru *Rule) Range(f func(table string, vt *VTable) bool) { } } } + +func (vt *VTable) GetShardColumnIndex(columns []string) (bingoList []int) { + for _, v := range vt.shardsC { + if bingoList = v.HasColumns(columns); len(bingoList) > 0 { + return bingoList + } + } + return []int{} +} diff --git a/pkg/runtime/optimize/dml/insert.go b/pkg/runtime/optimize/dml/insert.go index b12fdf12..08b15dd6 100644 --- a/pkg/runtime/optimize/dml/insert.go +++ b/pkg/runtime/optimize/dml/insert.go @@ -31,6 +31,7 @@ import ( "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/cmp" rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/logical" "github.com/arana-db/arana/pkg/runtime/optimize" "github.com/arana-db/arana/pkg/runtime/plan/dml" ) @@ -57,52 +58,36 @@ func optimizeInsert(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err return ret, nil } - // TODO: handle multiple shard keys. + bingoList := vt.GetShardColumnIndex(stmt.Columns) - bingo := -1 - // check existing shard columns - for i, col := range stmt.Columns { - if _, _, ok = vt.GetShardMetadata(col); ok { - bingo = i - break - } - } - - if bingo < 0 { + if len(bingoList) == 0 { return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert") } // check on duplicated key update for _, upd := range stmt.DuplicatedUpdates { - if upd.Column.Suffix() == stmt.Columns[bingo] { - return nil, errors.New("do not support update sharding key") + for bingo := range bingoList { + if upd.Column.Suffix() == stmt.Columns[bingo] { + return nil, errors.New("do not support update sharding key") + } } } var ( sharder = optimize.NewXSharder(ctx, o.Rule, o.Args) - left = ast.ColumnNameExpressionAtom(make([]string, 1)) - filter = &ast.PredicateExpressionNode{ - P: &ast.BinaryComparisonPredicateNode{ - Left: &ast.AtomPredicateNode{ - A: left, - }, - Op: cmp.Ceq, - }, - } - slots = make(map[string]map[string][]int) // (db,table,valuesIndex) + slots = make(map[string]map[string][]int) // (db,table,valuesIndex) ) - // reset filter - resetFilter := func(column string, value ast.ExpressionNode) { - left[0] = column - filter.P.(*ast.BinaryComparisonPredicateNode).Right = value.(*ast.PredicateExpressionNode).P - } - for i, values := range stmt.Values { - var shards rule.DatabaseTables - value := values[bingo] - resetFilter(stmt.Columns[bingo], value) + var ( + shards rule.DatabaseTables + filter ast.ExpressionNode + ) + if len(bingoList) == 1 { + filter = buildFilter(stmt.Columns[bingoList[0]], values[bingoList[0]]) + } else { + filter = buildLogicalFilter(stmt.Columns, values, bingoList) + } if len(o.Hints) > 0 { if shards, err = optimize.Hints(tableName, o.Hints, o.Rule); err != nil { @@ -287,3 +272,38 @@ func createSequenceIfAbsent(ctx context.Context, vtab *rule.VTable, metadata *pr } return nil } + +func buildFilter(column string, value ast.ExpressionNode) ast.ExpressionNode { + // reset filter + return &ast.PredicateExpressionNode{ + P: &ast.BinaryComparisonPredicateNode{ + Left: &ast.AtomPredicateNode{ + A: ast.ColumnNameExpressionAtom([]string{column}), + }, + Op: cmp.Ceq, + Right: value.(*ast.PredicateExpressionNode).P, + }, + } +} + +func buildLogicalFilter(columns []string, values []ast.ExpressionNode, bingoList []int) ast.ExpressionNode { + filter := &ast.LogicalExpressionNode{ + Op: logical.Land, + Left: buildFilter(columns[bingoList[0]], values[bingoList[0]]), + Right: buildFilter(columns[bingoList[1]], values[bingoList[1]]), + } + return appendLogicalFilter(columns, values, bingoList, 2, filter) +} + +func appendLogicalFilter(columns []string, values []ast.ExpressionNode, bingoList []int, index int, filter ast.ExpressionNode) ast.ExpressionNode { + if index == len(bingoList) { + return filter + } + newFilter := &ast.LogicalExpressionNode{ + Op: logical.Land, + Left: filter, + Right: buildFilter(columns[bingoList[index]], values[bingoList[index]]), + } + appendLogicalFilter(columns, values, bingoList, index, newFilter) + return filter +} diff --git a/pkg/runtime/rule/shard_script.go b/pkg/runtime/rule/shard_script.go index e0061259..120bb639 100644 --- a/pkg/runtime/rule/shard_script.go +++ b/pkg/runtime/rule/shard_script.go @@ -35,8 +35,8 @@ import ( var _ rule.ShardComputer = (*jsShardComputer)(nil) const ( - _jsEntrypoint = "__compute__" // shard method name - _jsValueName = "$value" // variable name of column in sharding script + _jsEntrypoint = "__compute__" // shard method name + _jsVariableName = "$" // variable name of column in sharding script ) type jsShardComputer struct { @@ -47,8 +47,8 @@ type jsShardComputer struct { } // NewJavascriptShardComputer returns a shard computer which is based on Javascript. -func NewJavascriptShardComputer(script string) (rule.ShardComputer, error) { - script = wrapScript(script) +func NewJavascriptShardComputer(script string, columns []string) (rule.ShardComputer, error) { + script = wrapScript(script, columns) vm, err := createVM(script) if err != nil { @@ -74,8 +74,17 @@ func (j *jsShardComputer) Compute(value interface{}) (int, error) { j.putVM(vm) }() + params, ok := value.([]interface{}) + if !ok { + return 0, errors.Wrapf(err, "javascript shard computer params type not is []string") + } + gojaValues := make([]goja.Value, 0, len(params)) + for _, v := range params { + gojaValues = append(gojaValues, vm.ToValue(v)) + } + fn, _ := goja.AssertFunction(vm.Get(_jsEntrypoint)) - res, err := fn(goja.Undefined(), vm.ToValue(value)) + res, err := fn(goja.Undefined(), gojaValues...) if err != nil { return 0, errors.WithStack(err) } @@ -99,15 +108,21 @@ func (j *jsShardComputer) putVM(vm *goja.Runtime) { } } -func wrapScript(script string) string { +func wrapScript(script string, columns []string) string { + params := make([]string, 0, len(columns)) + for i := range columns { + params = append(params, _jsVariableName+columns[i]) + } + jsFuncParams := strings.Join(params, ",") + var sb strings.Builder - sb.Grow(32 + len(_jsEntrypoint) + len(_jsValueName) + len(script)) + sb.Grow(32 + len(_jsEntrypoint) + len(jsFuncParams) + len(script)) sb.WriteString("function ") sb.WriteString(_jsEntrypoint) sb.WriteString("(") - sb.WriteString(_jsValueName) + sb.WriteString(jsFuncParams) sb.WriteString(") {\n") if strings.Contains(script, "return ") {