Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLVM: If statements, while loops, binary expressions, and expression statements #17

Merged
merged 7 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 257 additions & 16 deletions src/code_generation/llvm/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct LLVMGenerator<'ctx> {
builder: Builder<'ctx>,
module: Module<'ctx>,
symbol_table: SymbolTable<'ctx>,
current_function: Option<FunctionValue<'ctx>>, // The function currently being generated
counter: i32, // For naming labels of blocks
}

impl<'ctx> LLVMGenerator<'ctx> {
Expand All @@ -30,6 +32,8 @@ impl<'ctx> LLVMGenerator<'ctx> {
builder,
module,
symbol_table,
current_function: None,
counter: 0,
}
}
}
Expand All @@ -51,11 +55,11 @@ impl<'ctx> LLVMGenerator<'ctx> {
FunctionDeclaration(..) => self.generate_function_declaration(node).as_any_value_enum(),
FunctionDefinition(..) => self.generate_function_definition(node).as_any_value_enum(),
ExpressionNode(expression) => self.generate_expression(expression).as_any_value_enum(),
// Scope(..) => self.generate_scope(node),
// If(..) => self.generate_if_statement(node),
// While(..) => self.generate_while(node),
Scope(..) => self.generate_scope(node).as_any_value_enum(),
If(..) => self.generate_if_statement(node).as_any_value_enum(),
While(..) => self.generate_while(node).as_any_value_enum(),
// DoWhile(..) => self.generate_do_while(node),
// ExpressionStatement(..) => self.generate_expression_statement(node),
ExpressionStatement(..) => self.generate_expression_statement(node).as_any_value_enum(),
// For(..) => self.generate_for(node),
_ => panic!(),
}
Expand Down Expand Up @@ -280,7 +284,6 @@ impl<'ctx> LLVMGenerator<'ctx> {
),
}
}
// TODO add the function to the symbol table.
function
}
_ => panic!(
Expand All @@ -299,6 +302,8 @@ impl<'ctx> LLVMGenerator<'ctx> {
params.clone(),
));

self.current_function = Some(declaration);

self.symbol_table.push_scope();

let basic_block = self.context.append_basic_block(declaration, "entry");
Expand Down Expand Up @@ -386,10 +391,97 @@ impl<'ctx> LLVMGenerator<'ctx> {
match expression {
Expression::IntegerLiteral { .. } => self.generate_integer_literal(expression),
Expression::Variable(name) => self.generate_variable_expression(name),
Expression::Binary(op, lhs, rhs) => self.generate_binary_expression(op, lhs, rhs),
Expression::Assignment(lhs, rhs) => self.generate_assignment(lhs, rhs),
_ => todo!(),
}
}

fn generate_assignment(&mut self, lhs: &Token, rhs: &Expression) -> BasicValueEnum<'ctx> {
if self.is_in_global_scope() {
panic!("Assignment to variable {} not allowed in global scope", lhs.value);
}
if self.symbol_table.find_hierarchically(&lhs.value).is_none() {
panic!("Reference to undefined variable `{}`", lhs.value);
}

let rhs = self.generate_expression(&rhs);

// FIXME make sure that assignments to global variables are correct
self.builder
.build_store(
self.symbol_table
.find_hierarchically(&lhs.value)
.unwrap()
.pointer,
rhs,
)
.unwrap();

rhs.clone()
}

