Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix csv parser #9005

Merged
merged 11 commits into from
Jan 15, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) {
`"\0\b\n\r\t\Z\\\ \c\'\""`,
[]string{string([]byte{0, '\b', '\n', '\r', '\t', 26, '\\', ' ', ' ', 'c', '\'', '"'})},
},
// Test mixed.
{
`"123",456,"\t7890",abcd`,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to add more test cases to guarantee the behavior we support.

[]string{"123", "456", "\t7890", "abcd"},
},
}

ldInfo := LoadDataInfo{
Expand All @@ -214,7 +219,7 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) {
}

_, err := ldInfo.getFieldsFromLine([]byte(`1,a string,100.20`))
c.Assert(err, NotNil)
c.Assert(err, IsNil)
}

func assertEqualStrings(c *C, got []field, expect []string) {
Expand Down
185 changes: 164 additions & 21 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package executor

import (
"bytes"
"context"
"fmt"
"strings"
Expand Down Expand Up @@ -209,7 +208,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
if len(prevData) == 0 && len(curData) == 0 {
return nil, false, nil
}

var line []byte
var isEOF, hasStarting, reachLimit bool
if len(prevData) > 0 && len(curData) == 0 {
Expand All @@ -220,7 +218,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
for len(curData) > 0 {
line, curData, hasStarting = e.getLine(prevData, curData)
prevData = nil

// If it doesn't find the terminated symbol and this data isn't the last data,
// the data can't be inserted.
if line == nil && !isEOF {
Expand Down Expand Up @@ -313,28 +310,174 @@ func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) {
type field struct {
str []byte
maybeNull bool
enclosed bool
}

type fieldWriter struct {
pos int
enclosedChar byte
fieldTermChar byte
term *string
isEnclosed bool
isLineStart bool
isFieldStart bool
ReadBuf *[]byte
OutputBuf []byte
}

func (w *fieldWriter) Init(enclosedChar byte, fieldTermChar byte, readBuf *[]byte, term *string) {
w.isEnclosed = false
w.isLineStart = true
w.isFieldStart = true
w.ReadBuf = readBuf
w.enclosedChar = enclosedChar
w.fieldTermChar = fieldTermChar
w.term = term
}

func (w *fieldWriter) putback() {
w.pos--
}

func (w *fieldWriter) getChar() (bool, byte) {
if w.pos < len(*w.ReadBuf) {
ret := (*w.ReadBuf)[w.pos]
w.pos++
return true, ret
}
return false, 0
}

func (w *fieldWriter) isTerminator() bool {
chkpt, isterm := w.pos, true
for i := 1; i < len(*w.term); i++ {
flag, ch := w.getChar()
if !flag || ch != (*w.term)[i] {
isterm = false
break
}
}
if !isterm {
w.pos = chkpt
return false
}
return true
}

func (w *fieldWriter) outputField(enclosed bool) field {
var fild []byte
start := 0
if enclosed {
start = 1
}
for i := start; i < len(w.OutputBuf); i++ {
fild = append(fild, w.OutputBuf[i])
}
if len(fild) == 0 {
fild = []byte("")
}
w.OutputBuf = w.OutputBuf[0:0]
w.isEnclosed = false
w.isFieldStart = true
return field{fild, false, enclosed}
}

func (w *fieldWriter) GetField() (bool, field) {
// The first return value implies whether fieldWriter read the last character of line.
if w.isLineStart {
_, ch := w.getChar()
if ch == w.enclosedChar {
w.isEnclosed = true
w.isFieldStart, w.isLineStart = false, false
w.OutputBuf = append(w.OutputBuf, ch)
} else {
w.putback()
}
}
for {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it's complex enough and more difficult to maintain.
If we meet some error next time, I'll consider use some more general method instead of hard written those things.

flag, ch := w.getChar()
if !flag {
ret := w.outputField(false)
return true, ret
}
if ch == w.enclosedChar && w.isFieldStart {
// If read enclosed char at field start.
w.isEnclosed = true
w.OutputBuf = append(w.OutputBuf, ch)
w.isLineStart, w.isFieldStart = false, false
continue
}
w.isLineStart, w.isFieldStart = false, false
if ch == w.fieldTermChar && !w.isEnclosed {
// If read filed terminate char.
if w.isTerminator() {
ret := w.outputField(false)
return false, ret
}
w.OutputBuf = append(w.OutputBuf, ch)
} else if ch == w.enclosedChar && w.isEnclosed {
// If read enclosed char, look ahead.
flag, ch = w.getChar()
if !flag {
ret := w.outputField(true)
return true, ret
} else if ch == w.enclosedChar {
w.OutputBuf = append(w.OutputBuf, ch)
continue
} else if ch == w.fieldTermChar {
// If the next char is fieldTermChar, look ahead.
if w.isTerminator() {
ret := w.outputField(true)
return false, ret
}
w.OutputBuf = append(w.OutputBuf, ch)
} else {
// If there is no terminator behind enclosedChar, put the char back.
w.OutputBuf = append(w.OutputBuf, w.enclosedChar)
w.putback()
}
} else if ch == '\\' {
// TODO: escape only support '\'
w.OutputBuf = append(w.OutputBuf, ch)
flag, ch = w.getChar()
if flag {
if ch == w.enclosedChar {
w.OutputBuf = append(w.OutputBuf, ch)
} else {
w.putback()
}
}
} else {
w.OutputBuf = append(w.OutputBuf, ch)
}
}
}

// getFieldsFromLine splits line according to fieldsInfo.
func (e *LoadDataInfo) getFieldsFromLine(line []byte) ([]field, error) {
var sep []byte
if e.FieldsInfo.Enclosed != 0 {
if line[0] != e.FieldsInfo.Enclosed || line[len(line)-1] != e.FieldsInfo.Enclosed {
return nil, errors.Errorf("line %s should begin and end with %c", string(line), e.FieldsInfo.Enclosed)
var (
reader fieldWriter
fields []field
)

if len(line) == 0 {
str := []byte("")
fields = append(fields, field{str, false, false})
return fields, nil
}

reader.Init(e.FieldsInfo.Enclosed, e.FieldsInfo.Terminated[0], &line, &e.FieldsInfo.Terminated)
for {
eol, f := reader.GetField()
f = f.escape()
if string(f.str) == "NULL" && !f.enclosed {
f.str = []byte{'N'}
f.maybeNull = true
}
fields = append(fields, f)
if eol {
break
}
line = line[1 : len(line)-1]
sep = make([]byte, 0, len(e.FieldsInfo.Terminated)+2)
sep = append(sep, e.FieldsInfo.Enclosed)
sep = append(sep, e.FieldsInfo.Terminated...)
sep = append(sep, e.FieldsInfo.Enclosed)
} else {
sep = []byte(e.FieldsInfo.Terminated)
}
rawCols := bytes.Split(line, sep)
fields := make([]field, 0, len(rawCols))
for _, v := range rawCols {
f := field{v, false}
fields = append(fields, f.escape())
}
return fields, nil
}
Expand All @@ -354,7 +497,7 @@ func (f *field) escape() field {
f.str[pos] = c
pos++
}
return field{f.str[:pos], f.maybeNull}
return field{f.str[:pos], f.maybeNull, f.enclosed}
}

func (f *field) escapeChar(c byte) byte {
Expand Down
143 changes: 143 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,149 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, NotNil)
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can os.Remove(path) here immediately.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next test case still use this file.

// Test mixed unenclosed and enclosed fields.
_, err = fp.WriteString(
"\"abc\",123\n" +
"def,456,\n" +
"hig,\"789\",")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (str varchar(10) default null, i int default null)")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`)
dbt.Assert(err1, IsNil)
var (
str string
id int
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&str, &id)
dbt.Check(err, IsNil)
dbt.Check(str, DeepEquals, "abc")
dbt.Check(id, DeepEquals, 123)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&str, &id)
dbt.Check(str, DeepEquals, "def")
dbt.Check(id, DeepEquals, 456)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&str, &id)
dbt.Check(str, DeepEquals, "hig")
dbt.Check(id, DeepEquals, 789)
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

// Test irregular csv file.
_, err = fp.WriteString(
`,\N,NULL,,` + "\n" +
"00,0,000000,,\n" +
`2003-03-03, 20030303,030303,\N` + "\n")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (a date, b date, c date not null, d date)")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ','`)
dbt.Assert(err1, IsNil)
var (
a sql.NullString
b sql.NullString
d sql.NullString
c sql.NullString
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&a, &b, &c, &d)
dbt.Check(err, IsNil)
dbt.Check(a.String, Equals, "0000-00-00")
dbt.Check(b.String, Equals, "")
dbt.Check(c.String, Equals, "0000-00-00")
dbt.Check(d.String, Equals, "0000-00-00")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b, &c, &d)
dbt.Check(a.String, Equals, "0000-00-00")
dbt.Check(b.String, Equals, "0000-00-00")
dbt.Check(c.String, Equals, "0000-00-00")
dbt.Check(d.String, Equals, "0000-00-00")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b, &c, &d)
dbt.Check(a.String, Equals, "2003-03-03")
dbt.Check(b.String, Equals, "2003-03-03")
dbt.Check(c.String, Equals, "2003-03-03")
dbt.Check(d.String, Equals, "")
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

// Test double enclosed.
_, err = fp.WriteString(
`"field1","field2"` + "\n" +
`"a""b","cd""ef"` + "\n" +
`"a"b",c"d"e` + "\n")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (a varchar(20), b varchar(20))")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`)
dbt.Assert(err1, IsNil)
var (
a sql.NullString
b sql.NullString
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&a, &b)
dbt.Check(err, IsNil)
dbt.Check(a.String, Equals, "field1")
dbt.Check(b.String, Equals, "field2")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b)
dbt.Check(a.String, Equals, `a"b`)
dbt.Check(b.String, Equals, `cd"ef`)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b)
dbt.Check(a.String, Equals, `a"b`)
dbt.Check(b.String, Equals, `c"d"e`)
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

// unsupport ClientLocalFiles capability
server.capability ^= tmysql.ClientLocalFiles
runTestsOnNewDB(c, func(config *mysql.Config) {
Expand Down