Skip to content

Commit

Permalink
chore: Rename edit to insert, drop End(), rename Start()
Browse files Browse the repository at this point in the history
The edits are all insertions at specific positions.
There's no replacement of text otherwise -- at least at this time.

Simplify the interface:

- rename it to insert
- rename Start to Pos
- drop the End method
- rename implementations to match
  • Loading branch information
abhinav committed Nov 19, 2023
1 parent fe80c71 commit a58d9f3
Showing 1 changed file with 51 additions and 90 deletions.
141 changes: 51 additions & 90 deletions cmd/errtrace/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"log"
"os"
"sort"
"strings"
)

func main() {
Expand Down Expand Up @@ -97,11 +96,6 @@ func (cmd *mainCmd) Run(args []string) (exitCode int) {
//
// First, it walks the AST to find all the places that need to be modified,
// extracting other information as needed.
// The list of edits is at a higher level than plain text modifications:
// it tracks the kind of edit to make semantically,
// e.g. "wrap an expression" rather than "add these words at this position."
// This provides freedom to gather information
// before committing to specific strings and names.
//
// The collected information is used to pick a package name,
// whether we need an import, etc. and *then* the edits are applied.
Expand Down Expand Up @@ -131,16 +125,16 @@ func (cmd *mainCmd) processFile(write bool, filename string) error {
}
}

var edits []edit
var inserts []insert
w := walker{
fset: fset,
errtrace: errtracePkg,
logger: cmd.log,
edits: &edits,
inserts: &inserts,
}
ast.Walk(&w, f)

if !importsErrtrace && len(edits) > 0 {
if !importsErrtrace && len(inserts) > 0 {
var lastImportDecl *ast.GenDecl
for _, imp := range f.Decls {
decl, ok := imp.(*ast.GenDecl)
Expand All @@ -150,38 +144,25 @@ func (cmd *mainCmd) processFile(write bool, filename string) error {
lastImportDecl = decl
}

var edit appendImportEdit
var i insertImportErrtrace
if lastImportDecl != nil {
// import "foo"
// // becomes
// import "foo"; import "brace.dev/errtrace"
edit.Node = lastImportDecl
i.Node = lastImportDecl
} else {
// package foo
// // becomes
// package foo; import "brace.dev/errtrace"
edit.Node = f.Name
i.Node = f.Name
}
edits = append(edits, &edit)
inserts = append(inserts, &i)
}

sort.Slice(edits, func(i, j int) bool {
return edits[i].Start() < edits[j].Start()
sort.Slice(inserts, func(i, j int) bool {
return inserts[i].Pos() < inserts[j].Pos()
})

// Detect overlapping edits.
// This indicates a bug in the walker.
for i := 1; i < len(edits); i++ {
prev, cur := edits[i-1], edits[i]
if prev.End() > cur.Start() {
var msg strings.Builder
fmt.Fprintf(&msg, "%s:found overlapping edit:\n", filename)
fmt.Fprintf(&msg, "\t%s:%v\n", fset.Position(prev.End()), prev)
fmt.Fprintf(&msg, "\t%s:%v\n", fset.Position(cur.Start()), cur)
panic(msg.String())
}
}

outw := cmd.Stdout
if write {
f, err := os.Create(filename)
Expand All @@ -196,50 +177,49 @@ func (cmd *mainCmd) processFile(write bool, filename string) error {

var lastOffset int
file := fset.File(f.Pos())
for _, edit := range edits {
start, end := file.Offset(edit.Start()), file.Offset(edit.End())
_, _ = out.Write(src[lastOffset:start])
lastOffset = end
for _, it := range inserts {
offset := file.Offset(it.Pos())
_, _ = out.Write(src[lastOffset:offset])
lastOffset = offset

switch edit := edit.(type) {
case *appendImportEdit:
switch it := it.(type) {
case *insertImportErrtrace:
// Add the original node as-is.
_, _ = out.Write(src[start:end])
if errtracePkg == "errtrace" {
// Don't use named imports if we're using the default name.
fmt.Fprintf(out, "; import %q", "braces.dev/errtrace")
} else {
fmt.Fprintf(out, "; import %s %q", errtracePkg, "braces.dev/errtrace")
}

case *wrapOpenEdit:
case *insertWrapOpen:
fmt.Fprintf(out, "%s.Wrap(", errtracePkg)

case *wrapCloseEdit:
case *insertWrapClose:
_, _ = out.WriteString(")")

case *assignWrapEdit:
case *insertWrapAssign:
// Turns this:
// return
// Into this:
// x, y = errtrace.Wrap(x), errtrace.Wrap(y); return
for i, name := range edit.Names {
for i, name := range it.Names {
if i > 0 {
_, _ = out.WriteString(", ")
}
fmt.Fprintf(out, "%s", name)
}
_, _ = out.WriteString(" = ")
for i, name := range edit.Names {
for i, name := range it.Names {
if i > 0 {
_, _ = out.WriteString(", ")
}
fmt.Fprintf(out, "%s.Wrap(%s)", errtracePkg, name)
}
_, _ = out.WriteString("; return")
_, _ = out.WriteString("; ")

default:
cmd.log.Panicf("unhandled edit type %T", edit)
cmd.log.Panicf("unhandled insertion type %T", it)
}
}
_, _ = out.Write(src[lastOffset:]) // flush remaining
Expand All @@ -256,8 +236,8 @@ type walker struct {

// Outputs

// edits is the list of edits to make.
edits *[]edit
// inserts is the list of inserts to make.
inserts *[]insert

// State

Expand Down Expand Up @@ -290,7 +270,7 @@ func (t *walker) Visit(n ast.Node) (w ast.Visitor) {
// Naked return.
// Add assignments to the named return values.
if n.Results == nil {
*t.edits = append(*t.edits, &assignWrapEdit{
*t.inserts = append(*t.inserts, &insertWrapAssign{
Names: t.errorNames,
Stmt: n,
})
Expand Down Expand Up @@ -327,9 +307,9 @@ func (t *walker) Visit(n ast.Node) (w ast.Visitor) {
}
}

*t.edits = append(*t.edits,
&wrapOpenEdit{Expr: expr},
&wrapCloseEdit{Expr: expr},
*t.inserts = append(*t.inserts,
&insertWrapOpen{Expr: expr},
&insertWrapClose{Expr: expr},
)
}
}
Expand Down Expand Up @@ -396,76 +376,61 @@ func (t *walker) funcType(ft *ast.FuncType) ast.Visitor {
return &newT
}

// edit is a request to modify a range of source code.
type edit interface {
Start() token.Pos
End() token.Pos
String() string
// insert is a request to add something to the source code.
type insert interface {
Pos() token.Pos // position to insert at
String() string // description for debugging
}

// appendImportEdit adds an import declaration to the file
// insertImportErrtrace adds an import declaration to the file
// right after the given node.
type appendImportEdit struct {
type insertImportErrtrace struct {
Node ast.Node // the node to insert the import after
}

func (e *appendImportEdit) Start() token.Pos {
return e.Node.Pos()
}

func (e *appendImportEdit) End() token.Pos {
func (e *insertImportErrtrace) Pos() token.Pos {
return e.Node.End()
}

func (e *appendImportEdit) String() string {
func (e *insertImportErrtrace) String() string {
return fmt.Sprintf("append errtrace import after %T", e.Node)
}

// wrapOpenEdit adds a errtrace.Wrap call before an expression.
// insertWrapOpen adds a errtrace.Wrap call before an expression.
//
// foo() -> errtrace.Wrap(foo()
//
// This needs a corresponding wrapCloseEdit to close the call.
type wrapOpenEdit struct {
// This needs a corresponding insertWrapClose to close the call.
type insertWrapOpen struct {
Expr ast.Expr
}

func (e *wrapOpenEdit) Start() token.Pos {
return e.Expr.Pos()
}

// TODO: drop End() from edit interface

func (e *wrapOpenEdit) End() token.Pos {
func (e *insertWrapOpen) Pos() token.Pos {
return e.Expr.Pos()
}

func (e *wrapOpenEdit) String() string {
func (e *insertWrapOpen) String() string {
return fmt.Sprintf("wrap open %T", e.Expr)
}

// wrapCloseEdit closes a errtrace.Wrap call.
// insertWrapClose closes a errtrace.Wrap call.
//
// foo() -> foo())
//
// This needs a corresponding wrapOpenEdit to open the call.
type wrapCloseEdit struct {
// This needs a corresponding insertWrapOpen to open the call.
type insertWrapClose struct {
Expr ast.Expr
}

func (e *wrapCloseEdit) Start() token.Pos {
func (e *insertWrapClose) Pos() token.Pos {
return e.Expr.End()
}

func (e *wrapCloseEdit) End() token.Pos {
return e.Expr.End()
}

func (e *wrapCloseEdit) String() string {
func (e *insertWrapClose) String() string {
return fmt.Sprintf("wrap close %T", e.Expr)
}

// assignWrapEdit wraps a variable in-place with an errtrace.Wrap call.
// insertWrapAssign wraps a variable in-place with an errtrace.Wrap call.
// This is used for naked returns in functions with named return values
//
// For example, it will turn this:
Expand All @@ -481,20 +446,16 @@ func (e *wrapCloseEdit) String() string {
// // ...
// err = errtrace.Wrap(err); return
// }
type assignWrapEdit struct {
Names []string
type insertWrapAssign struct {
Names []string // names of variables to wrap
Stmt *ast.ReturnStmt // Stmt.Results == nil
}

func (e *assignWrapEdit) Start() token.Pos {
func (e *insertWrapAssign) Pos() token.Pos {
return e.Stmt.Pos()
}

func (e *assignWrapEdit) End() token.Pos {
return e.Stmt.End()
}

func (e *assignWrapEdit) String() string {
func (e *insertWrapAssign) String() string {
return fmt.Sprintf("assign errors before %v", e.Names)
}

Expand Down

0 comments on commit a58d9f3

Please sign in to comment.