Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support binary expressions for pointers (#416) #422

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1134,6 +1134,7 @@ impl Display for Operator {
Operator::Multiplication => "*",
Operator::Division => "/",
Operator::Equal => "=",
Operator::Modulo => "MOD",
_ => unimplemented!(),
};
f.write_str(symbol)
141 changes: 141 additions & 0 deletions src/codegen/generators/expression_generator.rs
Original file line number Diff line number Diff line change
@@ -213,6 +213,13 @@ impl<'a, 'b> ExpressionCodeGenerator<'a, 'b> {
self.generate_expression(left)?,
self.generate_expression(right)?,
))
} else if (ltype.is_pointer() && rtype.is_int())
|| (ltype.is_int() && rtype.is_pointer())
|| (ltype.is_pointer() && rtype.is_pointer())
{
self.create_llvm_binary_expression_for_pointer(
operator, left, ltype, right, rtype, expression,
)
} else {
self.create_llvm_generic_binary_expression(operator, left, right, expression)
}
@@ -1008,6 +1015,140 @@ impl<'a, 'b> ExpressionCodeGenerator<'a, 'b> {
})
}

/// generates the result of an pointer binary-expression
///
/// - `operator` the binary operator
/// - `left` the left side of the binary expression, needs to be an pointer/int-value
/// - `left_type` DataTypeInformation of the left side
/// - `right` the right side of the binary expression, needs to be an pointer/int-value
/// - `right_type` DataTypeInformation of the right side
/// - `expression` the binary expression
pub fn create_llvm_binary_expression_for_pointer(
&self,
operator: &Operator,
left: &AstStatement,
left_type: &DataTypeInformation,
right: &AstStatement,
right_type: &DataTypeInformation,
expression: &AstStatement,
) -> Result<BasicValueEnum<'a>, Diagnostic> {
let left_expr = self.generate_expression(left)?;
let right_expr = self.generate_expression(right)?;

let result = match operator {
Operator::Plus | Operator::Minus => {
let (ptr, index, name) = if left_type.is_pointer() && right_type.is_int() {
let ptr = left_expr.into_pointer_value();
let index = right_expr.into_int_value();
let name = format!("access_{}", left_type.get_name());
(Some(ptr), Some(index), Some(name))
} else if left_type.is_int() && right_type.is_pointer() {
let ptr = right_expr.into_pointer_value();
let index = left_expr.into_int_value();
let name = format!("access_{}", right_type.get_name());
(Some(ptr), Some(index), Some(name))
} else {
// if left and right are both pointers we can not perform plus/minus
(None, None, None)
};

if let (Some(ptr), Some(mut index), Some(name)) = (ptr, index, name) {
// if operator is minus we need to negate the index
if let Operator::Minus = operator {
index = index.const_neg();
}

Ok(self
.llvm
.load_array_element(ptr, &[index], name.as_str())?
.as_basic_value_enum())
} else {
Err(Diagnostic::codegen_error(
format!("'{}' operation must contain one int type", operator).as_str(),
expression.get_location(),
))
}
}
Operator::Equal => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::EQ,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
Operator::NotEqual => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::NE,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
Operator::Less => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::SLT,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
Operator::Greater => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::SGT,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
Operator::LessOrEqual => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::SLE,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
Operator::GreaterOrEqual => Ok(self
.llvm
.builder
.build_int_compare(
IntPredicate::SGE,
self.convert_to_int_value_if_pointer(left_expr),
self.convert_to_int_value_if_pointer(right_expr),
"tmpVar",
)
.as_basic_value_enum()),
_ => Err(Diagnostic::codegen_error(
format!("Operator '{}' unimplemented for pointers", operator).as_str(),
expression.get_location(),
)),
};

result
}

pub fn convert_to_int_value_if_pointer(&self, value: BasicValueEnum<'a>) -> IntValue<'a> {
match value {
BasicValueEnum::PointerValue(v) => {
let int_type = v.get_type().size_of().get_type();
v.const_to_int(int_type)
}
BasicValueEnum::IntValue(v) => v,
_ => unimplemented!(),
}
}

/// generates the result of an int/bool binary-expression (+, -, *, /, %, ==)
///
/// - `operator` the binary operator
92 changes: 92 additions & 0 deletions src/codegen/tests/codegen_error_messages_tests.rs
Original file line number Diff line number Diff line change
@@ -437,3 +437,95 @@ fn assigning_empty_string_literal_to_wide_char_results_in_error() {
panic!("expected code-gen error but got none")
}
}

