Skip to content

Commit

Permalink
fix shadow pointer, and handle some more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Nov 5, 2024
1 parent 686d328 commit 1f57d0f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
15 changes: 11 additions & 4 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,16 +1112,24 @@ pub(crate) fn differentiate(
}

// Before dumping the module, we want all the tt to become part of the module.
for (i, item) in diff_items.iter().enumerate() {
for item in diff_items.iter() {
let tt: FncTree = FncTree { args: item.inputs.clone(), ret: item.output.clone() };
let name = CString::new(item.source.clone()).unwrap();
dbg!("Source name: {:?}", &name);
let fn_def: &llvm::Value =
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap() };
let tgt_name = CString::new(item.target.clone()).unwrap();
dbg!("Target name: {:?}", &tgt_name);
let fn_target: &llvm::Value =
unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()).unwrap() };
let fn_target: Option<&llvm::Value> =
unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
let fn_target = match fn_target {
Some(x) => x,
None => return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find target function".to_owned(),
})),
};
if !ad.contains(&AutoDiff::NoTypeTrees) {
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
}
Expand All @@ -1138,7 +1146,6 @@ pub(crate) fn differentiate(
fn_def,
fn_target,
item.attrs.clone(),
i,
);
}
}
Expand Down
31 changes: 20 additions & 11 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ pub(crate) fn add_opt_dbg_helper2<'ll>(
val: &'ll Value,
tgt: &'ll Value,
attrs: AutoDiffAttrs,
i: usize,
) {
let inputs = attrs.input_activity;
let output = attrs.ret_activity;
Expand Down Expand Up @@ -159,18 +158,26 @@ pub(crate) fn add_opt_dbg_helper2<'ll>(
llvm::LLVMMDStringInContext2(llcx, "enzyme_primal_return".as_ptr() as *const c_char, 20);
final_num_args = num_args * 2 + 1;

//match output {
// DiffActivity::Duplicated => {
// args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret));
// },
// DiffActivity::Dual => {
// args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret));
// },
// _ => {},
//}
match output {
DiffActivity::Duplicated => {
args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret));
final_num_args += 1;
},
DiffActivity::Dual => {
args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret));
final_num_args += 1;
},
DiffActivity::Active => {
args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret));
final_num_args += 1;
},
_ => {},
}

let mut pos = 0;
for i in 0..num_args {
let arg = llvm::LLVMGetParam(tgt, i);
let arg = llvm::LLVMGetParam(tgt, pos);
pos += 1;
let activity = inputs[i as usize];
let (activity, duplicated): (&Metadata, bool) = match activity {
DiffActivity::None => panic!(),
Expand All @@ -186,6 +193,8 @@ pub(crate) fn add_opt_dbg_helper2<'ll>(
args.push(llvm::LLVMMetadataAsValue(llcx, activity));
args.push(arg);
if duplicated {
let arg = llvm::LLVMGetParam(tgt, pos);
pos += 1;
final_num_args += 1;
args.push(arg);
}
Expand Down

0 comments on commit 1f57d0f

Please sign in to comment.