diff --git a/dm/config/subtask.go b/dm/config/subtask.go index a8a6e9a8b2f..a98429d7d30 100644 --- a/dm/config/subtask.go +++ b/dm/config/subtask.go @@ -28,7 +28,6 @@ import ( "time" "github.com/BurntSushi/toml" - "github.com/pingcap/tidb-tools/pkg/column-mapping" extstorage "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/pkg/util/dbutil" "github.com/pingcap/tidb/pkg/util/filter" @@ -41,6 +40,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/utils" "github.com/pingcap/tiflow/engine/pkg/promutil" bf "github.com/pingcap/tiflow/pkg/binlog-filter" + "github.com/pingcap/tiflow/pkg/column-mapping" "github.com/pingcap/tiflow/pkg/version" "go.uber.org/atomic" "go.uber.org/zap" diff --git a/dm/config/task.go b/dm/config/task.go index 4d4e0fb49d1..031fece8457 100644 --- a/dm/config/task.go +++ b/dm/config/task.go @@ -28,7 +28,6 @@ import ( "github.com/coreos/go-semver/semver" "github.com/docker/go-units" "github.com/dustin/go-humanize" - "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb/pkg/lightning/config" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/util/filter" @@ -38,6 +37,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/terror" "github.com/pingcap/tiflow/dm/pkg/utils" bf "github.com/pingcap/tiflow/pkg/binlog-filter" + "github.com/pingcap/tiflow/pkg/column-mapping" "go.uber.org/zap" "gopkg.in/yaml.v2" ) diff --git a/dm/config/task_converters.go b/dm/config/task_converters.go index 168da1f3207..306037a285c 100644 --- a/dm/config/task_converters.go +++ b/dm/config/task_converters.go @@ -17,7 +17,6 @@ import ( "fmt" "strings" - "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb/pkg/util/filter" router "github.com/pingcap/tidb/pkg/util/table-router" "github.com/pingcap/tiflow/dm/config/dbconfig" @@ -27,6 +26,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/storage" "github.com/pingcap/tiflow/dm/pkg/terror" bf "github.com/pingcap/tiflow/pkg/binlog-filter" + 
"github.com/pingcap/tiflow/pkg/column-mapping" "go.uber.org/zap" ) diff --git a/dm/syncer/syncer_test.go b/dm/syncer/syncer_test.go index ce93d38e871..a592ad11550 100644 --- a/dm/syncer/syncer_test.go +++ b/dm/syncer/syncer_test.go @@ -31,7 +31,6 @@ import ( _ "github.com/go-sql-driver/mysql" . "github.com/pingcap/check" "github.com/pingcap/failpoint" - cm "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" @@ -58,6 +57,7 @@ import ( "github.com/pingcap/tiflow/dm/syncer/dbconn" "github.com/pingcap/tiflow/dm/syncer/metrics" bf "github.com/pingcap/tiflow/pkg/binlog-filter" + cm "github.com/pingcap/tiflow/pkg/column-mapping" "github.com/pingcap/tiflow/pkg/errorutil" "github.com/pingcap/tiflow/pkg/sqlmodel" "github.com/stretchr/testify/require" diff --git a/engine/jobmaster/dm/config/config.go b/engine/jobmaster/dm/config/config.go index 3d38df3c133..2cd4f3105f6 100644 --- a/engine/jobmaster/dm/config/config.go +++ b/engine/jobmaster/dm/config/config.go @@ -20,7 +20,6 @@ import ( "github.com/dustin/go-humanize" "github.com/google/uuid" - "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb/pkg/util/filter" router "github.com/pingcap/tidb/pkg/util/table-router" "github.com/pingcap/tiflow/dm/config" @@ -28,6 +27,7 @@ import ( "github.com/pingcap/tiflow/dm/config/dbconfig" "github.com/pingcap/tiflow/dm/master" bf "github.com/pingcap/tiflow/pkg/binlog-filter" + "github.com/pingcap/tiflow/pkg/column-mapping" "github.com/pingcap/tiflow/pkg/errors" "go.uber.org/atomic" "gopkg.in/yaml.v2" diff --git a/pkg/column-mapping/column.go b/pkg/column-mapping/column.go new file mode 100644 index 00000000000..77e1a4fd282 --- /dev/null +++ b/pkg/column-mapping/column.go @@ -0,0 +1,538 @@ +// Copyright 2018 PingCAP, Inc. 
// Bit layout used by the "partition id" expression. A generated ID is an
// int64 laid out, from the most significant bit down, as
//
//	[1 unused sign bit][instance ID][schema ID][table ID][origin ID]
//
// The three sizes below are the widths of the middle fields; whatever is left
// over belongs to the origin ID.
var (
	instanceIDBitSize = 4
	schemaIDBitSize   = 7
	tableIDBitSize    = 8
	// maxOriginID is the exclusive upper bound of an origin ID under the
	// default layout: 1 << 44 == 17592186044416.
	maxOriginID int64 = 1 << (64 - 4 - 7 - 8 - 1)
)

// SetPartitionRule sets the bit sizes of the instance ID, schema ID and table
// ID fields of a partition ID, and recomputes the exclusive upper bound of the
// origin ID from the bits that remain (one bit is always reserved for the sign).
//
// It overwrites package-level variables, so it affects every Mapping in the
// process. The sizes are not validated.
func SetPartitionRule(instanceIDSize, schemaIDSize, tableIDSize int) {
	instanceIDBitSize, schemaIDBitSize, tableIDBitSize = instanceIDSize, schemaIDSize, tableIDSize

	reservedBits := instanceIDSize + schemaIDSize + tableIDSize + 1
	maxOriginID = int64(1) << uint(64-reservedBits)
}
// Rule is a rule to map column
// TODO: we will do it later, if we need to implement a real column mapping, we need table structure of source and target system
//
// The field order is part of the contract: the package tests build rules with
// unkeyed struct literals.
type Rule struct {
	// PatternSchema and PatternTable are the patterns under which the rule is
	// inserted into the selector; an empty PatternTable makes it a
	// schema-level rule.
	PatternSchema string `yaml:"schema-pattern" json:"schema-pattern" toml:"schema-pattern"`
	PatternTable  string `yaml:"table-pattern" json:"table-pattern" toml:"table-pattern"`
	// SourceColumn and TargetColumn are looked up by name in the column list
	// of the row being handled. TargetColumn is mandatory (see Valid).
	SourceColumn string `yaml:"source-column" json:"source-column" toml:"source-column"` // modify, add refer column, ignore
	TargetColumn string `yaml:"target-column" json:"target-column" toml:"target-column"` // add column, modify
	// Expression selects the built-in function in Exprs that is applied to
	// the target column; Arguments are the parameters of that function.
	Expression Expr     `yaml:"expression" json:"expression" toml:"expression"`
	Arguments  []string `yaml:"arguments" json:"arguments" toml:"arguments"`
	// CreateTableQuery is carried with the rule but is not read by any code
	// in this file.
	CreateTableQuery string `yaml:"create-table-query" json:"create-table-query" toml:"create-table-query"`
}

// ToLower converts the schema/table patterns to lower case, in place.
func (r *Rule) ToLower() {
	r.PatternSchema = strings.ToLower(r.PatternSchema)
	r.PatternTable = strings.ToLower(r.PatternTable)
}
+// add prefix/suffix: it should have target column and one argument +// partition id: it should have 3 to 4 arguments +func (r *Rule) Valid() error { + if _, ok := Exprs[r.Expression]; !ok { + return errors.NotFoundf("expression %s", r.Expression) + } + + if r.TargetColumn == "" { + return errors.NotValidf("rule need to be applied a target column") + } + + if r.Expression == AddPrefix || r.Expression == AddSuffix { + if len(r.Arguments) != 1 { + return errors.NotValidf("arguments %v for add prefix/suffix", r.Arguments) + } + } + + if r.Expression == PartitionID { + switch len(r.Arguments) { + case 3, 4: + break + default: + return errors.NotValidf("arguments %v for patition id", r.Arguments) + } + } + + return nil +} + +// Adjust normalizes the rule into an easier-to-process form, e.g. filling in +// optional arguments with the default values. +func (r *Rule) Adjust() { + if r.Expression == PartitionID && len(r.Arguments) == 3 { + r.Arguments = append(r.Arguments, "") + } +} + +// check source and target position +func (r *Rule) adjustColumnPosition(source, target int) (int, int, error) { + // if not found target, ignore it + if target == -1 { + return source, target, errors.NotFoundf("target column %s", r.TargetColumn) + } + + return source, target, nil +} + +type mappingInfo struct { + ignore bool + sourcePosition int + targetPosition int + rule *Rule + + instanceID int64 + schemaID int64 + tableID int64 +} + +// Mapping maps column to something by rules +type Mapping struct { + selector.Selector + + caseSensitive bool + + cache struct { + sync.RWMutex + infos map[string]*mappingInfo + } +} + +// NewMapping returns a column mapping +func NewMapping(caseSensitive bool, rules []*Rule) (*Mapping, error) { + m := &Mapping{ + Selector: selector.NewTrieSelector(), + caseSensitive: caseSensitive, + } + m.resetCache() + + for _, rule := range rules { + if err := m.AddRule(rule); err != nil { + return nil, errors.Annotatef(err, "initial rule %+v in mapping", rule) + } + 
} + + return m, nil +} + +func (m *Mapping) addOrUpdateRule(rule *Rule, isUpdate bool) error { + if m == nil || rule == nil { + return nil + } + + err := rule.Valid() + if err != nil { + return errors.Trace(err) + } + if !m.caseSensitive { + rule.ToLower() + } + rule.Adjust() + + m.resetCache() + if isUpdate { + err = m.Insert(rule.PatternSchema, rule.PatternTable, rule, selector.Replace) + } else { + err = m.Insert(rule.PatternSchema, rule.PatternTable, rule, selector.Insert) + } + if err != nil { + var method string + if isUpdate { + method = "update" + } else { + method = "add" + } + return errors.Annotatef(err, "%s rule %+v into mapping", method, rule) + } + + return nil +} + +// AddRule adds a rule into mapping +func (m *Mapping) AddRule(rule *Rule) error { + return m.addOrUpdateRule(rule, false) +} + +// UpdateRule updates mapping rule +func (m *Mapping) UpdateRule(rule *Rule) error { + return m.addOrUpdateRule(rule, true) +} + +// RemoveRule removes a rule from mapping +func (m *Mapping) RemoveRule(rule *Rule) error { + if m == nil || rule == nil { + return nil + } + if !m.caseSensitive { + rule.ToLower() + } + + m.resetCache() + err := m.Remove(rule.PatternSchema, rule.PatternTable) + if err != nil { + return errors.Annotatef(err, "remove rule %+v from mapping", rule) + } + + return nil +} + +// HandleRowValue handles row value +func (m *Mapping) HandleRowValue(schema, table string, columns []string, vals []interface{}) ([]interface{}, []int, error) { + if m == nil { + return vals, nil, nil + } + + schemaL, tableL := schema, table + if !m.caseSensitive { + schemaL, tableL = strings.ToLower(schema), strings.ToLower(table) + } + + info, err := m.queryColumnInfo(schemaL, tableL, columns) + if err != nil { + return nil, nil, errors.Trace(err) + } + if info.ignore { + return vals, nil, nil + } + + exp, ok := Exprs[info.rule.Expression] + if !ok { + return nil, nil, errors.NotFoundf("column mapping expression %s", info.rule.Expression) + } + + vals, err = 
exp(info, vals) + if err != nil { + return nil, nil, errors.Trace(err) + } + + return vals, []int{info.sourcePosition, info.targetPosition}, nil +} + +// HandleDDL handles ddl +func (m *Mapping) HandleDDL(schema, table string, columns []string, statement string) (string, []int, error) { + if m == nil { + return statement, nil, nil + } + + schemaL, tableL := schema, table + if !m.caseSensitive { + schemaL, tableL = strings.ToLower(schema), strings.ToLower(table) + } + + info, err := m.queryColumnInfo(schemaL, tableL, columns) + if err != nil { + return statement, nil, errors.Trace(err) + } + + if info.ignore { + return statement, nil, nil + } + + m.resetCache() + // only output erro now, wait fix it manually + return statement, nil, errors.Errorf("ddl %s @ column mapping rule %s/%s:%+v not implemented", statement, schema, table, info.rule) +} + +func (m *Mapping) queryColumnInfo(schema, table string, columns []string) (*mappingInfo, error) { + m.cache.RLock() + ci, ok := m.cache.infos[tableName(schema, table)] + m.cache.RUnlock() + if ok { + return ci, nil + } + + info := &mappingInfo{ + ignore: true, + } + rules := m.Match(schema, table) + if len(rules) == 0 { + m.cache.Lock() + m.cache.infos[tableName(schema, table)] = info + m.cache.Unlock() + + return info, nil + } + + var ( + schemaRules []*Rule + tableRules = make([]*Rule, 0, 1) + ) + // classify rules into schema level rules and table level + // table level rules have highest priority + for i := range rules { + rule, ok := rules[i].(*Rule) + if !ok { + return nil, errors.NotValidf("column mapping rule %+v", rules[i]) + } + + if len(rule.PatternTable) == 0 { + schemaRules = append(schemaRules, rule) + } else { + tableRules = append(tableRules, rule) + } + } + + // only support one expression for one table now, refine it later + var rule *Rule + if len(table) == 0 || len(tableRules) == 0 { + if len(schemaRules) != 1 { + return nil, errors.NotSupportedf("`%s`.`%s` matches %d schema column mapping rules which 
// findColumnPosition returns the index of the first element of cols that is
// equal to col, or -1 when there is none. The comparison is case-sensitive.
func findColumnPosition(cols []string, col string) int {
	position := -1
	for idx, name := range cols {
		if name == col {
			position = idx
			break
		}
	}

	return position
}

// tableName builds the back-quoted `schema`.`table` string that is used as
// the key of the mapping info cache.
func tableName(schema, table string) string {
	const quotedNameFormat = "`%s`.`%s`"
	return fmt.Sprintf(quotedNameFormat, schema, table)
}
len(prefix)+len(originStr)) + rawByte = append(rawByte, prefix...) + rawByte = append(rawByte, originStr...) + + vals[info.targetPosition] = string(rawByte) + return vals, nil +} + +func addSuffix(info *mappingInfo, vals []interface{}) ([]interface{}, error) { + suffix := info.rule.Arguments[0] + originStr, ok := vals[info.targetPosition].(string) + if !ok { + return nil, errors.NotValidf("column %d value is not string, but %v, which is", info.targetPosition, vals[info.targetPosition]) + } + + rawByte := make([]byte, 0, len(suffix)+len(originStr)) + rawByte = append(rawByte, originStr...) + rawByte = append(rawByte, suffix...) + + vals[info.targetPosition] = string(rawByte) + return vals, nil +} + +func partitionID(info *mappingInfo, vals []interface{}) ([]interface{}, error) { + // only int64 now + var ( + originID int64 + err error + isChars bool + ) + + switch rawID := vals[info.targetPosition].(type) { + case int: + originID = int64(rawID) + case int8: + originID = int64(rawID) + case int32: + originID = int64(rawID) + case int64: + originID = rawID + case uint: + originID = int64(rawID) + case uint16: + originID = int64(rawID) + case uint32: + originID = int64(rawID) + case uint64: + originID = int64(rawID) + case string: + originID, err = strconv.ParseInt(rawID, 10, 64) + if err != nil { + return nil, errors.NotValidf("column %d value is not int, but %v, which is", info.targetPosition, vals[info.targetPosition]) + } + isChars = true + default: + return nil, errors.NotValidf("type %T(%v)", vals[info.targetPosition], vals[info.targetPosition]) + } + + if originID >= maxOriginID || originID < 0 { + return nil, errors.NotValidf("id must less than %d, greater than or equal to 0, but got %d, which is", maxOriginID, originID) + } + + originID = info.instanceID | info.schemaID | info.tableID | originID + if isChars { + vals[info.targetPosition] = strconv.FormatInt(originID, 10) + } else { + vals[info.targetPosition] = originID + } + + return vals, nil +} + +func 
computePartitionID(schema, table string, rule *Rule) (instanceID int64, schemaID int64, tableID int64, err error) { + shiftCnt := uint(63) + if instanceIDBitSize > 0 && len(rule.Arguments[0]) > 0 { + var instanceIDUnsign uint64 + shiftCnt = shiftCnt - uint(instanceIDBitSize) + instanceIDUnsign, err = strconv.ParseUint(rule.Arguments[0], 10, instanceIDBitSize) + if err != nil { + return + } + instanceID = int64(instanceIDUnsign << shiftCnt) + } + + sep := rule.Arguments[3] + + if schemaIDBitSize > 0 && len(rule.Arguments[1]) > 0 { + shiftCnt = shiftCnt - uint(schemaIDBitSize) + schemaID, err = computeID(schema, rule.Arguments[1], sep, schemaIDBitSize, shiftCnt) + if err != nil { + return + } + } + + if tableIDBitSize > 0 && len(rule.Arguments[2]) > 0 { + shiftCnt = shiftCnt - uint(tableIDBitSize) + tableID, err = computeID(table, rule.Arguments[2], sep, tableIDBitSize, shiftCnt) + } + + return +} + +func computeID(name string, prefix, sep string, bitSize int, shiftCount uint) (int64, error) { + if name == prefix { + return 0, nil + } + + prefix += sep + if len(prefix) >= len(name) || prefix != name[:len(prefix)] { + return 0, errors.NotValidf("%s is not the prefix of %s", prefix, name) + } + + idStr := name[len(prefix):] + id, err := strconv.ParseUint(idStr, 10, bitSize) + if err != nil { + return 0, errors.NotValidf("the suffix of %s can't be converted to int64", idStr) + } + + return int64(id << shiftCount), nil +} diff --git a/pkg/column-mapping/column_test.go b/pkg/column-mapping/column_test.go new file mode 100644 index 00000000000..6a8ddb98e6a --- /dev/null +++ b/pkg/column-mapping/column_test.go @@ -0,0 +1,278 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package column + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRule(t *testing.T) { + // test invalid rules + inValidRule := &Rule{"test*", "abc*", "id", "id", "Error", nil, "xxx"} + require.Error(t, inValidRule.Valid()) + + inValidRule.TargetColumn = "" + require.Error(t, inValidRule.Valid()) + + inValidRule.Expression = AddPrefix + inValidRule.TargetColumn = "id" + require.Error(t, inValidRule.Valid()) + + inValidRule.Arguments = []string{"1"} + require.NoError(t, inValidRule.Valid()) + + inValidRule.Expression = PartitionID + require.Error(t, inValidRule.Valid()) + + inValidRule.Arguments = []string{"1", "test_", "t_"} + require.NoError(t, inValidRule.Valid()) +} + +func TestHandle(t *testing.T) { + rules := []*Rule{ + {"Test*", "xxx*", "", "id", AddPrefix, []string{"instance_id:"}, "xx"}, + } + + // initial column mapping + m, err := NewMapping(false, rules) + require.NoError(t, err) + require.Len(t, m.cache.infos, 0) + + // test add prefix, add suffix is similar + vals, poss, err := m.HandleRowValue("test", "xxx", []string{"age", "id"}, []interface{}{1, "1"}) + require.NoError(t, err) + require.Equal(t, []interface{}{1, "instance_id:1"}, vals) + require.Equal(t, []int{-1, 1}, poss) + + // test cache + vals, poss, err = m.HandleRowValue("test", "xxx", []string{"name"}, []interface{}{1, "1"}) + require.NoError(t, err) + require.Equal(t, []interface{}{1, "instance_id:1"}, vals) + require.Equal(t, []int{-1, 1}, poss) + + // test resetCache + m.resetCache() + _, _, err = m.HandleRowValue("test", "xxx", []string{"name"}, []interface{}{"1"}) + + require.Error(t, err) + + // test DDL 
// TestComputePartitionID checks the instance/schema/table ID bits that
// computePartitionID derives from a schema name, a table name and the
// arguments of a partition id rule, under several bit layouts.
func TestComputePartitionID(t *testing.T) {
	SetPartitionRule(4, 7, 8)

	// "test" is not a number, so parsing the instance ID (first argument) fails
	rule := &Rule{
		Arguments: []string{"test", "t"},
	}
	_, _, _, err := computePartitionID("test_1", "t_1", rule)
	require.Error(t, err)
	_, _, _, err = computePartitionID("test", "t", rule)
	require.Error(t, err)

	// all three IDs present: shifts are 63-4=59, 59-7=52 and 52-8=44
	rule = &Rule{
		Arguments: []string{"2", "test", "t", "_"},
	}
	instanceID, schemaID, tableID, err := computePartitionID("test_1", "t_1", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(1<<52), schemaID)
	require.Equal(t, int64(1<<44), tableID)

	// test default partition ID to zero
	// (a name equal to the bare prefix has no suffix and gets ID 0)
	instanceID, schemaID, tableID, err = computePartitionID("test", "t_3", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(0), schemaID)
	require.Equal(t, int64(3<<44), tableID)

	instanceID, schemaID, tableID, err = computePartitionID("test_5", "t", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(5<<52), schemaID)
	require.Equal(t, int64(0), tableID)

	// test names that do not follow the prefix + separator + number form
	_, _, _, err = computePartitionID("unrelated", "t_6", rule)
	require.ErrorContains(t, err, "test_ is not the prefix of unrelated")

	_, _, _, err = computePartitionID("test", "x", rule)
	require.ErrorContains(t, err, "t_ is not the prefix of x")

	_, _, _, err = computePartitionID("test_0", "t_0xa", rule)
	require.ErrorContains(t, err, "the suffix of 0xa can't be converted to int64")

	_, _, _, err = computePartitionID("test_0", "t_", rule)
	require.ErrorContains(t, err, "t_ is not the prefix of t_") // needs a better error message

	_, _, _, err = computePartitionID("testx", "t_3", rule)
	require.ErrorContains(t, err, "test_ is not the prefix of testx")

	// test schema ID bit size of zero: the schema ID is skipped and the table
	// ID is shifted by 63-4-8=51
	SetPartitionRule(4, 0, 8)
	rule = &Rule{
		Arguments: []string{"2", "test_", "t_", ""},
	}
	instanceID, schemaID, tableID, err = computePartitionID("test_1", "t_1", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(0), schemaID)
	require.Equal(t, int64(1<<51), tableID)

	instanceID, schemaID, tableID, err = computePartitionID("test_", "t_", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(0), schemaID)
	require.Equal(t, int64(0), tableID)

	// test ignore instance ID
	// (an empty first argument reserves no bits, so the schema ID starts at 63-7=56)
	SetPartitionRule(4, 7, 8)
	rule = &Rule{
		Arguments: []string{"", "test_", "t_", ""},
	}
	instanceID, schemaID, tableID, err = computePartitionID("test_1", "t_1", rule)
	require.NoError(t, err)
	require.Equal(t, int64(0), instanceID)
	require.Equal(t, int64(1<<56), schemaID)
	require.Equal(t, int64(1<<48), tableID)

	// test ignore schema ID
	// (an empty schema prefix reserves no bits, so the table ID is shifted by 59-8=51)
	rule = &Rule{
		Arguments: []string{"2", "", "t_", ""},
	}
	instanceID, schemaID, tableID, err = computePartitionID("test_1", "t_1", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(0), schemaID)
	require.Equal(t, int64(1<<51), tableID)

	// test ignore table ID
	rule = &Rule{
		Arguments: []string{"2", "test_", "", ""},
	}
	instanceID, schemaID, tableID, err = computePartitionID("test_1", "t_1", rule)
	require.NoError(t, err)
	require.Equal(t, int64(2<<59), instanceID)
	require.Equal(t, int64(1<<52), schemaID)
	require.Equal(t, int64(0), tableID)
}
test case insensitive in TestHandle + rules := []*Rule{ + {"Test*", "xxx*", "", "id", AddPrefix, []string{"instance_id:"}, "xx"}, + } + + // case sensitive + // initial column mapping + m, err := NewMapping(true, rules) + require.NoError(t, err) + require.Len(t, m.cache.infos, 0) + + // test add prefix, add suffix is similar + vals, poss, err := m.HandleRowValue("test", "xxx", []string{"age", "id"}, []interface{}{1, "1"}) + require.NoError(t, err) + require.Equal(t, []interface{}{1, "1"}, vals) + require.Nil(t, poss) +}