#[test]
fn pointer_binary_expression_adding_two_pointers() {
let result = codegen_without_unwrap(
r#"
PROGRAM mainProg
VAR
x : INT;
ptr : REF_TO INT;
END_VAR
ptr := &(x);
ptr := ptr + ptr;
END_PROGRAM"#,
);
if let Err(msg) = result {
assert_eq!(
Diagnostic::codegen_error("'+' operation must contain one int type", (88..97).into()),
msg
)
} else {
panic!("expected code-gen error but got none")
}
}

#[test]
fn pointer_binary_expression_multiplication() {
let result = codegen_without_unwrap(
r#"
PROGRAM mainProg
VAR
x : INT;
ptr : REF_TO INT;
END_VAR
ptr := &(x);
ptr := ptr * ptr;
END_PROGRAM"#,
);
if let Err(msg) = result {
assert_eq!(
Diagnostic::codegen_error("Operator '*' unimplemented for pointers", (88..97).into()),
msg
)
} else {
panic!("expected code-gen error but got none")
}
}

#[test]
fn pointer_binary_expression_division() {
let result = codegen_without_unwrap(
r#"
PROGRAM mainProg
VAR
x : INT;
ptr : REF_TO INT;
END_VAR
ptr := &(x);
ptr := ptr / ptr;
END_PROGRAM"#,
);
if let Err(msg) = result {
assert_eq!(
Diagnostic::codegen_error("Operator '/' unimplemented for pointers", (88..97).into()),
msg
)
} else {
panic!("expected code-gen error but got none")
}
}

#[test]
fn pointer_binary_expression_modulo() {
let result = codegen_without_unwrap(
r#"
PROGRAM mainProg
VAR
x : INT;
ptr : REF_TO INT;
END_VAR
ptr := &(x);
ptr := ptr MOD ptr;
END_PROGRAM"#,
);
if let Err(msg) = result {
assert_eq!(
Diagnostic::codegen_error("Operator 'MOD' unimplemented for pointers", (88..99).into()),
msg
)
} else {
panic!("expected code-gen error but got none")
}
}
32 changes: 32 additions & 0 deletions src/codegen/tests/expression_tests.rs
Original file line number Diff line number Diff line change
@@ -301,3 +301,35 @@ fn cast_lword_to_pointer() {
//should result in normal number-comparisons
insta::assert_snapshot!(result);
}

#[test]
fn pointer_arithmetics() {
// codegen should be successful for binary expression for pointer<->int / int<->pointer / pointer<->pointer
let result = codegen(
"
PROGRAM main
VAR
x : INT := 10;
y : INT := 20;
pt : REF_TO INT;
comp : BOOL;
END_VAR
pt := &(x);

(* +/- *)
pt := pt + 1;
pt := 1 + pt;
pt := pt - y;

(* compare pointer-pointer / pointer-int *)
comp := pt = pt;
comp := pt <> y;
comp := pt < pt;
comp := pt > y;
comp := pt <= pt;
comp := y >= pt;
END_PROGRAM
",
);
insta::assert_snapshot!(result);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
source: src/codegen/tests/expression_tests.rs
expression: result

---
; ModuleID = 'main'
source_filename = "main"

%main_interface = type { i16, i16, i16*, i1 }

@main_instance = global %main_interface { i16 10, i16 20, i16* null, i1 false }

define void @main(%main_interface* %0) {
entry:
%x = getelementptr inbounds %main_interface, %main_interface* %0, i32 0, i32 0
%y = getelementptr inbounds %main_interface, %main_interface* %0, i32 0, i32 1
%pt = getelementptr inbounds %main_interface, %main_interface* %0, i32 0, i32 2
%comp = getelementptr inbounds %main_interface, %main_interface* %0, i32 0, i32 3
store i16* %x, i16** %pt, align 8
%load_pt = load i16*, i16** %pt, align 8
%access___main_pt = getelementptr inbounds i16, i16* %load_pt, i32 1
store i16* %access___main_pt, i16** %pt, align 8
%load_pt1 = load i16*, i16** %pt, align 8
%access___main_pt2 = getelementptr inbounds i16, i16* %load_pt1, i32 1
store i16* %access___main_pt2, i16** %pt, align 8
%load_pt3 = load i16*, i16** %pt, align 8
%load_y = load i16, i16* %y, align 2
%access___main_pt4 = getelementptr inbounds i16, i16* %load_pt3, i16 sub (i16 0, i16 %load_y)
store i16* %access___main_pt4, i16** %pt, align 8
%load_pt5 = load i16*, i16** %pt, align 8
%load_pt6 = load i16*, i16** %pt, align 8
store i1 icmp eq (i64 ptrtoint (i16* %load_pt5 to i64), i64 ptrtoint (i16* %load_pt6 to i64)), i1* %comp, align 1
%load_pt7 = load i16*, i16** %pt, align 8
%load_y8 = load i16, i16* %y, align 2
%tmpVar = icmp ne i64 ptrtoint (i16* %load_pt7 to i64), i16 %load_y8
store i1 %tmpVar, i1* %comp, align 1
%load_pt9 = load i16*, i16** %pt, align 8
%load_pt10 = load i16*, i16** %pt, align 8
store i1 icmp slt (i64 ptrtoint (i16* %load_pt9 to i64), i64 ptrtoint (i16* %load_pt10 to i64)), i1* %comp, align 1
%load_pt11 = load i16*, i16** %pt, align 8
%load_y12 = load i16, i16* %y, align 2
%tmpVar13 = icmp sgt i64 ptrtoint (i16* %load_pt11 to i64), i16 %load_y12
store i1 %tmpVar13, i1* %comp, align 1
%load_pt14 = load i16*, i16** %pt, align 8
%load_pt15 = load i16*, i16** %pt, align 8
store i1 icmp sle (i64 ptrtoint (i16* %load_pt14 to i64), i64 ptrtoint (i16* %load_pt15 to i64)), i1* %comp, align 1
%load_y16 = load i16, i16* %y, align 2
%load_pt17 = load i16*, i16** %pt, align 8
%tmpVar18 = icmp sge i16 %load_y16, i64 ptrtoint (i16* %load_pt17 to i64)
store i1 %tmpVar18, i1* %comp, align 1
ret void
}

4 changes: 4 additions & 0 deletions src/typesystem.rs
Original file line number Diff line number Diff line change
@@ -259,6 +259,10 @@ impl DataTypeInformation {
)
}

