-
Notifications
You must be signed in to change notification settings - Fork 501
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
batch: add function to split batch functions
Fixes #186
- Loading branch information
Showing
3 changed files
with
419 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
// Copyright 2017 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
// bach splits a single script containing multiple batches separated by | ||
// a keyword into multiple scripts. | ||
package batch | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"strconv" | ||
"strings" | ||
"unicode" | ||
) | ||
|
||
// Split the provided SQL into multiple sql scripts based on a given | ||
// separator, often "GO". It also allows escaping newlines with a | ||
// backslash. | ||
func Split(sql, separator string) []string { | ||
if len(separator) == 0 || len(sql) < len(separator) { | ||
return []string{sql} | ||
} | ||
l := &lexer{ | ||
Sql: sql, | ||
Sep: separator, | ||
At: 0, | ||
} | ||
state := stateWhitespace | ||
for state != nil { | ||
state = state(l) | ||
} | ||
l.AddCurrent(1) | ||
return l.Batch | ||
} | ||
|
||
const debugPrintStateName = false | ||
|
||
func printStateName(name string, l *lexer) { | ||
if debugPrintStateName { | ||
fmt.Printf("state %s At=%d\n", name, l.At) | ||
} | ||
} | ||
|
||
func hasPrefixFold(s, sep string) bool { | ||
if len(s) < len(sep) { | ||
return false | ||
} | ||
return strings.EqualFold(s[:len(sep)], sep) | ||
} | ||
|
||
type lexer struct { | ||
Sql string | ||
Sep string | ||
At int | ||
Start int | ||
|
||
Skip []int | ||
|
||
Batch []string | ||
} | ||
|
||
func (l *lexer) Add(b string) { | ||
if len(b) == 0 { | ||
return | ||
} | ||
l.Batch = append(l.Batch, b) | ||
} | ||
|
||
func (l *lexer) Next() bool { | ||
l.At++ | ||
return l.At < len(l.Sql) | ||
} | ||
|
||
func (l *lexer) AddCurrent(count int64) bool { | ||
if count < 0 { | ||
count = 0 | ||
} | ||
if l.At >= len(l.Sql) { | ||
l.At = len(l.Sql) | ||
} | ||
text := l.Sql[l.Start:l.At] | ||
if len(l.Skip) > 0 { | ||
buf := &bytes.Buffer{} | ||
nextSkipIndex := 0 | ||
nextSkip := l.Skip[nextSkipIndex] | ||
for i, r := range text { | ||
if i == nextSkip { | ||
nextSkipIndex++ | ||
if nextSkipIndex < len(l.Skip) { | ||
nextSkip = l.Skip[nextSkipIndex] | ||
} | ||
continue | ||
} | ||
buf.WriteRune(r) | ||
} | ||
text = buf.String() | ||
l.Skip = nil | ||
} | ||
// Limit the number of counts for sanity. | ||
if count > 1000 { | ||
count = 1000 | ||
} | ||
for i := int64(0); i < count; i++ { | ||
l.Add(text) | ||
} | ||
l.At += len(l.Sep) | ||
l.Start = l.At | ||
return (l.At < len(l.Sql)) | ||
} | ||
|
||
type stateFn func(*lexer) stateFn | ||
|
||
const ( | ||
lineComment = "--" | ||
leftComment = "/*" | ||
rightComment = "*/" | ||
) | ||
|
||
func stateSep(l *lexer) stateFn { | ||
printStateName("sep", l) | ||
if l.At+len(l.Sep) >= len(l.Sql) { | ||
return nil | ||
} | ||
s := l.Sql[l.At+len(l.Sep):] | ||
|
||
parseNumberStart := -1 | ||
loop: | ||
for i, r := range s { | ||
switch { | ||
case r == '\n', r == '\r': | ||
l.AddCurrent(1) | ||
return stateWhitespace | ||
case unicode.IsSpace(r): | ||
case unicode.IsNumber(r): | ||
parseNumberStart = i | ||
break loop | ||
} | ||
} | ||
if parseNumberStart < 0 { | ||
return nil | ||
} | ||
|
||
parseNumberCount := 0 | ||
numLoop: | ||
for i, r := range s[parseNumberStart:] { | ||
switch { | ||
case unicode.IsNumber(r): | ||
parseNumberCount = i | ||
default: | ||
break numLoop | ||
} | ||
} | ||
parseNumberEnd := parseNumberStart + parseNumberCount + 1 | ||
|
||
count, err := strconv.ParseInt(s[parseNumberStart:parseNumberEnd], 10, 64) | ||
if err != nil { | ||
return stateText | ||
} | ||
for _, r := range s[parseNumberEnd:] { | ||
switch { | ||
case r == '\n', r == '\r': | ||
l.AddCurrent(count) | ||
l.At += parseNumberEnd | ||
l.Start = l.At | ||
return stateWhitespace | ||
case unicode.IsSpace(r): | ||
default: | ||
return stateText | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func stateText(l *lexer) stateFn { | ||
printStateName("text", l) | ||
for { | ||
ch := l.Sql[l.At] | ||
|
||
switch { | ||
case strings.HasPrefix(l.Sql[l.At:], lineComment): | ||
l.At += len(lineComment) | ||
return stateLineComment | ||
case strings.HasPrefix(l.Sql[l.At:], leftComment): | ||
l.At += len(leftComment) | ||
return stateMultiComment | ||
case ch == '\'': | ||
l.At += 1 | ||
return stateString | ||
case ch == '\r', ch == '\n': | ||
l.At += 1 | ||
return stateWhitespace | ||
default: | ||
if l.Next() == false { | ||
return nil | ||
} | ||
} | ||
} | ||
} | ||
|
||
func stateWhitespace(l *lexer) stateFn { | ||
printStateName("whitespace", l) | ||
if l.At >= len(l.Sql) { | ||
return nil | ||
} | ||
ch := l.Sql[l.At] | ||
|
||
switch { | ||
case unicode.IsSpace(rune(ch)): | ||
l.At += 1 | ||
return stateWhitespace | ||
case hasPrefixFold(l.Sql[l.At:], l.Sep): | ||
return stateSep | ||
default: | ||
return stateText | ||
} | ||
} | ||
|
||
func stateLineComment(l *lexer) stateFn { | ||
printStateName("line-comment", l) | ||
for { | ||
if l.At >= len(l.Sql) { | ||
return nil | ||
} | ||
ch := l.Sql[l.At] | ||
|
||
switch { | ||
case ch == '\r', ch == '\n': | ||
l.At += 1 | ||
return stateWhitespace | ||
default: | ||
if l.Next() == false { | ||
return nil | ||
} | ||
} | ||
} | ||
} | ||
|
||
func stateMultiComment(l *lexer) stateFn { | ||
printStateName("multi-line-comment", l) | ||
for { | ||
switch { | ||
case strings.HasPrefix(l.Sql[l.At:], rightComment): | ||
l.At += len(leftComment) | ||
return stateWhitespace | ||
default: | ||
if l.Next() == false { | ||
return nil | ||
} | ||
} | ||
} | ||
} | ||
|
||
func stateString(l *lexer) stateFn { | ||
printStateName("string", l) | ||
for { | ||
if l.At >= len(l.Sql) { | ||
return nil | ||
} | ||
ch := l.Sql[l.At] | ||
chNext := rune(-1) | ||
if l.At+1 < len(l.Sql) { | ||
chNext = rune(l.Sql[l.At+1]) | ||
} | ||
|
||
switch { | ||
case ch == '\\' && (chNext == '\r' || chNext == '\n'): | ||
next := 2 | ||
l.Skip = append(l.Skip, l.At, l.At+1) | ||
if chNext == '\r' && l.At+2 < len(l.Sql) && l.Sql[l.At+2] == '\n' { | ||
l.Skip = append(l.Skip, l.At+2) | ||
next = 3 | ||
} | ||
l.At += next | ||
case ch == '\'' && chNext == '\'': | ||
l.At += 2 | ||
case ch == '\'' && chNext != '\'': | ||
l.At += 1 | ||
return stateWhitespace | ||
default: | ||
if l.Next() == false { | ||
return nil | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright 2017 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
// +build gofuzz | ||
|
||
package batch | ||
|
||
func Fuzz(data []byte) int { | ||
Split(string(data), "GO") | ||
return 0 | ||
} |
Oops, something went wrong.