diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index b454b6c8ed4b2..d6287d36d52ab 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -198,6 +198,11 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) { `"\0\b\n\r\t\Z\\\ \c\'\""`, []string{string([]byte{0, '\b', '\n', '\r', '\t', 26, '\\', ' ', ' ', 'c', '\'', '"'})}, }, + // Test mixed. + { + `"123",456,"\t7890",abcd`, + []string{"123", "456", "\t7890", "abcd"}, + }, } ldInfo := LoadDataInfo{ @@ -214,7 +219,7 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) { } _, err := ldInfo.getFieldsFromLine([]byte(`1,a string,100.20`)) - c.Assert(err, NotNil) + c.Assert(err, IsNil) } func assertEqualStrings(c *C, got []field, expect []string) { diff --git a/executor/load_data.go b/executor/load_data.go index 8246212d80804..4ad67808e0dfc 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -14,7 +14,6 @@ package executor import ( - "bytes" "context" "fmt" "strings" @@ -209,7 +208,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error if len(prevData) == 0 && len(curData) == 0 { return nil, false, nil } - var line []byte var isEOF, hasStarting, reachLimit bool if len(prevData) > 0 && len(curData) == 0 { @@ -220,7 +218,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error for len(curData) > 0 { line, curData, hasStarting = e.getLine(prevData, curData) prevData = nil - // If it doesn't find the terminated symbol and this data isn't the last data, // the data can't be inserted. if line == nil && !isEOF { @@ -313,28 +310,174 @@ func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) { type field struct { str []byte maybeNull bool + enclosed bool +} + +type fieldWriter struct { + pos int + enclosedChar byte + fieldTermChar byte + term *string + isEnclosed bool + isLineStart bool + isFieldStart bool + ReadBuf *[]byte + OutputBuf []byte +} + +func (w *fieldWriter) Init(enclosedChar byte, fieldTermChar byte, readBuf *[]byte, term *string) { + w.isEnclosed = false + w.isLineStart = true + w.isFieldStart = true + w.ReadBuf = readBuf + w.enclosedChar = enclosedChar + w.fieldTermChar = fieldTermChar + w.term = term +} + +func (w *fieldWriter) putback() { + w.pos-- +} + +func (w *fieldWriter) getChar() (bool, byte) { + if w.pos < len(*w.ReadBuf) { + ret := (*w.ReadBuf)[w.pos] + w.pos++ + return true, ret + } + return false, 0 +} + +func (w *fieldWriter) isTerminator() bool { + chkpt, isterm := w.pos, true + for i := 1; i < len(*w.term); i++ { + flag, ch := w.getChar() + if !flag || ch != (*w.term)[i] { + isterm = false + break + } + } + if !isterm { + w.pos = chkpt + return false + } + return true +} + +func (w *fieldWriter) outputField(enclosed bool) field { + var fild []byte + start := 0 + if enclosed { + start = 1 + } + for i := start; i < len(w.OutputBuf); i++ { + fild = append(fild, w.OutputBuf[i]) + } + if len(fild) == 0 { + fild = []byte("") + } + w.OutputBuf = w.OutputBuf[0:0] + w.isEnclosed = false + w.isFieldStart = true + return field{fild, false, enclosed} +} + +func (w *fieldWriter) GetField() (bool, field) { + // The first return value implies whether fieldWriter read the last character of line. + if w.isLineStart { + _, ch := w.getChar() + if ch == w.enclosedChar { + w.isEnclosed = true + w.isFieldStart, w.isLineStart = false, false + w.OutputBuf = append(w.OutputBuf, ch) + } else { + w.putback() + } + } + for { + flag, ch := w.getChar() + if !flag { + ret := w.outputField(false) + return true, ret + } + if ch == w.enclosedChar && w.isFieldStart { + // If read enclosed char at field start. + w.isEnclosed = true + w.OutputBuf = append(w.OutputBuf, ch) + w.isLineStart, w.isFieldStart = false, false + continue + } + w.isLineStart, w.isFieldStart = false, false + if ch == w.fieldTermChar && !w.isEnclosed { + // If read filed terminate char. + if w.isTerminator() { + ret := w.outputField(false) + return false, ret + } + w.OutputBuf = append(w.OutputBuf, ch) + } else if ch == w.enclosedChar && w.isEnclosed { + // If read enclosed char, look ahead. + flag, ch = w.getChar() + if !flag { + ret := w.outputField(true) + return true, ret + } else if ch == w.enclosedChar { + w.OutputBuf = append(w.OutputBuf, ch) + continue + } else if ch == w.fieldTermChar { + // If the next char is fieldTermChar, look ahead. + if w.isTerminator() { + ret := w.outputField(true) + return false, ret + } + w.OutputBuf = append(w.OutputBuf, ch) + } else { + // If there is no terminator behind enclosedChar, put the char back. + w.OutputBuf = append(w.OutputBuf, w.enclosedChar) + w.putback() + } + } else if ch == '\\' { + // TODO: escape only support '\' + w.OutputBuf = append(w.OutputBuf, ch) + flag, ch = w.getChar() + if flag { + if ch == w.enclosedChar { + w.OutputBuf = append(w.OutputBuf, ch) + } else { + w.putback() + } + } + } else { + w.OutputBuf = append(w.OutputBuf, ch) + } + } } // getFieldsFromLine splits line according to fieldsInfo. func (e *LoadDataInfo) getFieldsFromLine(line []byte) ([]field, error) { - var sep []byte - if e.FieldsInfo.Enclosed != 0 { - if line[0] != e.FieldsInfo.Enclosed || line[len(line)-1] != e.FieldsInfo.Enclosed { - return nil, errors.Errorf("line %s should begin and end with %c", string(line), e.FieldsInfo.Enclosed) + var ( + reader fieldWriter + fields []field + ) + + if len(line) == 0 { + str := []byte("") + fields = append(fields, field{str, false, false}) + return fields, nil + } + + reader.Init(e.FieldsInfo.Enclosed, e.FieldsInfo.Terminated[0], &line, &e.FieldsInfo.Terminated) + for { + eol, f := reader.GetField() + f = f.escape() + if string(f.str) == "NULL" && !f.enclosed { + f.str = []byte{'N'} + f.maybeNull = true + } + fields = append(fields, f) + if eol { + break } - line = line[1 : len(line)-1] - sep = make([]byte, 0, len(e.FieldsInfo.Terminated)+2) - sep = append(sep, e.FieldsInfo.Enclosed) - sep = append(sep, e.FieldsInfo.Terminated...) - sep = append(sep, e.FieldsInfo.Enclosed) - } else { - sep = []byte(e.FieldsInfo.Terminated) - } - rawCols := bytes.Split(line, sep) - fields := make([]field, 0, len(rawCols)) - for _, v := range rawCols { - f := field{v, false} - fields = append(fields, f.escape()) } return fields, nil } @@ -354,7 +497,7 @@ func (f *field) escape() field { f.str[pos] = c pos++ } - return field{f.str[:pos], f.maybeNull} + return field{f.str[:pos], f.maybeNull, f.enclosed} } func (f *field) escapeChar(c byte) byte { diff --git a/server/server_test.go b/server/server_test.go index e7eed294b7ee2..a0d9323314c56 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -483,6 +483,149 @@ func runTestLoadData(c *C, server *Server) { dbt.Assert(err, NotNil) }) + err = fp.Close() + c.Assert(err, IsNil) + err = os.Remove(path) + c.Assert(err, IsNil) + + fp, err = os.Create(path) + c.Assert(err, IsNil) + c.Assert(fp, NotNil) + + // Test mixed unenclosed and enclosed fields. + _, err = fp.WriteString( + "\"abc\",123\n" + + "def,456,\n" + + "hig,\"789\",") + c.Assert(err, IsNil) + + runTestsOnNewDB(c, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Strict = false + }, "LoadData", func(dbt *DBTest) { + dbt.mustExec("create table test (str varchar(10) default null, i int default null)") + _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) + dbt.Assert(err1, IsNil) + var ( + str string + id int + ) + rows := dbt.mustQuery("select * from test") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&str, &id) + dbt.Check(err, IsNil) + dbt.Check(str, DeepEquals, "abc") + dbt.Check(id, DeepEquals, 123) + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&str, &id) + dbt.Check(str, DeepEquals, "def") + dbt.Check(id, DeepEquals, 456) + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&str, &id) + dbt.Check(str, DeepEquals, "hig") + dbt.Check(id, DeepEquals, 789) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + dbt.mustExec("delete from test") + }) + + err = fp.Close() + c.Assert(err, IsNil) + err = os.Remove(path) + c.Assert(err, IsNil) + + fp, err = os.Create(path) + c.Assert(err, IsNil) + c.Assert(fp, NotNil) + + // Test irregular csv file. + _, err = fp.WriteString( + `,\N,NULL,,` + "\n" + + "00,0,000000,,\n" + + `2003-03-03, 20030303,030303,\N` + "\n") + c.Assert(err, IsNil) + + runTestsOnNewDB(c, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Strict = false + }, "LoadData", func(dbt *DBTest) { + dbt.mustExec("create table test (a date, b date, c date not null, d date)") + _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ','`) + dbt.Assert(err1, IsNil) + var ( + a sql.NullString + b sql.NullString + d sql.NullString + c sql.NullString + ) + rows := dbt.mustQuery("select * from test") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&a, &b, &c, &d) + dbt.Check(err, IsNil) + dbt.Check(a.String, Equals, "0000-00-00") + dbt.Check(b.String, Equals, "") + dbt.Check(c.String, Equals, "0000-00-00") + dbt.Check(d.String, Equals, "0000-00-00") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&a, &b, &c, &d) + dbt.Check(a.String, Equals, "0000-00-00") + dbt.Check(b.String, Equals, "0000-00-00") + dbt.Check(c.String, Equals, "0000-00-00") + dbt.Check(d.String, Equals, "0000-00-00") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&a, &b, &c, &d) + dbt.Check(a.String, Equals, "2003-03-03") + dbt.Check(b.String, Equals, "2003-03-03") + dbt.Check(c.String, Equals, "2003-03-03") + dbt.Check(d.String, Equals, "") + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + dbt.mustExec("delete from test") + }) + + err = fp.Close() + c.Assert(err, IsNil) + err = os.Remove(path) + c.Assert(err, IsNil) + + fp, err = os.Create(path) + c.Assert(err, IsNil) + c.Assert(fp, NotNil) + + // Test double enclosed. + _, err = fp.WriteString( + `"field1","field2"` + "\n" + + `"a""b","cd""ef"` + "\n" + + `"a"b",c"d"e` + "\n") + c.Assert(err, IsNil) + + runTestsOnNewDB(c, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Strict = false + }, "LoadData", func(dbt *DBTest) { + dbt.mustExec("create table test (a varchar(20), b varchar(20))") + _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) + dbt.Assert(err1, IsNil) + var ( + a sql.NullString + b sql.NullString + ) + rows := dbt.mustQuery("select * from test") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&a, &b) + dbt.Check(err, IsNil) + dbt.Check(a.String, Equals, "field1") + dbt.Check(b.String, Equals, "field2") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&a, &b) + dbt.Check(a.String, Equals, `a"b`) + dbt.Check(b.String, Equals, `cd"ef`) + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows.Scan(&a, &b) + dbt.Check(a.String, Equals, `a"b`) + dbt.Check(b.String, Equals, `c"d"e`) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + dbt.mustExec("delete from test") + }) + // unsupport ClientLocalFiles capability server.capability ^= tmysql.ClientLocalFiles runTestsOnNewDB(c, func(config *mysql.Config) {