diff --git a/batch/batch.go b/batch/batch.go new file mode 100644 index 00000000..5b793dcb --- /dev/null +++ b/batch/batch.go @@ -0,0 +1,287 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// bach splits a single script containing multiple batches separated by +// a keyword into multiple scripts. +package batch + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "unicode" +) + +// Split the provided SQL into multiple sql scripts based on a given +// separator, often "GO". It also allows escaping newlines with a +// backslash. +func Split(sql, separator string) []string { + if len(separator) == 0 || len(sql) < len(separator) { + return []string{sql} + } + l := &lexer{ + Sql: sql, + Sep: separator, + At: 0, + } + state := stateWhitespace + for state != nil { + state = state(l) + } + l.AddCurrent(1) + return l.Batch +} + +const debugPrintStateName = false + +func printStateName(name string, l *lexer) { + if debugPrintStateName { + fmt.Printf("state %s At=%d\n", name, l.At) + } +} + +func hasPrefixFold(s, sep string) bool { + if len(s) < len(sep) { + return false + } + return strings.EqualFold(s[:len(sep)], sep) +} + +type lexer struct { + Sql string + Sep string + At int + Start int + + Skip []int + + Batch []string +} + +func (l *lexer) Add(b string) { + if len(b) == 0 { + return + } + l.Batch = append(l.Batch, b) +} + +func (l *lexer) Next() bool { + l.At++ + return l.At < len(l.Sql) +} + +func (l *lexer) AddCurrent(count int64) bool { + if count < 0 { + count = 0 + } + if l.At >= len(l.Sql) { + l.At = len(l.Sql) + } + text := l.Sql[l.Start:l.At] + if len(l.Skip) > 0 { + buf := &bytes.Buffer{} + nextSkipIndex := 0 + nextSkip := l.Skip[nextSkipIndex] + for i, r := range text { + if i == nextSkip { + nextSkipIndex++ + if nextSkipIndex < len(l.Skip) { + nextSkip = l.Skip[nextSkipIndex] + } + continue + } + buf.WriteRune(r) + } + text = buf.String() + l.Skip = nil + } + // Limit the number of counts for sanity. + if count > 1000 { + count = 1000 + } + for i := int64(0); i < count; i++ { + l.Add(text) + } + l.At += len(l.Sep) + l.Start = l.At + return (l.At < len(l.Sql)) +} + +type stateFn func(*lexer) stateFn + +const ( + lineComment = "--" + leftComment = "/*" + rightComment = "*/" +) + +func stateSep(l *lexer) stateFn { + printStateName("sep", l) + if l.At+len(l.Sep) >= len(l.Sql) { + return nil + } + s := l.Sql[l.At+len(l.Sep):] + + parseNumberStart := -1 +loop: + for i, r := range s { + switch { + case r == '\n', r == '\r': + l.AddCurrent(1) + return stateWhitespace + case unicode.IsSpace(r): + case unicode.IsNumber(r): + parseNumberStart = i + break loop + } + } + if parseNumberStart < 0 { + return nil + } + + parseNumberCount := 0 +numLoop: + for i, r := range s[parseNumberStart:] { + switch { + case unicode.IsNumber(r): + parseNumberCount = i + default: + break numLoop + } + } + parseNumberEnd := parseNumberStart + parseNumberCount + 1 + + count, err := strconv.ParseInt(s[parseNumberStart:parseNumberEnd], 10, 64) + if err != nil { + return stateText + } + for _, r := range s[parseNumberEnd:] { + switch { + case r == '\n', r == '\r': + l.AddCurrent(count) + l.At += parseNumberEnd + l.Start = l.At + return stateWhitespace + case unicode.IsSpace(r): + default: + return stateText + } + } + + return nil +} + +func stateText(l *lexer) stateFn { + printStateName("text", l) + for { + ch := l.Sql[l.At] + + switch { + case strings.HasPrefix(l.Sql[l.At:], lineComment): + l.At += len(lineComment) + return stateLineComment + case strings.HasPrefix(l.Sql[l.At:], leftComment): + l.At += len(leftComment) + return stateMultiComment + case ch == '\'': + l.At += 1 + return stateString + case ch == '\r', ch == '\n': + l.At += 1 + return stateWhitespace + default: + if l.Next() == false { + return nil + } + } + } +} + +func stateWhitespace(l *lexer) stateFn { + printStateName("whitespace", l) + if l.At >= len(l.Sql) { + return nil + } + ch := l.Sql[l.At] + + switch { + case unicode.IsSpace(rune(ch)): + l.At += 1 + return stateWhitespace + case hasPrefixFold(l.Sql[l.At:], l.Sep): + return stateSep + default: + return stateText + } +} + +func stateLineComment(l *lexer) stateFn { + printStateName("line-comment", l) + for { + if l.At >= len(l.Sql) { + return nil + } + ch := l.Sql[l.At] + + switch { + case ch == '\r', ch == '\n': + l.At += 1 + return stateWhitespace + default: + if l.Next() == false { + return nil + } + } + } +} + +func stateMultiComment(l *lexer) stateFn { + printStateName("multi-line-comment", l) + for { + switch { + case strings.HasPrefix(l.Sql[l.At:], rightComment): + l.At += len(leftComment) + return stateWhitespace + default: + if l.Next() == false { + return nil + } + } + } +} + +func stateString(l *lexer) stateFn { + printStateName("string", l) + for { + if l.At >= len(l.Sql) { + return nil + } + ch := l.Sql[l.At] + chNext := rune(-1) + if l.At+1 < len(l.Sql) { + chNext = rune(l.Sql[l.At+1]) + } + + switch { + case ch == '\\' && (chNext == '\r' || chNext == '\n'): + next := 2 + l.Skip = append(l.Skip, l.At, l.At+1) + if chNext == '\r' && l.At+2 < len(l.Sql) && l.Sql[l.At+2] == '\n' { + l.Skip = append(l.Skip, l.At+2) + next = 3 + } + l.At += next + case ch == '\'' && chNext == '\'': + l.At += 2 + case ch == '\'' && chNext != '\'': + l.At += 1 + return stateWhitespace + default: + if l.Next() == false { + return nil + } + } + } +} diff --git a/batch/batch_fuzz.go b/batch/batch_fuzz.go new file mode 100644 index 00000000..150818db --- /dev/null +++ b/batch/batch_fuzz.go @@ -0,0 +1,12 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build gofuzz + +package batch + +func Fuzz(data []byte) int { + Split(string(data), "GO") + return 0 +} diff --git a/batch/batch_test.go b/batch/batch_test.go new file mode 100644 index 00000000..b151f2e6 --- /dev/null +++ b/batch/batch_test.go @@ -0,0 +1,120 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batch + +import ( + "fmt" + "testing" +) + +func TestBatchSplit(t *testing.T) { + type testItem struct { + Sql string + Expect []string + } + + list := []testItem{ + testItem{ + Sql: `use DB +go +select 1 +go +select 2 +`, + Expect: []string{`use DB +`, ` +select 1 +`, ` +select 2 +`, + }, + }, + testItem{ + Sql: `go +use DB go +`, + Expect: []string{` +use DB go +`, + }, + }, + testItem{ + Sql: `select 'It''s go time' +go +select top 1 1`, + Expect: []string{`select 'It''s go time' +`, ` +select top 1 1`, + }, + }, + testItem{ + Sql: `select 1 /* go */ +go +select top 1 1`, + Expect: []string{`select 1 /* go */ +`, ` +select top 1 1`, + }, + }, + testItem{ + Sql: `select 1 -- go +go +select top 1 1`, + Expect: []string{`select 1 -- go +`, ` +select top 1 1`, + }, + }, + testItem{Sql: `"0'"`, Expect: []string{`"0'"`}}, + testItem{Sql: "0'", Expect: []string{"0'"}}, + testItem{Sql: "--", Expect: []string{"--"}}, + testItem{Sql: "GO", Expect: nil}, + testItem{Sql: "/*", Expect: []string{"/*"}}, + testItem{Sql: "gO\x01\x00O550655490663051008\n", Expect: []string{"\n"}}, + testItem{Sql: "select 1;\nGO 2\nselect 2;", Expect: []string{"select 1;\n", "select 1;\n", "\nselect 2;"}}, + testItem{Sql: "select 'hi\\\n-hello';", Expect: []string{"select 'hi-hello';"}}, + testItem{Sql: "select 'hi\\\r\n-hello';", Expect: []string{"select 'hi-hello';"}}, + testItem{Sql: "select 'hi\\\r-hello';", Expect: []string{"select 'hi-hello';"}}, + testItem{Sql: "select 'hi\\\n\nhello';", Expect: []string{"select 'hi\nhello';"}}, + } + + index := -1 + + for i := range list { + if index >= 0 && index != i { + continue + } + sqltext := list[i].Sql + t.Run(fmt.Sprintf("index-%d", i), func(t *testing.T) { + ss := Split(sqltext, "go") + if len(ss) != len(list[i].Expect) { + t.Errorf("Test Item index %d; expect %d items, got %d %q", i, len(list[i].Expect), len(ss), ss) + return + } + for j := 0; j < len(ss); j++ { + if ss[j] != list[i].Expect[j] { + t.Errorf("Test Item index %d, batch index %d; expect <%s>, got <%s>", i, j, list[i].Expect[j], ss[j]) + } + } + }) + } +} + +func TestHasPrefixFold(t *testing.T) { + list := []struct { + s, pre string + is bool + }{ + {"h", "H", true}, + {"h", "K", false}, + {"go 5\n", "go", true}, + } + for _, item := range list { + is := hasPrefixFold(item.s, item.pre) + if is != item.is { + t.Error("want (%q, %q)=%t got %t", item.s, item.pre, item.is, is) + } + } +}