Skip to content

Commit

Permalink
Fix custom ops bug for pytorch 1.12 and onwards
Browse files Browse the repository at this point in the history
Adapt to newer _jit_get_operation API that changed in
pytorch/pytorch#76814

for NVlabs#188, NVlabs#193
  • Loading branch information
jannehellsten authored and phcerdan committed Nov 25, 2023
1 parent 70e105b commit c884633
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torch_utils/ops/grid_sample_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12

#----------------------------------------------------------------------------

Expand Down Expand Up @@ -58,6 +59,8 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
Expand Down

0 comments on commit c884633

Please sign in to comment.