Skip to content

Commit

Permalink
Rust flags (#159)
Browse files Browse the repository at this point in the history
* add RUSTFLAGS version of enzyme flags

* Remove old env arg checks, use flags now

* small fixups
  • Loading branch information
ZuseZ4 authored Aug 11, 2024
1 parent 0bd1b5d commit 1f83693
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 55 deletions.
108 changes: 54 additions & 54 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use rustc_data_structures::small_c_str::SmallCStr;
use rustc_errors::{DiagCtxt, FatalError, Level};
use rustc_fs_util::{link_or_copy, path_to_c_string};
use rustc_middle::ty::TyCtxt;
use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath};
use rustc_session::config::{self, AutoDiff, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath};
use rustc_session::Session;
use rustc_span::symbol::sym;
use rustc_span::InnerSpan;
Expand Down Expand Up @@ -707,7 +707,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {


unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) {
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize], ad: &[AutoDiff]) {

// first, remove all calls from fnc
let bb = LLVMGetFirstBasicBlock(tgt);
Expand All @@ -729,12 +729,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
let last_inst = LLVMRustGetLastInstruction(bb).unwrap();
LLVMPositionBuilderAtEnd(builder, bb);

let safety_run_checks;
if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() {
safety_run_checks = false;
} else {
safety_run_checks = true;
}
let safety_run_checks = !ad.contains(&AutoDiff::NoSafetyChecks);

if inner_param_num == outer_param_num {
call_args = outer_args;
Expand Down Expand Up @@ -951,6 +946,7 @@ pub(crate) unsafe fn enzyme_ad(
diag_handler: &DiagCtxt,
item: AutoDiffItem,
logic_ref: EnzymeLogicRef,
ad: &[AutoDiff],
) -> Result<(), FatalError> {
let autodiff_mode = item.attrs.mode;
let rust_name = item.source;
Expand Down Expand Up @@ -1010,16 +1006,16 @@ pub(crate) unsafe fn enzyme_ad(

llvm::set_strict_aliasing(false);

if std::env::var("ENZYME_PRINT_TA").is_ok() {
if ad.contains(&AutoDiff::PrintTA) {
llvm::set_print_type(true);
}
if std::env::var("ENZYME_PRINT_AA").is_ok() {
llvm::set_print_activity(true);
// Activity-analysis printing has its own flag (`PrintAA`); it must not be keyed on `PrintTA`,
// which only controls type-analysis printing.
if ad.contains(&AutoDiff::PrintAA) {
    llvm::set_print_activity(true);
}
if std::env::var("ENZYME_PRINT_PERF").is_ok() {
if ad.contains(&AutoDiff::PrintPerf) {
llvm::set_print_perf(true);
}
if std::env::var("ENZYME_PRINT").is_ok() {
if ad.contains(&AutoDiff::Print) {
llvm::set_print(true);
}

Expand Down Expand Up @@ -1062,7 +1058,7 @@ pub(crate) unsafe fn enzyme_ad(
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));

let rev_mode = item.attrs.mode == DiffMode::Reverse;
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions);
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad);
// TODO: implement drop for wrapper type?
FreeTypeAnalysis(type_analysis);

Expand All @@ -1087,7 +1083,9 @@ pub(crate) unsafe fn differentiate(

llvm::set_strict_aliasing(false);

if std::env::var("ENZYME_LOOSE_TYPES").is_ok() {
let ad = &config.autodiff;

if ad.contains(&AutoDiff::LooseTypes) {
dbg!("Setting loose types to true");
llvm::set_loose_types(true);
}
Expand All @@ -1110,41 +1108,42 @@ pub(crate) unsafe fn differentiate(
// trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
// This is super helpful if we want to create an MWE bug reproducer, e.g. to run in
// Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions?
if std::env::var("ENZYME_OPT").is_ok() {
if ad.contains(&AutoDiff::OPT) {
dbg!("Enable extra debug helper to debug Enzyme through the opt plugin");
crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i);
}
}

if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){
if ad.contains(&AutoDiff::PrintModBefore) || ad.contains(&AutoDiff::OPT) {
unsafe {
LLVMDumpModule(llmod);
}
}

if std::env::var("ENZYME_INLINE").is_ok() {
if ad.contains(&AutoDiff::Inline) {
dbg!("Setting inline to true");
llvm::set_inline(true);
}

if std::env::var("ENZYME_TT_DEPTH").is_ok() {
let depth = std::env::var("ENZYME_TT_DEPTH").unwrap();
let depth = depth.parse::<u64>().unwrap();
assert!(depth >= 1);
llvm::set_max_int_offset(depth);
}
if std::env::var("ENZYME_TT_WIDTH").is_ok() {
let width = std::env::var("ENZYME_TT_WIDTH").unwrap();
let width = width.parse::<u64>().unwrap();
assert!(width >= 1);
llvm::set_max_type_offset(width);
}

if std::env::var("ENZYME_RUNTIME_ACTIVITY").is_ok() {
if ad.contains(&AutoDiff::RuntimeActivity) {
dbg!("Setting runtime activity check to true");
llvm::set_runtime_activity_check(true);
}

for val in ad {
match &val {
AutoDiff::TTDepth(depth) => {
assert!(*depth >= 1);
llvm::set_max_int_offset(*depth);
}
AutoDiff::TTWidth(width) => {
assert!(*width >= 1);
llvm::set_max_type_offset(*width);
}
_ => {},
}
};

let differentiate = !diff_items.is_empty();
let mut first_order_items: Vec<AutoDiffItem> = vec![];
let mut higher_order_items: Vec<AutoDiffItem> = vec![];
Expand All @@ -1157,29 +1156,29 @@ pub(crate) unsafe fn differentiate(
}
}

let mut fnc_opt = false;
if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() {
dbg!("Enable extra optimizations for Enzyme");
fnc_opt = true;
}

let fnc_opt = ad.contains(&AutoDiff::EnableFncOpt);

// If a function is a base for some higher order ad, always optimize
let fnc_opt_base = true;
let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8);

for item in first_order_items {
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt);
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt, ad);
assert!(res.is_ok());
}

// For the rest, follow the user choice on debug vs release.
// Reuse the opt one if possible for better compile time (Enzyme internal caching).
let logic_ref = match fnc_opt {
true => logic_ref_opt,
true => {
dbg!("Enable extra optimizations for Enzyme");
logic_ref_opt
}
false => CreateEnzymeLogic(fnc_opt as u8),
};
for item in higher_order_items {
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref);
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref, ad);
assert!(res.is_ok());
}

Expand Down Expand Up @@ -1212,14 +1211,14 @@ pub(crate) unsafe fn differentiate(
break;
}
}
if std::env::var("ENZYME_PRINT_MOD_AFTER_ENZYME").is_ok() {
if ad.contains(&AutoDiff::PrintModAfterEnzyme) {
unsafe {
LLVMDumpModule(llmod);
}
}


if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate {
if ad.contains(&AutoDiff::NoModOptAfter) || !differentiate {
trace!("Skipping module optimization after automatic differentiation");
} else {
if let Some(opt_level) = config.opt_level {
Expand All @@ -1231,18 +1230,18 @@ pub(crate) unsafe fn differentiate(
};
let mut first_run = false;
dbg!("Running Module Optimization after differentiation");
if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() {
if ad.contains(&AutoDiff::NoVecUnroll) {
// disables vectorization and loop unrolling
first_run = true;
}
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
if ad.contains(&AutoDiff::AltPipeline) {
dbg!("Running first postAD optimization");
first_run = true;
}
let noop = false;
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
}
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
if ad.contains(&AutoDiff::AltPipeline) {
dbg!("Running Second postAD optimization");
if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto {
Expand All @@ -1253,7 +1252,7 @@ pub(crate) unsafe fn differentiate(
};
let mut first_run = false;
dbg!("Running Module Optimization after differentiation");
if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() {
if ad.contains(&AutoDiff::NoVecUnroll) {
// enables vectorization and loop unrolling
first_run = false;
}
Expand All @@ -1263,7 +1262,7 @@ pub(crate) unsafe fn differentiate(
}
}

if std::env::var("ENZYME_PRINT_MOD_AFTER_OPTS").is_ok() {
if ad.contains(&AutoDiff::PrintModAfterOpts) {
unsafe {
LLVMDumpModule(llmod);
}
Expand Down Expand Up @@ -1341,15 +1340,16 @@ pub(crate) unsafe fn optimize(
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO,
};

// Second run only relevant for AD
let first_run = true;
let noop;
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
noop = true;
dbg!("Skipping PreAD optimization");
} else {
noop = false;
}
let noop = false;
//if ad.contains(&AutoDiff::AltPipeline) {
// noop = true;
// dbg!("Skipping PreAD optimization");
//} else {
// noop = false;
//}
return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop);
}
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ pub struct ModuleConfig {
pub inline_threshold: Option<u32>,
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -259,6 +260,7 @@ impl ModuleConfig {
inline_threshold: sess.opts.cg.inline_threshold,
emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_interface/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ fn test_unstable_options_tracking_hash() {

// Make sure that changing a [TRACKED] option changes the hash.
// tidy-alphabetical-start
tracked!(autodiff, vec![String::from("ad_flags")]);
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(asm_comments, true);
Expand Down
51 changes: 50 additions & 1 deletion compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,53 @@ pub enum InstrumentCoverage {
Off,
}

/// The different settings that the `-Z autodiff` flag can have.
///
/// The selected settings are kept as a list and individual variants are looked up with
/// `contains`, so the type derives full equality (`Eq`) alongside `Hash`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AutoDiff {
    /// Print TypeAnalysis information.
    PrintTA,
    /// Print ActivityAnalysis information.
    PrintAA,
    /// Print performance warnings from Enzyme.
    PrintPerf,
    /// Combines the three print flags above.
    Print,
    /// Print the whole module, before running opts.
    PrintModBefore,
    /// Print the whole module just before we pass it to Enzyme.
    /// For debugging purposes, prefer the `OPT` flag below.
    ///
    /// NOTE(review): in `differentiate` this flag appears to be checked after the post-AD
    /// optimization runs rather than before Enzyme — confirm which behavior is intended.
    PrintModAfterOpts,
    /// Print the module after Enzyme differentiated everything.
    PrintModAfterEnzyme,

    /// Enzyme's loose type debug helper (can cause incorrect gradients).
    LooseTypes,
    /// Output a module using `__enzyme` calls to prepare it for opt + enzyme pass usage.
    /// Setting this also dumps the module, like `PrintModBefore`.
    OPT,

    // TypeTree options.
    // TODO: Figure out how to let users construct these,
    // or whether we want to leave this option in the first place.
    /// TypeTree width, forwarded as the maximum type offset. Must be at least 1.
    TTWidth(u64),
    /// TypeTree depth, forwarded as the maximum int offset. Must be at least 1.
    TTDepth(u64),

    /// Skip the module optimization that normally runs after automatic differentiation.
    NoModOptAfter,
    /// Tell Enzyme to run LLVM opts on each function it generated. By default off,
    /// since we already optimize the whole module after Enzyme is done.
    EnableFncOpt,
    /// Disable vectorization and loop unrolling in the optimization run after differentiation.
    NoVecUnroll,
    /// Obviously unsafe, disable the length checks that we have for shadow args.
    NoSafetyChecks,
    /// Enable Enzyme's runtime activity check.
    RuntimeActivity,
    /// Runs Enzyme specific inlining.
    Inline,
    /// Runs optimization twice after AD, and zero times before.
    /// This is mainly for benchmarking purposes, to show that
    /// compiler based AD has a performance benefit. TODO: fix
    ///
    /// NOTE(review): skipping the pre-AD optimization looks currently disabled
    /// (commented out in `optimize`) — presumably what the TODO refers to; confirm.
    AltPipeline,
}

/// Settings for `-Z instrument-xray` flag.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct InstrumentXRay {
Expand Down Expand Up @@ -3229,8 +3276,9 @@ pub(crate) mod dep_tracking {
LinkerPluginLto, LocationDetail, LtoCli, NextSolverConfig, OomStrategy, OptLevel,
OutFileName, OutputType, OutputTypes, Polonius, RemapPathScopeComponents, ResolveDocLinks,
SourceFileHashAlgorithm, SplitDwarfKind, SwitchWithOptPath, SymbolManglingVersion,
TrimmedDefPaths, WasiExecModel,
TrimmedDefPaths, WasiExecModel, AutoDiff,
};
//use crate::config::AutoDiff;
use crate::lint;
use crate::utils::NativeLib;
use rustc_data_structures::fx::FxIndexMap;
Expand Down Expand Up @@ -3285,6 +3333,7 @@ pub(crate) mod dep_tracking {
}

impl_dep_tracking_hash_via_hash!(
AutoDiff,
bool,
usize,
NonZeroUsize,
Expand Down
Loading

0 comments on commit 1f83693

Please sign in to comment.