custom autograd func memory refinement #8993

Merged: 11 commits, Sep 9, 2021
@@ -81,6 +81,9 @@ def wrap_all_outputs(result, training_mode_flag):
def register_context(result):
# Search for context among all outputs.
ctx = None
# All forward outputs of torch.autograd.Function share the same gradient function pointer,
# so here we just take the first tensor that has a grad_fn attribute.
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267)
first_tensor_output = None
for arg in result:
if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'):
@@ -93,6 +96,22 @@ def register_context(result):
if training_mode_flag:
# Must extract one valid context from result tensors.
assert ctx is not None

# FORWARD BACKWARD FUNCTION CONNECTIONS
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
# ↓ ↑
# autograd.Function apply() ------------> autograd.Function backward()
# ↓ | ↑
# output_1, output_2 --- shared_ptr<PyNode> --- ↑
# ↓ previous gradient function

# We remove the edges between the current autograd.Function's gradient function and
# its inputs' gradient functions (e.g. the AccumulateGrad gradient function). The
# AccumulateGrad gradient function is then destroyed, releasing its reference to input_1
# (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21).
# The next edges are stored in Node, from which we can reach the next gradient functions.
# https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527
torch_interop_utils.clear_grad_fns_for_next_edges(first_tensor_output, ctx.saved_tensors)
torch_interop_utils.register_grad_fn(id(ctx), first_tensor_output)
else:
# Context must not be present in non-training mode.
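The comments above lean on two PyTorch behaviors that can be checked outside this PR: every differentiable output of a torch.autograd.Function shares one backward node, and that node's next edges lead to the inputs' gradient functions (an AccumulateGrad node for a leaf). A minimal plain-PyTorch sketch (the TwoOutputs name is illustrative, not part of this change):

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2, x + 1

    @staticmethod
    def backward(ctx, g1, g2):
        (x,) = ctx.saved_tensors
        return g1 * 2 + g2

x = torch.ones(3, requires_grad=True)   # leaf, like input_1 in the diagram
y1, y2 = TwoOutputs.apply(x)

# Both outputs carry the same gradient function (the shared PyNode), which is
# why register_context only needs the first output that has a grad_fn.
print(y1.grad_fn is y2.grad_fn)         # True

# The backward node's next edges lead to the leaf input's AccumulateGrad node,
# the kind of edge clear_grad_fns_for_next_edges removes when it is safe to.
print(y1.grad_fn.next_functions)        # ((<AccumulateGrad object ...>, 0),)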
@@ -158,36 +177,37 @@ def call_python_backward_function(
inplace: indicates if args can be modified inside the custom function.
args: inputs to "backward_function".
'''
-    def wrap_all_outputs(result):
-        if isinstance(result, torch.Tensor):
-            return [to_dlpack(result)]
-        elif isinstance(result, tuple) or isinstance(result, list):
-            return [to_dlpack(value) if value is not None else None for value in result]
-        else:
-            raise wrap_exception(ORTModuleIOError,
-                                 TypeError(f'ORTModule does not support the following model output type {type(result)}.'))
+    with torch.no_grad():
+        def wrap_all_outputs(result):
+            if isinstance(result, torch.Tensor):
+                return [to_dlpack(result)]
+            elif isinstance(result, tuple) or isinstance(result, list):
+                return [to_dlpack(value) if value is not None else None for value in result]
+            else:
+                raise wrap_exception(ORTModuleIOError,
+                                     TypeError(f'ORTModule does not support the following model output type {type(result)}.'))

-    try:
-        # Backward inputs should not require gradients.
-        assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
+        try:
+            # Backward inputs should not require gradients.
+            assert all(grad_flag == 0 for grad_flag in requires_grad_flags)

-        # Prepare inputs for calling Python function.
-        wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
-                            for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))
+            # Prepare inputs for calling Python function.
+            wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg)
+                                for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args))

-        # Call Python function.
-        result = backward_function(*wrapped_args)
+            # Call Python function.
+            result = backward_function(*wrapped_args)

-        # Extract results as DLPack tensor list.
-        wrapped_returned_args = wrap_all_outputs(result)
+            # Extract results as DLPack tensor list.
+            wrapped_returned_args = wrap_all_outputs(result)

-        ctx = wrapped_args[0]
-        torch_interop_utils.unregister_grad_fn(id(ctx))
+            ctx = wrapped_args[0]
+            torch_interop_utils.unregister_grad_fn(id(ctx))

-        return tuple(wrapped_returned_args)
-    except Exception as e:
-        # Flush buffers. Otherwise, calling this from C++ may lose them.
-        print('Exception happens when running ', backward_function)
-        sys.stdout.flush()
-        sys.stderr.flush()
-        raise wrap_exception(ORTModuleFallbackException, e)
+            return tuple(wrapped_returned_args)
+        except Exception as e:
+            # Flush buffers. Otherwise, calling this from C++ may lose them.
+            print('Exception happens when running ', backward_function)
+            sys.stdout.flush()
+            sys.stderr.flush()
+            raise wrap_exception(ORTModuleFallbackException, e)
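For context on the torch.no_grad() wrapper added in the hunk above: inside that context manager autograd records nothing, so tensors produced by the backward computation carry no grad_fn and keep no references to upstream nodes. A small illustrative sketch, independent of this PR:

import torch

a = torch.ones(2, requires_grad=True)

with torch.no_grad():
    b = a * 3              # no graph is recorded for this op

print(b.requires_grad)     # False
print(b.grad_fn)           # None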
@@ -3,16 +3,22 @@
#include <torch/extension.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/autograd/functions/accumulate_grad.h>

-// In Torch forward run (e.g. THPVariable_apply), ctx of type THPFunction* (which is also a PyObject*)
-// is created. The ctx is used to run user-defined forward function and backward function as the first
-// parameter. The same time, a cdata of type std::shared_ptr<PyNode> is created, cdata is owned by:
-// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor own
+// In Torch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*)
+// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673).
+// The ctx is used to run the user-defined forward function and backward function as the first
+// parameter. At the same time, a cdata of type std::shared_ptr<PyNode> is created
+// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677),
+// cdata is owned by:
+// a). the forward run output tensors, as their grad_fn_ property. (The full hierarchy is: Tensor owns
 // shared_pointer<TensorImpl>; TensorImpl owns std::unique_ptr<AutogradMeta>; AutogradMeta
 // manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr<PyNode>,
-// the so called gradient function.)
-// b). the consumer operator of forward run outputs, will let its own PyNode/Node own the grad_fn_
-// (of type std::shared_ptr<PyNode>) of all inputs that require grad.
+// i.e. the so-called gradient function.)
+// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194
+// b). the consumer operator of the forward run outputs, which lets its own PyNode/Node (gradient function)
+// own the grad_fn_ (of type std::shared_ptr<PyNode>) of all inputs that require grad.
+// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263
 // BUT, if we run torch computation within PythonOp, b) is lost. SO, for some cases, where forward outputs
 // are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr<PyNode>) references
 // in a) will be released. Without b)'s reference, the grad_fn_ releases the PyNode as its reference count reaches 0;
@@ -55,6 +61,45 @@ class PyNodeSharedPointerPool {
};


void clear_grad_fns_for_next_edges(at::Tensor target, std::vector<at::Tensor> saved_tensors) {
// For a leaf tensor, an AccumulateGrad (gradient function) node is created, which owns a
// reference to the tensor.
// For any user saved tensors (with save_for_backward), if the tensor is a leaf, we put the
// {AccumulateGrad*, Tensor*} pair into grad_fn_to_tensor_map.
std::unordered_map<torch::autograd::Node*, at::Tensor*> grad_fn_to_tensor_map;
for (auto& t: saved_tensors) {
auto grad_fn = t.grad_fn();
if (!grad_fn) {
grad_fn = torch::autograd::impl::try_get_grad_accumulator(t);
if (grad_fn) {
TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(),
"found AccumulateGrad* is used by more than one tensors.");
grad_fn_to_tensor_map.insert({grad_fn.get(), &t});
}
}
}

const auto& gradient_func_sptr = target.grad_fn();
for (auto& edge : gradient_func_sptr->next_edges()) {
torch::autograd::Node* node_func = edge.function.get();
// If the next gradient function is an AccumulateGrad node, we check whether the tensor it
// owns is in ctx.saved_tensors or not. If yes, we skip it; otherwise, we clear the edge, which
// releases the AccumulateGrad function.
if (dynamic_cast<torch::autograd::AccumulateGrad*>(node_func)) {
if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) {
// Skip the edges that connect to saved_tensors: when ctx.saved_tensors is unpacked in backward
// (e.g. input, = ctx.saved_tensors), there is a check that a saved tensor which is a leaf and
// requires grad must still have a grad accumulator. If we cleared this edge, a
// "RuntimeError: No grad accumulator for a saved leaf!" exception would be thrown.
TORCH_WARN("Find a AccumulateGrad node, but skip it because the owned tensor is in saved_tensors.");
continue;
} else {
TORCH_WARN("Find a AccumulateGrad node, and planned to clean the edge to it.");
edge.function.reset();
}
}
}
}
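
A Python-side view of the reference that clear_grad_fns_for_next_edges severs; this is a plain-PyTorch sketch and not part of the extension:

import torch

w = torch.ones(3, requires_grad=True)    # leaf tensor
out = (w * 2).sum()

# Follow the next edges until we reach the leaf's AccumulateGrad node; its
# .variable attribute is the leaf tensor itself, i.e. the reference that gets
# released once the edge to the AccumulateGrad node is cleared.
node = out.grad_fn
while node is not None and node.next_functions:
    candidates = [fn for fn, _ in node.next_functions if fn is not None]
    if not candidates:
        break
    node = candidates[0]

print(type(node).__name__)   # AccumulateGrad
print(node.variable is w)    # True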

void register_grad_fn(size_t ctx_address, at::Tensor target)
{
torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target);
@@ -69,4 +114,5 @@ void unregister_grad_fn(size_t ctx_address)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("register_grad_fn", &register_grad_fn, "increase grad_fn shared pointer reference.");
m.def("unregister_grad_fn", &unregister_grad_fn, "release grad_fn shared pointer referece.");
m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, "remove reference on next edges' gradident funtions.");
}
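
To make the a)/b) ownership described in the header comment concrete, the following plain-PyTorch sketch shows case b): a consumer's backward node keeps the producer's grad_fn alive through its next edges, which is the link that is lost when the consumer runs inside a PythonOp. Illustrative only, not part of this change:

import torch

x = torch.ones(2, requires_grad=True)
y = x.exp()          # a) y owns its grad_fn (its backward node)
z = y * 1.0          # b) z's backward node owns an edge to y's grad_fn

y_fn = y.grad_fn
del y                # even with y gone, the consumer's edge still reaches the node

print(any(fn is y_fn for fn, _ in z.grad_fn.next_functions))   # True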