     saved_symbols |= new_symbols
-    # Update saved_sym_nodes that are now reordered to have all bindings
-    # at front
-    saved_sym_nodes = saved_sym_nodes_binding + saved_sym_nodes_derived
+    # Update saved_sym_nodes that are now reordered to have all bindings at
+    # front. This can also be used later on to figure out the position of saved
+    # sym nodes in the output of fwd graph.
+    saved_sym_nodes.clear()
+    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

     # Now, we re-generate the fwd/bwd graphs.
     # NB: This might increase compilation time, but I doubt it matters
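The switch from rebinding to clear()/extend() matters because the caller holds its own reference to this list and reads it after the helper returns. A minimal sketch of the aliasing behavior, with hypothetical names (the string-prefix test stands in for the real binding check):

    def reorder_bindings_first(nodes):
        binding = [n for n in nodes if n.startswith("b_")]
        derived = [n for n in nodes if not n.startswith("b_")]
        # "nodes = binding + derived" would only rebind the local name, leaving
        # the caller's list untouched; in-place mutation is visible outside.
        nodes.clear()
        nodes.extend(binding + derived)

    saved = ["d_1", "b_0", "b_2"]
    reorder_bindings_first(saved)
    print(saved)  # ['b_0', 'b_2', 'd_1'] -- the caller observes the reorder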
@@ -816,7 +818,7 @@
     return new_gm

-def functionalize_rng_ops(joint_module, fw_module, bw_module):
+def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes):
     # During user-driven activation checkpointing, we have to ensure that a rng
     # op in fwd yields the same output as the recomputed rng op in the bwd. To
     # do this, we use functionalize wrappers to wrap the random ops and share
@@ -827,7 +829,9 @@
     # Step 2 - Modify the fwd pass such that
     #   1) Replace rand with run_and_save_rng_state wrapper
     #   2) Replace the users of the original op with the output[1] of this op.
-    #   3) Collect all the rng_state - output[0] of each op, and make them output nodes.
+    #   3) Collect all the rng_state - output[0] of each op, and make them
+    #   output nodes. Special care needs to be taken here because fwd outputs
+    #   have symints at the very end.
     # Step 3 - Modify the bwd pass such that
     #   1) Add the input nodes just before the tangents for the stashed rng states
     #   2) Replace rand with run_with_save_rng_state wrappers
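A toy illustration of the invariant Steps 2 and 3 establish: the recomputed random op in the bwd must see exactly the RNG state that the original op consumed in the fwd. This sketch uses the public torch.get_rng_state/torch.set_rng_state APIs as a stand-in for the internal run_and_save_rng_state wrappers:

    import torch

    # Forward: stash the CPU RNG state, then run the random op.
    state = torch.get_rng_state()
    fwd_out = torch.rand(3)

    # Backward recompute: restore the stashed state before re-running the op,
    # so the recomputed values match the forward values bit-for-bit.
    torch.set_rng_state(state)
    recomputed = torch.rand(3)

    assert torch.equal(fwd_out, recomputed)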
@@ -910,11 +914,15 @@
     bw_graph.erase_node(bw_node)

-    # Add the rng states in the output of the fwd graph
-    fw_output = [node for node in fw_module.graph.nodes if node.op == "output"][0]
-    outputs = fw_output.args[0] + fw_rng_state_outputs
+    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
+    # that symints are at the end of forward graph outputs. So, insert the new
+    # rng states accordingly.
+    fw_output_node = [node for node in fw_module.graph.nodes if node.op == "output"][0]
+    fw_outputs = fw_output_node.args[0]
+    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
+    outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:]
     fw_module.graph.output(outputs)
-    fw_module.graph.erase_node(fw_output)
+    fw_module.graph.erase_node(fw_output_node)
     fw_module.recompile()
     bw_module.recompile()

     return fw_module, bw_module
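The slicing arithmetic above, demonstrated on plain lists (values hypothetical): the rng states are spliced in just before the trailing symints, preserving the "symints last" layout that AOT Autograd relies on:

    fw_outputs = ["out0", "out1", "sym0", "sym1"]  # symints sit at the end
    fw_rng_state_outputs = ["rng0"]
    num_sym_nodes = 2

    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
    outputs = (fw_outputs[:sym_node_start_idx]
               + fw_rng_state_outputs
               + fw_outputs[sym_node_start_idx:])
    print(outputs)  # ['out0', 'out1', 'rng0', 'sym0', 'sym1']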
@@ -1202,13 +1210,14 @@
     # save_for_backward on tensors and stashes symints in autograd .ctx
     saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
     saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
+    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
     fw_module, bw_module = _extract_fwd_bwd_modules(
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)

     if graph_has_recomputable_ops:
         if graph_has_recomputable_rng_ops:
             fw_module, bw_module = functionalize_rng_ops(
-                joint_module, fw_module, bw_module
+                joint_module, fw_module, bw_module, len(saved_sym_nodes)
             )

     bw_module = reordering_to_mimic_autograd_engine(bw_module)
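Tying the changes together: the filter split collects the symint nodes, _extract_fwd_bwd_modules reorders that list in place (see the clear()/extend() change above), and len(saved_sym_nodes) then tells functionalize_rng_ops how many trailing fwd outputs are symints. A sketch of the split with a stand-in predicate (is_sym_node here is hypothetical, keyed on int):

    def is_sym_node(n):
        # Stand-in for the real predicate, which detects symint/symfloat nodes.
        return isinstance(n, int)

    saved_values = ["t0", 7, "t1", 9]
    saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
    print(saved_sym_nodes)  # [7, 9]; len() gives the num_sym_nodes argument
    print(saved_values)     # ['t0', 't1']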
diff --git a/nightly/aot_autograd.html b/nightly/aot_autograd.html
index eb2eaee08..4058004ec 100644
--- a/nightly/aot_autograd.html
+++ b/nightly/aot_autograd.html
@@ -217,7 +217,7 @@