@@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t};
 use llvm::{
     LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
 };
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::link::ensure_removed;
 use rustc_codegen_ssa::back::versioned_llvm_target;
 use rustc_codegen_ssa::back::write::{
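For orientation, the newly imported `AutoDiffItem` is the handoff between attribute expansion in the frontend and this backend pass. A paraphrased sketch of its shape (field names match the uses in `differentiate` below; see `rustc_ast::expand::autodiff_attrs` for the real definition):

    struct AutoDiffItem {
        source: String,       // symbol name of the function to differentiate
        target: String,       // symbol name of the placeholder Enzyme fills in
        attrs: AutoDiffAttrs, // mode (forward/reverse) and per-argument activity
    }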
@@ -28,7 +29,7 @@ use rustc_session::config::{
 use rustc_span::symbol::sym;
 use rustc_span::{BytePos, InnerSpan, Pos, SpanData, SyntaxContext};
 use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
-use tracing::debug;
+use tracing::{debug, trace};
 
 use crate::back::lto::ThinBuffer;
 use crate::back::owned_target_machine::OwnedTargetMachine;
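The `trace!` diagnostics added below are off by default. When debugging this pipeline they can be enabled through rustc's usual tracing filter, for example (the module path matches this file):

    RUSTC_LOG=rustc_codegen_llvm::back::write=trace rustc ...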
@@ -530,9 +531,38 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
+    skip_size_increasing_opts: bool,
 ) -> Result<(), FatalError> {
-    let unroll_loops =
-        opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+    // Enzyme:
+    // The whole point of compiler-based AD is to differentiate optimized IR instead of unoptimized
+    // source code. However, benchmarks show that optimizations increasing the code size
+    // tend to reduce AD performance. Therefore, deactivate them before AD, then differentiate the
+    // code, and finally re-optimize the module, now with all optimizations available.
+    // FIXME(ZuseZ4): In a future update we could figure out how to only optimize the individual
+    // functions getting differentiated.
+
+    let unroll_loops;
+    let vectorize_slp;
+    let vectorize_loop;
+
+    // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
+    // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
+    // we should make this more granular, or at least check that the user has at least one autodiff
+    // call in their code, to justify altering the compilation pipeline.
+    if skip_size_increasing_opts && cfg!(llvm_enzyme) {
+        unroll_loops = false;
+        vectorize_slp = false;
+        vectorize_loop = false;
+    } else {
+        unroll_loops =
+            opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        vectorize_slp = config.vectorize_slp;
+        vectorize_loop = config.vectorize_loop;
+    }
+    trace!(
+        "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
+        unroll_loops, vectorize_slp, vectorize_loop
+    );
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
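The branch above amounts to a small pure decision. A standalone sketch that mirrors it (hypothetical names, no rustc internals), useful for seeing the whole truth table at once:

    #[derive(PartialEq)]
    enum OptLevel { No, Less, Default, Aggressive, Size, SizeMin }

    // Returns (unroll_loops, vectorize_slp, vectorize_loop).
    fn size_increasing_opts(
        skip_for_autodiff: bool, // `skip_size_increasing_opts && cfg!(llvm_enzyme)` above
        opt_level: OptLevel,
        cfg_vectorize_slp: bool,
        cfg_vectorize_loop: bool,
    ) -> (bool, bool, bool) {
        if skip_for_autodiff {
            // Hold back size-increasing passes until after differentiation.
            (false, false, false)
        } else {
            let unroll = opt_level != OptLevel::Size && opt_level != OptLevel::SizeMin;
            (unroll, cfg_vectorize_slp, cfg_vectorize_loop)
        }
    }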
@@ -596,8 +626,8 @@ pub(crate) unsafe fn llvm_optimize(
         using_thin_buffers,
         config.merge_functions,
         unroll_loops,
-        config.vectorize_slp,
-        config.vectorize_loop,
+        vectorize_slp,
+        vectorize_loop,
         config.no_builtins,
         config.emit_lifetime_markers,
         sanitizer_options.as_ref(),
@@ -619,6 +649,83 @@ pub(crate) unsafe fn llvm_optimize(
     result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
 }
 
+pub(crate) fn differentiate(
+    module: &ModuleCodegen<ModuleLlvm>,
+    cgcx: &CodegenContext<LlvmCodegenBackend>,
+    diff_items: Vec<AutoDiffItem>,
+    config: &ModuleConfig,
+) -> Result<(), FatalError> {
+    for item in &diff_items {
+        trace!("{}", item);
+    }
+
+    let llmod = module.module_llvm.llmod();
+    let llcx = &module.module_llvm.llcx;
+    let diag_handler = cgcx.create_dcx();
+
+    // Before dumping the module, we want all the generated Enzyme calls to become part of it.
+    for item in diff_items.iter() {
+        let name = CString::new(item.source.clone()).unwrap();
+        let fn_def: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
+        let fn_def = match fn_def {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find source function".to_owned(),
+                }));
+            }
+        };
+        let target_name = CString::new(item.target.clone()).unwrap();
+        debug!("target name: {:?}", &target_name);
+        let fn_target: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) };
+        let fn_target = match fn_target {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find target function".to_owned(),
+                }));
+            }
+        };
+
+        crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
+    }
+
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
+    if let Some(opt_level) = config.opt_level {
+        let opt_stage = match cgcx.lto {
+            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
+            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
+            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
+            _ => llvm::OptStage::PreLinkNoLTO,
+        };
+        // This is our second opt call, so now we run all opts,
+        // to make sure we get the best performance.
+        let skip_size_increasing_opts = false;
+        trace!("running Module Optimization after differentiation");
+        unsafe {
+            llvm_optimize(
+                cgcx,
+                diag_handler.handle(),
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )?
+        };
+    }
+    trace!("done with differentiate()");
+
+    Ok(())
+}
+
 // Unsafe due to LLVM calls.
 pub(crate) unsafe fn optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
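For context on where the `source`/`target` pairs consumed by `differentiate` come from: the frontend expands an autodiff attribute into a placeholder function whose body Enzyme later fills in. A hedged illustration of the intended user-facing shape; the surface syntax was still experimental when this landed, so treat the attribute form and names here as assumptions:

    #![feature(autodiff)] // nightly-only, requires an `llvm_enzyme` build of rustc

    // Hypothetical usage: expands to a placeholder `d_square`, which reaches the
    // backend roughly as AutoDiffItem { source: "square", target: "d_square", .. }.
    #[autodiff(d_square, Reverse, Active, Active)]
    fn square(x: f64) -> f64 {
        x * x
    }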
@@ -641,14 +748,29 @@ pub(crate) unsafe fn optimize(
         unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
     }
 
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
     if let Some(opt_level) = config.opt_level {
         let opt_stage = match cgcx.lto {
             Lto::Fat => llvm::OptStage::PreLinkFatLTO,
             Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
             _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
             _ => llvm::OptStage::PreLinkNoLTO,
         };
-        return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
+
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling.
+        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        return unsafe {
+            llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )
+        };
     }
     Ok(())
 }
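Net effect of the two call sites, summarized directly from the hunks above:

    // With an `llvm_enzyme` build:
    //   optimize()      -> llvm_optimize(.., skip_size_increasing_opts = true)
    //   differentiate() -> generate_enzyme_call(..) per AutoDiffItem,
    //                      then llvm_optimize(.., skip_size_increasing_opts = false)
    // Without `llvm_enzyme`, cfg!(llvm_enzyme) is false, so the pass behaves as
    // before apart from threading the extra parameter through.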