Skip to content

Commit

Permalink
handle DualOnly ret more reliably (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 authored Apr 12, 2024
1 parent 17c772f commit 9a411dc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
18 changes: 13 additions & 5 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,20 @@ fn gen_enzyme_body(
};
}

let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, exprs);
let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
if d_sig.decl.output.has_ret() {
// If we return (), we don't have to match the return type.
body.stmts.push(ecx.stmt_expr(ret));
let ret : P<ast::Expr>;
if exprs.len() > 1 {
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, exprs);
ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
} else if exprs.len() == 1 {
let ret_scal = exprs.pop().unwrap();
ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_scal]);
} else {
assert!(!d_sig.decl.output.has_ret());
// We don't have to match the return type.
return body;
}
assert!(d_sig.decl.output.has_ret());
body.stmts.push(ecx.stmt_expr(ret));

body
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2762,6 +2762,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<Diff
// We care about safety checks, if an argument get's duplicated and we write into the
// shadow. That's equivalent to Duplicated or DuplicatedOnly.
let safety = if !da.is_empty() {
assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
// If we have Activities, we also have spans
assert!(span.is_some());
match da[i] {
Expand Down

0 comments on commit 9a411dc

Please sign in to comment.