fn generate_binary_expression(
&mut self,
token: &Token,
lhs: &Expression,
rhs: &Expression,
) -> BasicValueEnum<'ctx> {
let lhs = self.generate_expression(lhs).into_int_value();
let rhs = self.generate_expression(rhs).into_int_value();
// TODO implement the rest of the binary expressions
match token.token_type {
TokenType::EqualsEquals => {
self.builder
.build_int_compare(inkwell::IntPredicate::EQ, lhs, rhs, "bool_value")
}
TokenType::NotEquals => {
self.builder
.build_int_compare(inkwell::IntPredicate::NE, lhs, rhs, "bool_value")
}
TokenType::GreaterThan => {
self.builder
.build_int_compare(inkwell::IntPredicate::SGT, lhs, rhs, "bool_value")
}
TokenType::GreaterThanEquals => {
self.builder
.build_int_compare(inkwell::IntPredicate::SGE, lhs, rhs, "bool_value")
}
TokenType::LessThan => {
self.builder
.build_int_compare(inkwell::IntPredicate::SLT, lhs, rhs, "bool_value")
}
TokenType::LessThanEquals => {
self.builder
.build_int_compare(inkwell::IntPredicate::SLE, lhs, rhs, "bool_value")
}
TokenType::Plus => self.builder.build_int_add(lhs, rhs, "temp_add"),
TokenType::Minus => self.builder.build_int_sub(lhs, rhs, "temp_sub"),
TokenType::Star => self.builder.build_int_mul(lhs, rhs, "temp_mul"),
TokenType::Slash => self.builder.build_int_signed_div(lhs, rhs, "temp_div"),
TokenType::And => self.builder.build_and(lhs, rhs, "temp_and"),
TokenType::Bar => self.builder.build_or(lhs, rhs, "temp_or"),
TokenType::AndAnd | TokenType::BarBar => {
let lhs = self
.builder
.build_int_cast(lhs, self.context.bool_type(), "bool_lhs")
.unwrap();
let rhs = self
.builder
.build_int_cast(rhs, self.context.bool_type(), "bool_rhs")
.unwrap();
if token.token_type == TokenType::AndAnd {
self.builder.build_and(lhs, rhs, "temp_logical_and")
} else {
self.builder.build_or(lhs, rhs, "temp_logical_or")
}
}
_ => panic!(),
}
.unwrap()
.as_basic_value_enum()
}

fn generate_variable_expression(&mut self, name: &Token) -> BasicValueEnum<'ctx> {
if let Some(variable) = self.symbol_table.find(&name.value) {
self.builder
Expand Down Expand Up @@ -422,25 +514,159 @@ impl<'ctx> LLVMGenerator<'ctx> {
_ => panic!(),
}
}
fn generate_scope(&mut self, node: &ASTNode) -> IntValue<'ctx> {
self.symbol_table.push_scope();

fn generate_scope(&mut self, scope: &ASTNode) -> String {
todo!()
let statements = match node {
Scope(statements) => statements,
_ => panic!(
"Internal error: expected translation unit, found: {:?}",
node
),
};

for statement in statements {
self.generate(statement);
}

self.symbol_table.pop_scope();

self.context.i32_type().const_int(0, false)
}

fn generate_if_statement(&mut self, node: &ASTNode) -> String {
todo!()
fn generate_if_statement(&mut self, node: &ASTNode) -> IntValue<'ctx> {
let (condition_node, then_node, else_node) = match node {
If(_, condition_node, then_node, else_node) => (condition_node, then_node, else_node),
_ => panic!(),
};

let counter = self.counter;

let then_block = self
.context
.append_basic_block(self.current_function.unwrap(), &format!("then_{counter}"));
let end_block = self
.context
.append_basic_block(self.current_function.unwrap(), &format!("if_end_{counter}"));
let else_block = if else_node.is_some() {
self.context
.append_basic_block(self.current_function.unwrap(), &format!("else_{counter}"))
} else {
end_block
};

self.counter += 1;

let condition = match condition_node.as_ref() {
ASTNode::ExpressionNode(expression) => expression,
_ => panic!(),
};

let cond_result = self.generate_expression(condition);

let zero = self.context.i32_type().const_int(0, false);
let i32_value = self
.builder
.build_int_z_extend(
cond_result.into_int_value(),
self.context.i32_type(),
"extended_condition",
)
.unwrap();
let bool_value = self
.builder
.build_int_compare(inkwell::IntPredicate::NE, i32_value, zero, "bool_value")
.unwrap();
self.builder
.build_conditional_branch(bool_value, then_block, else_block)
.unwrap();

self.builder.position_at_end(then_block);

self.generate(then_node);
self.builder.build_unconditional_branch(end_block).unwrap();

if else_node.is_some() {
self.builder.position_at_end(else_block);
self.generate(else_node.as_ref().unwrap());
self.builder.build_unconditional_branch(end_block).unwrap();
}

self.builder.position_at_end(end_block);

zero
}

