Skip to content

Commit

Permalink
composite_sharding (#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
binbin0325 authored Apr 9, 2023
1 parent 76c4e80 commit da03da0
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 59 deletions.
12 changes: 8 additions & 4 deletions conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,17 @@ data:
type: snowflake
option:
db_rules:
- column: uid
- columns:
- uid
- id
type: scriptExpr
expr: parseInt($value % 32 / 8)
expr: parseInt(($uid % 32 +$id % 32) / 8)
tbl_rules:
- column: uid
- columns:
- uid
- id
type: scriptExpr
expr: $value % 32
expr: $uid % 32 + $id % 32
step: 32
topology:
db_pattern: employees_${0000..0003}
Expand Down
2 changes: 1 addition & 1 deletion pkg/boot/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ func toSharder(input *config.Rule) (rule.ShardComputer, error) {
case rrule.HashCrc32Shard:
computer = rrule.NewHashCrc32Shard(mod)
case rrule.ScriptExpr:
computer, err = rrule.NewJavascriptShardComputer(input.Expr)
computer, err = rrule.NewJavascriptShardComputer(input.Expr, input.Columns)
default:
panic(fmt.Errorf("error config, unsupport shard type: %s", input.Type))
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/boot/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ func makeVTable(tableName string, table *config.Table) (*rule.VTable, error) {
if dbSteps == nil {
dbSteps = make(map[string]int)
}
dbSharder[it.Column] = shd
keys[it.Column] = struct{}{}
dbSteps[it.Column] = it.Step
columnKey := it.ColumnKey()
dbSharder[columnKey] = shd
keys[columnKey] = struct{}{}
dbSteps[columnKey] = it.Step
}

for _, it := range table.TblRules {
Expand All @@ -94,9 +95,10 @@ func makeVTable(tableName string, table *config.Table) (*rule.VTable, error) {
if tbSteps == nil {
tbSteps = make(map[string]int)
}
tbSharder[it.Column] = shd
keys[it.Column] = struct{}{}
tbSteps[it.Column] = it.Step
columnKey := it.ColumnKey()
tbSharder[columnKey] = shd
keys[columnKey] = struct{}{}
tbSteps[columnKey] = it.Step
}

for k := range keys {
Expand Down
8 changes: 4 additions & 4 deletions pkg/config/equals.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,20 @@ func (r Rules) Equals(o Rules) bool {
oldTmp := map[string]*Rule{}

for i := range r {
newTmp[r[i].Column] = r[i]
newTmp[r[i].ColumnKey()] = r[i]
}
for i := range o {
oldTmp[o[i].Column] = o[i]
oldTmp[o[i].ColumnKey()] = o[i]
}

for i := range r {
if _, ok := oldTmp[o[i].Column]; !ok {
if _, ok := oldTmp[o[i].ColumnKey()]; !ok {
newT = append(newT, o[i])
}
}

for i := range o {
val, ok := newTmp[o[i].Column]
val, ok := newTmp[o[i].ColumnKey()]
if !ok {
deleteT = append(deleteT, o[i])
continue
Expand Down
14 changes: 10 additions & 4 deletions pkg/config/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"os"
"regexp"
"sort"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -191,10 +192,10 @@ type (
}

Rule struct {
Column string `validate:"required" yaml:"column" json:"column"`
Type string `validate:"required" yaml:"type" json:"type"`
Expr string `validate:"required" yaml:"expr" json:"expr"`
Step int `yaml:"step" json:"step"`
Columns []string `validate:"required" yaml:"columns" json:"columns"`
Type string `validate:"required" yaml:"type" json:"type"`
Expr string `validate:"required" yaml:"expr" json:"expr"`
Step int `yaml:"step" json:"step"`
}

Topology struct {
Expand Down Expand Up @@ -427,3 +428,8 @@ func (l *Listener) String() string {
socketAddr := fmt.Sprintf("%s:%d", l.SocketAddress.Address, l.SocketAddress.Port)
return fmt.Sprintf("Listener protocol_type:%s, socket_address:%s, server_version:%s", l.ProtocolType, socketAddr, l.ServerVersion)
}

func (l *Rule) ColumnKey() string {
sort.Strings(l.Columns)
return strings.Join(l.Columns, ",")
}
29 changes: 29 additions & 0 deletions pkg/proto/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ type VTable struct {
autoIncrement *AutoIncrement
topology *Topology
shards map[string][2]*ShardMetadata // column -> [db shard metadata,table shard metadata]
shardsC []*VShards // todo use shardsC replace shards
}

type VShards struct {
columns []string
shardMetadata [2]*ShardMetadata
key string
}

func (vt *VShards) HasColumns(columns []string) []int {
var bingoList []int
for _, v := range vt.columns {
for i, column := range columns {
if v != column {
return []int{}
}
bingoList = append(bingoList, i)
}
}
return bingoList
}

func (vt *VTable) HasColumn(column string) bool {
Expand Down Expand Up @@ -245,3 +265,12 @@ func (ru *Rule) Range(f func(table string, vt *VTable) bool) {
}
}
}

func (vt *VTable) GetShardColumnIndex(columns []string) (bingoList []int) {
for _, v := range vt.shardsC {
if bingoList = v.HasColumns(columns); len(bingoList) > 0 {
return bingoList
}
}
return []int{}
}
84 changes: 52 additions & 32 deletions pkg/runtime/optimize/dml/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/arana-db/arana/pkg/runtime/ast"
"github.com/arana-db/arana/pkg/runtime/cmp"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
"github.com/arana-db/arana/pkg/runtime/logical"
"github.com/arana-db/arana/pkg/runtime/optimize"
"github.com/arana-db/arana/pkg/runtime/plan/dml"
)
Expand All @@ -57,52 +58,36 @@ func optimizeInsert(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
return ret, nil
}

// TODO: handle multiple shard keys.
bingoList := vt.GetShardColumnIndex(stmt.Columns)

bingo := -1
// check existing shard columns
for i, col := range stmt.Columns {
if _, _, ok = vt.GetShardMetadata(col); ok {
bingo = i
break
}
}

if bingo < 0 {
if len(bingoList) == 0 {
return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert")
}

// check on duplicated key update
for _, upd := range stmt.DuplicatedUpdates {
if upd.Column.Suffix() == stmt.Columns[bingo] {
return nil, errors.New("do not support update sharding key")
for bingo := range bingoList {
if upd.Column.Suffix() == stmt.Columns[bingo] {
return nil, errors.New("do not support update sharding key")
}
}
}

var (
sharder = optimize.NewXSharder(ctx, o.Rule, o.Args)
left = ast.ColumnNameExpressionAtom(make([]string, 1))
filter = &ast.PredicateExpressionNode{
P: &ast.BinaryComparisonPredicateNode{
Left: &ast.AtomPredicateNode{
A: left,
},
Op: cmp.Ceq,
},
}
slots = make(map[string]map[string][]int) // (db,table,valuesIndex)
slots = make(map[string]map[string][]int) // (db,table,valuesIndex)
)

// reset filter
resetFilter := func(column string, value ast.ExpressionNode) {
left[0] = column
filter.P.(*ast.BinaryComparisonPredicateNode).Right = value.(*ast.PredicateExpressionNode).P
}

for i, values := range stmt.Values {
var shards rule.DatabaseTables
value := values[bingo]
resetFilter(stmt.Columns[bingo], value)
var (
shards rule.DatabaseTables
filter ast.ExpressionNode
)
if len(bingoList) == 1 {
filter = buildFilter(stmt.Columns[bingoList[0]], values[bingoList[0]])
} else {
filter = buildLogicalFilter(stmt.Columns, values, bingoList)
}

if len(o.Hints) > 0 {
if shards, err = optimize.Hints(tableName, o.Hints, o.Rule); err != nil {
Expand Down Expand Up @@ -287,3 +272,38 @@ func createSequenceIfAbsent(ctx context.Context, vtab *rule.VTable, metadata *pr
}
return nil
}

func buildFilter(column string, value ast.ExpressionNode) ast.ExpressionNode {
// reset filter
return &ast.PredicateExpressionNode{
P: &ast.BinaryComparisonPredicateNode{
Left: &ast.AtomPredicateNode{
A: ast.ColumnNameExpressionAtom([]string{column}),
},
Op: cmp.Ceq,
Right: value.(*ast.PredicateExpressionNode).P,
},
}
}

func buildLogicalFilter(columns []string, values []ast.ExpressionNode, bingoList []int) ast.ExpressionNode {
filter := &ast.LogicalExpressionNode{
Op: logical.Land,
Left: buildFilter(columns[bingoList[0]], values[bingoList[0]]),
Right: buildFilter(columns[bingoList[1]], values[bingoList[1]]),
}
return appendLogicalFilter(columns, values, bingoList, 2, filter)
}

func appendLogicalFilter(columns []string, values []ast.ExpressionNode, bingoList []int, index int, filter ast.ExpressionNode) ast.ExpressionNode {
if index == len(bingoList) {
return filter
}
newFilter := &ast.LogicalExpressionNode{
Op: logical.Land,
Left: filter,
Right: buildFilter(columns[bingoList[index]], values[bingoList[index]]),
}
appendLogicalFilter(columns, values, bingoList, index, newFilter)
return filter
}
31 changes: 23 additions & 8 deletions pkg/runtime/rule/shard_script.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ import (
var _ rule.ShardComputer = (*jsShardComputer)(nil)

const (
_jsEntrypoint = "__compute__" // shard method name
_jsValueName = "$value" // variable name of column in sharding script
_jsEntrypoint = "__compute__" // shard method name
_jsVariableName = "$" // variable name of column in sharding script
)

type jsShardComputer struct {
Expand All @@ -47,8 +47,8 @@ type jsShardComputer struct {
}

// NewJavascriptShardComputer returns a shard computer which is based on Javascript.
func NewJavascriptShardComputer(script string) (rule.ShardComputer, error) {
script = wrapScript(script)
func NewJavascriptShardComputer(script string, columns []string) (rule.ShardComputer, error) {
script = wrapScript(script, columns)

vm, err := createVM(script)
if err != nil {
Expand All @@ -74,8 +74,17 @@ func (j *jsShardComputer) Compute(value interface{}) (int, error) {
j.putVM(vm)
}()

params, ok := value.([]interface{})
if !ok {
return 0, errors.Wrapf(err, "javascript shard computer params type not is []string")
}
gojaValues := make([]goja.Value, 0, len(params))
for _, v := range params {
gojaValues = append(gojaValues, vm.ToValue(v))
}

fn, _ := goja.AssertFunction(vm.Get(_jsEntrypoint))
res, err := fn(goja.Undefined(), vm.ToValue(value))
res, err := fn(goja.Undefined(), gojaValues...)
if err != nil {
return 0, errors.WithStack(err)
}
Expand All @@ -99,15 +108,21 @@ func (j *jsShardComputer) putVM(vm *goja.Runtime) {
}
}

func wrapScript(script string) string {
func wrapScript(script string, columns []string) string {
params := make([]string, 0, len(columns))
for i := range columns {
params = append(params, _jsVariableName+columns[i])
}
jsFuncParams := strings.Join(params, ",")

var sb strings.Builder

sb.Grow(32 + len(_jsEntrypoint) + len(_jsValueName) + len(script))
sb.Grow(32 + len(_jsEntrypoint) + len(jsFuncParams) + len(script))

sb.WriteString("function ")
sb.WriteString(_jsEntrypoint)
sb.WriteString("(")
sb.WriteString(_jsValueName)
sb.WriteString(jsFuncParams)
sb.WriteString(") {\n")

if strings.Contains(script, "return ") {
Expand Down

0 comments on commit da03da0

Please sign in to comment.