Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tt to mem calls #82

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);

bx.lifetime_end(llscratch, scratch_size);
Expand Down
56 changes: 51 additions & 5 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,10 +703,56 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
let f_ty = LLVMRustGetFunctionType(src);

let inner_param_num = LLVMCountParams(src);
let mut outer_args: Vec<&Value> = get_params(tgt);
let outer_param_num = LLVMCountParams(tgt);
let outer_args: Vec<&Value> = get_params(tgt);
let inner_args: Vec<&Value> = get_params(src);
let mut call_args: Vec<&Value> = vec![];

if inner_param_num as usize != outer_args.len() {
panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, outer_args.len());
if inner_param_num == outer_param_num {
call_args = outer_args;
} else {
dbg!("Different number of args, adjusting");
let mut outer_pos: usize = 0;
let mut inner_pos: usize = 0;
// copy over if they are identical.
// If not, skip the outer arg (and assert it's int).
while outer_pos < outer_param_num as usize {
let inner_arg = inner_args[inner_pos];
let outer_arg = outer_args[outer_pos];
let inner_arg_ty = llvm::LLVMTypeOf(inner_arg);
let outer_arg_ty = llvm::LLVMTypeOf(outer_arg);
if inner_arg_ty == outer_arg_ty {
call_args.push(outer_arg);
inner_pos += 1;
outer_pos += 1;
} else {
// out: (ptr, <>int1, ptr, int2)
// inner: (ptr, <>ptr, int)
// goal: (ptr, ptr, int1), skipping int2
// we are here: <>
assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer);
assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer);
let next_outer_arg = outer_args[outer_pos + 1];
let next_inner_arg = inner_args[inner_pos + 1];
let next_outer_arg_ty = llvm::LLVMTypeOf(next_outer_arg);
let next_inner_arg_ty = llvm::LLVMTypeOf(next_inner_arg);
assert!(llvm::LLVMRustGetTypeKind(next_outer_arg_ty) == llvm::TypeKind::Pointer);
assert!(llvm::LLVMRustGetTypeKind(next_inner_arg_ty) == llvm::TypeKind::Integer);
let next2_outer_arg = outer_args[outer_pos + 2];
let next2_outer_arg_ty = llvm::LLVMTypeOf(next2_outer_arg);
assert!(llvm::LLVMRustGetTypeKind(next2_outer_arg_ty) == llvm::TypeKind::Integer);
call_args.push(next_outer_arg);
call_args.push(outer_arg);

outer_pos += 3;
inner_pos += 2;
}
}
}


if inner_param_num as usize != call_args.len() {
panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, call_args.len());
}

let inner_fnc_name = llvm::get_value_name(src);
Expand All @@ -719,8 +765,8 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
builder,
f_ty,
src,
outer_args.as_mut_ptr(),
outer_args.len(),
call_args.as_mut_ptr(),
call_args.len(),
c_inner_fnc_name.as_ptr(),
);

Expand Down
13 changes: 12 additions & 1 deletion compiler/rustc_codegen_llvm/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ use rustc_data_structures::small_c_str::SmallCStr;
use rustc_middle::dep_graph;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::mir::mono::{Linkage, Visibility};
use rustc_middle::ty::TyCtxt;
use rustc_session::config::DebugInfo;
use rustc_span::symbol::Symbol;
use rustc_target::spec::SanitizerSet;

use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{ParamEnv, TyCtxt, fnc_typetrees};

use std::time::Instant;

pub struct ValueIter<'ll> {
Expand Down Expand Up @@ -86,6 +88,15 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen
let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx);
for &(mono_item, data) in &mono_items {
mono_item.predefine::<Builder<'_, '_, '_>>(&cx, data.linkage, data.visibility);
let inst = match mono_item {
MonoItem::Fn(instance) => instance,
_ => continue,
};
let fn_ty = inst.ty(tcx, ParamEnv::empty());
let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![]);
//trace!("codegen_module: predefine fn {}", inst);
//trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree);
// Manuel: TODO
}

// ... and now that we have everything pre-defined, fill out those definitions.
Expand Down
64 changes: 58 additions & 6 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ use std::iter;
use std::ops::Deref;
use std::ptr;

use crate::typetree::to_enzyme_typetree;
use rustc_ast::expand::typetree::{TypeTree, FncTree};

// All Builders must have an llfn associated with them
#[must_use]
pub struct Builder<'a, 'll, 'tcx> {
Expand Down Expand Up @@ -134,6 +137,35 @@ macro_rules! builder_methods_for_value_instructions {
}
}

fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
let inputs = tt.args;
let _ret: TypeTree = tt.ret;
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
let attr_name = "enzyme_type";
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
for (i, &ref input) in inputs.iter().enumerate() {
let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
unsafe {
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::LLVMRustAddParamAttr(val, i as u32, attr);
}
unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) };
}
dbg!(&val);
}


impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Self {
let bx = Builder::with_cx(cx);
Expand Down Expand Up @@ -874,11 +906,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
Expand All @@ -887,7 +920,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
} else {
trace!("builder: no tt");
}
}

