diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index 441b3795..017f03ac 100644 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -22,6 +22,7 @@ enabled = False # Enable the custom op by setting this to true. _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 +_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12 #---------------------------------------------------------------------------- @@ -58,6 +59,8 @@ class _GridSample2dBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input, grid): op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + if _use_pytorch_1_12_api: + op = op[0] if _use_pytorch_1_11_api: output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)