Skip to content

Commit 68c5486

Browse files
authoredJan 26, 2024
fix: Apply trait constraints from method calls (#4152)
# Description ## Problem\* Resolves #4124 Resolves #4095 ## Summary\* We were never applying trait constraints from method calls before. These have been handled for other identifiers since #4000, but not for method calls which desugar to a function identifier that is called, then type checked with its own special function. I've fixed this by removing the special function and recursively type checking the function call they desugar to instead. This way we have less code duplication and only need to fix things in one spot in the future. ## Additional Context It is a good day when you get to fix a bug by removing code. This is a draft currently because I still need: - [x] To add `&mut` implicitly where applicable to the function calls that are now checked recursively - [x] To add the test case I'm using locally ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings.
1 parent 16becb8 commit 68c5486

File tree

8 files changed

+149
-169
lines changed

8 files changed

+149
-169
lines changed
 

‎compiler/noirc_frontend/src/hir/resolution/resolver.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,8 @@ impl<'a> Resolver<'a> {
14501450
HirExpression::MemberAccess(HirMemberAccess {
14511451
lhs: self.resolve_expression(access.lhs),
14521452
rhs: access.rhs,
1453+
// This is only used when lhs is a reference and we want to return a reference to rhs
1454+
is_offset: false,
14531455
})
14541456
}
14551457
ExpressionKind::Error => HirExpression::Error,

‎compiler/noirc_frontend/src/hir/type_check/expr.rs

+68-144
Large diffs are not rendered by default.

‎compiler/noirc_frontend/src/hir/type_check/stmt.rs

+17-20
Original file line numberDiff line numberDiff line change
@@ -206,24 +206,22 @@ impl<'interner> TypeChecker<'interner> {
206206
let object_ref = &mut object;
207207
let mutable_ref = &mut mutable;
208208

209+
let dereference_lhs = move |_: &mut Self, _, element_type| {
210+
// We must create a temporary value first to move out of object_ref before
211+
// we eventually reassign to it.
212+
let id = DefinitionId::dummy_id();
213+
let location = Location::new(span, fm::FileId::dummy());
214+
let ident = HirIdent::non_trait_method(id, location);
215+
let tmp_value = HirLValue::Ident(ident, Type::Error);
216+
217+
let lvalue = std::mem::replace(object_ref, Box::new(tmp_value));
218+
*object_ref = Box::new(HirLValue::Dereference { lvalue, element_type });
219+
*mutable_ref = true;
220+
};
221+
222+
let name = &field_name.0.contents;
209223
let (object_type, field_index) = self
210-
.check_field_access(
211-
&lhs_type,
212-
&field_name.0.contents,
213-
span,
214-
move |_, _, element_type| {
215-
// We must create a temporary value first to move out of object_ref before
216-
// we eventually reassign to it.
217-
let id = DefinitionId::dummy_id();
218-
let location = Location::new(span, fm::FileId::dummy());
219-
let ident = HirIdent::non_trait_method(id, location);
220-
let tmp_value = HirLValue::Ident(ident, Type::Error);
221-
222-
let lvalue = std::mem::replace(object_ref, Box::new(tmp_value));
223-
*object_ref = Box::new(HirLValue::Dereference { lvalue, element_type });
224-
*mutable_ref = true;
225-
},
226-
)
224+
.check_field_access(&lhs_type, name, span, Some(dereference_lhs))
227225
.unwrap_or((Type::Error, 0));
228226

229227
let field_index = Some(field_index);
@@ -325,6 +323,7 @@ impl<'interner> TypeChecker<'interner> {
325323
// Now check if LHS is the same type as the RHS
326324
// Importantly, we do not coerce any types implicitly
327325
let expr_span = self.interner.expr_span(&rhs_expr);
326+
328327
self.unify_with_coercions(&expr_type, &annotated_type, rhs_expr, || {
329328
TypeCheckError::TypeMismatch {
330329
expected_typ: annotated_type.to_string(),
@@ -335,10 +334,8 @@ impl<'interner> TypeChecker<'interner> {
335334
if annotated_type.is_unsigned() {
336335
self.lint_overflowing_uint(&rhs_expr, &annotated_type);
337336
}
338-
annotated_type
339-
} else {
340-
expr_type
341337
}
338+
expr_type
342339
}
343340

344341
/// Check if an assignment is overflowing with respect to `annotated_type`

‎compiler/noirc_frontend/src/hir_def/expr.rs

+12-4
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ pub struct HirMemberAccess {
152152
// This field is not an IdentId since the rhs of a field
153153
// access has no corresponding definition
154154
pub rhs: Ident,
155+
156+
/// True if we should return an offset of the field rather than the field itself.
157+
/// For most cases this is false, corresponding to `foo.bar` in source code.
158+
/// This is true when calling methods or when we have an lvalue we want to preserve such
159+
/// that if `foo : &mut Foo` has a field `bar : Bar`, we can return an `&mut Bar`.
160+
pub is_offset: bool,
155161
}
156162

157163
#[derive(Debug, Clone)]
@@ -201,13 +207,14 @@ pub enum HirMethodReference {
201207
}
202208

203209
impl HirMethodCallExpression {
210+
/// Converts a method call into a function call
204211
pub fn into_function_call(
205212
mut self,
206213
method: &HirMethodReference,
207214
object_type: Type,
208215
location: Location,
209216
interner: &mut NodeInterner,
210-
) -> (ExprId, HirExpression) {
217+
) -> HirExpression {
211218
let mut arguments = vec![self.object];
212219
arguments.append(&mut self.arguments);
213220

@@ -225,9 +232,10 @@ impl HirMethodCallExpression {
225232
(id, ImplKind::TraitMethod(*method_id, constraint, false))
226233
}
227234
};
228-
let expr = HirExpression::Ident(HirIdent { location, id, impl_kind });
229-
let func = interner.push_expr(expr);
230-
(func, HirExpression::Call(HirCallExpression { func, arguments, location }))
235+
let func = HirExpression::Ident(HirIdent { location, id, impl_kind });
236+
let func = interner.push_expr(func);
237+
interner.push_expr_location(func, location.span, location.file);
238+
HirExpression::Call(HirCallExpression { func, arguments, location })
231239
}
232240
}
233241

‎compiler/noirc_frontend/src/monomorphization/printer.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ impl AstPrinter {
3030

3131
pub fn print_expr(&mut self, expr: &Expression, f: &mut Formatter) -> std::fmt::Result {
3232
match expr {
33-
Expression::Ident(ident) => write!(f, "{}${}", ident.name, ident.definition),
33+
Expression::Ident(ident) => {
34+
write!(f, "{}${}", ident.name, ident.definition)
35+
}
3436
Expression::Literal(literal) => self.print_literal(literal, f),
3537
Expression::Block(exprs) => self.print_block(exprs, f),
3638
Expression::Unary(unary) => self.print_unary(unary, f),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[package]
2+
name = "regression_4124"
3+
type = "bin"
4+
authors = [""]
5+
compiler_version = ">=0.22.0"
6+
7+
[dependencies]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
value = 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use dep::std::option::Option;
2+
3+
trait MyDeserialize<N> {
4+
fn deserialize(fields: [Field; N]) -> Self;
5+
}
6+
7+
impl MyDeserialize<1> for Field {
8+
fn deserialize(fields: [Field; 1]) -> Self {
9+
fields[0]
10+
}
11+
}
12+
13+
pub fn storage_read<N>() -> [Field; N] {
14+
dep::std::unsafe::zeroed()
15+
}
16+
17+
struct PublicState<T> {
18+
storage_slot: Field,
19+
}
20+
21+
impl<T> PublicState<T> {
22+
pub fn new(storage_slot: Field) -> Self {
23+
assert(storage_slot != 0, "Storage slot 0 not allowed. Storage slots must start from 1.");
24+
PublicState { storage_slot }
25+
}
26+
27+
pub fn read<T_SERIALIZED_LEN>(_self: Self) -> T where T: MyDeserialize<T_SERIALIZED_LEN> {
28+
// storage_read returns slice here
29+
let fields: [Field; T_SERIALIZED_LEN] = storage_read();
30+
T::deserialize(fields)
31+
}
32+
}
33+
34+
fn main(value: Field) {
35+
let ps: PublicState<Field> = PublicState::new(27);
36+
37+
// error here
38+
assert(ps.read() == value);
39+
}

0 commit comments

Comments
 (0)