diff --git a/testdata/src/default_config/else_if.go b/testdata/src/default_config/else_if.go new file mode 100644 index 0000000..7b10487 --- /dev/null +++ b/testdata/src/default_config/else_if.go @@ -0,0 +1,20 @@ +func fn() { + if true { // want "block should not start with a whitespace" + + fmt.Println("a") + } + + if true { + fmt.Println("a") + } else if false { // want "block should not start with a whitespace" + + fmt.Println("b") + } else { // want "block should not start with a whitespace" + + fmt.Println("c") + fmt.Println("c") + fmt.Println("c") + return // want "return statements should not be cuddled if block has more than two lines" + + } // want "block should not end with a whitespace" +} diff --git a/testdata/src/default_config/else_if.go.golden b/testdata/src/default_config/else_if.go.golden new file mode 100644 index 0000000..44af882 --- /dev/null +++ b/testdata/src/default_config/else_if.go.golden @@ -0,0 +1,16 @@ +func fn() { + if true { // want "block should not start with a whitespace" + fmt.Println("a") + } + + if true { + fmt.Println("a") + } else if false { // want "block should not start with a whitespace" + fmt.Println("b") + } else { // want "block should not start with a whitespace" + fmt.Println("c") + fmt.Println("c") + fmt.Println("c") + return // want "return statements should not be cuddled if block has more than two lines" + } // want "block should not end with a whitespace" +} diff --git a/wsl.go b/wsl.go index ee98f07..4700002 100644 --- a/wsl.go +++ b/wsl.go @@ -766,7 +766,14 @@ func (p *processor) firstBodyStatement(i int, allStmt []ast.Stmt) ast.Node { } } - p.parseBlockBody(nil, statementBodyContent) + // If statement bodies will be parsed already when finding block bodies. + // The reason is because if/else-if/else chains is nested in the AST + // where the else bit is a part of the if statement. Since if statements + // is the only statement that can be chained like this we exclude it + // from parsing it again here. + if _, ok := stmt.(*ast.IfStmt); !ok { + p.parseBlockBody(nil, statementBodyContent) + } case []ast.Stmt: // The Body field for an *ast.CaseClause or *ast.CommClause is of type // []ast.Stmt. We must check leading and trailing whitespaces and then @@ -946,6 +953,8 @@ func (p *processor) findBlockStmt(node ast.Node) []*ast.BlockStmt { var blocks []*ast.BlockStmt switch t := node.(type) { + case *ast.BlockStmt: + return []*ast.BlockStmt{t} case *ast.AssignStmt: for _, x := range t.Rhs { blocks = append(blocks, p.findBlockStmt(x)...) @@ -968,6 +977,8 @@ func (p *processor) findBlockStmt(node ast.Node) []*ast.BlockStmt { blocks = append(blocks, p.findBlockStmt(t.Call)...) case *ast.GoStmt: blocks = append(blocks, p.findBlockStmt(t.Call)...) + case *ast.IfStmt: + blocks = append([]*ast.BlockStmt{t.Body}, p.findBlockStmt(t.Else)...) } return blocks