Expand All @@ -899,11 +939,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemMove(
self.llbuilder,
dst,
Expand All @@ -912,7 +953,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
}
}

Expand All @@ -923,17 +969,23 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
size: &'ll Value,
align: Align,
flags: MemFlags,
tt: Option<FncTree>,
) {
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemSet(
self.llbuilder,
ptr,
align.bytes() as c_uint,
fill_byte,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
}
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,10 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
// We don't support volatile / extern / (global?) values.
// This is only because I haven't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_tts.len()];
if args_uncacheable.len() != input_activity.len() {
dbg!("args_uncacheable.len(): {}", args_uncacheable.len());
dbg!("input_activity.len(): {}", input_activity.len());
}
assert!(args_uncacheable.len() == input_activity.len());
let num_fnc_args = LLVMCountParams(fnc);
println!("num_fnc_args: {}", num_fnc_args);
Expand Down
16 changes: 15 additions & 1 deletion compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ use std::time::{Duration, Instant};

use itertools::Itertools;

use rustc_middle::ty::typetree_from;
use rustc_ast::expand::typetree::{TypeTree, FncTree};

pub fn bin_op_to_icmp_predicate(op: hir::BinOpKind, signed: bool) -> IntPredicate {
match op {
hir::BinOpKind::Eq => IntPredicate::IntEQ,
Expand Down Expand Up @@ -357,6 +360,7 @@ pub fn wants_new_eh_instructions(sess: &Session) -> bool {
wants_wasm_eh(sess) || wants_msvc_seh(sess)
}

// Manuel TODO
pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &mut Bx,
dst: Bx::Value,
Expand All @@ -370,15 +374,25 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
if size == 0 {
return;
}
let my_ty = layout.ty;
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, my_ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

if flags == MemFlags::empty()
&& let Some(bty) = bx.cx().scalar_copy_backend_type(layout)
{
let temp = bx.load(bty, src, src_align);
bx.store(temp, dst, dst_align);
} else {
bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags);
trace!("my_ty: {:?}, enzyme tt: {:?}", my_ty, fnc_tree);
trace!("memcpy_ty: {:?} -> {:?} (size={}, align={:?})", src, dst, size, dst_align);
bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags, Some(fnc_tree));
}
//let (_args, _ret): (Vec<TypeTree>, TypeTree) = (fnc_tree.args, fnc_tree.ret);
}

pub fn codegen_instance<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>(
Expand Down
25 changes: 22 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ use rustc_target::abi::{
WrappingRange,
};

use rustc_middle::ty::typetree_from;
use rustc_ast::expand::typetree::{TypeTree, FncTree};
use crate::rustc_middle::ty::layout::HasTyCtxt;

fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &mut Bx,
allow_overlap: bool,
Expand All @@ -25,15 +29,23 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
src: Bx::Value,
count: Bx::Value,
) {
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

let layout = bx.layout_of(ty);
let size = layout.size;
let align = layout.align.abi;
let size = bx.mul(bx.const_usize(size.bytes()), count);
let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() };
trace!("copy: mir ty: {:?}, enzyme tt: {:?}", ty, fnc_tree);
if allow_overlap {
bx.memmove(dst, align, src, align, size, flags);
bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree));
} else {
bx.memcpy(dst, align, src, align, size, flags);
bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree));
}
}

Expand All @@ -45,12 +57,19 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
val: Bx::Value,
count: Bx::Value,
) {
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

let layout = bx.layout_of(ty);
let size = layout.size;
let align = layout.align.abi;
let size = bx.mul(bx.const_usize(size.bytes()), count);
let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() };
bx.memset(dst, val, size, align, flags);
bx.memset(dst, val, size, align, flags, Some(fnc_tree));
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
let neg_address = bx.neg(address);
let offset = bx.and(neg_address, align_minus_1);
let dst = bx.inbounds_gep(bx.type_i8(), alloca, &[offset]);
bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty());
bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty(), None);

// Store the allocated region and the extra to the indirect place.
let indirect_operand = OperandValue::Pair(dst, llextra);
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {

// Use llvm.memset.p0i8.* to initialize all zero arrays
if bx.cx().const_to_opt_u128(v, false) == Some(0) {
//let ty = bx.cx().val_ty(v);
let fill = bx.cx().const_u8(0);
bx.memset(start, fill, size, dest.align, MemFlags::empty());
bx.memset(start, fill, size, dest.align, MemFlags::empty(), None);
return;
}

// Use llvm.memset.p0i8.* to initialize byte arrays
let v = bx.from_immediate(v);
if bx.cx().val_ty(v) == bx.cx().type_i8() {
bx.memset(start, v, size, dest.align, MemFlags::empty());
//let ty = bx.cx().type_i8();
bx.memset(start, v, size, dest.align, MemFlags::empty(), None);
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let align = pointee_layout.align;
let dst = dst_val.immediate();
let src = src_val.immediate();
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
}
mir::StatementKind::FakeRead(..)
| mir::StatementKind::Retag { .. }
Expand Down
Loading
Loading