From e2b420f4e9a6e6dd07ffac14bc02c4692aaff423 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Wed, 7 Oct 2020 09:48:30 +1100 Subject: [PATCH] Codegen for stateful lexer. This uses a fairly naive recursive solution but even with that caveat gives significant gains. Roughly 10x performance with constant GC overhead per tokenisation call. The latter is possible because token values are slices of the input string rather than copies. --- lexer/lexer.go | 2 +- lexer/stateful/codegen/codegen.go | 413 ++++++++++++++++++ lexer/stateful/codegen/codegen_test.go | 105 +++++ lexer/stateful/codegen/generatedlexer_test.go | 393 +++++++++++++++++ lexer/stateful/stateful.go | 72 +-- 5 files changed, 958 insertions(+), 27 deletions(-) create mode 100644 lexer/stateful/codegen/codegen.go create mode 100644 lexer/stateful/codegen/codegen_test.go create mode 100644 lexer/stateful/codegen/generatedlexer_test.go diff --git a/lexer/lexer.go b/lexer/lexer.go index 741bc65d..fb60a410 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -63,7 +63,7 @@ func Must(def Definition, err error) Definition { // ConsumeAll reads all tokens from a Lexer. func ConsumeAll(lexer Lexer) ([]Token, error) { - tokens := []Token{} + tokens := make([]Token, 0, 1024) for { token, err := lexer.Next() if err != nil { diff --git a/lexer/stateful/codegen/codegen.go b/lexer/stateful/codegen/codegen.go new file mode 100644 index 00000000..eea70da3 --- /dev/null +++ b/lexer/stateful/codegen/codegen.go @@ -0,0 +1,413 @@ +// Package codegen generates Go code for stateful lexers. 
+package codegen + +import ( + "fmt" + "io" + "regexp" + "regexp/syntax" + "text/template" + "unicode/utf8" + + "github.com/alecthomas/participle/lexer/stateful" +) + +var backrefRe = regexp.MustCompile(`(\\+)(\d)`) + +var tmpl = template.Must(template.New("lexgen").Funcs(template.FuncMap{ + "IsPush": func(r stateful.Rule) string { + if p, ok := r.Action.(stateful.ActionPush); ok { + return p.State + } + return "" + }, + "IsPop": func(r stateful.Rule) bool { + _, ok := r.Action.(stateful.ActionPop) + return ok + }, + "HaveBackrefs": func(def *stateful.Definition, state string) bool { + for _, rule := range def.Rules()[state] { + if backrefRe.MatchString(rule.Pattern) { + return true + } + } + return false + }, +}).Parse(` +// Code generated by Participle. DO NOT EDIT. +package {{.Package}} + +import ( + "strings" + "unicode/utf8" + + "github.com/alecthomas/participle" + "github.com/alecthomas/participle/lexer" +) + +var Lexer lexer.Definition = definitionImpl{} + +type definitionImpl struct {} + +func (definitionImpl) Symbols() map[string]rune { + return map[string]rune{ +{{- range $sym, $rn := .Def.Symbols}} + "{{$sym}}": {{$rn}}, +{{- end}} + } +} + +func (definitionImpl) LexString(filename string, s string) (lexer.Lexer, error) { + return &lexerImpl{ + s: s, + pos: lexer.Position{ + Filename: filename, + Line: 1, + Column: 1, + }, + states: []lexerState{lexerState{name: "Root"}}, + }, nil +} + +func (d definitionImpl) LexBytes(filename string, b []byte) (lexer.Lexer, error) { + return d.LexString(filename, string(b)) +} + +func (d definitionImpl) Lex(filename string, r io.Reader) (lexer.Lexer, error) { + s := &strings.Builder{} + _, err := io.Copy(s, r) + if err != nil { + return nil, err + } + return d.LexString(filename, s.String()) +} + +type lexerState struct { + name string + groups []string +} + +type lexerImpl struct { + s string + p int + pos lexer.Position + states []lexerState +} + +func (l *lexerImpl) Next() (lexer.Token, error) { + if l.p == len(l.s) 
{ + return lexer.EOFToken(l.pos), nil + } + var ( + state = l.states[len(l.states)-1] + groups []int + sym rune + ) + switch state.name { +{{- range $state, $rules := .Def.Rules}} + case "{{$state}}": +{{- range $i, $rule := $rules}} + {{- if $i}} else {{end -}} +{{- if .Pattern -}} + if match := match{{.Name}}(l.s, l.p); match[1] != 0 { + sym = {{index $.Def.Symbols .Name}} + groups = match[:] +{{- else}} + if true { +{{- end}} +{{- if .|IsPush}} + l.states = append(l.states, lexerState{name: "{{.|IsPush}}"{{if HaveBackrefs $.Def $state}}, groups: l.sgroups(groups){{end}}}) +{{- else if .|IsPop}} + l.states = l.states[:len(l.states)-1] +{{- else if not .Action}} +{{- else}} + Unsupported action {{.Action}} +{{- end}} + } +{{- end}} +{{- end}} + } + if groups == nil { + return lexer.Token{}, participle.Errorf(l.pos, "no lexer rules in state %q matched input text", l.states[len(l.states)-1]) + } + pos := l.pos + span := l.s[groups[0]:groups[1]] + l.p = groups[1] + l.pos.Offset = groups[1] + lines := strings.Count(span, "\n") + l.pos.Line += lines + // Update column. + if lines == 0 { + l.pos.Column += utf8.RuneCountInString(span) + } else { + l.pos.Column = utf8.RuneCountInString(span[strings.LastIndex(span, "\n"):]) + } + return lexer.Token{ + Type: sym, + Value: span, + Pos: pos, + }, nil +} + +func (l *lexerImpl) sgroups(match []int) []string { + sgroups := make([]string, len(match)/2) + for i := 0; i < len(match)-1; i += 2 { + sgroups[i/2] = l.s[l.p+match[i]:l.p+match[i+1]] + } + return sgroups +} + +`)) + +// Generate Go code for the given stateful lexer. +// +// The generated code should in general be around 10x faster and produce zero garbage per token. +func Generate(w io.Writer, pkg string, def *stateful.Definition) error { + type ctx struct { + Package string + Def *stateful.Definition + } + rules := def.Rules() + err := tmpl.Execute(w, ctx{pkg, def}) + if err != nil { + return err + } + seen := map[string]bool{} // Rules can be duplicated by Include(). 
+ for _, rules := range rules { + for _, rule := range rules { + if rule.Name == "" { + panic(rule) + } + if seen[rule.Name] { + continue + } + seen[rule.Name] = true + fmt.Fprintf(w, "\n") + err := generateRegexMatch(w, rule.Name, rule.Pattern) + if err != nil { + return err + } + } + } + return nil +} + +func generateRegexMatch(w io.Writer, name, pattern string) error { + re, err := syntax.Parse(pattern, syntax.Perl) + if err != nil { + return err + } + ids := map[string]int{} + idn := 0 + reid := func(re *syntax.Regexp) int { + key := re.Op.String() + ":" + re.String() + id, ok := ids[key] + if ok { + return id + } + id = idn + idn++ + ids[key] = id + return id + } + exists := func(re *syntax.Regexp) bool { + key := re.Op.String() + ":" + re.String() + _, ok := ids[key] + return ok + } + re = re.Simplify() + fmt.Fprintf(w, "// %s\n", re) + fmt.Fprintf(w, "func match%s(s string, p int) (groups [%d]int) {\n", name, 2*re.MaxCap()+2) + flattened := flatten(re) + + // Fast-path a single literal. 
+ if len(flattened) == 1 && re.Op == syntax.OpLiteral { + n := utf8.RuneCountInString(string(re.Rune)) + if n == 1 { + fmt.Fprintf(w, "if p < len(s) && s[p] == %q {\n", re.Rune[0]) + } else { + fmt.Fprintf(w, "if p+%d < len(s) && s[p:p+%d] == %q {\n", n, n, string(re.Rune)) + } + fmt.Fprintf(w, "groups[0] = p\n") + fmt.Fprintf(w, "groups[1] = p + %d\n", n) + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "return\n") + fmt.Fprintf(w, "}\n") + return nil + } + for _, re := range flattened { + if exists(re) { + continue + } + fmt.Fprintf(w, "// %s (%s)\n", re, re.Op) + fmt.Fprintf(w, "l%d := func(s string, p int) int {\n", reid(re)) + if re.Flags&syntax.NonGreedy != 0 { + panic("non-greedy match not supported: " + re.String()) + } + switch re.Op { + case syntax.OpNoMatch: // matches no strings + fmt.Fprintf(w, "return p\n") + + case syntax.OpEmptyMatch: // matches empty string + fmt.Fprintf(w, "if len(s) == 0 { return p }\n") + fmt.Fprintf(w, "return -1\n") + + case syntax.OpLiteral: // matches Runes sequence + n := utf8.RuneCountInString(string(re.Rune)) + if n == 1 { + fmt.Fprintf(w, "if p < len(s) && s[p] == %q { return p+1 }\n", re.Rune[0]) + } else { + fmt.Fprintf(w, "if p+%d < len(s) && s[p:p+%d] == %q { return p+%d }\n", n, n, string(re.Rune), n) + } + fmt.Fprintf(w, "return -1\n") + + case syntax.OpCharClass: // matches Runes interpreted as range pair list + fmt.Fprintf(w, "if len(s) <= p { return -1 }\n") + needDecode := false + for i := 0; i < len(re.Rune); i += 2 { + l, r := re.Rune[i], re.Rune[i+1] + ln, rn := utf8.RuneLen(l), utf8.RuneLen(r) + if ln != 1 || rn != 1 { + needDecode = true + break + } + } + if needDecode { + fmt.Fprintf(w, "var (rn rune; n int)\n") + decodeRune(w, "p", "rn", "n") + } else { + fmt.Fprintf(w, "rn := s[p]\n") + } + fmt.Fprintf(w, "switch {\n") + for i := 0; i < len(re.Rune); i += 2 { + l, r := re.Rune[i], re.Rune[i+1] + ln, rn := utf8.RuneLen(l), utf8.RuneLen(r) + if ln == 1 && rn == 1 { + if l == r { + fmt.Fprintf(w, "case rn == %q: 
return p+1\n", l) + } else { + fmt.Fprintf(w, "case rn >= %q && rn <= %q: return p+1\n", l, r) + } + } else { + if l == r { + fmt.Fprintf(w, "case rn == %q: return p+n\n", l) + } else { + fmt.Fprintf(w, "case rn >= %q && rn <= %q: return p+n\n", l, r) + } + } + } + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "return -1\n") + + case syntax.OpAnyCharNotNL: // matches any character except newline + fmt.Fprintf(w, "var (rn rune; n int)\n") + decodeRune(w, "p", "rn", "n") + fmt.Fprintf(w, "if len(s) <= p+n || rn == '\\n' { return -1 }\n") + fmt.Fprintf(w, "return p+n\n") + + case syntax.OpAnyChar: // matches any character + fmt.Fprintf(w, "var n int\n") + fmt.Fprintf(w, "if s[p] < utf8.RuneSelf {\n") + fmt.Fprintf(w, " n = 1\n") + fmt.Fprintf(w, "} else {\n") + fmt.Fprintf(w, " _, n = utf8.DecodeRuneInString(s[p:])\n") + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "if len(s) <= p+n { return -1 }\n") + fmt.Fprintf(w, "return p+n\n") + + case syntax.OpWordBoundary, syntax.OpNoWordBoundary, + syntax.OpBeginText, syntax.OpEndText, + syntax.OpBeginLine, syntax.OpEndLine: + fmt.Fprintf(w, "var l, u rune = -1, -1\n") + fmt.Fprintf(w, "if p == 0 {\n") + decodeRune(w, "0", "u", "_") + fmt.Fprintf(w, "} else if p == len(s) {\n") + fmt.Fprintf(w, " l, _ = utf8.DecodeLastRuneInString(s)\n") + fmt.Fprintf(w, "} else {\n") + fmt.Fprintf(w, " var ln int\n") + decodeRune(w, "p", "l", "ln") + fmt.Fprintf(w, " if p+ln <= len(s) {\n") + decodeRune(w, "p+ln", "u", "_") + fmt.Fprintf(w, " }\n") + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "op := syntax.EmptyOpContext(l, u)\n") + lut := map[syntax.Op]string{ + syntax.OpWordBoundary: "EmptyWordBoundary", + syntax.OpNoWordBoundary: "EmptyNoWordBoundary", + syntax.OpBeginText: "EmptyBeginText", + syntax.OpEndText: "EmptyEndText", + syntax.OpBeginLine: "EmptyBeginLine", + syntax.OpEndLine: "EmptyEndLine", + } + fmt.Fprintf(w, "if op & syntax.%s != 0 { return p }\n", lut[re.Op]) + fmt.Fprintf(w, "return -1\n") + + case syntax.OpCapture: // capturing 
subexpression with index Cap, optional name Name + fmt.Fprintf(w, "np := l%d(s, p)\n", reid(re.Sub0[0])) + fmt.Fprintf(w, "if np != -1 {\n") + fmt.Fprintf(w, " groups[%d] = p\n", re.Cap*2) + fmt.Fprintf(w, " groups[%d] = np\n", re.Cap*2+1) + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "return np") + + case syntax.OpStar: // matches Sub[0] zero or more times + fmt.Fprintf(w, "for len(s) > p {\n") + fmt.Fprintf(w, "if np := l%d(s, p); np == -1 { return p } else { p = np }\n", reid(re.Sub0[0])) + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "return p\n") + + case syntax.OpPlus: // matches Sub[0] one or more times + fmt.Fprintf(w, "if p = l%d(s, p); p == -1 { return -1 }\n", reid(re.Sub0[0])) + fmt.Fprintf(w, "for len(s) > p {\n") + fmt.Fprintf(w, "if np := l%d(s, p); np == -1 { return p } else { p = np }\n", reid(re.Sub0[0])) + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "return p\n") + + case syntax.OpQuest: // matches Sub[0] zero or one times + fmt.Fprintf(w, "if np := l%d(s, p); np != -1 { return np }\n", reid(re.Sub0[0])) + fmt.Fprintf(w, "return p\n") + + case syntax.OpRepeat: // matches Sub[0] at least Min times, at most Max (Max == -1 is no limit) + panic("??") + + case syntax.OpConcat: // matches concatenation of Subs + for _, sub := range re.Sub { + fmt.Fprintf(w, "if p = l%d(s, p); p == -1 { return -1 }\n", reid(sub)) + } + fmt.Fprintf(w, "return p\n") + + case syntax.OpAlternate: // matches alternation of Subs + for _, sub := range re.Sub { + fmt.Fprintf(w, "if np := l%d(s, p); np != -1 { return np }\n", reid(sub)) + } + fmt.Fprintf(w, "return -1\n") + } + fmt.Fprintf(w, "}\n") + } + fmt.Fprintf(w, "np := l%d(s, p)\n", reid(re)) + fmt.Fprintf(w, "if np == -1 {\n") + fmt.Fprintf(w, " return\n") + fmt.Fprintf(w, "}\n") + fmt.Fprintf(w, "groups[0] = p\n") + fmt.Fprintf(w, "groups[1] = np\n") + fmt.Fprintf(w, "return\n") + fmt.Fprintf(w, "}\n") + return nil +} + +// This exists because of https://github.com/golang/go/issues/31666 +func decodeRune(w io.Writer, offset string, rn 
string, n string) { + fmt.Fprintf(w, "if s[%s] < utf8.RuneSelf {\n", offset) + fmt.Fprintf(w, " %s, %s = rune(s[%s]), 1\n", rn, n, offset) + fmt.Fprintf(w, "} else {\n") + fmt.Fprintf(w, " %s, %s = utf8.DecodeRuneInString(s[%s:])\n", rn, n, offset) + fmt.Fprintf(w, "}\n") +} + +func flatten(re *syntax.Regexp) (out []*syntax.Regexp) { + for _, sub := range re.Sub { + out = append(out, flatten(sub)...) + } + out = append(out, re) + return +} diff --git a/lexer/stateful/codegen/codegen_test.go b/lexer/stateful/codegen/codegen_test.go new file mode 100644 index 00000000..c0cb2637 --- /dev/null +++ b/lexer/stateful/codegen/codegen_test.go @@ -0,0 +1,105 @@ +package codegen_test + +import ( + "bytes" + "os" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/alecthomas/participle/lexer" + "github.com/alecthomas/participle/lexer/stateful" + "github.com/alecthomas/participle/lexer/stateful/codegen" +) + +var ( + benchmarkInput = `"` + strings.Repeat(`hello ${name} world what's the song that you're singing, come on get ${emotion}`, 1000) + `"` + exprLexer = stateful.Must(stateful.Rules{ + "Root": { + {`String`, `"`, stateful.Push("String")}, + }, + "String": { + {"Escaped", `\\.`, nil}, + {"StringEnd", `"`, stateful.Pop()}, + {"Expr", `\${`, stateful.Push("Expr")}, + {"Char", `[^$"\\]+`, nil}, + }, + "Expr": { + stateful.Include("Root"), + {`Whitespace`, `\s+`, nil}, + {`Oper`, `[-+/*%]`, nil}, + {"Ident", `\w+`, nil}, + {"ExprEnd", `}`, stateful.Pop()}, + }, + }) +) + +func TestGenerate(t *testing.T) { + w := &bytes.Buffer{} + err := codegen.Generate(w, "codegen_test", exprLexer) + require.NoError(t, err) + source := w.String() + cmd := exec.Command("pbcopy") + cmd.Stdin = strings.NewReader(source) + err = cmd.Run() + require.NoError(t, err) + + formatted := &bytes.Buffer{} + cmd = exec.Command("goimports") + cmd.Stdin = strings.NewReader(source) + cmd.Stdout = formatted + cmd.Stderr = os.Stderr + err = cmd.Run() + 
require.NoError(t, err, source) + + cmd = exec.Command("pbcopy") + cmd.Stdin = formatted + err = cmd.Run() + require.NoError(t, err) +} + +func BenchmarkStatefulGenerated(b *testing.B) { + b.ReportAllocs() + b.ReportMetric(float64(len(benchmarkInput)), "B") + slex := Lexer.(interface { + LexString(string, string) (lexer.Lexer, error) + }) + for i := 0; i < b.N; i++ { + lex, err := slex.LexString("", benchmarkInput) + if err != nil { + b.Fatal(err) + } + for { + t, err := lex.Next() + if err != nil { + b.Fatal(err) + } + if t.EOF() { + break + } + } + } +} + +func BenchmarkStatefulRegex(b *testing.B) { + b.ReportAllocs() + b.ReportMetric(float64(len(benchmarkInput)), "B") + input := []byte(benchmarkInput) + for i := 0; i < b.N; i++ { + lex, err := exprLexer.LexBytes("", input) + if err != nil { + b.Fatal(err) + } + for { + t, err := lex.Next() + if err != nil { + b.Fatal(err) + } + if t.EOF() { + break + } + } + } +} diff --git a/lexer/stateful/codegen/generatedlexer_test.go b/lexer/stateful/codegen/generatedlexer_test.go new file mode 100644 index 00000000..55710ee5 --- /dev/null +++ b/lexer/stateful/codegen/generatedlexer_test.go @@ -0,0 +1,393 @@ +// Code generated by Participle. DO NOT EDIT. 
+package codegen_test + +import ( + "io" + "strings" + "unicode/utf8" + + "github.com/alecthomas/participle" + "github.com/alecthomas/participle/lexer" +) + +var Lexer lexer.Definition = definitionImpl{} + +type definitionImpl struct{} + +func (definitionImpl) Symbols() map[string]rune { + return map[string]rune{ + "Char": -11, + "EOF": -1, + "Escaped": -8, + "Expr": -10, + "ExprEnd": -6, + "Ident": -5, + "Oper": -4, + "String": -7, + "StringEnd": -9, + "Whitespace": -3, + } +} + +func (definitionImpl) LexString(filename string, s string) (lexer.Lexer, error) { + return &lexerImpl{ + s: s, + pos: lexer.Position{ + Filename: filename, + Line: 1, + Column: 1, + }, + states: []lexerState{lexerState{name: "Root"}}, + }, nil +} + +func (d definitionImpl) LexBytes(filename string, b []byte) (lexer.Lexer, error) { + return d.LexString(filename, string(b)) +} + +func (d definitionImpl) Lex(filename string, r io.Reader) (lexer.Lexer, error) { + s := &strings.Builder{} + _, err := io.Copy(s, r) + if err != nil { + return nil, err + } + return d.LexString(filename, s.String()) +} + +type lexerState struct { + name string + groups []string +} + +type lexerImpl struct { + s string + p int + pos lexer.Position + states []lexerState +} + +func (l *lexerImpl) Next() (lexer.Token, error) { + if l.p == len(l.s) { + return lexer.EOFToken(l.pos), nil + } + var ( + state = l.states[len(l.states)-1] + groups []int + sym rune + ) + switch state.name { + case "Expr": + if match := matchString(l.s, l.p); match[1] != 0 { + sym = -7 + groups = match[:] + l.states = append(l.states, lexerState{name: "String"}) + } else if match := matchWhitespace(l.s, l.p); match[1] != 0 { + sym = -3 + groups = match[:] + } else if match := matchOper(l.s, l.p); match[1] != 0 { + sym = -4 + groups = match[:] + } else if match := matchIdent(l.s, l.p); match[1] != 0 { + sym = -5 + groups = match[:] + } else if match := matchExprEnd(l.s, l.p); match[1] != 0 { + sym = -6 + groups = match[:] + l.states = 
l.states[:len(l.states)-1] + } + case "Root": + if match := matchString(l.s, l.p); match[1] != 0 { + sym = -7 + groups = match[:] + l.states = append(l.states, lexerState{name: "String"}) + } + case "String": + if match := matchEscaped(l.s, l.p); match[1] != 0 { + sym = -8 + groups = match[:] + } else if match := matchStringEnd(l.s, l.p); match[1] != 0 { + sym = -9 + groups = match[:] + l.states = l.states[:len(l.states)-1] + } else if match := matchExpr(l.s, l.p); match[1] != 0 { + sym = -10 + groups = match[:] + l.states = append(l.states, lexerState{name: "Expr"}) + } else if match := matchChar(l.s, l.p); match[1] != 0 { + sym = -11 + groups = match[:] + } + } + if groups == nil { + return lexer.Token{}, participle.Errorf(l.pos, "no lexer rules in state %q matched input text", l.states[len(l.states)-1]) + } + pos := l.pos + span := l.s[groups[0]:groups[1]] + l.p = groups[1] + l.pos.Offset = groups[1] + lines := strings.Count(span, "\n") + l.pos.Line += lines + // Update column. + if lines == 0 { + l.pos.Column += utf8.RuneCountInString(span) + } else { + l.pos.Column = utf8.RuneCountInString(span[strings.LastIndex(span, "\n"):]) + } + return lexer.Token{ + Type: sym, + Value: span, + Pos: pos, + }, nil +} + +func (l *lexerImpl) sgroups(match []int) []string { + sgroups := make([]string, len(match)/2) + for i := 0; i < len(match)-1; i += 2 { + sgroups[i/2] = l.s[l.p+match[i] : l.p+match[i+1]] + } + return sgroups +} + +// " +func matchString(s string, p int) (groups [2]int) { + if p < len(s) && s[p] == '"' { + groups[0] = p + groups[1] = p + 1 + } + return +} + +// \\(?-s:.) +func matchEscaped(s string, p int) (groups [2]int) { + // \\ (Literal) + l0 := func(s string, p int) int { + if p < len(s) && s[p] == '\\' { + return p + 1 + } + return -1 + } + // (?-s:.) 
(AnyCharNotNL) + l1 := func(s string, p int) int { + var ( + rn rune + n int + ) + if s[p] < utf8.RuneSelf { + rn, n = rune(s[p]), 1 + } else { + rn, n = utf8.DecodeRuneInString(s[p:]) + } + if len(s) <= p+n || rn == '\n' { + return -1 + } + return p + n + } + // \\(?-s:.) (Concat) + l2 := func(s string, p int) int { + if p = l0(s, p); p == -1 { + return -1 + } + if p = l1(s, p); p == -1 { + return -1 + } + return p + } + np := l2(s, p) + if np == -1 { + return + } + groups[0] = p + groups[1] = np + return +} + +// " +func matchStringEnd(s string, p int) (groups [2]int) { + if p < len(s) && s[p] == '"' { + groups[0] = p + groups[1] = p + 1 + } + return +} + +// \$\{ +func matchExpr(s string, p int) (groups [2]int) { + if p+2 < len(s) && s[p:p+2] == "${" { + groups[0] = p + groups[1] = p + 2 + } + return +} + +// [^"\$\\]+ +func matchChar(s string, p int) (groups [2]int) { + // [^"\$\\] (CharClass) + l0 := func(s string, p int) int { + if len(s) <= p { + return -1 + } + var ( + rn rune + n int + ) + if s[p] < utf8.RuneSelf { + rn, n = rune(s[p]), 1 + } else { + rn, n = utf8.DecodeRuneInString(s[p:]) + } + switch { + case rn >= '\x00' && rn <= '!': + return p + 1 + case rn == '#': + return p + 1 + case rn >= '%' && rn <= '[': + return p + 1 + case rn >= ']' && rn <= '\U0010ffff': + return p + n + } + return -1 + } + // [^"\$\\]+ (Plus) + l1 := func(s string, p int) int { + if p = l0(s, p); p == -1 { + return -1 + } + for len(s) > p { + if np := l0(s, p); np == -1 { + return p + } else { + p = np + } + } + return p + } + np := l1(s, p) + if np == -1 { + return + } + groups[0] = p + groups[1] = np + return +} + +// [\t-\n\f-\r ]+ +func matchWhitespace(s string, p int) (groups [2]int) { + // [\t-\n\f-\r ] (CharClass) + l0 := func(s string, p int) int { + if len(s) <= p { + return -1 + } + rn := s[p] + switch { + case rn >= '\t' && rn <= '\n': + return p + 1 + case rn >= '\f' && rn <= '\r': + return p + 1 + case rn == ' ': + return p + 1 + } + return -1 + } + // 
[\t-\n\f-\r ]+ (Plus) + l1 := func(s string, p int) int { + if p = l0(s, p); p == -1 { + return -1 + } + for len(s) > p { + if np := l0(s, p); np == -1 { + return p + } else { + p = np + } + } + return p + } + np := l1(s, p) + if np == -1 { + return + } + groups[0] = p + groups[1] = np + return +} + +// [%\*-\+\-/] +func matchOper(s string, p int) (groups [2]int) { + // [%\*-\+\-/] (CharClass) + l0 := func(s string, p int) int { + if len(s) <= p { + return -1 + } + rn := s[p] + switch { + case rn == '%': + return p + 1 + case rn >= '*' && rn <= '+': + return p + 1 + case rn == '-': + return p + 1 + case rn == '/': + return p + 1 + } + return -1 + } + np := l0(s, p) + if np == -1 { + return + } + groups[0] = p + groups[1] = np + return +} + +// [0-9A-Z_a-z]+ +func matchIdent(s string, p int) (groups [2]int) { + // [0-9A-Z_a-z] (CharClass) + l0 := func(s string, p int) int { + if len(s) <= p { + return -1 + } + rn := s[p] + switch { + case rn >= '0' && rn <= '9': + return p + 1 + case rn >= 'A' && rn <= 'Z': + return p + 1 + case rn == '_': + return p + 1 + case rn >= 'a' && rn <= 'z': + return p + 1 + } + return -1 + } + // [0-9A-Z_a-z]+ (Plus) + l1 := func(s string, p int) int { + if p = l0(s, p); p == -1 { + return -1 + } + for len(s) > p { + if np := l0(s, p); np == -1 { + return p + } else { + p = np + } + } + return p + } + np := l1(s, p) + if np == -1 { + return + } + groups[0] = p + groups[1] = np + return +} + +// \} +func matchExprEnd(s string, p int) (groups [2]int) { + if p < len(s) && s[p] == '}' { + groups[0] = p + groups[1] = p + 1 + } + return +} diff --git a/lexer/stateful/stateful.go b/lexer/stateful/stateful.go index 598a5feb..1c445f1b 100644 --- a/lexer/stateful/stateful.go +++ b/lexer/stateful/stateful.go @@ -219,41 +219,46 @@ type RulesAction interface { applyRules(state string, rule int, rules compiledRules) error } -// ActionFunc is a function that is also a Action. 
-type ActionFunc func(*Lexer, []string) error +// ActionPop pops to the previous state when the Rule matches. +type ActionPop struct{} -func (m ActionFunc) applyAction(lexer *Lexer, groups []string) error { return m(lexer, groups) } // nolint: golint +func (p ActionPop) applyAction(lexer *Lexer, groups []string) error { + if groups[0] == "" { + return errors.New("did not consume any input") + } + lexer.stack = lexer.stack[:len(lexer.stack)-1] + return nil +} // Pop to the previous state. func Pop() Action { - return ActionFunc(func(lexer *Lexer, groups []string) error { - if groups[0] == "" { - return errors.New("did not consume any input") - } - lexer.stack = lexer.stack[:len(lexer.stack)-1] - return nil - }) + return ActionPop{} } -var returnToParent = Rule{"popIfEmpty", "", nil} +var returnToParent = Rule{"returnToParent", "", nil} // Return to the parent state. // // Useful as the last rule in a sub-state. func Return() Rule { return returnToParent } +// ActionPush pushes the current state and switches to "State" when the Rule matches. +type ActionPush struct{ State string } + +func (p ActionPush) applyAction(lexer *Lexer, groups []string) error { + if groups[0] == "" { + return errors.New("did not consume any input") + } + lexer.stack = append(lexer.stack, lexerState{name: p.State, groups: groups}) + return nil +} + // Push to the given state. // // The target state will then be the set of rules used for matching // until another Push or Pop is encountered. func Push(state string) Action { - return ActionFunc(func(lexer *Lexer, groups []string) error { - if groups[0] == "" { - return errors.New("did not consume any input") - } - lexer.stack = append(lexer.stack, lexerState{name: state, groups: groups}) - return nil - }) + return ActionPush{state} } type include struct{ state string } @@ -286,7 +291,7 @@ type Definition struct { // MustSimple creates a new lexer definition based on a single state described by `rules`. 
// panics if the rules trigger an error -func MustSimple(rules []Rule) lexer.Definition { +func MustSimple(rules []Rule) *Definition { def, err := NewSimple(rules) if err != nil { panic(err) @@ -295,12 +300,12 @@ func MustSimple(rules []Rule) lexer.Definition { } // NewSimple creates a new stateful lexer with a single "Root" state. -func NewSimple(rules []Rule) (lexer.Definition, error) { +func NewSimple(rules []Rule) (*Definition, error) { return New(Rules{"Root": rules}) } // Must creates a new stateful lexer and panics if it is incorrect. -func Must(rules Rules) lexer.Definition { +func Must(rules Rules) *Definition { def, err := New(rules) if err != nil { panic(err) @@ -309,7 +314,7 @@ func Must(rules Rules) lexer.Definition { } // New constructs a new stateful lexer from rules. -func New(rules Rules) (lexer.Definition, error) { +func New(rules Rules) (*Definition, error) { compiled := compiledRules{} for key, set := range rules { if _, ok := compiled[key]; !ok { @@ -387,11 +392,18 @@ restart: }, nil } -func (d *Definition) Lex(filename string, r io.Reader) (lexer.Lexer, error) { // nolint: golint - data, err := ioutil.ReadAll(r) - if err != nil { - return nil, err +// Rules returns the user-provided Rules used to construct the lexer. 
+func (d *Definition) Rules() Rules { + out := Rules{} + for state, rules := range d.rules { + for _, rule := range rules.rules { + out[state] = append(out[state], rule.Rule) + } } + return out +} + +func (d *Definition) LexBytes(filename string, data []byte) (lexer.Lexer, error) { // nolint: golint return &Lexer{ def: d, data: data, @@ -404,6 +416,14 @@ func (d *Definition) Lex(filename string, r io.Reader) (lexer.Lexer, error) { // }, nil } +func (d *Definition) Lex(filename string, r io.Reader) (lexer.Lexer, error) { // nolint: golint + data, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + return d.LexBytes(filename, data) +} + func (d *Definition) Symbols() map[string]rune { // nolint: golint return d.symbols }