diff --git a/rust/yarp/build.rs b/rust/yarp/build.rs index ab41d64e275..9735909e308 100644 --- a/rust/yarp/build.rs +++ b/rust/yarp/build.rs @@ -154,6 +154,12 @@ fn write_node(file: &mut File, node: &Node) -> Result<(), Box {}<'pr> {{", node.name)?; + writeln!(file, " /// Converts this node to a generic node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn as_node(&self) -> Node<'pr> {{")?; + writeln!(file, " Node::{} {{ pointer: self.pointer, marker: PhantomData }}", node.name)?; + writeln!(file, " }}")?; + writeln!(file)?; writeln!(file, " /// Returns the location of this node.")?; writeln!(file, " #[must_use]")?; writeln!(file, " pub fn location(&self) -> Location<'pr> {{")?; diff --git a/rust/yarp/src/lib.rs b/rust/yarp/src/lib.rs index 9c3e19ed9b7..1df92e8f4d8 100644 --- a/rust/yarp/src/lib.rs +++ b/rust/yarp/src/lib.rs @@ -308,4 +308,20 @@ mod tests { assert_eq!(visitor.count, 2); } + + #[test] + fn node_upcast_test() { + use super::Node; + + let source = "module Foo; end"; + let result = parse(source.as_ref()); + + let node = result.node(); + let upcast_node = node.as_program_node().unwrap().as_node(); + assert!(matches!(upcast_node, Node::ProgramNode { .. })); + + let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap(); + let upcast_node = node.as_module_node().unwrap().as_node(); + assert!(matches!(upcast_node, Node::ModuleNode { .. })); + } }