diff --git a/src/code_generation/llvm/generator.rs b/src/code_generation/llvm/generator.rs index 6e72bc8..f7298b6 100644 --- a/src/code_generation/llvm/generator.rs +++ b/src/code_generation/llvm/generator.rs @@ -398,8 +398,34 @@ impl<'ctx> LLVMGenerator<'ctx> { self.generate_unary_expression(operator, operand) } Expression::Assignment(lhs, rhs) => self.generate_assignment(lhs, rhs), - _ => todo!(), + Expression::Parenthesized(expression) => self.generate_expression(expression), + Expression::FunctionCall(func_name, args) => { + self.generate_function_call(&func_name.value, args) + } + Expression::Empty => self + .context + .i32_type() + .const_int(0, false) + .as_basic_value_enum(), + } + } + + fn generate_function_call( + &mut self, + name: &str, + args: &Vec, + ) -> BasicValueEnum<'ctx> { + let function = self.module.get_function(name).unwrap(); + let mut arguments = Vec::new(); + for arg in args { + arguments.push(self.generate_expression(arg).into()); } + self.builder + .build_call(function, &arguments, "call") + .unwrap() + .try_as_basic_value() + .left() + .unwrap() } fn generate_unary_expression( @@ -834,16 +860,28 @@ mod tests { let ast = syntax_analysis::parser::Parser::new(tokens).parse(); let generated_ir = code_generation::llvm::generator::LLVMGenerator::new(&mut context).generate(&ast); - let exit_code = interpret_llvm_ir(&generated_ir); - assert_eq!( - test_case.expected % 256, - exit_code % 256, - "Test case: {} -- Expected: {}, found: {}\nGenerated IR:\n{}", - test_case.name, - test_case.expected, - exit_code, - generated_ir - ); + let (exit_code, stdout_str) = interpret_llvm_ir(&generated_ir); + if test_case.expected_exit_code.is_some() { + let expected_exit_code = test_case.expected_exit_code.unwrap(); + assert_eq!( + (expected_exit_code + 256) % 256, + (exit_code + 256) % 256, + "Test case: {} -- Expected exit code: {}, found: {}\nGenerated IR:\n{}", + test_case.name, + expected_exit_code, + exit_code, + generated_ir + ); + } + + if test_case.expected_output.is_some() { + let expected_output = test_case.expected_output.unwrap(); + assert_eq!( + expected_output, stdout_str, + "Test case: {} -- Expected stdout: {}, found: {}\nGenerated IR:\n{}", + test_case.name, expected_output, stdout_str, generated_ir + ); + } } } @@ -887,4 +925,9 @@ mod tests { fn test_for() { run_tests_from_file("./src/tests/for.c"); } + + #[test] + fn test_function_calls() { + run_tests_from_file("./src/tests/function_calls.c"); + } } diff --git a/src/syntax_analysis/parser.rs b/src/syntax_analysis/parser.rs index c672376..63a90e1 100644 --- a/src/syntax_analysis/parser.rs +++ b/src/syntax_analysis/parser.rs @@ -28,8 +28,11 @@ fn binary_operator_precedence(token_type: TokenType) -> u8 { } } -fn is_unary_operator(token_type: &TokenType) -> bool { - *token_type == TokenType::Plus || *token_type == TokenType::Minus +fn unary_operator_precedence(token_type: &TokenType) -> u8 { + match token_type { + TokenType::Plus | TokenType::Minus | TokenType::Bang => 10, + _ => 0, + } } impl Parser { @@ -240,10 +243,16 @@ impl Parser { if self.is_assignment() { return self.parse_assignment_expression(); } - if is_unary_operator(&self.current().token_type) { - return Expression::Unary(self.consume().clone(), Box::new(self.parse_expression())); + let mut left = None; + let unary_op_precedence = unary_operator_precedence(&self.current().token_type); + if unary_op_precedence != 0 && unary_op_precedence >= parent_precedence { + left = Some(Expression::Unary( + self.consume().clone(), + Box::new(self.parse_expression_internal(unary_op_precedence)), + )); + } else { + left = Some(self.parse_primary_expression()); } - let mut left = self.parse_primary_expression(); loop { let operator_token = self.current().clone(); let operator_precedence = binary_operator_precedence(operator_token.token_type.clone()); @@ -255,9 +264,13 @@ impl Parser { } self.advance(); let right = self.parse_expression_internal(operator_precedence); - left = Expression::Binary(operator_token, Box::new(left), Box::new(right)) + left = Some(Expression::Binary( + operator_token, + Box::new(left.unwrap()), + Box::new(right), + )) } - left + left.unwrap() } fn parse_parenthesized_expression(&mut self) -> Expression { @@ -690,6 +703,37 @@ mod tests { ) ) ] + #[case("return -2 + 1;", + TranslationUnit( + vec![ + ReturnStatement( + Token{value: "return".to_string(), token_type: TokenType::Return, pos: 0}, + Box::new( + ExpressionNode( + Expression::Binary( + Token{value: "+".to_string(), token_type: TokenType::Plus, pos: 10}, + Box::new( + Expression::Unary( + Token{value: "-".to_string(), token_type: TokenType::Minus, pos: 7}, + Box::new( + Expression::IntegerLiteral( + Token{value: "2".to_string(), token_type: TokenType::IntegerLiteral, pos: 8} + ) + ) + ) + ), + Box::new( + Expression::IntegerLiteral( + Token{value: "1".to_string(), token_type: TokenType::IntegerLiteral, pos: 12} + ) + ) + ) + ) + ) + ) + ] + ) + )] fn test_parse_unary_expression(#[case] test_case: String, #[case] expected: ASTNode) { let tokens = Lexer::new(test_case).lex(); let result = Parser::new(tokens).parse(); diff --git a/src/tests/function_calls.c b/src/tests/function_calls.c new file mode 100644 index 0000000..f8faa03 --- /dev/null +++ b/src/tests/function_calls.c @@ -0,0 +1,37 @@ +// CASE Basic function call +// RETURNS 1 + +int f(int x) { + return x - 1; +} + +int main() { + return f(2); +} + +// CASE Recursive Fibonacci +// RETURNS 3 + +int fib(int n) { + if (n <= 1) return n; + return fib(n-1) + fib(n-2); +} + +int main() { + return fib(4); +} + +// CASE Print to stdout +// RETURNS 0 +// Outputs Hello + +int putchar(int c); + +int main() { + putchar(72); + putchar(101); + putchar(108); + putchar(108); + putchar(111); + return 0; +} diff --git a/src/tests/operators.c b/src/tests/operators.c index f03a4a8..6011f2c 100644 --- a/src/tests/operators.c +++ b/src/tests/operators.c @@ -1,33 +1,33 @@ -// Case Binary operator + -// Returns 3 +// CASE Binary operator + +// RETURNS 3 int main() { return 1 + 2; } -// Case Binary operator - -// Returns -1 +// CASE Binary operator - +// RETURNS -1 int main() { return 1 - 2; } -// Case Binary operator * -// Returns 10 +// CASE Binary operator * +// RETURNS 10 int main() { return 5 * 2; } -// Case Binary operator * -// Returns 4 +// CASE Binary operator * +// RETURNS 4 int main() { return 9 / 2; } -// Case Binary arithmetic operators combined -// Returns 9697 +// CASE Binary arithmetic operators combined +// RETURNS 9697 int main() { int x = 312; @@ -37,43 +37,56 @@ int main() { return z - x; } -// Case Binary operator || -// Returns 1 +// CASE Binary operator || +// RETURNS 1 int main() { - int false = 0; int true = 123; return true || false; + int false = 0; int true = 123; int y = true || false; return y; } -// Case Binary operator && -// Returns 0 +// CASE Binary operator && +// RETURNS 0 int main() { - int false = 0; int true = 123; return true && false; + int false = 0; int true = 123; int y = true && false; return y; } -// Case Binary operator ^ -// Returns 123 +// CASE Binary operator ^ +// RETURNS 123 int main() { int false = 0; int true = 123; return true ^ false; } -// Case Unary operator ! 1 -// Returns 1 +// CASE Unary operator ! 1 +// RETURNS 1 int main() { - return !0; + int x = !0; + return x; } -// Case Unary operator ! 2 -// Returns 0 +// CASE Unary operator ! 2 +// RETURNS 0 int main() { - return !1; + int x = !1; + return x; } -// Case Unary +- -// Returns 144 +// CASE Unary +- +// RETURNS 144 + int main() { int x = -----++++----+------12; return x * x; -} \ No newline at end of file +} + +// CASE Binary, unary and parenthesized expressions +// RETURNS -1 + +int main() { + int x = 3; + int y = 4; + int z = 5; + return (-x + y) * (z - x) / (-y + z) - x; +} diff --git a/src/utils/test_utils.rs b/src/utils/test_utils.rs index cd80e1c..489856c 100644 --- a/src/utils/test_utils.rs +++ b/src/utils/test_utils.rs @@ -46,7 +46,8 @@ pub fn expect_exit_code(source: String, expected: i32) -> std::io::Result<()> { pub struct TestCase { pub name: String, pub source: String, - pub expected: i32, + pub expected_exit_code: Option, + pub expected_output: Option, } pub fn parse_test_file(path: &str) -> Vec { @@ -56,26 +57,44 @@ pub fn parse_test_file(path: &str) -> Vec { let test_strings = contents.split("// CASE ").skip(1).collect::>(); let mut result = vec![]; for test_string in test_strings { - let mut lines = test_string.lines(); - let name = lines.next().unwrap().clone().to_string(); - let expected = lines - .next() - .unwrap() - .strip_prefix("// RETURNS ") - .unwrap() - .parse::() - .unwrap(); - let source = lines.collect::>().join("\n"); - result.push(TestCase { - name, - source, - expected, - }); + let mut lines = test_string.lines().peekable(); + + let mut test_case = TestCase { + name: lines.next().unwrap().to_string(), + source: "".to_string(), + expected_exit_code: None, + expected_output: None, + }; + + while lines.peek().is_some() { + let line = lines.peek().unwrap().to_string(); + + if line.is_empty() { + lines.next().unwrap(); + continue; + } + + let line = line.split_ascii_whitespace().collect::>(); + if line[0] != "//" { + break; + } + + match line[1].to_ascii_lowercase().as_str() { + "returns" => test_case.expected_exit_code = Some(line[2].parse::().unwrap()), + "outputs" => test_case.expected_output = Some(line[2..].join(" ")), + _ => panic!("Invalid test case metadata: {}", line.join(" ")), + } + + lines.next().unwrap(); + } + + test_case.source = lines.collect::>().join("\n"); + result.push(test_case); } result } -pub fn interpret_llvm_ir(ir: &str) -> i32 { +pub fn interpret_llvm_ir(ir: &str) -> (i32, String) { let id = Uuid::new_v4(); let ir_path = format!("./{}.ll", id); let mut ir_file = File::create(&ir_path).unwrap(); @@ -87,6 +106,7 @@ pub fn interpret_llvm_ir(ir: &str) -> i32 { .output() .expect("Failed to compile generated code"); let exit_code = output.status; + let stdout_str = String::from_utf8(output.stdout).unwrap(); remove_file(&ir_path).unwrap(); - return exit_code.code().unwrap(); + return (exit_code.code().unwrap(), stdout_str); }