pub fn is_pointer(&self) -> bool {
matches!(self, DataTypeInformation::Pointer { .. })
}

pub fn is_unsigned_int(&self) -> bool {
matches!(self, DataTypeInformation::Integer { signed: false, .. })
}
2 changes: 1 addition & 1 deletion src/validation/stmt_validator.rs
Original file line number Diff line number Diff line change
@@ -439,7 +439,7 @@ impl StatementValidator {
.get_type_information();

if std::mem::discriminant(left_type) == std::mem::discriminant(right_type)
&& !left_type.is_numerical()
&& !(left_type.is_numerical() || left_type.is_pointer())
{
//see if we have the right compare-function (non-numbers are compared using user-defined callback-functions)
if operator.is_comparison_operator()
73 changes: 73 additions & 0 deletions tests/correctness/expressions.rs
Original file line number Diff line number Diff line change
@@ -222,3 +222,76 @@ fn enums_can_be_compared() {
let _: i32 = compile_and_run(function, &mut main);
assert_eq!([true, true, true], [main.a, main.b, main.c]);
}

#[test]
fn binary_expressions_for_pointers() {
#[derive(Default)]
struct Main {
a: u8,
b: u8,
c: u8,
d: u8,
e: u8,
equal: bool,
not_equal: bool,
less: bool,
greater: bool,
less_or_equal: bool,
greater_or_equal: bool,
}

let function = "
PROGRAM main
VAR
a : CHAR;
b : CHAR;
c : CHAR;
d : CHAR;
e : CHAR;
equal : BOOL;
not_equal : BOOL;
less : BOOL;
greater : BOOL;
less_or_equal : BOOL;
greater_or_equal : BOOL;
END_VAR
VAR_TEMP
arr : ARRAY[0..3] OF CHAR := ['a','b','c','d'];
ptr : REF_TO CHAR;
negative : INT := -1;
END_VAR
ptr := &(arr);
ptr := ptr + 2;
a := ptr^;
ptr := ptr + 1;
b := ptr^;
ptr := ptr - 1;
c := ptr^;
ptr := ptr + negative;
d := ptr^;
ptr := ptr - negative;
e := ptr^;
equal := ptr = ptr;
not_equal := ptr <> ptr;
less := ptr < ptr;
greater := ptr > ptr;
less_or_equal := ptr <= ptr;
greater_or_equal := ptr >= ptr;
END_PROGRAM
";
let mut main = Main::default();
let _: i32 = compile_and_run(function, &mut main);
assert_eq!(main.a, "c".as_bytes()[0]);
assert_eq!(main.b, "d".as_bytes()[0]);
assert_eq!(main.c, "c".as_bytes()[0]);
assert_eq!(main.d, "b".as_bytes()[0]);
assert_eq!(main.e, "c".as_bytes()[0]);
assert_eq!(main.equal, true);
assert_eq!(main.not_equal, false);
assert_eq!(main.less, false);
assert_eq!(main.greater, false);
assert_eq!(main.less_or_equal, true);
assert_eq!(main.greater_or_equal, true);
}