`torch` allows users to implement their own custom backward functions (vjps for now, jvps in the making as of July 2024). The interface looks like:
```python
from typing import Any
import torch

class CustomBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, *args: Any) -> Any:
        pass

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        pass
```
We require two functions: `forward` and `backward`. `torch` uses a context object `ctx` to store intermediate results of the forward pass, which can then be accessed in the backward pass to compute derivatives. Let's look at a fleshed-out pseudo example:
```python
from typing import Any, Tuple
import torch

def custom_logic(intermediate_result: torch.Tensor, some_tensors: torch.Tensor) -> torch.Tensor:
    # This function does some custom gradient logic.
    return intermediate_result * some_tensors

class CustomTorchBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, staticarg0: Any, staticarg1: Any, requires_grad_arg: torch.Tensor) -> torch.Tensor:
        # 1. Compute and store intermediate results needed for the backward pass in `ctx`.
        # NOTE: Assume `requires_grad_arg` is a parameter Tensor which `requires_grad`.
        intermediate_result = staticarg0(requires_grad_arg)
        ctx.save_for_backward(requires_grad_arg)
        ctx.intermediate_result = intermediate_result
        # 2. Compute the full forward pass.
        result = staticarg1(intermediate_result)
        # 3. Return the result.
        return result

    @staticmethod
    def backward(ctx: Any, grad_out: torch.Tensor) -> Tuple[Any, ...]:
        # 1. Retrieve the intermediate result needed for the vjp computation from the `ctx` object.
        # 2. Compute gradients using `custom_logic` and the tensors saved during forward.
        grads = tuple(
            grad_out * custom_logic(ctx.intermediate_result, saved)
            for saved in ctx.saved_tensors
        )
        # 3. Return a tuple containing as many `None`s as there are static args which we
        #    do not intend to compute grads for (`staticarg0` and `staticarg1` in our case),
        #    followed by the unpacked `grads`: one gradient for each tensor input of
        #    `forward` that requires grad (here just `requires_grad_arg`).
        return (None, None, *grads)
```
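To see the pieces fit together, here is a minimal sketch of how such a Function is invoked, assuming two hypothetical callables `scale` and `shift` as stand-ins for `staticarg0` and `staticarg1` (any `Tensor -> Tensor` callables would do). Note that the gradient it produces is whatever the pseudo `custom_logic` dictates, not a mathematically derived vjp.

```python
import torch

# Hypothetical static args: any Tensor -> Tensor callables work here.
def scale(t: torch.Tensor) -> torch.Tensor:
    return 2.0 * t

def shift(t: torch.Tensor) -> torch.Tensor:
    return t + 1.0

x = torch.randn(3, requires_grad=True)
y = CustomTorchBackward.apply(scale, shift, x)  # custom Functions are called via `.apply`
y.sum().backward()  # invokes CustomTorchBackward.backward with grad_out = ones_like(y)
print(x.grad)  # gradient produced by the pseudo `custom_logic`, not an analytic derivative
```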