From 2022e9d93990846b0a31fac6b2f28af926a460fb Mon Sep 17 00:00:00 2001 From: Osama Ahmad Date: Fri, 6 Oct 2023 06:59:27 +0300 Subject: [PATCH 1/2] Change the Program node to a TranslationUnit node, with no injection of the main label --- src/ast.rs | 2 +- src/code_generator.rs | 229 +++++++++++++++++++++++------------------- src/parser.rs | 75 +++++++------- 3 files changed, 162 insertions(+), 144 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index ee88031..dfb8a1f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -20,7 +20,7 @@ pub enum ASTNode { FunctionDeclaration(Token, Token, Vec), // Type, Identifier, Parameters FunctionDefinition(Token, Token, Vec, Box), // Type, Identifier, Parameters, Body ReturnStatement(Token, Box), // ReturnKeyword, Expression - Program(Vec), + TranslationUnit(Vec), If(Token, Box, Box, Option>), // If, Condition, Body, Else While(Token, Box, Box), // While, Condition, Body DoWhile(Token, Box, Token, Box), // Do, Body, While, Condition diff --git a/src/code_generator.rs b/src/code_generator.rs index c8c5b2c..6ae68b9 100644 --- a/src/code_generator.rs +++ b/src/code_generator.rs @@ -18,7 +18,7 @@ impl CodeGenerator { pub fn generate(&mut self, root: &ASTNode) -> String { match root { - Program(..) => self.generate_program(root), + TranslationUnit(..) => self.generate_translation_unit(root), ReturnStatement(..) => self.generate_return_statement(root), VariableDeclaration(..) => self.generate_variable_declaration(root), VariableDefinition(..) => self.generate_variable_definition(root), @@ -316,31 +316,19 @@ impl CodeGenerator { } } - fn generate_program(&mut self, node: &ASTNode) -> String { + fn generate_translation_unit(&mut self, node: &ASTNode) -> String { let mut result = String::new(); match node { - Program(nodes_vector) => { - self.symbol_table.push_scope(symbol_table::Scope::new(-8)); - self.symbol_table.reset_largest_offset(); + TranslationUnit(nodes_vector) => { + // TODO Support global variables + // FIXME this is a scope for storing global FUNCTION declarations only + // and should not be used for global variables + self.symbol_table.push_scope(symbol_table::Scope::new(0)); for node in nodes_vector { result.push_str(&self.generate(node)); } self.symbol_table.pop_scope(); - - format!( - ".global main\n\ - main:\n\ - push %rbp\n\ - mov %rsp, %rbp\n\ - subq ${}, %rsp\n\ - {}\ - mov %rbp, %rsp\n\ - mov {}, %rax\n\ - pop %rbp", - -self.symbol_table.current_largest_offset(), - result, - CodeGenerator::get_reg1(8) - ) + result } _ => panic!("Internal Error: Expected program node, found {:?}", node), } @@ -608,22 +596,31 @@ mod tests { } #[rstest::rstest] - #[case("return 5 + 3 * 2 + (2 * 19 * 4) / 2 + 9 * 12 / 3 * 3;", 195)] - #[case("int x = 1; int y = 2; return x > y;", 0)] - #[case("int x = 1; int y = 2; return x < y;", 1)] - #[case("int x = 1; int y = 2; return x >= y;", 0)] - #[case("int x = 1; int y = 2; return x <= y;", 1)] - #[case("int x = 1; int y = 2; return x != y;", 1)] - #[case("int x = 1; int y = 2; int z = 1; return x + z == y;", 1)] - #[case("int x = 1; int y = x; return x == y;", 1)] - #[case("return (1 | 2 | 4 | 8 | 16 | 32) & 85;", 21)] - #[case("int x = 12; int y = 423; return x || y;", 1)] - #[case("int x = -7; int y = 15; return x + y;", 8)] - #[case("int x = -----++++----+------12; return x * x;", 144)] - #[case("int x = 1; return -(+(-(+x)));", 1)] - #[case("int false = 0; int true = 123; return true || false;", 1)] #[case( - "int x = 312; int y = 99; int z; z = 2 * x / 3 + y * y; return z - x;", + "int main() { return 5 + 3 * 2 + (2 * 19 * 4) / 2 + 9 * 12 / 3 * 3; }", + 195 + )] + #[case("int main() { int x = 1; int y = 2; return x > y; }", 0)] + #[case("int main() { int x = 1; int y = 2; return x < y; }", 1)] + #[case("int main() { int x = 1; int y = 2; return x >= y; }", 0)] + #[case("int main() { int x = 1; int y = 2; return x <= y; }", 1)] + #[case("int main() { int x = 1; int y = 2; return x != y; }", 1)] + #[case( + "int main() { int x = 1; int y = 2; int z = 1; return x + z == y; }", + 1 + )] + #[case("int main() { int x = 1; int y = x; return x == y; }", 1)] + #[case("int main() { return (1 | 2 | 4 | 8 | 16 | 32) & 85; }", 21)] + #[case("int main() { int x = 12; int y = 423; return x || y; }", 1)] + #[case("int main() { int x = -7; int y = 15; return x + y; }", 8)] + #[case("int main() { int x = -----++++----+------12; return x * x; }", 144)] + #[case("int main() { int x = 1; return -(+(-(+x))); }", 1)] + #[case( + "int main() { int false = 0; int true = 123; return true || false; }", + 1 + )] + #[case( + "int main() { int x = 312; int y = 99; int z; z = 2 * x / 3 + y * y; return z - x; }", 9697 )] fn test_generate_expression_with_precedence( @@ -636,15 +633,15 @@ mod tests { } #[rstest::rstest] - #[case("int x = 55; { int y = 5; x = y; } return x;", 5)] - #[case("int x = 6; { int y = 55; return y; } return x;", 55)] - #[case("int x = 6; { int y = 55; return y; }", 55)] - #[case("int x = 1; { int x = 2; { int x = 3; return x; } }", 3)] - #[case("int x = 1; { int y = 2; { return x; } }", 1)] - #[case("{{{{{{{{{{ return 5; }}}}}}}}}}", 5)] - #[case("{{{{{{ }}}} return 1; }}", 1)] + #[case("int main() { int x = 55; { int y = 5; x = y; } return x; }", 5)] + #[case("int main() { int x = 6; { int y = 55; return y; } return x; }", 55)] + #[case("int main() { int x = 6; { int y = 55; return y; } }", 55)] + #[case("int main() { int x = 1; { int x = 2; { int x = 3; return x; } } }", 3)] + #[case("int main() { int x = 1; { int y = 2; { return x; } } }", 1)] + #[case("int main() { {{{{{{{{{{ return 5; }}}}}}}}}} }", 5)] + #[case("int main() { {{{{{{ }}}} return 1; }} }", 1)] #[case( - "int x = 1; { int y = 5; } { int y = 6; } int y = 7; { int y = 8; return y; }", + "int main() { int x = 1; { int y = 5; } { int y = 6; } int y = 7; { int y = 8; return y; } }", 8 )] fn test_generated_scoped_programs( @@ -657,26 +654,29 @@ mod tests { } #[rstest::rstest] - #[case("{ int x = 4; } return x;")] - #[case("{ int x = 4; } { int y = 5; int z = 6; } return x + z;")] - #[case("{{{{{{ }}}} return 1; }")] - #[case("if (1) { int x = 1; } return x;")] - #[case("if (1) { int x = 1; } else if (0) { int y = 1; } return y;")] - #[case("if (1) int x = 1; return x;")] - #[case("while (0) { int x = 1; } return x;")] - #[case("while (0) int x = 1; return x;")] - #[case("do { int x; } while (1); return x;")] + #[case("int main() { { int x = 4; } return x; }")] + #[case("int main() { { int x = 4; } { int y = 5; int z = 6; } return x + z; }")] + #[case("int main() { {{{{{{ }}}} return 1; } }")] + #[case("int main() { if (1) { int x = 1; } return x; }")] + #[case("int main() { if (1) { int x = 1; } else if (0) { int y = 1; } return y; }")] + #[case("int main() { if (1) int x = 1; return x; }")] + #[case("int main() { while (0) { int x = 1; } return x; }")] + #[case("int main() { while (0) int x = 1; return x; }")] + #[case("int main() { do { int x; } while (1); return x; }")] #[should_panic] fn test_undefined_variables_in_scope(#[case] test_case: String) { let generated = generate_code(test_case); } #[rstest::rstest] - #[case("int x = 55; if (x & 1) { return 16; } return x;", 16)] - #[case("int x = 55; if (x & 0) { return 16; } return x;", 55)] - #[case("int x = 12; if (x & 1) { return 16; } else { return 5 * 12; }", 60)] + #[case("int main() { int x = 55; if (x & 1) { return 16; } return x; }", 16)] + #[case("int main() { int x = 55; if (x & 0) { return 16; } return x; }", 55)] + #[case( + "int main() { int x = 12; if (x & 1) { return 16; } else { return 5 * 12; } }", + 60 + )] #[case( - "if (0) { return 16; } else if (1) { return 5 * 12; } else { return 7; }", + "int main() { if (0) { return 16; } else if (1) { return 5 * 12; } else { return 7; } }", 60 )] fn test_if_statements(#[case] test_case: String, #[case] expected: i32) -> std::io::Result<()> { @@ -687,19 +687,21 @@ mod tests { #[rstest::rstest] #[case( - "int res = 0; for (int i = 0; i <= 5; i = i + 1) { res = res + i * i; } return res;", + "int main() { int res = 0; for (int i = 0; i <= 5; i = i + 1) { res = res + i * i; } return res; }", 55 )] #[case( - "int a = 0; - int b = 1; - int c; - for (int i = 0; i <= 44; i = i + 1) { - c = b + a; - a = b; - b = c; - } - return c;", + "int main() { + int a = 0; + int b = 1; + int c; + for (int i = 0; i <= 44; i = i + 1) { + c = b + a; + a = b; + b = c; + } + return c; + }", 1836311903 )] fn test_for_statements( @@ -714,31 +716,35 @@ mod tests { // Calculating the maximum Fibonacci number that fits in a 32-bit integer. #[rstest::rstest] #[case( - "int x = 44; - int a = 0; - int b = 1; - int c; - while (x >= 0) { - c = a + b; - a = b; - b = c; - x = x - 1; - } - return c;", + "int main() { + int x = 44; + int a = 0; + int b = 1; + int c; + while (x >= 0) { + c = a + b; + a = b; + b = c; + x = x - 1; + } + return c; + }", 1836311903 )] #[case( - "int x = 5; int sum = 0; while (x != 0) { sum = sum + x * x; x = x - 1; } return sum; ", + "int main() { int x = 5; int sum = 0; while (x != 0) { sum = sum + x * x; x = x - 1; } return sum; }", 55 )] #[case( - "int n = 10; - int result = 1; - while (n > 1) { - result = result * n; - n = n - 1; - } - return result;", + "int main() { + int n = 10; + int result = 1; + while (n > 1) { + result = result * n; + n = n - 1; + } + return result; + }", 3628800 )] fn test_while_statements( @@ -752,7 +758,8 @@ mod tests { #[rstest::rstest] #[case( - "int x = 45;\ + "int main() {\ + int x = 45;\ int a = 0;\ int b = 1;\ int c;\ @@ -762,33 +769,36 @@ mod tests { b = c;\ x = x - 1;\ } while (x);\ - return c;", + return c;\ + }", 1836311903 )] #[case( - "int x = 6; - int y = 3; - int temp1 = x * y + 3; - int temp2 = x - y; - int result = 1; - int counter = 0; - - while (temp1 > temp2) { - temp2 = temp2 + x; - temp1 = temp1 - y; - counter = counter + 1; - } + "int main() { + int x = 6; + int y = 3; + int temp1 = x * y + 3; + int temp2 = x - y; + int result = 1; + int counter = 0; + + while (temp1 > temp2) { + temp2 = temp2 + x; + temp1 = temp1 - y; + counter = counter + 1; + } - while (counter > 0) { - result = result * temp1 + temp2; - counter = counter - 1; - } + while (counter > 0) { + result = result * temp1 + temp2; + counter = counter - 1; + } - return result;", + return result; + }", 465 )] #[case( - "int x = 5; int sum = 0; do { sum = sum + x * x; x = x - 1; } while (x); return sum; ", + "int main() { int x = 5; int sum = 0; do { sum = sum + x * x; x = x - 1; } while (x); return sum; }", 55 )] fn test_do_while_statements( @@ -799,4 +809,13 @@ mod tests { expect_exit_code(generated, expected)?; Ok(()) } + + #[rstest::rstest] + #[case("int not_main() { return 1; }")] + #[case("return 1;")] + #[should_panic] + fn test_program_without_main(#[case] test_case: String) { + let generated = generate_code(test_case); + expect_exit_code(generated, 1).unwrap(); + } } diff --git a/src/parser.rs b/src/parser.rs index a2d90fd..61e15e9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -52,7 +52,7 @@ impl Parser { } result.push(self.parse_unit()); } - ASTNode::Program(result) + TranslationUnit(result) } fn parse_scope(&mut self) -> ASTNode { @@ -64,7 +64,7 @@ impl Parser { result.push(self.parse_unit()); } self.try_consume(TokenType::CloseCurly); - ASTNode::Scope(result) + Scope(result) } fn parse_statement(&mut self) -> ASTNode { @@ -368,10 +368,9 @@ mod tests { use super::*; use crate::ast::ASTNode::*; use crate::lexer::Lexer; - use rstest::rstest; #[rstest::rstest] - #[case("int x = 55;", Program( + #[case("int x = 55;", TranslationUnit( vec![ VariableDefinition( Token{value: "int".to_string(), token_type: TokenType::Type, pos: 0}, @@ -389,7 +388,7 @@ mod tests { } #[rstest::rstest] - #[case("return 123;", Program( + #[case("return 123;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode(Expression::IntegerLiteral( @@ -404,7 +403,7 @@ mod tests { } #[rstest::rstest] - #[case("return 1 ^ 2;", Program( + #[case("return 1 ^ 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -420,7 +419,7 @@ mod tests { )) )]) )] - #[case("return 1 + 2 * 3;", Program( + #[case("return 1 + 2 * 3;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -442,7 +441,7 @@ mod tests { )) )]) )] - #[case("return 1 || x * 3;", Program( + #[case("return 1 || x * 3;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -465,7 +464,7 @@ mod tests { ) ]) )] - #[case("{ return 1 && x * 3; }", Program(vec![Scope( + #[case("{ return 1 && x * 3; }", TranslationUnit(vec![Scope( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 2}, Box::new(ExpressionNode( @@ -495,7 +494,7 @@ mod tests { } #[rstest::rstest] - #[case("return 1 != 2;", Program( + #[case("return 1 != 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -511,7 +510,7 @@ mod tests { )) )]) )] - #[case("return 1 >= 2;", Program( + #[case("return 1 >= 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -527,7 +526,7 @@ mod tests { )) )]) )] - #[case("return 1 <= 2;", Program( + #[case("return 1 <= 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -543,7 +542,7 @@ mod tests { )) )]) )] - #[case("return 1 < 2;", Program( + #[case("return 1 < 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -559,7 +558,7 @@ mod tests { )) )]) )] - #[case("return 1 > 2;", Program( + #[case("return 1 > 2;", TranslationUnit( vec![ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, Box::new(ExpressionNode( @@ -582,7 +581,7 @@ mod tests { } #[rstest::rstest] - #[case("if (true) { return 1; }", Program(vec![If( + #[case("if (true) { return 1; }", TranslationUnit(vec![If( Token { value: "if".to_string(), token_type: TokenType::If, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "true".to_string(), token_type: TokenType::Identifier, pos: 4} @@ -604,7 +603,7 @@ mod tests { #[rstest::rstest] #[ case("return -123;", - Program( + TranslationUnit( vec![ ReturnStatement( Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, @@ -632,7 +631,7 @@ mod tests { } #[rstest::rstest] - #[case("if (ture) { return 1; } else { return 2; }", Program(vec![If( + #[case("if (ture) { return 1; } else { return 2; }", TranslationUnit(vec![If( Token { value: "if".to_string(), token_type: TokenType::If, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "ture".to_string(), token_type: TokenType::Identifier, pos: 4} @@ -657,7 +656,7 @@ mod tests { } #[rstest::rstest] - #[case("if (ture) { return 1; } else if (false) { return 2; }", Program(vec![If( + #[case("if (ture) { return 1; } else if (false) { return 2; }", TranslationUnit(vec![If( Token { value: "if".to_string(), token_type: TokenType::If, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "ture".to_string(), token_type: TokenType::Identifier, pos: 4} @@ -689,7 +688,7 @@ mod tests { } #[rstest::rstest] - #[case("if (ture) { return 1; } else if (false) { return 2; } else { return 3; }", Program(vec![If( + #[case("if (ture) { return 1; } else if (false) { return 2; } else { return 3; }", TranslationUnit(vec![If( Token { value: "if".to_string(), token_type: TokenType::If, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "ture".to_string(), token_type: TokenType::Identifier, pos: 4} @@ -726,7 +725,7 @@ mod tests { } #[rstest::rstest] - #[case("if (ture) return 1; else if (false) return 2; else { return 3; }", Program(vec![If( + #[case("if (ture) return 1; else if (false) return 2; else { return 3; }", TranslationUnit(vec![If( Token { value: "if".to_string(), token_type: TokenType::If, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "ture".to_string(), token_type: TokenType::Identifier, pos: 4} @@ -766,7 +765,7 @@ mod tests { } #[rstest::rstest] - #[case("while (true) { return 1; }", Program(vec![While( + #[case("while (true) { return 1; }", TranslationUnit(vec![While( Token { value: "while".to_string(), token_type: TokenType::While, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "true".to_string(), token_type: TokenType::Identifier, pos: 7} @@ -778,7 +777,7 @@ mod tests { ))) )])) )]))] - #[case("while (true) return 1;", Program(vec![While( + #[case("while (true) return 1;", TranslationUnit(vec![While( Token { value: "while".to_string(), token_type: TokenType::While, pos: 0 }, Box::new(ExpressionNode(Expression::Variable( Token{value: "true".to_string(), token_type: TokenType::Identifier, pos: 7} @@ -797,7 +796,7 @@ mod tests { } #[rstest::rstest] - #[case("int x = 1; do { x = x + 1; } while (x);", Program(vec![ + #[case("int x = 1; do { x = x + 1; } while (x);", TranslationUnit(vec![ VariableDefinition( Token{value: "int".to_string(), token_type: TokenType::Type, pos: 0}, Token{value: "x".to_string(), token_type: TokenType::Identifier, pos: 4}, @@ -826,7 +825,7 @@ mod tests { ) ]))] #[rstest::rstest] - #[case("int x = 1; do x = x + 1; while (x);", Program(vec![ + #[case("int x = 1; do x = x + 1; while (x);", TranslationUnit(vec![ VariableDefinition( Token{value: "int".to_string(), token_type: TokenType::Type, pos: 0}, Token{value: "x".to_string(), token_type: TokenType::Identifier, pos: 4}, @@ -861,7 +860,7 @@ mod tests { } #[rstest::rstest] - #[case("for (;;);", Program(vec![For( + #[case("for (;;);", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(ExpressionStatement(Expression::Empty)), @@ -870,7 +869,7 @@ mod tests { ], Box::new(Scope(vec![ExpressionStatement(Expression::Empty)])) )]))] - #[case("for (int i;;) { ;; }", Program(vec![For( + #[case("for (int i;;) { ;; }", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(VariableDeclaration( @@ -882,7 +881,7 @@ mod tests { ], Box::new(Scope(vec![ExpressionStatement(Expression::Empty), ExpressionStatement(Expression::Empty)])) )]))] - #[case("for (int i = 1;;) {}", Program(vec![For( + #[case("for (int i = 1;;) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(VariableDefinition( @@ -897,7 +896,7 @@ mod tests { ], Box::new(Scope(vec![])) )]))] - #[case("for (; i < 10;) {}", Program(vec![For( + #[case("for (; i < 10;) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(ExpressionStatement(Expression::Empty)), @@ -914,7 +913,7 @@ mod tests { ], Box::new(Scope(vec![])) )]))] - #[case("for (; i = 1;) {}", Program(vec![For( + #[case("for (; i = 1;) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(ExpressionStatement(Expression::Empty)), @@ -928,7 +927,7 @@ mod tests { ], Box::new(Scope(vec![])) )]))] - #[case("for (;; i) {}", Program(vec![For( + #[case("for (;; i) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(ExpressionStatement(Expression::Empty)), @@ -939,7 +938,7 @@ mod tests { ], Box::new(Scope(vec![])) )]))] - #[case("for (;; i = i + 1) {}", Program(vec![For( + #[case("for (;; i = i + 1) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(ExpressionStatement(Expression::Empty)), @@ -959,7 +958,7 @@ mod tests { ], Box::new(Scope(vec![])) )]))] - #[case("for (int i = 1; i < 10; i = i + 1) {}", Program(vec![For( + #[case("for (int i = 1; i < 10; i = i + 1) {}", TranslationUnit(vec![For( Token{value: "for".to_string(), token_type: TokenType::For, pos: 0}, [ Box::new(VariableDefinition( @@ -1000,14 +999,14 @@ mod tests { } #[rstest::rstest] - #[case("int func();", Program(vec![ + #[case("int func();", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, vec![] ) ]))] - #[case("int func(int x);", Program(vec![ + #[case("int func(int x);", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, @@ -1019,7 +1018,7 @@ mod tests { ] ) ]))] - #[case("int func(int x, int y);", Program(vec![ + #[case("int func(int x, int y);", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, @@ -1042,14 +1041,14 @@ mod tests { } #[rstest::rstest] - #[case("int func();", Program(vec![ + #[case("int func();", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, vec![] ) ]))] - #[case("int func(int x);", Program(vec![ + #[case("int func(int x);", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, @@ -1061,7 +1060,7 @@ mod tests { ] ) ]))] - #[case("int func(int x, int y);", Program(vec![ + #[case("int func(int x, int y);", TranslationUnit(vec![ FunctionDeclaration( Token { value: "int".to_string(), token_type: TokenType::Type, pos: 0 }, Token { value: "func".to_string(), token_type: TokenType::Identifier, pos: 4 }, From b603c034a78f42f603d7fdcb076aca51f9c64d0f Mon Sep 17 00:00:00 2001 From: Osama Ahmad Date: Fri, 6 Oct 2023 09:26:44 +0300 Subject: [PATCH 2/2] Generate code for function invocation --- src/ast.rs | 1 + src/code_generator.rs | 78 ++++++++++++++++++++++++++++++++++++++++++- src/parser.rs | 77 +++++++++++++++++++++++++++++++++++++++++- src/symbol_table.rs | 12 +++++-- 4 files changed, 164 insertions(+), 4 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index dfb8a1f..e59ff1d 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -8,6 +8,7 @@ pub enum Expression { Unary(Token, Box), Parenthesized(Box), Assignment(Token, Box), + FunctionCall(Token, Vec), Empty, } diff --git a/src/code_generator.rs b/src/code_generator.rs index 6ae68b9..188ee86 100644 --- a/src/code_generator.rs +++ b/src/code_generator.rs @@ -155,12 +155,43 @@ impl CodeGenerator { Expression::Binary(_, _, _) => self.generate_binary_expression(expression), Expression::Unary(_, _) => self.generate_unary_expression(expression), Expression::Assignment(..) => self.generate_assignment(expression), + Expression::FunctionCall(..) => self.generate_function_call(expression), Expression::Parenthesized(internal_expression) => { self.generate_expression(internal_expression) } } } + fn get_expression_size_in_bytes(exp: &Expression) -> usize { + 4 + } + + fn generate_function_call(&mut self, call: &Expression) -> String { + let mut result = String::new(); + match call { + Expression::FunctionCall(name, parameters) => { + // TODO extract pointer size into a function (and use it in the symbol table too) + let mut push_offset = -16; // return address + rbp + for param in parameters { + let param_size = CodeGenerator::get_expression_size_in_bytes(param); + // Since we're dealing with the stack, subtraction of the offset occurs first + push_offset -= param_size as i32; + result.push_str(&format!( + "{computation}\ + {mov} {result}, {offset}(%rsp)\n", + computation = self.generate_expression(param), + mov = CodeGenerator::mov_mnemonic(param_size), + result = CodeGenerator::get_reg1(param_size), + offset = push_offset + )); + } + result.push_str(&format!("call {}\n", name.value)); + result + } + _ => panic!("Exp"), + } + } + fn generate_unary_expression(&mut self, expression: &Expression) -> String { let mut result = String::new(); match expression { @@ -302,7 +333,7 @@ impl CodeGenerator { format!( "{} {}(%rbp), {}\n", mov_instruction, - stack_offset, + *stack_offset, CodeGenerator::get_reg1(definition.size()) ) } @@ -818,4 +849,49 @@ mod tests { let generated = generate_code(test_case); expect_exit_code(generated, 1).unwrap(); } + + #[rstest::rstest] + #[case( + "int f(int y, int x) { return x + y; } + + int main() { + return f(6, 7); + }", + 13 + )] + #[case( + " + int add(int y, int x) { + return x + y; + } + + int ident(int x) { return x; } + + int fib(int n) { + + int a = 0; + int b = 1; + int c; + while (n >= 0) { + c = add(a, b); + a = ident(b); + b = ident(ident(c)); + n = ident(n - 1); + } + return ident(c); + } + + int main() { + return fib(3) * fib(5) - fib(2); + }", + 62 + )] + fn test_function_invocation( + #[case] test_case: String, + #[case] expected: i32, + ) -> std::io::Result<()> { + let generated = generate_code(test_case); + expect_exit_code(generated, expected)?; + Ok(()) + } } diff --git a/src/parser.rs b/src/parser.rs index 61e15e9..230b19e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -250,9 +250,38 @@ impl Parser { expr } + fn parse_function_arguments(&mut self) -> Vec { + let mut result = Vec::new(); + self.try_consume(TokenType::OpenParen); + while self.current().token_type != TokenType::CloseParen + && self.current().token_type != TokenType::Eof + { + result.push(self.parse_expression()); + if self.current().token_type != TokenType::CloseParen { + self.try_consume(TokenType::Comma); + } + } + self.try_consume(TokenType::CloseParen); + result + } + + fn parse_function_call(&mut self) -> Expression { + let identifier = self.try_consume(TokenType::Identifier); + let arguments = self.parse_function_arguments(); + Expression::FunctionCall(identifier, arguments) + } + + fn parse_primary_expression_starting_with_identifier(&mut self) -> Expression { + if self.peak(1).token_type == TokenType::OpenParen { + self.parse_function_call() + } else { + Expression::Variable(self.consume()) + } + } + fn parse_primary_expression(&mut self) -> Expression { match self.current().token_type { - TokenType::Identifier => Expression::Variable(self.consume()), + TokenType::Identifier => self.parse_primary_expression_starting_with_identifier(), TokenType::IntegerLiteral => Expression::IntegerLiteral(self.consume()), TokenType::OpenParen => self.parse_parenthesized_expression(), _ => panic!("Unexpected token: {:?}", self.current()), @@ -1082,6 +1111,52 @@ mod tests { assert_eq!(expected, result); } + #[rstest::rstest] + #[case("f();", TranslationUnit(vec![ + ExpressionStatement( + Expression::FunctionCall( + Token { value: "f".to_string(), token_type: TokenType::Identifier, pos: 0}, + vec![] + ) + )] + ))] + #[case("f(1);", TranslationUnit(vec![ + ExpressionStatement( + Expression::FunctionCall( + Token { value: "f".to_string(), token_type: TokenType::Identifier, pos: 0}, + vec![ + Expression::IntegerLiteral( + Token { value: "1".to_string(), token_type: TokenType::IntegerLiteral, pos: 2} + ) + ] + ) + )] + ))] + #[case("f(1, x, g());", TranslationUnit(vec![ + ExpressionStatement( + Expression::FunctionCall( + Token { value: "f".to_string(), token_type: TokenType::Identifier, pos: 0}, + vec![ + Expression::IntegerLiteral( + Token { value: "1".to_string(), token_type: TokenType::IntegerLiteral, pos: 2} + ), + Expression::Variable( + Token { value: "x".to_string(), token_type: TokenType::Identifier, pos: 5} + ), + Expression::FunctionCall( + Token { value: "g".to_string(), token_type: TokenType::Identifier, pos: 8}, + vec![] + ) + ] + ) + )] + ))] + fn test_parse_function_call(#[case] test_case: String, #[case] expected: ASTNode) { + let tokens = Lexer::new(test_case.clone()).lex(); + let result = Parser::new(tokens).parse(); + assert_eq!(expected, result); + } + #[rstest::rstest] #[case("return 123")] #[case("int x")] diff --git a/src/symbol_table.rs b/src/symbol_table.rs index 92a039e..e2b6f1a 100644 --- a/src/symbol_table.rs +++ b/src/symbol_table.rs @@ -16,9 +16,16 @@ pub enum Symbol { } impl Symbol { - pub fn size(&self) -> usize { + fn get_type_size_in_bytes(type_str: &str) -> usize { 4 } + + pub fn size(&self) -> usize { + match self { + Variable { variable_type, .. } => Symbol::get_type_size_in_bytes(variable_type), + Symbol::Function { .. } => panic!(), + } + } } pub struct Scope { @@ -52,7 +59,8 @@ impl Scope { symbol_name, &Variable { variable_type: String::from(variable_type), - stack_offset: self.stack_top, + stack_offset: self.stack_top + - (Symbol::get_type_size_in_bytes(variable_type) as isize), }, ) }