diff --git a/Sources/SwiftSyntax/RawSyntax.swift b/Sources/SwiftSyntax/RawSyntax.swift index 15ee54f558136..5b05146845c3c 100644 --- a/Sources/SwiftSyntax/RawSyntax.swift +++ b/Sources/SwiftSyntax/RawSyntax.swift @@ -478,7 +478,7 @@ extension RawSyntax { if shouldVisit { // Visit this node realizes a syntax node. visitor.visitPre() - visitChildren = visitor.visit() + visitChildren = visitor.visit() == .visitChildren } if visitChildren { for (offset, element) in layout.enumerated() { diff --git a/Sources/SwiftSyntax/SyntaxClassifier.swift.gyb b/Sources/SwiftSyntax/SyntaxClassifier.swift.gyb index 30854b0dd67a0..c4d7ca5c0f18e 100644 --- a/Sources/SwiftSyntax/SyntaxClassifier.swift.gyb +++ b/Sources/SwiftSyntax/SyntaxClassifier.swift.gyb @@ -70,7 +70,7 @@ fileprivate class _SyntaxClassifier: SyntaxVisitor { } } - override func visit(_ token: TokenSyntax) { + override func visit(_ token: TokenSyntax) -> SyntaxVisitorContinueKind { assert(token.isPresent) // FIXME: We need to come up with some way in which the SyntaxClassifier can // classify trivia (i.e. comments). In particular we need to be able to @@ -91,16 +91,17 @@ fileprivate class _SyntaxClassifier: SyntaxVisitor { assert(classifications[token] == nil, "\(token) has already been classified") classifications[token] = classification + return .skipChildren } % for node in SYNTAX_NODES: % if is_visitable(node): - override func visit(_ node: ${node.name}) -> Bool { + override func visit(_ node: ${node.name}) -> SyntaxVisitorContinueKind { if skipNodeIds.contains(node.raw.id) { - return false + return .skipChildren } % if node.is_unknown() or node.is_syntax_collection(): - return true + return .visitChildren % else: % for child in node.children: % if child.is_optional: @@ -123,7 +124,7 @@ fileprivate class _SyntaxClassifier: SyntaxVisitor { % end % end % end - return false + return .skipChildren % end } % end diff --git a/Sources/SwiftSyntax/SyntaxRewriter.swift.gyb b/Sources/SwiftSyntax/SyntaxRewriter.swift.gyb index bebe308840f3b..9ded221844dfd 100644 --- a/Sources/SwiftSyntax/SyntaxRewriter.swift.gyb +++ b/Sources/SwiftSyntax/SyntaxRewriter.swift.gyb @@ -94,24 +94,35 @@ open class SyntaxRewriter { } } +/// The enum describes how the SyntaxVistor should continue after visiting +/// the current node. +public enum SyntaxVisitorContinueKind { + + /// The visitor should visit the descendents of the current node. + case visitChildren + + /// The visitor should avoid visiting the descendents of the current node. + case skipChildren +} + open class SyntaxVisitor { public init() {} % for node in SYNTAX_NODES: % if is_visitable(node): /// Visting ${node.name} specifically. /// - Parameter node: the node we are visiting. - /// - Returns: whether we should visit the descendents of node. - open func visit(_ node: ${node.name}) -> Bool { - return true + /// - Returns: how should we continue visiting. + open func visit(_ node: ${node.name}) -> SyntaxVisitorContinueKind { + return .visitChildren } % end % end /// Visting UnknownSyntax specifically. /// - Parameter node: the node we are visiting. - /// - Returns: whether we should visit the descendents of node. - open func visit(_ node: UnknownSyntax) -> Bool { - return true + /// - Returns: how should we continue visiting. + open func visit(_ node: UnknownSyntax) -> SyntaxVisitorContinueKind { + return .visitChildren } /// Whether we should ever visit a given syntax kind. @@ -128,7 +139,9 @@ open class SyntaxVisitor { return true } - open func visit(_ token: TokenSyntax) {} + open func visit(_ token: TokenSyntax) -> SyntaxVisitorContinueKind { + return .skipChildren + } /// The function called before visiting the node and its descendents. /// - node: the node we are about to visit. @@ -138,18 +151,17 @@ open class SyntaxVisitor { /// - node: the node we just finished visiting. open func visitPost(_ node: Syntax) {} - public func visit(_ node: Syntax) -> Bool { + public func visit(_ node: Syntax) -> SyntaxVisitorContinueKind { switch node.raw.kind { - case .token: visit(node as! TokenSyntax) + case .token: return visit(node as! TokenSyntax) % for node in SYNTAX_NODES: % if is_visitable(node): case .${node.swift_syntax_kind}: return visit(node as! ${node.name}) % end % end case .unknown: return visit(node as! UnknownSyntax) - default: break + default: return .skipChildren } - return false } } @@ -158,7 +170,7 @@ open class SyntaxVisitor { /// otherwise the node is represented as a child index list from a realized /// ancestor. class PendingSyntaxNode { - let parent: PendingSyntaxNode? + let parent: PendingSyntaxNode! private var kind: PendingSyntaxNodeKind private enum PendingSyntaxNodeKind { @@ -174,7 +186,7 @@ class PendingSyntaxNode { case .realized(let node): return node case .virtual(let index): - let _node = parent!.node.child(at: index)! + let _node = parent.node.child(at: index)! kind = .realized(node: _node) return _node } @@ -199,7 +211,7 @@ class PendingSyntaxNode { /// not interesting to users' SyntaxVisitor. class RawSyntaxVisitor { private let visitor: SyntaxVisitor - private var currentNode: PendingSyntaxNode? + private var currentNode: PendingSyntaxNode! required init(_ visitor: SyntaxVisitor, _ root: Syntax) { self.visitor = visitor @@ -215,25 +227,25 @@ class RawSyntaxVisitor { } func addChildIdx(_ idx: Int) { - currentNode = PendingSyntaxNode(currentNode!, idx) + currentNode = PendingSyntaxNode(currentNode, idx) } func moveUp() { - currentNode = currentNode!.parent + currentNode = currentNode.parent } func visitPre() { - visitor.visitPre(currentNode!.node) + visitor.visitPre(currentNode.node) } func visitPost() { - visitor.visitPost(currentNode!.node) + visitor.visitPost(currentNode.node) } // The current raw syntax node is interesting for the user, so realize a // correponding syntax node and feed it into the visitor. - func visit() -> Bool { - return visitor.visit(currentNode!.node) + func visit() -> SyntaxVisitorContinueKind { + return visitor.visit(currentNode.node) } } diff --git a/Sources/lit-test-helper/ClassifiedSyntaxTreePrinter.swift b/Sources/lit-test-helper/ClassifiedSyntaxTreePrinter.swift index 86fcb0836e24f..4a67c518bcf4c 100644 --- a/Sources/lit-test-helper/ClassifiedSyntaxTreePrinter.swift +++ b/Sources/lit-test-helper/ClassifiedSyntaxTreePrinter.swift @@ -115,11 +115,12 @@ class ClassifiedSyntaxTreePrinter: SyntaxVisitor { } } - override func visit(_ node: TokenSyntax) { + override func visit(_ node: TokenSyntax) -> SyntaxVisitorContinueKind { visit(node.leadingTrivia) let classification = classifications[node] ?? SyntaxClassification.none recordCurrentClassification(classification) result += node.text visit(node.trailingTrivia) + return .skipChildren } } diff --git a/Tests/SwiftSyntaxTest/AbsolutePosition.swift b/Tests/SwiftSyntaxTest/AbsolutePosition.swift index 5521385f9650f..82dce37e0257a 100644 --- a/Tests/SwiftSyntaxTest/AbsolutePosition.swift +++ b/Tests/SwiftSyntaxTest/AbsolutePosition.swift @@ -64,9 +64,10 @@ public class AbsolutePositionTestCase: XCTestCase { _ = node.byteSize _ = node.positionAfterSkippingLeadingTrivia } - override func visit(_ node: TokenSyntax) { + override func visit(_ node: TokenSyntax) -> SyntaxVisitorContinueKind { XCTAssertEqual(node.positionAfterSkippingLeadingTrivia.utf8Offset, node.position.utf8Offset + node.leadingTrivia.byteSize) + return .skipChildren } } parsed.walk(Visitor()) diff --git a/Tests/SwiftSyntaxTest/DiagnosticTest.swift b/Tests/SwiftSyntaxTest/DiagnosticTest.swift index 17206ea03203c..0c2c6d2af553e 100644 --- a/Tests/SwiftSyntaxTest/DiagnosticTest.swift +++ b/Tests/SwiftSyntaxTest/DiagnosticTest.swift @@ -76,14 +76,14 @@ public class DiagnosticTestCase: XCTestCase { self.url = url self.engine = engine } - override func visit(_ function: FunctionDeclSyntax) -> Bool { + override func visit(_ function: FunctionDeclSyntax) -> SyntaxVisitorContinueKind { let startLoc = function.identifier.startLocation(in: url) let endLoc = function.endLocation(in: url) engine.diagnose(.badFunction(function.identifier), location: startLoc) { $0.highlight(function.identifier.sourceRange(in: self.url)) } engine.diagnose(.endOfFunction(function.identifier), location: endLoc) - return true + return .visitChildren } } diff --git a/Tests/SwiftSyntaxTest/VisitorTest.swift b/Tests/SwiftSyntaxTest/VisitorTest.swift index d74bf0f66f978..f30871d43987d 100644 --- a/Tests/SwiftSyntaxTest/VisitorTest.swift +++ b/Tests/SwiftSyntaxTest/VisitorTest.swift @@ -13,9 +13,9 @@ public class SyntaxVisitorTestCase: XCTestCase { public func testBasic() { class FuncCounter: SyntaxVisitor { var funcCount = 0 - override func visit(_ node: FunctionDeclSyntax) -> Bool { + override func visit(_ node: FunctionDeclSyntax) -> SyntaxVisitorContinueKind { funcCount += 1 - return true + return .visitChildren } } XCTAssertNoThrow(try { @@ -70,9 +70,9 @@ public class SyntaxVisitorTestCase: XCTestCase { class VisitCollections: SyntaxVisitor { var numberOfCodeBlockItems = 0 - override func visit(_ items: CodeBlockItemListSyntax) -> Bool { + override func visit(_ items: CodeBlockItemListSyntax) -> SyntaxVisitorContinueKind { numberOfCodeBlockItems += items.count - return true + return .visitChildren } }