-
Notifications
You must be signed in to change notification settings - Fork 7
/
s2s_reverse_mode_ad.jl
1740 lines (1474 loc) · 68.4 KB
/
s2s_reverse_mode_ad.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
SharedDataPairs()
A data structure used to manage the captured data in the `OpaqueClosures` which implement
the bulk of the forwards- and reverse-passes of AD. An entry `(id, data)` at element `n`
of the `pairs` field of this data structure means that `data` will be available at register
`id` during the forwards- and reverse-passes of `AD`.
This is achieved by storing all of the data in the `pairs` field in the captured tuple which
is passed to an `OpaqueClosure`, and extracting this data into registers associated to the
corresponding `ID`s.
"""
struct SharedDataPairs
    # Ordered `(id, data)` entries. The position of an entry in this vector is the field
    # index used to pull `data` back out of the tuple built by `shared_data_tuple`.
    pairs::Vector{Tuple{ID,Any}}

    # Start empty; entries are appended one at a time via `add_data!`.
    function SharedDataPairs()
        return new(Vector{Tuple{ID,Any}}())
    end
end
"""
add_data!(p::SharedDataPairs, data)::ID
Puts `data` into `p`, and returns the `id` associated to it. This `id` should be assumed to
be available during the forwards- and reverse-passes of AD, and it should further be assumed
that the value associated to this `id` is always `data`.
"""
function add_data!(p::SharedDataPairs, data)::ID
    # Allocate a fresh `ID` and record the association before handing the `ID` back.
    fresh_id = ID()
    entry = (fresh_id, data)
    push!(p.pairs, entry)
    return fresh_id
end
"""
shared_data_tuple(p::SharedDataPairs)::Tuple
Create the tuple that will constitute the captured variables in the forwards- and reverse-
pass `OpaqueClosure`s.
For example, if `p.pairs` is
```julia
[(ID(5), 5.0), (ID(3), "hello")]
```
then the output of this function is
```julia
(5.0, "hello")
```
"""
function shared_data_tuple(p::SharedDataPairs)::Tuple
    # Drop the `ID`s, keep the data in insertion order, and splat into a tuple.
    captured_values = map(last, p.pairs)
    return (captured_values...,)
end
"""
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}
Produce a sequence of id-statement pairs which will extract the data from
`shared_data_tuple(p)` such that the correct value is associated to the correct `ID`.
For example, if `p.pairs` is
```julia
[(ID(5), 5.0), (ID(3), "hello")]
```
then the output of this function is
```julia
IDInstPair[
(ID(5), new_inst(Expr(:call, get_shared_data_field, Argument(1), 1))),
(ID(3), new_inst(Expr(:call, get_shared_data_field, Argument(1), 2))),
]
```
"""
function shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}
    return map(enumerate(p.pairs)) do (position, pair)
        # `pair` is `(id, data)`. Only the `id` is needed here: the data itself is read
        # at run time from field `position` of the captured tuple (argument 1).
        id = pair[1]
        extract = Expr(:call, get_shared_data_field, Argument(1), position)
        return (id, new_inst(extract))
    end
end
# Inlinable wrapper around `getfield`; this is the callee placed in the statements built
# by `shared_data_stmts` to read field `n` of the captured data.
@inline function get_shared_data_field(shared_data, n)
    return getfield(shared_data, n)
end
"""
The block stack is the stack used to keep track of which basic blocks are visited on the
forwards pass, and therefore which blocks need to be visited on the reverse pass. There is
one block stack per derived rule.
By using Int32, we assume that there aren't more than `typemax(Int32)` unique basic blocks
in a given function, which ought to be reasonable.
"""
const BlockStack = Stack{Int32} # type alias: a `Stack` whose elements are `Int32`s
"""
ADInfo
This data structure is used to hold "global" information associated to a particular call to
`build_rrule`. It is used as a means of communication between `make_ad_stmts!` and the
codegen which produces the forwards- and reverse-passes.
- `interp`: a `MooncakeInterpreter`.
- `block_stack_id`: the ID associated to the block stack -- the stack which keeps track of
which blocks we visited during the forwards-pass, and which is used on the reverse-pass
to determine which blocks to visit.
- `block_stack`: the block stack. Can always be found at `block_stack_id` in the forwards-
and reverse-passes.
- `entry_id`: ID associated to the block inserted at the start of execution in the
forwards-pass, and the end of execution in the pullback.
- `shared_data_pairs`: the `SharedDataPairs` used to define the captured variables passed
to both the forwards- and reverse-passes.
- `arg_types`: a map from `Argument` to its static type.
- `ssa_insts`: a map from `ID` associated to lines to the primal `NewInstruction`. This
contains the line of code, its static / inferred type, and some other details. See
`Core.Compiler.NewInstruction` for a full list of fields.
- `arg_rdata_ref_ids`: the dict mapping from arguments to the `ID` which creates and
initialises the `Ref` which contains the reverse data associated to that argument.
Recall that the heap allocations associated to this `Ref` are always optimised away in
the final programme.
- `ssa_rdata_ref_ids`: the same as `arg_rdata_ref_ids`, but for each `ID` associated to an
ssa rather than each argument.
- `debug_mode`: if `true`, run in "debug mode" -- wraps all rule calls in `DebugRRule`. This
is applied recursively, so that debug mode is also switched on in derived rules.
- `is_used_dict`: for each `ID` associated to a line of code, is `false` if line is not used
anywhere in any other line of code.
- `lazy_zero_rdata_ref_id`: for any arguments whose type doesn't permit the construction of
a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly-
typed field), we need to have a zero-valued rdata available on the reverse-pass so that
this zero-valued rdata can be returned if the argument (or a part of it) is never used
during the forwards-pass and consequently doesn't obtain a value on the reverse-pass.
To achieve this, we construct a `LazyZeroRData` for each of the arguments on the
forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be
associated to this information.
"""
struct ADInfo
    # Interpreter for which the rule is being built.
    interp::MooncakeInterpreter

    # Block-stack bookkeeping: the `ID` under which the stack is available, and the stack.
    block_stack_id::ID
    block_stack::BlockStack

    # `ID` of the extra entry block.
    entry_id::ID

    # Data captured by both the forwards- and reverse-pass closures.
    shared_data_pairs::SharedDataPairs

    # Static information about the primal: argument types and per-line instructions.
    arg_types::Dict{Argument,Any}
    ssa_insts::Dict{ID,NewInstruction}

    # `ID`s of the statements which create the rdata `Ref`s, keyed by argument / line.
    arg_rdata_ref_ids::Dict{Argument,ID}
    ssa_rdata_ref_ids::Dict{ID,ID}

    # Whether rule calls should run in debug mode.
    debug_mode::Bool

    # For each line `ID`, whether any other line uses it.
    is_used_dict::Dict{ID,Bool}

    # `ID` under which the lazy zero rdata for the arguments is available.
    lazy_zero_rdata_ref_id::ID
end
# Constructor for `ADInfo` to use when no `BBCode` is to hand. The arguments are described
# in the docstring of the `ADInfo` struct.
function ADInfo(
    interp::MooncakeInterpreter,
    arg_types::Dict{Argument,Any},
    ssa_insts::Dict{ID,NewInstruction},
    is_used_dict::Dict{ID,Bool},
    debug_mode::Bool,
    zero_lazy_rdata_ref::Ref{<:Tuple},
)
    shared_data_pairs = SharedDataPairs()
    block_stack = BlockStack()

    # NOTE: the statements below are deliberately kept in this order. Each one allocates
    # fresh `ID`s, and the two `add_data!` calls determine the position of the block stack
    # and the lazy zero rdata within the shared data.
    block_stack_id = add_data!(shared_data_pairs, block_stack)
    entry_id = ID()
    arg_rdata_ref_ids = Dict((arg, ID()) for arg in keys(arg_types))
    ssa_rdata_ref_ids = Dict((line_id, ID()) for line_id in keys(ssa_insts))
    lazy_zero_rdata_ref_id = add_data!(shared_data_pairs, zero_lazy_rdata_ref)

    return ADInfo(
        interp,
        block_stack_id,
        block_stack,
        entry_id,
        shared_data_pairs,
        arg_types,
        ssa_insts,
        arg_rdata_ref_ids,
        ssa_rdata_ref_ids,
        debug_mode,
        is_used_dict,
        lazy_zero_rdata_ref_id,
    )
end
# Constructor for `ADInfo` to use when a `BBCode` _is_ available. Everything other than
# `interp` and `debug_mode` (see the `ADInfo` struct) is derived from `ir`.
function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool)
    # Widened static type of each argument, keyed by its position.
    arg_types = Dict{Argument,Any}(
        Argument(n) => CC.widenconst(T) for (n, T) in enumerate(ir.argtypes)
    )

    # Per-line instructions, and whether each line is used anywhere.
    stmts = collect_stmts(ir)
    ssa_insts = Dict{ID,NewInstruction}(stmts)
    is_used_dict = characterise_used_ids(stmts)

    # Uninitialised `Ref` with one lazy-zero-rdata slot per argument.
    lazy_rdata_types = map(T -> lazy_zero_rdata_type(CC.widenconst(T)), ir.argtypes)
    zero_lazy_rdata_ref = Ref{Tuple{lazy_rdata_types...}}()

    return ADInfo(
        interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref
    )
end
"""
add_data!(info::ADInfo, data)::ID
Equivalent to `add_data!(info.shared_data_pairs, data)`.
"""
function add_data!(info::ADInfo, data)::ID
    # Forward to the method for the underlying `SharedDataPairs`.
    return add_data!(info.shared_data_pairs, data)
end
"""
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)
Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it on the
forwards- and reverse-passes. The reason for this is that if something is a singleton, it
can be inserted directly into the IR.
"""
function add_data_if_not_singleton!(p::Union{ADInfo,SharedDataPairs}, x)
    # Instances of singleton types are returned as-is; nothing is added to `p` for them.
    if Base.issingletontype(_typeof(x))
        return x
    else
        return add_data!(p, x)
    end
end
"""
is_used(info::ADInfo, id::ID)::Bool
Returns `true` if `id` is used by any of the lines in the ir, false otherwise.
"""
function is_used(info::ADInfo, id::ID)::Bool
    # Plain dictionary lookup: throws a `KeyError` if `id` has no entry.
    return info.is_used_dict[id]
end
"""
get_primal_type(info::ADInfo, x)
Returns the static / inferred type associated to `x`.
"""
function get_primal_type(info::ADInfo, x::Argument)
    return info.arg_types[x]
end

# Lines: widen the type recorded against the instruction.
function get_primal_type(info::ADInfo, x::ID)
    return CC.widenconst(info.ssa_insts[x].type)
end

# Quoted values: the type of the wrapped value.
function get_primal_type(::ADInfo, x::QuoteNode)
    return _typeof(x.value)
end

# Anything else is treated as a literal.
function get_primal_type(::ADInfo, x)
    return _typeof(x)
end

# Globals: the type of the current value if constant, otherwise the type declared on the
# binding.
function get_primal_type(::ADInfo, x::GlobalRef)
    if isconst(x)
        return _typeof(getglobal(x.mod, x.name))
    else
        return x.binding.ty
    end
end

# The only expression accepted in an argument slot is `:boundscheck`, which is a `Bool`.
function get_primal_type(::ADInfo, x::Expr)
    if x.head === :boundscheck
        return Bool
    end
    return error("Unrecognised expression $x found in argument slot.")
end
"""
get_rev_data_id(info::ADInfo, x)
Returns the `ID` associated to the line in the reverse pass which will contain the
reverse data for `x`. If `x` is not an `Argument` or `ID`, then `nothing` is returned.
"""
function get_rev_data_id(info::ADInfo, x::Argument)
    return info.arg_rdata_ref_ids[x]
end

function get_rev_data_id(info::ADInfo, x::ID)
    return info.ssa_rdata_ref_ids[x]
end

# Fallback: anything which is neither an `Argument` nor an `ID` has no rdata `Ref`.
function get_rev_data_id(::ADInfo, ::Any)
    return nothing
end
"""
reverse_data_ref_stmts(info::ADInfo)
Create the statements which initialise the reverse-data `Ref`s.
"""
function reverse_data_ref_stmts(info::ADInfo)
    # One `__make_ref` call per argument, keyed on the argument's widened static type.
    arg_ref_stmts = map(collect(info.arg_rdata_ref_ids)) do (arg, ref_id)
        P = CC.widenconst(info.arg_types[arg])
        return (ref_id, new_inst(Expr(:call, __make_ref, P)))
    end

    # One `__make_ref` call per line, keyed on the line's widened inferred type.
    ssa_ref_stmts = map(collect(info.ssa_rdata_ref_ids)) do (line_id, ref_id)
        P = CC.widenconst(info.ssa_insts[line_id].type)
        return (ref_id, new_inst(Expr(:call, __make_ref, P)))
    end

    return vcat(arg_ref_stmts, ssa_ref_stmts)
end
"""
__make_ref(p::Type{P}) where {P}
Helper for [`reverse_data_ref_stmts`](@ref). Constructs a `Ref` whose element type is the
[`zero_like_rdata_type`](@ref) for `P`, and whose element is the zero-like rdata for `P`.
"""
@inline @generated function __make_ref(p::Type{P}) where {P}
    # Fall back to `_typeof(p)` if the static parameter `P` is not defined.
    primal_type = @isdefined(P) ? P : _typeof(p)
    rdata_type = zero_like_rdata_type(primal_type)
    return :(Ref{$rdata_type}(Mooncake.zero_like_rdata_from_type($primal_type)))
end

# This specialised method is necessary to ensure that `__make_ref` works properly for
# `DataType`s with unbound type parameters. See `TestResources.typevar_tester` for an
# example. The generated method above requires that `P` be a type in which all parameters
# are fully-bound. Strange errors occur if this property does not hold.
@inline function __make_ref(::Type{<:Type})
    return Ref{NoRData}(NoRData())
end

# No `Ref` is created for `Union{}`.
@inline function __make_ref(::Type{Union{}})
    return nothing
end
# Number of arguments of the primal function, i.e. the number of entries in `arg_types`.
function num_args(info::ADInfo)
    return length(info.arg_types)
end
"""
RRuleZeroWrapper(rule)
This struct is used to ensure that `ZeroRData`s, which are used as placeholder zero
elements whenever an actual instance of a zero rdata for a particular primal type cannot
be constructed without also having an instance of said type, never reach rules.
On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures
that if it is a `ZeroRData`, we instead get an actual zero of the correct type. If it is
not a zero rdata, the computation _should_ be elided via inlining + constant prop.
"""
struct RRuleZeroWrapper{Trule}
    rule::Trule
end

# Copy the wrapped rule, and re-wrap it in the same concrete wrapper type.
function _copy(x::P) where {P<:RRuleZeroWrapper}
    return P(_copy(x.rule))
end

# Pullback produced by `RRuleZeroWrapper`: pairs the wrapped rule's pullback with the lazy
# zero rdata for the primal output.
struct RRuleWrapperPb{Tpb!!,Tl}
    pb!!::Tpb!!
    l::Tl
end

function (rule::RRuleWrapperPb)(dy)
    # Incrementing by an instantiated zero means the wrapped pullback never receives a
    # placeholder zero rdata.
    concrete_dy = increment!!(dy, instantiate(rule.l))
    return rule.pb!!(concrete_dy)
end

@inline function (rule::RRuleZeroWrapper{R})(f::F, args::Vararg{CoDual,N}) where {R,F,N}
    y, pb!! = rule.rule(f, args...)
    l = lazy_zero_rdata(primal(y))

    # A `NoPullback` is handed back untouched; every other pullback gets wrapped.
    if pb!! isa NoPullback
        return y::CoDual, pb!!
    else
        return y::CoDual, RRuleWrapperPb(pb!!, l)
    end
end
"""
ADStmtInfo
Data structure which contains the result of `make_ad_stmts!`. Fields are
- `line`: the ID associated to the primal line from which this is derived
- `comms_id`: an `ID` from one of the lines in `fwds`, whose value will be made
available on the reverse-pass in the same `ID`. Nothing is asserted about _how_ this
value is made available on the reverse-pass of AD, so this package is free to do this in
whichever way is most efficient, in particular to group these communication `ID` on a
per-block basis.
- `fwds`: the instructions which run the forwards-pass of AD
- `rvs`: the instructions which run the reverse-pass of AD / the pullback
"""
struct ADStmtInfo
    # `ID` of the primal line from which these statements were derived.
    line::ID

    # `ID` of the forwards-pass statement whose value is communicated to the reverse-pass,
    # or `nothing` if no value is communicated.
    comms_id::Union{ID,Nothing}

    # Forwards-pass statements.
    fwds::Vector{IDInstPair}

    # Reverse-pass / pullback statements.
    rvs::Vector{IDInstPair}
end
"""
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)
Convenient constructor for `ADStmtInfo`. If either `fwds` or `rvs` is not a vector,
`__vec` promotes it to a single-element `Vector`.
"""
function ad_stmt_info(line::ID, comms_id::Union{ID,Nothing}, fwds, rvs)
    # Normalise `fwds` once, and reuse the result for both the check and the output.
    fwds_vec = __vec(line, fwds)

    # If a communication `ID` is given, it must belong to one of the forwards statements.
    if comms_id !== nothing && !any(pair -> first(pair) == comms_id, fwds_vec)
        throw(ArgumentError("comms_id not found in IDs of `fwds` instructions."))
    end
    return ADStmtInfo(line, comms_id, fwds_vec, __vec(line, rvs))
end

# Anything which is not already an instruction is first wrapped via `new_inst`.
__vec(line::ID, x::Any) = __vec(line, new_inst(x))

# A single instruction becomes a one-element vector associated to `line`.
__vec(line::ID, x::NewInstruction) = IDInstPair[(line, x)]

# A loosely-typed vector signals a construction mistake at the call site: fail loudly
# with an actionable message. `error` throws, so no explicit `throw` is needed.
function __vec(line::ID, x::Vector{Tuple{ID,Any}})
    return error(
        "Received a `Vector{Tuple{ID,Any}}` of instructions for line $line, but a " *
        "`Vector{IDInstPair}` is required. Construct the vector using an " *
        "`IDInstPair[...]` literal so that it is concretely typed.",
    )
end

# Already in the required form.
__vec(line::ID, x::Vector{IDInstPair}) = x
"""
comms_channel(info::ADStmtInfo)
Return the element of `fwds` whose `ID` is the communication `ID`. Returns `nothing` if
`comms_id` is `nothing`.
"""
function comms_channel(info::ADStmtInfo)
    if info.comms_id === nothing
        return nothing
    end
    matches = filter(info.fwds) do pair
        pair[1] == info.comms_id
    end
    # `only` throws unless exactly one forwards statement carries the communication `ID`.
    return only(matches)
end
"""
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo
Every line in the primal code is associated to one or more lines in the forwards-pass of AD,
and one or more lines in the pullback. This function has a method specific to every
node type in the Julia SSAIR.
Translates the instruction `inst`, associated to `line` in the primal, into a specification
of what should happen for this instruction in the forwards- and reverse-passes of AD, and
what data should be shared between the forwards- and reverse-passes. Returns this in the
form of an `ADStmtInfo`.
`info` is a data structure containing various bits of global information that certain types
of nodes need access to.
"""
function make_ad_stmts! end # declaration only; methods for each node type follow
#=
    make_ad_stmts!(::Nothing, line::ID, ::ADInfo)

A `nothing` statement in Julia IR marks a line which will be removed later on. Both the
forwards- and reverse-passes get a no-op, and no data is shared between them.
=#
function make_ad_stmts!(::Nothing, line::ID, ::ADInfo)
    no_op = nothing
    return ad_stmt_info(line, nothing, no_op, no_op)
end
#=
    make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)

A `ReturnNode` has a single field, `val`. There are three cases:
1. `val` is undefined: the `ReturnNode` is unreachable, so the associated statements are
    never hit on the forwards-pass or in the pullback. The original statement is used on
    the forwards-pass, and `nothing` on the reverse-pass.
2. `val isa Union{Argument, ID}`: this is an active piece of data, so it is already a
    `CoDual` and can be returned directly. `stmt` is therefore used on the forwards-pass
    (with any `Argument`s incremented). On the reverse-pass, the rdata ref associated to
    `val` is incremented with the rdata passed to the pullback, which lives in argument 2.
3. `val` is defined, but is not a `Union{Argument, ID}`: a constant is being returned.
    Build a constant `CoDual` and return that. Nothing happens on the reverse-pass.
=#
function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
    # Case 1: unreachable return.
    if !is_reachable_return_node(stmt)
        return ad_stmt_info(line, nothing, inc_args(stmt), nothing)
    end

    # Case 2: returning active data.
    if is_active(stmt.val)
        ref_id = get_rev_data_id(info, stmt.val)
        rvs = new_inst(Expr(:call, increment_ref!, ref_id, Argument(2)))
        return ad_stmt_info(line, nothing, inc_args(stmt), rvs)
    end

    # Case 3: returning a constant. The constant `CoDual` is created by one statement,
    # and returned by the next.
    const_id = ID()
    fwds = [
        (const_id, new_inst(const_codual_stmt(stmt.val, info))),
        (ID(), new_inst(ReturnNode(const_id))),
    ]
    return ad_stmt_info(line, nothing, fwds, nothing)
end
# Unconditional jumps: the statement itself (passed through `inc_args`) on the forwards-
# pass, nothing on the reverse-pass, and no shared data.
function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo)
    fwds = inc_args(stmt)
    return ad_stmt_info(line, nothing, fwds, nothing)
end
# Conditional jumps: no-op on the reverse-pass, and no shared data. On the forwards-pass,
# the branch is kept, but an active condition must first be unwrapped from its `CoDual`.
function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo)
    inc_stmt = inc_args(stmt)
    cond = inc_stmt.cond

    # If the condition is not going to be wrapped in a `CoDual`, just use the statement.
    if !is_active(cond)
        return ad_stmt_info(line, nothing, inc_stmt, nothing)
    end

    # The condition is active, so its primal must be extracted before branching on it.
    primal_cond_id = ID()
    extract_cond = (primal_cond_id, new_inst(Expr(:call, primal, cond)))
    branch = (line, new_inst(IDGotoIfNot(primal_cond_id, inc_stmt.dest), Any))
    return ad_stmt_info(line, nothing, [extract_cond, branch], nothing)
end
# Phi nodes: the forwards-pass is a phi node over the `CoDual` versions of the incoming
# values. No-op on the reverse-pass, and no shared data.
function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo)
    primal_vals = stmt.values
    codual_vals = Vector{Any}(undef, length(primal_vals))
    for n in eachindex(primal_vals)
        # Unassigned slots in the primal stay unassigned in the new phi node.
        if isassigned(primal_vals, n)
            codual_vals[n] = inc_or_const(primal_vals[n], info)
        end
    end

    # It turns out to be really very important to do type inference correctly for
    # PhiNodes. For some reason, type inference really doesn't like it when you encounter
    # mutually-dependent PhiNodes whose types are unknown and for which you set the flag
    # to CC.IR_FLAG_REFINED. To avoid this we directly tell the compiler what the type is.
    codual_type = fcodual_type(get_primal_type(info, line))
    flag = info.ssa_insts[line].flag
    phi = new_inst(IDPhiNode(stmt.edges, codual_vals), codual_type, flag)
    return ad_stmt_info(line, nothing, phi, nothing)
end
function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
    # PiNodes of the form `π (nothing, Union{})` have started appearing in 1.11. These
    # nodes appear in unreachable sections of code, and appear to serve no purpose. They
    # are passed through unchanged on the forwards-pass, with a no-op on the reverse-pass.
    # NOTE(review): an earlier version of this comment said these nodes are replaced with
    # `nothing`, but the code passes `stmt` itself -- confirm which is intended. There is
    # currently no unit test for this case, but integration testing seems to catch it in
    # multiple places.
    if stmt == PiNode(nothing, Union{})
        return ad_stmt_info(line, nothing, stmt, nothing)
    end

    # If the value of the PiNode is a constant / QuoteNode etc, build a constant `CoDual`
    # and wrap that. There is nothing to do on the reverse-pass.
    if !is_active(stmt.val)
        const_id = ID()
        const_fwds = [
            (const_id, new_inst(const_codual_stmt(stmt.val, info))),
            (line, new_inst(PiNode(const_id, fcodual_type(CC.widenconst(stmt.typ))))),
        ]
        return ad_stmt_info(line, nothing, const_fwds, nothing)
    end

    # Active value. Get the primal type of this line, and the rdata refs for the `val` of
    # this `PiNode` and for this line itself.
    P = get_primal_type(info, line)
    val_rdata_ref_id = get_rev_data_id(info, stmt.val)
    output_rdata_ref_id = get_rev_data_id(info, line)
    fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ)))
    rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id)
    return ad_stmt_info(line, nothing, fwds, rvs)
end
@inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P}
    # Take the rdata out of the output ref, and add it to the ref for the wrapped value.
    output_rdata = __deref_and_zero(P, output_rdata_ref)
    increment_ref!(val_rdata_ref, output_rdata)
    return nothing
end
# Constant GlobalRefs are treated as constants -- see `const_ad_stmt`. Non-constant
# GlobalRefs are handled by assuming that they are constant, and creating a CoDual from
# the value they hold when this function runs. A run-time check then verifies that the
# value has not changed.
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo)
    if isconst(stmt)
        return const_ad_stmt(stmt, line, info)
    end

    const_id, globalref_id = ID(), ID()

    # 1. Read the global at run time.
    read_global = (globalref_id, new_inst(stmt))

    # 2. Make available the CoDual built from the value the global holds right now.
    assumed_value = getglobal(stmt.mod, stmt.name)
    stored_codual = (const_id, new_inst(const_codual_stmt(assumed_value, info)))

    # 3. Compare the two, producing the CoDual for this line.
    verify = (line, new_inst(Expr(:call, __verify_const, globalref_id, const_id)))

    return ad_stmt_info(line, nothing, [read_global, stored_codual, verify], nothing)
end
# Helper used by `make_ad_stmts!` for `GlobalRef`. Noinline to avoid IR bloat.
#
# `global_ref` is the value read from the global at run time, and `stored_value` is the
# CoDual that was built when the global was assumed to be constant. Differentiation is
# only valid if the two agree, so the check is written as an explicit `throw` rather than
# an `@assert`, which Julia documents as something that may be disabled. The exception
# type (`AssertionError`) is the same as the one `@assert` raises.
@noinline function __verify_const(global_ref, stored_value)
    if !(global_ref == primal(stored_value))
        throw(
            AssertionError(
                "The value of a non-constant global has changed since it was assumed " *
                "to be constant: global_ref == primal(stored_value) does not hold.",
            ),
        )
    end
    return uninit_fcodual(global_ref)
end
# A `QuoteNode` wraps a constant value.
function make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo)
    return const_ad_stmt(stmt, line, info)
end

# Fallback: any statement without a more specific method is treated as a literal constant.
function make_ad_stmts!(stmt, line::ID, info::ADInfo)
    return const_ad_stmt(stmt, line, info)
end
"""
const_ad_stmt(stmt, line::ID, info::ADInfo)
Implementation of `make_ad_stmts!` used for constants.
"""
function const_ad_stmt(stmt, line::ID, info::ADInfo)
    # Forwards-pass: produce the constant `CoDual`. Reverse-pass: nothing. No shared data.
    fwds = const_codual_stmt(stmt, info)
    return ad_stmt_info(line, nothing, fwds, nothing)
end
"""
const_codual_stmt(stmt, info::ADInfo)
Returns a `:call` expression which will return a `CoDual` whose primal is `stmt`, and whose
tangent is whatever `uninit_tangent` returns.
"""
function const_codual_stmt(stmt, info::ADInfo)
    v = get_const_primal_value(stmt)

    # Values which can sit directly in the IR are wrapped by a call made at run time.
    safe_for_literal(v) && return Expr(:call, uninit_fcodual, v)

    # Otherwise the `CoDual` is built now, placed in shared data, and fetched by `ID`.
    shared_id = add_data!(info, uninit_fcodual(v))
    return Expr(:call, identity, shared_id)
end
"""
const_codual(stmt, info::ADInfo)
Build a `CoDual` from `stmt`, with zero / uninitialised fdata. If the primal value is safe
to use as a literal (see `safe_for_literal`), the CoDual is returned. If it is not, then
the CoDual is put into shared data, and the ID associated to it is returned.
"""
function const_codual(stmt, info::ADInfo)
    v = get_const_primal_value(stmt)
    x = uninit_fcodual(v)
    if safe_for_literal(v)
        return x
    else
        return add_data!(info, x)
    end
end
# Returns `true` for `:boundscheck` expressions, `String`s, `Type`s, `Tuple`s whose
# elements are all themselves safe, and values of bits types. Returns `false` otherwise.
function safe_for_literal(v)
    return (v isa Expr && v.head === :boundscheck) ||
           v isa String ||
           v isa Type ||
           (v isa Tuple && all(safe_for_literal, v)) ||
           isbitstype(_typeof(v))
end
# Active values are incremented via `__inc`; anything else becomes a constant `CoDual`.
function inc_or_const(stmt, info::ADInfo)
    if is_active(stmt)
        return __inc(stmt)
    else
        return const_codual(stmt, info)
    end
end

# As `inc_or_const`, but always produces a `:call` expression.
function inc_or_const_stmt(stmt, info::ADInfo)
    is_active(stmt) && return Expr(:call, identity, __inc(stmt))
    return const_codual_stmt(stmt, info)
end
"""
get_const_primal_value(x::GlobalRef)
Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant.
"""
function get_const_primal_value(x::GlobalRef)
    if !isconst(x)
        unhandled_feature("Non-constant GlobalRef not supported: $x")
    end
    return getglobal(x.mod, x.name)
end

# A `QuoteNode` is unwrapped to the value it carries.
function get_const_primal_value(x::QuoteNode)
    return x.value
end

# Anything else is already the value.
function get_const_primal_value(x)
    return x
end
# `PhiCNode`s are not yet handled by Mooncake: report them as an unhandled feature.
function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo)
    msg = "Encountered PhiCNode: $stmt"
    return unhandled_feature(msg)
end
# `UpsilonNode`s are not yet handled by Mooncake: report them as an unhandled feature,
# along with some advice on how to work around the limitation.
function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo)
    msg =
        "Encountered UpsilonNode: $stmt. These are generated as part of some try / catch " *
        "/ finally blocks. At the present time, Mooncake.jl cannot differentiate through " *
        "these, so they must be avoided. Strategies for resolving this error include " *
        "re-writing code such that it avoids generating any UpsilonNodes, or writing a " *
        "rule to differentiate the code by hand. If you are in any doubt as to what to " *
        "do, please request assistance by opening an issue at " *
        "github.com/compintell/Mooncake.jl."
    return unhandled_feature(msg)
end
# There are quite a number of possible `Expr`s that can be encountered. Each case has its
# own comment, explaining what is going on.
function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
    is_invoke = Meta.isexpr(stmt, :invoke)
    if Meta.isexpr(stmt, :call) || is_invoke

        # Find the types of all arguments to this call / invoke. For an `:invoke`, the
        # first element of `stmt.args` is the method instance, so it is skipped.
        args = ((is_invoke ? stmt.args[2:end] : stmt.args)...,)
        arg_types = map(arg -> get_primal_type(info, arg), args)

        # Special case: if the result of a call to getfield is un-used, then leave the
        # primal statement alone (just increment arguments as usual). This was causing
        # performance problems in a couple of situations where the field being requested
        # is not known at compile time. `getfield` cannot be dead-code eliminated, because
        # it can throw an error if the requested field does not exist. Everything _other_
        # than the boundscheck is eliminated in LLVM codegen, so it's important that AD
        # doesn't get in the way of this.
        #
        # This might need to be generalised to more things than just `getfield`, but at
        # the time of writing this comment, it's unclear whether or not this is the case.
        if !is_used(info, line) && get_const_primal_value(args[1]) == getfield
            fwds = new_inst(Expr(:call, __fwds_pass_no_ad!, map(__inc, args)...))
            return ad_stmt_info(line, nothing, fwds, nothing)
        end

        # Construct signature, and determine how the rrule is to be computed.
        sig = Tuple{arg_types...}
        raw_rule = if is_primitive(context_type(info.interp), sig)
            rrule!! # intrinsic / builtin / thing we provably have rule for
        elseif is_invoke
            mi = stmt.args[1]::Core.MethodInstance
            LazyDerivedRule(mi, info.debug_mode) # Static dispatch
        else
            DynamicDerivedRule(info.debug_mode)  # Dynamic dispatch
        end

        # Wrap the raw rule in a struct which ensures that any `ZeroRData`s are stripped
        # away before the raw_rule is called. The wrapper is skipped if either
        # `can_produce_zero_rdata_from_type(P)` holds, where `P` is the type of the value
        # returned by this line, or the pullback is provably a `NoPullback`.
        is_no_pullback = pullback_type(_typeof(raw_rule), arg_types) <: NoPullback
        tmp = can_produce_zero_rdata_from_type(get_primal_type(info, line))
        zero_wrapped_rule = (tmp || is_no_pullback) ? raw_rule : RRuleZeroWrapper(raw_rule)

        # If debug mode has been requested, use a debug rule.
        rule = info.debug_mode ? DebugRRule(zero_wrapped_rule) : zero_wrapped_rule

        # If the rule is a singleton (e.g. it is `rrule!!` because `sig` is primitive),
        # then don't bother putting the rule into shared data, because it's safe to put
        # it directly into the code.
        rule_ref = add_data_if_not_singleton!(info, rule)

        # Bound on the type of the pullback. If it is a singleton type, then there is no
        # need to store the pullback in the shared data -- it can be interpolated directly
        # into the generated IR.
        T_pb!! = pullback_type(_typeof(rule), arg_types)

        #
        # Write forwards-pass. These statements are written out manually, as writing them
        # out in a function would prevent inlining in some (all?) type-unstable situations.
        #

        # Make arguments to rrule call. Things which are not already CoDual must be made so.
        codual_arg_ids = map(_ -> ID(), collect(args))
        codual_args = map(args, codual_arg_ids) do arg, id
            return (id, new_inst(inc_or_const_stmt(arg, info)))
        end

        # Make call to rule.
        rule_call_id = ID()
        rule_call = Expr(:call, rule_ref, codual_arg_ids...)

        # Extract the output-codual from the returned tuple.
        raw_output_id = ID()
        raw_output = Expr(:call, getfield, rule_call_id, 1)

        # Extract the pullback from the returned tuple. Specialise on the case that the
        # pullback is provably a singleton type.
        if Base.issingletontype(T_pb!!)
            pb = T_pb!!.instance
            pb_stmt = (ID(), new_inst(nothing))
            comms_id = nothing
        else
            pb = ID()
            pb_stmt = (pb, new_inst(Expr(:call, getfield, rule_call_id, 2), T_pb!!))
            comms_id = pb
        end

        # Provide a type assertion to help the compiler out. Doing it this way, rather than
        # directly changing the inferred type of the instruction associated to raw_output,
        # has the advantage of not introducing the possibility of segfaults. It will still
        # be optimised away in situations where the compiler is able to successfully infer
        # the type, so performance in performance-critical situations is unaffected.
        output_id = line
        F = fcodual_type(get_primal_type(info, line))
        output = Expr(:call, Core.typeassert, raw_output_id, F)

        # Create statements associated to forwards-pass.
        fwds = vcat(
            codual_args,
            IDInstPair[
                (rule_call_id, new_inst(rule_call)),
                (raw_output_id, new_inst(raw_output)),
                pb_stmt,
                (output_id, new_inst(output)),
            ],
        )

        # Make statement associated to reverse-pass. If the reverse-pass is provably a
        # NoPullback, then don't bother doing anything at all.
        rvs_pass = if T_pb!! <: NoPullback
            nothing
        else
            Expr(
                :call,
                __run_rvs_pass!,
                get_primal_type(info, line),
                sig,
                pb,
                get_rev_data_id(info, line),
                map(Base.Fix1(get_rev_data_id, info), args)...,
            )
        end
        return ad_stmt_info(line, comms_id, fwds, new_inst(rvs_pass))

    elseif Meta.isexpr(stmt, :boundscheck)
        # For some reason the compiler cannot handle boundscheck statements when we run it
        # again. Consequently, emit `true` to be safe. Ideally we would handle this in a
        # more natural way, but I'm not sure how to do that.
        return ad_stmt_info(line, nothing, zero_fcodual(true), nothing)

    elseif Meta.isexpr(stmt, :code_coverage_effect)
        # Code coverage irrelevant for derived code, and really inflates it in some
        # situations. Since code coverage is usually only requested during CI, including
        # these effects also creates differences between the code generated when developing
        # and the code generated in CI, which occasionally creates hard-to-debug issues.
        return ad_stmt_info(line, nothing, nothing, nothing)

    elseif Meta.isexpr(stmt, :copyast)
        # Get constant out and shove it in shared storage.
        return ad_stmt_info(line, nothing, const_codual_stmt(stmt.args[1], info), nothing)

    elseif Meta.isexpr(stmt, :loopinfo)
        # Cannot pass loopinfo back through the optimiser for some reason.
        # At the time of writing, I am unclear why this is not possible.
        return ad_stmt_info(line, nothing, nothing, nothing)

    elseif stmt.head in (
        :enter,
        :gc_preserve_begin,
        :gc_preserve_end,
        :leave,
        :pop_exception,
        :throw_undef_if_not,
        :meta,
    )
        # Expressions which do not require any special treatment. (A tuple is used for
        # the membership test to avoid allocating an array on every call.)
        return ad_stmt_info(line, nothing, stmt, nothing)

    elseif stmt.head === :(=) && stmt.args[1] isa GlobalRef
        msg =
            "Encountered assignment to global variable: $(stmt.args[1]). " *
            "Cannot differentiate through assignments to globals. " *
            "Please refactor your code to avoid assigning to a global, for example by " *
            "passing the variable in to the function as an argument."
        unhandled_feature(msg)
    else
        # Encountered an expression that we've not seen before. `error` throws an
        # `ErrorException` itself, so no explicit `throw` is required.
        error("Unrecognised expression $stmt")
    end
end
# `Argument`s and `ID`s are active; everything else is not.
function is_active(::Union{Argument,ID})
    return true
end

function is_active(::Any)
    return false
end
"""
pullback_type(Trule, arg_types)
Get a bound on the pullback type, given a rule and associated primal types.
"""
function pullback_type(Trule, arg_types)
    # Ask inference what calling the rule on the forwards-pass codual types returns.
    codual_arg_types = map(fcodual_type, arg_types)
    Tret = Core.Compiler.return_type(Tuple{Trule,codual_arg_types...})

    # Anything that is not known to be a tuple gives no information about the pullback.
    Tret <: Tuple || return Any
    return _pullback_type(Tret)
end
# Helpers for `pullback_type`: extract a bound on the type of the second element of a
# tuple type `T`. Whenever no precise answer is available the loosest bound, `Any`, is
# returned rather than throwing.

# `Union{}`: nothing is known.
_pullback_type(::Core.TypeofBottom) = Any

function _pullback_type(T::DataType)
    params = T.parameters
    # Tuple types with fewer than two declared elements (e.g. `Tuple` itself, which is
    # `Tuple{Vararg{Any}}`) have no second parameter to read.
    length(params) >= 2 || return Any
    P = params[2]
    # A `Vararg` in second position is not a `Type`, so it cannot be used as a bound.
    return P isa Type ? P : Any
end

# Union of tuple types: the bound is the union of the bounds of each member.
_pullback_type(T::Union) = Union{_pullback_type(T.a),_pullback_type(T.b)}

# `UnionAll` tuple types (e.g. `Tuple{A,P} where P`) may mention free type variables in
# the second position, so fall back to the loosest bound.
_pullback_type(::UnionAll) = Any
# Used by the getfield special-case in call / invoke statements.
# Strips the codual wrapper (if any) from `f` and from each argument, then calls the
# primal function on the primal arguments.
@inline function __fwds_pass_no_ad!(f::F, raw_args::Vararg{Any,N}) where {F,N}
    primal_f = __get_primal(f)
    primal_args = tuple_map(__get_primal, raw_args)
    return tuple_splat(primal_f, primal_args)
end
# Return the primal of a `CoDual`; any other value is passed through unchanged.
function __get_primal(x::CoDual)
    return primal(x)
end
function __get_primal(x)
    return x
end
"""
__run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`.
"""
@inline function __run_rvs_pass!(
    P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
    # Run the pullback on the value currently held in the return-value ref.
    arg_rdata = pb!!(ret_rev_data_ref[])

    # Fold each resulting piece of data into the corresponding argument ref.
    tuple_map(increment_if_ref!, arg_rev_data_refs, arg_rdata)

    # Reset the return-value ref now that its contents have been consumed.
    set_ret_ref_to_zero!!(P, ret_rev_data_ref)
    return nothing
end
# Accumulate `rvs_data` into `ref`. Nothing is done when the data is a `ZeroRData`, or
# when there is no ref at all (`nothing`).
@inline function increment_if_ref!(ref::Ref, rvs_data)
    return increment_ref!(ref, rvs_data)
end
@inline function increment_if_ref!(::Ref, ::ZeroRData)
    return nothing
end
@inline function increment_if_ref!(::Nothing, ::Any)
    return nothing
end
# Replace the contents of `x` with `increment!!(x[], t)`. A `RefValue{NoRData}` is left
# untouched.
@inline function increment_ref!(x::Ref, t)
    return setindex!(x, increment!!(x[], t))
end
@inline function increment_ref!(::Base.RefValue{NoRData}, t)
    return nothing
end
# Overwrite the contents of `r` with `zero_like_rdata_from_type(P)` and return the value
# that was stored. A `RefValue{NoRData}` is left untouched.
@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P,R}
    zero_rdata = zero_like_rdata_from_type(P)
    r[] = zero_rdata
    return zero_rdata
end
@inline function set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P}
    return nothing
end
#
# Runners for generated code. The main job of these functions is to handle the translation
# between differing varargs conventions.
#
# Callable wrapper around a generated pullback. Calling it runs the wrapped closure and
# then flattens any grouped varargs in the result via `__flatten_varargs`.
# `Tprimal` is a type parameter only: no field is declared with it.
struct Pullback{Tprimal,Tpb_oc,Tisva<:Val,Tnvargs<:Val}
# Container which is dereferenced with `[]` when the pullback is called; the value it
# holds must have an `oc` field that is callable.
pb_oc::Tpb_oc
# `Val` flag forwarded to `__flatten_varargs`, indicating whether the last argument is a
# varargs group.
isva::Tisva
# `Val` forwarded to `__flatten_varargs`, giving the number of elements in that group.
nvargs::Tnvargs
end
# Convenience constructor: `Tprimal` is supplied explicitly, the remaining type
# parameters are taken from the types of the field values.
function Pullback(
    Tprimal, pb_oc::Tpb_oc, isva::Tisva, nvargs::Tnvargs
) where {Tpb_oc,Tisva,Tnvargs}
    Tpb = Pullback{Tprimal,Tpb_oc,Tisva,Tnvargs}
    return Tpb(pb_oc, isva, nvargs)
end
# Run the wrapped closure on `dy`, then flatten a trailing varargs group (if any) so that
# the result has one element per argument of the original call.
@inline function (pb::Pullback)(dy)
    grouped_rdata = pb.pb_oc[].oc(dy)
    return __flatten_varargs(pb.isva, grouped_rdata, pb.nvargs)
end
# Callable rule produced by rule derivation. Calling it with `CoDual` arguments groups
# any trailing varargs, runs the forwards-pass closure, and returns its result together
# with `pb`. `Tprimal` is a type parameter only: no field is declared with it.
struct DerivedRule{Tprimal,Tfwds_oc,Tpb,Tisva<:Val,Tnargs<:Val}
# Forwards-pass closure wrapper; its `oc` field is what actually gets called.
fwds_oc::Tfwds_oc
# Pullback object returned alongside the forwards-pass result.
pb::Tpb
# `Val` flag forwarded to `__unflatten_codual_varargs`, indicating whether the trailing
# arguments must be grouped into a single varargs tuple.
isva::Tisva
# `Val` forwarded to `__unflatten_codual_varargs`, giving the position at which the
# varargs group starts.
nargs::Tnargs
end
# Convenience constructor: `Tprimal` is supplied explicitly, the remaining type
# parameters are taken from the types of the field values.
function DerivedRule(Tprimal, fwds_oc::T, pb::U, isva::V, nargs::W) where {T,U,V,W}
    Trule = DerivedRule{Tprimal,T,U,V,W}
    return Trule(fwds_oc, pb, isva, nargs)
end
# Extends functionality defined for debug_mode.
# Checks that the primal types of the (varargs-grouped) arguments `x` subtype the
# signature `sig` that the rule was derived for; throws an `ArgumentError` otherwise.
function verify_args(r::DerivedRule{sig}, x) where {sig}
    grouped = __unflatten_codual_varargs(r.isva, x, r.nargs)
    Tx = Tuple{map(_typeof ∘ primal, grouped)...}
    if !(Tx <: sig)
        throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig"))
    end
    return nothing
end
# `nothing` has no state to copy.
function _copy(::Nothing)
    return nothing
end
# Copy a `DerivedRule`: the captures of the forwards-pass closure are copied once, and
# that single copy is installed in both the forwards-pass closure and the pullback
# closure of the new rule.
function _copy(x::P) where {P<:DerivedRule}
    shared_captures = _copy(x.fwds_oc.oc.captures)
    fwds_oc = replace_captures(x.fwds_oc, shared_captures)
    pb_oc_ref = Ref(replace_captures(x.pb.pb_oc[], shared_captures))
    pb = typeof(x.pb)(pb_oc_ref, x.isva, x.pb.nvargs)
    return P(fwds_oc, pb, x.isva, x.nargs)
end
# Symbols are returned as-is.
function _copy(x::Symbol)
    return x
end

# Tuples and NamedTuples are copied element-wise.
function _copy(x::Tuple)
    return map(_copy, x)
end
function _copy(x::NamedTuple)
    return map(_copy, x)
end

# Refs get a fresh `Ref{T}`; an unassigned ref stays unassigned.
function _copy(x::Ref{T}) where {T}
    isassigned(x) || return Ref{T}()
    return Ref{T}(_copy(x[]))
end

# Types are returned as-is.
function _copy(x::Type)
    return x
end

# Fallback: defer to `copy`.
function _copy(x)
    return copy(x)
end
# Forwards pass: group trailing varargs (if the rule is varargs), run the forwards-pass
# closure, and return its `CoDual` result together with the rule's pullback.
@inline function (fwds::DerivedRule{P,Q,S})(args::Vararg{CoDual,N}) where {P,Q,S,N}
    grouped = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs)
    y = fwds.fwds_oc.oc(grouped...)::CoDual
    return y, fwds.pb
end
"""
__flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}
If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
"""
function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva,nvargs}
    # Nothing to flatten unless the last element is a varargs group.
    if !isva
        return args
    end

    # A single `NoRData` stands in for the whole group: expand it to one per element.
    group = args[end]
    expanded = group isa NoRData ? ntuple(n -> NoRData(), nvargs) : group
    return (args[1:(end - 1)]..., expanded...)
end
"""
__unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}
If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))`
are transformed into `(CoDual(5.0, 0.0), CoDual((4.0, 3.0), (0.0, 0.0)))`.
"""
function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva,nargs}
    # Nothing to group unless the rule is varargs.
    if !isva
        return args
    end

    # Everything from position `nargs` onwards is collapsed into a single codual.
    trailing = args[nargs:end]
    group_primal = map(primal, trailing)
    grouped_args = if fdata_type(tangent_type(_typeof(group_primal))) == NoFData
        zero_fcodual(group_primal)
    else
        CoDual(group_primal, map(tangent, trailing))
    end
    return (args[1:(nargs - 1)]..., grouped_args)
end
#
# Rule derivation.
#
# Forward to `is_primitive`, accepting either a signature type directly or a
# `MethodInstance`, in which case its `specTypes` is used as the signature.
function _is_primitive(C::Type, mi::Core.MethodInstance)
    return is_primitive(C, mi.specTypes)
end
function _is_primitive(C::Type, sig::Type)
    return is_primitive(C, sig)
end
# Shorthand for a `MistyClosure` wrapping an `OpaqueClosure{A,R}`. Used in `rule_type` to
# spell the type of the forwards-pass closure of a `DerivedRule`.
const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}}
"""
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}
Compute the concrete type of the rule that will be returned from `build_rrule`. This is
important for performance in dynamic dispatch, and to ensure that recursion works
properly.
"""
function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}
    # Primitives are handled by `rrule!!` itself, so the rule type is just its type.
    if _is_primitive(C, sig_or_mi)
        Tprimitive = typeof(rrule!!)
        return debug_mode ? DebugRRule{Tprimitive} : Tprimitive
    end

    # Everything else is determined by the IR's argument and return types.
    ir, _ = lookup_ir(interp, sig_or_mi)
    Tret = Base.Experimental.compute_ir_rettype(ir)
    isva, _ = is_vararg_and_sparam_names(sig_or_mi)
    primal_arg_types = map(CC.widenconst, ir.argtypes)
    sig = Tuple{primal_arg_types...}

    # Forwards-pass closure: coduals in, codual out.
    fwds_arg_types = Tuple{map(fcodual_type, primal_arg_types)...}
    Tfwds_oc = RuleMC{fwds_arg_types,fcodual_type(Tret)}

    # Pullback closure: rdata of the return value in, rdata of each argument out.
    rvs_arg_types = Tuple{map(rdata_type ∘ tangent_type, primal_arg_types)...}
    Trvs_ret = rdata_type(tangent_type(Tret))
    Tpb_oc = MistyClosure{OpaqueClosure{Tuple{Trvs_ret},rvs_arg_types}}
    Tpb = Pullback{sig,Base.RefValue{Tpb_oc},Val{isva},nvargs(isva, sig)}

    Trule = DerivedRule{sig,Tfwds_oc,Tpb,Val{isva},Val{length(ir.argtypes)}}
    return debug_mode ? DebugRRule{Trule} : Trule
end
# Type-level count of the elements in the trailing varargs group of `sig`: the number of
# parameters of the last parameter of `sig` when `isva` is true, and zero otherwise.
function nvargs(isva, sig)
    isva || return Val{0}
    return Val{length(sig.parameters[end].parameters)}
end
# Exception carrying the inputs of a rule compilation. Its `showerror` method uses the
# three fields to print a `Mooncake.build_rrule(...)` call for reproducing the problem.
struct MooncakeRuleCompilationError <: Exception
# Interpreter printed as the first argument of the reproducer call.
interp::MooncakeInterpreter
# Signature (or whatever was passed in its place) printed as the second argument.
sig
# Value printed for the `debug_mode` keyword of the reproducer call.
debug_mode::Bool
end
function Base.showerror(io::IO, err::MooncakeRuleCompilationError)
    # 1. Explain what happened and how to report it.
    explanation =
        "MooncakeRuleCompilationError: an error occured while Mooncake was compiling a " *
        "rule to differentiate something. If the `caused by` error " *
        "message below does not make it clear to you how the problem can be fixed, " *
        "please open an issue at github.com/compintell/Mooncake.jl describing your " *
        "problem.\n" *
        "To replicate this error run the following:\n"
    println(io, explanation)

    # 2. Print a call which reproduces the failure.
    reproducer = "Mooncake.build_rrule(Mooncake.$(err.interp), $(err.sig); debug_mode=$(err.debug_mode))"
    println(io, reproducer)

    # 3. Warn that the printed signature may mention names which are not loaded.
    caveat =
        "\nNote that you may need to `using` some additional packages if not all of the " *
        "names printed in the above signature are available currently in your environment."
    println(io, caveat)
    return nothing
end
"""
build_rrule(args...; debug_mode=false)
Helper method. Only uses static information from `args`.
"""
function build_rrule(args...; debug_mode=false)
    interp = get_interpreter()
    # Only the type of the primals is passed on, not their values.
    sig = _typeof(TestUtils.__get_primals(args))
    return build_rrule(interp, sig; debug_mode)
end
"""
build_rrule(sig::Type{<:Tuple})
Equivalent to `build_rrule(Mooncake.get_interpreter(), sig)`.
"""
function build_rrule(sig::Type{<:Tuple})
    return build_rrule(get_interpreter(), sig)
end
# Module-level reentrant lock.
# NOTE(review): the name suggests it serialises inference / rule compilation; its use
# sites are not visible in this part of the file — confirm against the callers.
const MOONCAKE_INFERENCE_LOCK = ReentrantLock()