fn generate_while(&mut self, while_node: &ASTNode) -> String {
todo!()
fn generate_while(&mut self, node: &ASTNode) -> IntValue<'ctx> {
let counter = self.counter;
let cond_block = self.context.append_basic_block(
self.current_function.unwrap(),
&format!("while_condition_{counter}"),
);
let body_block = self.context.append_basic_block(
self.current_function.unwrap(),
&format!("while_body_{counter}"),
);
let end_block = self.context.append_basic_block(
self.current_function.unwrap(),
&format!("while_end_{counter}"),
);

self.counter += 1;

self.builder.build_unconditional_branch(cond_block).unwrap();

self.builder.position_at_end(cond_block);

let (condition_node, body_node) = match node {
While(_, condition_node, body_node) => (condition_node, body_node),
_ => panic!(),
};

let condition = match condition_node.as_ref() {
ASTNode::ExpressionNode(expression) => expression,
_ => panic!(),
};

let cond_result = self.generate_expression(condition);
let i32_value = self
.builder
.build_int_z_extend(
cond_result.into_int_value(),
self.context.i32_type(),
"extended_condition",
)
.unwrap();
let zero = self.context.i32_type().const_int(0, false);
let bool_value = self
.builder
.build_int_compare(inkwell::IntPredicate::NE, i32_value, zero, "bool_value")
.unwrap();

self.builder
.build_conditional_branch(bool_value, body_block, end_block)
.unwrap();

self.builder.position_at_end(body_block);

self.generate(body_node);
self.builder.build_unconditional_branch(cond_block).unwrap();

self.builder.position_at_end(end_block);

self.context.i32_type().const_int(0, false)
}

fn generate_do_while(&mut self, node: &ASTNode) -> String {
todo!()
}

fn generate_expression_statement(&mut self, node: &ASTNode) -> String {
todo!()
fn generate_expression_statement(&mut self, node: &ASTNode) -> BasicValueEnum<'ctx> {
let expression = match node {
ExpressionStatement(expression) => expression,
_ => panic!(),
};
return self.generate_expression(expression);
}

fn generate_for(&mut self, node: &ASTNode) -> String {
Expand Down Expand Up @@ -472,9 +698,9 @@ mod tests {
code_generation::llvm::generator::LLVMGenerator::new(&mut context).generate(&ast);
let exit_code = interpret_llvm_ir(&generated_ir);
assert_eq!(
test_case.expected, exit_code,
"Test case: {} -- Expected: {}, found: {}",
test_case.name, test_case.expected, exit_code
test_case.expected % 256, exit_code % 256,
"Test case: {} -- Expected: {}, found: {}\nGenerated IR:\n{}",
test_case.name, test_case.expected, exit_code, generated_ir
);
}
}
Expand All @@ -489,4 +715,19 @@ mod tests {
fn test_erroneous_variable_declarations_and_definitions() {
run_tests_from_file("./src/tests/variables_error.c");
}

#[test]
fn test_basic_if() {
run_tests_from_file("./src/tests/if.c");
}

#[test]
fn test_while() {
run_tests_from_file("./src/tests/while.c");
}

#[test]
fn test_assignment() {
run_tests_from_file("./src/tests/assignment.c");
}
}
8 changes: 8 additions & 0 deletions src/code_generation/llvm/symbol_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl<'ctx> SymbolTable<'ctx> {
None
}

pub fn find_hierarchically(&self, name: &str) -> Option<&Variable<'ctx>> {
let mut result = self.find(&name);
if result.is_none() {
result = self.find_in_global_scope(&name);
}
result
}

pub fn find_in_current_scope(&self, name: &str) -> Option<&Variable<'ctx>> {
if let Some(scope) = self.scopes.last() {
return scope.get(name);
Expand Down
Loading