Skip to content

Commit

Permalink
executor: fix load data panic if the data is broken at escape charact…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Mar 2, 2022
1 parent d3a3830 commit 459917c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 109 deletions.
136 changes: 27 additions & 109 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,65 +363,20 @@ func (e *LoadDataInfo) SetMaxRowsInBatch(limit uint64) {
e.curBatchCnt = 0
}

// getValidData returns prevData and curData that starts from starting symbol.
// If the data doesn't have starting symbol, prevData is nil and curData is curData[len(curData)-startingLen+1:].
// If curData size less than startingLen, curData is returned directly.
func (e *LoadDataInfo) getValidData(prevData, curData []byte) ([]byte, []byte) {
startingLen := len(e.LinesInfo.Starting)
if startingLen == 0 {
return prevData, curData
}

prevLen := len(prevData)
if prevLen > 0 {
// starting symbol in the prevData
idx := strings.Index(string(hack.String(prevData)), e.LinesInfo.Starting)
if idx != -1 {
return prevData[idx:], curData
}

// starting symbol in the middle of prevData and curData
restStart := curData
if len(curData) >= startingLen {
restStart = curData[:startingLen-1]
}
prevData = append(prevData, restStart...)
idx = strings.Index(string(hack.String(prevData)), e.LinesInfo.Starting)
if idx != -1 {
return prevData[idx:prevLen], curData
}
}

// starting symbol in the curData
// getValidData returns curData that starts from starting symbol.
// If the data doesn't have starting symbol, return curData[len(curData)-startingLen+1:] and false.
func (e *LoadDataInfo) getValidData(curData []byte) ([]byte, bool) {
idx := strings.Index(string(hack.String(curData)), e.LinesInfo.Starting)
if idx != -1 {
return nil, curData[idx:]
if idx == -1 {
return curData[len(curData)-len(e.LinesInfo.Starting)+1:], false
}

// no starting symbol
if len(curData) >= startingLen {
curData = curData[len(curData)-startingLen+1:]
}
return nil, curData
}

func (e *LoadDataInfo) isInQuoter(bs []byte) bool {
inQuoter := false
for i := 0; i < len(bs); i++ {
switch bs[i] {
case e.FieldsInfo.Enclosed:
inQuoter = !inQuoter
case e.FieldsInfo.Escaped:
i++
default:
}
}
return inQuoter
return curData[idx:], true
}

// indexOfTerminator return index of terminator, if not, return -1.
// normally, the field terminator and line terminator is short, so we just use brute force algorithm.
func (e *LoadDataInfo) indexOfTerminator(bs []byte, isInQuoter bool) int {
func (e *LoadDataInfo) indexOfTerminator(bs []byte) int {
fieldTerm := []byte(e.FieldsInfo.Terminated)
fieldTermLen := len(fieldTerm)
lineTerm := []byte(e.LinesInfo.Terminated)
Expand Down Expand Up @@ -462,15 +417,13 @@ func (e *LoadDataInfo) indexOfTerminator(bs []byte, isInQuoter bool) int {
inQuoter := false
loop:
for i := 0; i < len(bs); i++ {
if atFieldStart && bs[i] == e.FieldsInfo.Enclosed {
if !isInQuoter {
inQuoter = true
}
if atFieldStart && e.FieldsInfo.Enclosed != byte(0) && bs[i] == e.FieldsInfo.Enclosed {
inQuoter = !inQuoter
atFieldStart = false
continue
}
restLen := len(bs) - i - 1
if inQuoter && bs[i] == e.FieldsInfo.Enclosed {
if inQuoter && e.FieldsInfo.Enclosed != byte(0) && bs[i] == e.FieldsInfo.Enclosed {
// look ahead to see if it is end of line or field.
switch cmpTerm(restLen, bs[i+1:]) {
case lineTermType:
Expand Down Expand Up @@ -508,67 +461,32 @@ loop:
// getLine returns a line, curData, the next data start index and a bool value.
// If it has starting symbol the bool is true, otherwise is false.
func (e *LoadDataInfo) getLine(prevData, curData []byte, ignore bool) ([]byte, []byte, bool) {
startingLen := len(e.LinesInfo.Starting)
prevData, curData = e.getValidData(prevData, curData)
if prevData == nil && len(curData) < startingLen {
return nil, curData, false
}
inquotor := e.isInQuoter(prevData)
prevLen := len(prevData)
terminatedLen := len(e.LinesInfo.Terminated)
curStartIdx := 0
if prevLen < startingLen {
curStartIdx = startingLen - prevLen
}
endIdx := -1
if len(curData) >= curStartIdx {
if ignore {
endIdx = strings.Index(string(hack.String(curData[curStartIdx:])), e.LinesInfo.Terminated)
} else {
endIdx = e.indexOfTerminator(curData[curStartIdx:], inquotor)
}
}
if endIdx == -1 {
// no terminated symbol
if len(prevData) == 0 {
return nil, curData, true
}

// terminated symbol in the middle of prevData and curData
if prevData != nil {
curData = append(prevData, curData...)
if ignore {
endIdx = strings.Index(string(hack.String(curData[startingLen:])), e.LinesInfo.Terminated)
} else {
endIdx = e.indexOfTerminator(curData[startingLen:], inquotor)
}
startLen := len(e.LinesInfo.Starting)
if startLen != 0 {
if len(curData) < startLen {
return nil, curData, false
}
if endIdx != -1 {
nextDataIdx := startingLen + endIdx + terminatedLen
return curData[startingLen : startingLen+endIdx], curData[nextDataIdx:], true
var ok bool
curData, ok = e.getValidData(curData)
if !ok {
return nil, curData, false
}
// no terminated symbol
return nil, curData, true
}

// terminated symbol in the curData
nextDataIdx := curStartIdx + endIdx + terminatedLen
if len(prevData) == 0 {
return curData[curStartIdx : curStartIdx+endIdx], curData[nextDataIdx:], true
}

// terminated symbol in the curData
prevData = append(prevData, curData[:nextDataIdx]...)
var endIdx int
if ignore {
endIdx = strings.Index(string(hack.String(prevData[startingLen:])), e.LinesInfo.Terminated)
endIdx = strings.Index(string(hack.String(curData[startLen:])), e.LinesInfo.Terminated)
} else {
endIdx = e.indexOfTerminator(prevData[startingLen:], inquotor)
endIdx = e.indexOfTerminator(curData[startLen:])
}
if endIdx >= prevLen {
return prevData[startingLen : startingLen+endIdx], curData[nextDataIdx:], true

if endIdx == -1 {
return nil, curData, true
}

// terminated symbol in the middle of prevData and curData
lineLen := startingLen + endIdx + terminatedLen
return prevData[startingLen : startingLen+endIdx], curData[lineLen-prevLen:], true
return curData[startLen : startLen+endIdx], curData[startLen+endIdx+len(e.LinesInfo.Terminated):], true
}

// InsertData inserts data into specified table according to the specified format.
Expand Down
2 changes: 2 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2326,6 +2326,8 @@ func (s *testSuite4) TestLoadDataEscape(c *C) {
{nil, []byte("7\trtn0ZbN\n"), []string{"7|" + string([]byte{'r', 't', 'n', '0', 'Z', 'b', 'N'})}, nil, trivialMsg},
{nil, []byte("8\trtn0Zb\\N\n"), []string{"8|" + string([]byte{'r', 't', 'n', '0', 'Z', 'b', 'N'})}, nil, trivialMsg},
{nil, []byte("9\ttab\\ tab\n"), []string{"9|tab tab"}, nil, trivialMsg},
// data broken at escape character.
{[]byte("1\ta string\\"), []byte("\n1\n"), []string{"1|a string\n1"}, nil, trivialMsg},
}
deleteSQL := "delete from load_data_test"
selectSQL := "select * from load_data_test;"
Expand Down

0 comments on commit 459917c

Please sign in to comment.