diff --git a/v4/export/ir.go b/v4/export/ir.go index 9d5037a5fe72e..b50a6e2060536 100644 --- a/v4/export/ir.go +++ b/v4/export/ir.go @@ -19,6 +19,15 @@ type SQLRowIter interface { HasNext() bool } +type RowReceiverStringer interface { + RowReceiver + Stringer +} + +type Stringer interface { + ToString() string +} + type RowReceiver interface { BindAddress([]interface{}) ReportSize() uint64 diff --git a/v4/export/ir_impl_test.go b/v4/export/ir_impl_test.go index 2910a5df825b9..4a2901a81759e 100644 --- a/v4/export/ir_impl_test.go +++ b/v4/export/ir_impl_test.go @@ -9,6 +9,18 @@ var _ = Suite(&testIRImplSuite{}) type testIRImplSuite struct{} +type simpleRowReceiver struct { + data string +} + +func (s *simpleRowReceiver) BindAddress(arg []interface{}) { + arg[0] = &s.data +} + +func (s *simpleRowReceiver) ReportSize() uint64 { + panic("not implement") +} + func (s *testIRImplSuite) TestRowIter(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) @@ -26,15 +38,15 @@ func (s *testIRImplSuite) TestRowIter(c *C) { for i := 0; i < 100; i += 1 { c.Assert(iter.HasNext(), IsTrue) } - res := make(dumplingRow, 1) + res := &simpleRowReceiver{} c.Assert(iter.Next(res), IsNil) - c.Assert(res[0].String, Equals, "1") + c.Assert(res.data, Equals, "1") c.Assert(iter.HasNext(), IsTrue) c.Assert(iter.HasNext(), IsTrue) c.Assert(iter.Next(res), IsNil) - c.Assert(res[0].String, Equals, "2") + c.Assert(res.data, Equals, "2") c.Assert(iter.HasNext(), IsTrue) c.Assert(iter.Next(res), IsNil) - c.Assert(res[0].String, Equals, "3") + c.Assert(res.data, Equals, "3") c.Assert(iter.HasNext(), IsFalse) } diff --git a/v4/export/sql_test.go b/v4/export/sql_test.go index 383d49bcb7405..ac2070a76e701 100644 --- a/v4/export/sql_test.go +++ b/v4/export/sql_test.go @@ -57,7 +57,7 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { mockConf := DefaultConfig() mockConf.SortByPk = true - // Test when the server is TiDB. + // Test TiDB server. mockConf.ServerInfo.ServerType = ServerTypeTiDB // _tidb_rowid is available. @@ -73,7 +73,7 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { c.Assert(q, Equals, "SELECT * FROM test.t") c.Assert(mock.ExpectationsWereMet(), IsNil) - // Test other server. + // Test other servers. otherServers := []ServerType{ServerTypeUnknown, ServerTypeMySQL, ServerTypeMariaDB} // Test table with primary key. diff --git a/v4/export/sql_type.go b/v4/export/sql_type.go new file mode 100644 index 0000000000000..a383782853227 --- /dev/null +++ b/v4/export/sql_type.go @@ -0,0 +1,147 @@ +package export + +import ( + "database/sql" + "fmt" + "strings" +) + +var colTypeRowReceiverMap = map[string]func() RowReceiverStringer{} + +func init() { + for _, s := range dataTypeString { + colTypeRowReceiverMap[s] = SQLTypeStringMaker + } + for _, s := range dataTypeNum { + colTypeRowReceiverMap[s] = SQLTypeNumberMaker + } + for _, s := range dataTypeBin { + colTypeRowReceiverMap[s] = SQLTypeBytesMaker + } +} + +var dataTypeString = []string{ + "CHAR", "NCHAR", "VARCHAR", "NVARCHAR", "CHARACTER", "VARCHARACTER", + "TIMESTAMP", "DATETIME", "DATE", "TIME", "YEAR", "SQL_TSI_YEAR", + "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", + "ENUM", "SET", "JSON", +} + +var dataTypeNum = []string{ + "INTEGER", "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", + "INT", "INT1", "INT2", "INT3", "INT8", + "FLOAT", "REAL", "DOUBLE", "DOUBLE PRECISION", + "DECIMAL", "NUMERIC", "FIXED", + "BOOL", "BOOLEAN", +} + +var dataTypeBin = []string{ + "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "LONG", + "BINARY", "VARBINARY", + "BIT", +} + +func SQLTypeStringMaker() RowReceiverStringer { + return &SQLTypeString{} +} + +func SQLTypeBytesMaker() RowReceiverStringer { + return &SQLTypeBytes{} +} + +func SQLTypeNumberMaker() RowReceiverStringer { + return &SQLTypeNumber{} +} + +func MakeRowReceiver(colTypes []string) RowReceiverStringer { + rowReceiverArr := make(RowReceiverArr, len(colTypes)) + for i, colTp := range colTypes { + recMaker, ok := colTypeRowReceiverMap[colTp] + if !ok { + recMaker = SQLTypeStringMaker + } + rowReceiverArr[i] = recMaker() + } + return rowReceiverArr +} + +type RowReceiverArr []RowReceiverStringer + +func (r RowReceiverArr) BindAddress(args []interface{}) { + for i := range args { + var singleAddr [1]interface{} + r[i].BindAddress(singleAddr[:]) + args[i] = singleAddr[0] + } +} +func (r RowReceiverArr) ReportSize() uint64 { + var sum uint64 + for _, receiver := range r { + sum += receiver.ReportSize() + } + return sum +} +func (r RowReceiverArr) ToString() string { + var sb strings.Builder + sb.WriteString("(") + for i, receiver := range r { + sb.WriteString(receiver.ToString()) + if i != len(r)-1 { + sb.WriteString(", ") + } + } + sb.WriteString(")") + return sb.String() +} + +type SQLTypeNumber struct { + SQLTypeString +} + +func (s SQLTypeNumber) ToString() string { + if s.Valid { + return s.String + } else { + return "NULL" + } +} + +type SQLTypeString struct { + sql.NullString +} + +func (s *SQLTypeString) BindAddress(arg []interface{}) { + arg[0] = s +} +func (s *SQLTypeString) ReportSize() uint64 { + if s.Valid { + return uint64(len(s.String)) + } + return uint64(len("NULL")) +} +func (s *SQLTypeString) ToString() string { + if s.Valid { + return fmt.Sprintf(`'%s'`, escape(s.String)) + } else { + return "NULL" + } +} + +func escape(src string) string { + src = strings.ReplaceAll(src, "'", "''") + return strings.ReplaceAll(src, `\`, `\\`) +} + +type SQLTypeBytes struct { + bytes []byte +} + +func (s *SQLTypeBytes) BindAddress(arg []interface{}) { + arg[0] = &s.bytes +} +func (s *SQLTypeBytes) ReportSize() uint64 { + return uint64(len(s.bytes)) +} +func (s *SQLTypeBytes) ToString() string { + return fmt.Sprintf("x'%x'", s.bytes) +} diff --git a/v4/export/test_util.go b/v4/export/test_util.go index 7f8cc8b6cc597..385131057266e 100644 --- a/v4/export/test_util.go +++ b/v4/export/test_util.go @@ -2,7 +2,10 @@ package export import ( "database/sql" + "database/sql/driver" "fmt" + + "github.com/DATA-DOG/go-sqlmock" ) type mockStringWriter struct { @@ -71,65 +74,60 @@ func newMockMetaIR(targetName string, meta string, specialComments []string) Met } } -func makeNullString(ss []string) []sql.NullString { - var ns []sql.NullString - for _, s := range ss { - if len(s) != 0 { - ns = append(ns, sql.NullString{String: s, Valid: true}) - } else { - ns = append(ns, sql.NullString{Valid: false}) - } - } - return ns -} - -type mockTableDataIR struct { +type mockTableIR struct { dbName string tblName string - data [][]sql.NullString + data [][]driver.Value specCmt []string colTypes []string } -func (m *mockTableDataIR) ColumnTypes() []string { - return m.colTypes +func (m *mockTableIR) DatabaseName() string { + return m.dbName } -func newMockTableDataIR(databaseName, tableName string, data [][]string, specialComments []string, colTypes []string) TableDataIR { - var nData [][]sql.NullString - for _, ss := range data { - nData = append(nData, makeNullString(ss)) - } - - return &mockTableDataIR{ - dbName: databaseName, - tblName: tableName, - data: nData, - specCmt: specialComments, - colTypes: colTypes, - } +func (m *mockTableIR) TableName() string { + return m.tblName } -func (m *mockTableDataIR) DatabaseName() string { - return m.dbName +func (m *mockTableIR) ColumnCount() uint { + return uint(len(m.colTypes)) } -func (m *mockTableDataIR) TableName() string { - return "employee" +func (m *mockTableIR) ColumnTypes() []string { + return m.colTypes } -func (m *mockTableDataIR) ColumnCount() uint { - return 5 +func (m *mockTableIR) SpecialComments() StringIter { + return newStringIter(m.specCmt...) } -func (m *mockTableDataIR) SpecialComments() StringIter { - return newStringIter(m.specCmt...) +func (m *mockTableIR) Rows() SQLRowIter { + mockRows := sqlmock.NewRows(m.colTypes) + for _, datum := range m.data { + mockRows.AddRow(datum...) + } + db, mock, err := sqlmock.New() + if err != nil { + panic(fmt.Sprintf("sqlmock.New return error: %v", err)) + } + defer db.Close() + mock.ExpectQuery("select 1").WillReturnRows(mockRows) + rows, err := db.Query("select 1") + if err != nil { + panic(fmt.Sprintf("sqlmock.New return error: %v", err)) + } + + return newRowIter(rows, len(m.colTypes)) } -func (m *mockTableDataIR) Rows() SQLRowIter { - return &mockSQLRowIterator{ - idx: 0, - data: m.data, +func newMockTableIR(databaseName, tableName string, data [][]driver.Value, specialComments, colTypes []string) TableDataIR { + return &mockTableIR{ + dbName: databaseName, + tblName: tableName, + data: data, + specCmt: specialComments, + colTypes: colTypes, } } diff --git a/v4/export/writer_util.go b/v4/export/writer_util.go index 5558bff55bf8d..4f55e32be1a92 100644 --- a/v4/export/writer_util.go +++ b/v4/export/writer_util.go @@ -1,32 +1,11 @@ package export import ( - "database/sql" "fmt" "io" "strings" ) -type dumplingRow []sql.NullString - -func (d dumplingRow) BindAddress(args []interface{}) { - for i := range d { - args[i] = &d[i] - } -} - -func (d dumplingRow) ReportSize() uint64 { - var totalSize uint64 - for _, ns := range d { - if ns.Valid { - totalSize += 4 - } else { - totalSize += uint64(len(ns.String)) - } - } - return totalSize -} - func WriteMeta(meta MetaIR, w io.StringWriter, cfg *Config) error { log := cfg.Logger log.Debug("start dumping meta data for target %s", meta.TargetName()) @@ -70,15 +49,13 @@ func WriteInsert(tblIR TableDataIR, w io.StringWriter, cfg *Config) error { } for rowIter.HasNext() { - var dumplingRow = make(dumplingRow, tblIR.ColumnCount()) - if err := rowIter.Next(dumplingRow); err != nil { + row := MakeRowReceiver(tblIR.ColumnTypes()) + if err := rowIter.Next(row); err != nil { log.Error("scanning from sql.Row failed, error: %s", err.Error()) return err } - row := convert(dumplingRow, tblIR.ColumnTypes()) - - if err := write(w, fmt.Sprintf("(%s)", strings.Join(row, ", ")), log); err != nil { + if err := write(w, row.ToString(), log); err != nil { return err } @@ -107,38 +84,3 @@ func write(writer io.StringWriter, str string, logger Logger) error { func wrapStringWith(str string, wrapper string) string { return fmt.Sprintf("%s%s%s", wrapper, str, wrapper) } - -func convert(origin []sql.NullString, colTypes []string) []string { - ret := make([]string, len(origin)) - for i, s := range origin { - if !s.Valid { - ret[i] = "NULL" - continue - } - - if isCharTypes(colTypes[i]) { - ret[i] = wrapStringWith(s.String, "'") - } else { - ret[i] = s.String - } - } - return ret -} - -var charTypes = map[string]struct{}{ - "CHAR": {}, - "NCHAR": {}, - "VARCHAR": {}, - "NVARCHAR": {}, - "BINARY": {}, - "VARBINARY": {}, - "BLOB": {}, - "TEXT": {}, - "ENUM": {}, - "SET": {}, -} - -func isCharTypes(colType string) bool { - _, ok := charTypes[colType] - return ok -} diff --git a/v4/export/writer_util_test.go b/v4/export/writer_util_test.go index 66efb54dc2605..7e19a5465016e 100644 --- a/v4/export/writer_util_test.go +++ b/v4/export/writer_util_test.go @@ -1,8 +1,12 @@ package export import ( - . "github.com/pingcap/check" + "fmt" + "strings" "testing" + + "database/sql/driver" + . "github.com/pingcap/check" ) func TestT(t *testing.T) { @@ -40,8 +44,8 @@ func (s *testUtilSuite) TestWriteMeta(c *C) { } func (s *testUtilSuite) TestWriteInsert(c *C) { - data := [][]string{ - {"1", "male", "bob@mail.com", "020-1234", ""}, + data := [][]driver.Value{ + {"1", "male", "bob@mail.com", "020-1234", nil}, {"2", "female", "sarah@mail.com", "020-1253", "healthy"}, {"3", "male", "john@mail.com", "020-1256", "healthy"}, {"4", "female", "sarah@mail.com", "020-1235", "healthy"}, @@ -51,7 +55,7 @@ func (s *testUtilSuite) TestWriteInsert(c *C) { "/*!40101 SET NAMES binary*/;", "/*!40014 SET FOREIGN_KEY_CHECKS=0*/;", } - tableIR := newMockTableDataIR("test", "employee", data, specCmts, colTypes) + tableIR := newMockTableIR("test", "employee", data, specCmts, colTypes) strCollector := &mockStringCollector{} err := WriteInsert(tableIR, strCollector, s.mockCfg) @@ -66,6 +70,29 @@ func (s *testUtilSuite) TestWriteInsert(c *C) { c.Assert(strCollector.buf, Equals, expected) } +func (s *testUtilSuite) TestSQLDataTypes(c *C) { + data := [][]driver.Value{ + {"CHAR", "char1", `'char1'`}, + {"INT", 12345, `12345`}, + {"BINARY", 1234, "x'31323334'"}, + } + + for _, datum := range data { + sqlType, origin, result := datum[0].(string), datum[1], datum[2].(string) + + tableData := [][]driver.Value{{origin}} + colType := []string{sqlType} + tableIR := newMockTableIR("test", "t", tableData, nil, colType) + strCollector := &mockStringCollector{} + + err := WriteInsert(tableIR, strCollector, s.mockCfg) + c.Assert(err, IsNil) + lines := strings.Split(strCollector.buf, "\n") + c.Assert(len(lines), Equals, 3) + c.Assert(lines[1], Equals, fmt.Sprintf("(%s);", result)) + } +} + func (s *testUtilSuite) TestWrite(c *C) { mocksw := &mockStringWriter{} src := []string{"test", "loooooooooooooooooooong", "poison"} @@ -83,10 +110,3 @@ func (s *testUtilSuite) TestWrite(c *C) { err := write(mocksw, "test", nil) c.Assert(err, IsNil) } - -func (s *testUtilSuite) TestConvert(c *C) { - srcColTypes := []string{"INT", "CHAR", "BIGINT", "VARCHAR", "SET"} - src := makeNullString([]string{"255", "", "25535", "computer_science", "male"}) - exp := []string{"255", "NULL", "25535", "'computer_science'", "'male'"} - c.Assert(convert(src, srcColTypes), DeepEquals, exp) -}