From 6cd754af958792e0947303517cf7b662d7072b04 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 23 Sep 2021 19:28:27 +0000 Subject: [PATCH 1/3] Added rot90 batch rule Description: - Added rot90 batch rule - Enabled associated tests Note: tests seem not be reliable to output bdim. Tests are passing even if batching rule outputs wrong bdim --- functorch/csrc/BatchRulesViews.cpp | 15 +++++++++++++++ test/test_ops.py | 2 +- test/test_vmap.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/functorch/csrc/BatchRulesViews.cpp b/functorch/csrc/BatchRulesViews.cpp index 7fb2bdbbc..a1fe6a67a 100644 --- a/functorch/csrc/BatchRulesViews.cpp +++ b/functorch/csrc/BatchRulesViews.cpp @@ -283,6 +283,20 @@ std::tuple> _reshape_alias_batch_rule(const Tensor& se return std::make_tuple(at::reshape(self, new_shape), bdim); } +std::tuple> rot90_batch_rule( + const Tensor& self, + optional bdim, + int64_t k, + const IntArrayRef dims) { + + auto self_ = moveBatchDimToFront(self, bdim); + VmapDimVector new_dims; + for (auto i: dims) { + new_dims.push_back(getPhysicalDim(self_, true, i)); + } + return std::make_tuple(at::rot90(self_, k, new_dims), 0); +} + std::tuple> diagonal_backward_batch_rule( const Tensor& grad_input, optional grad_input_bdim, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { @@ -339,6 +353,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("squeeze", squeeze_batch_rule); VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule); VMAP_SUPPORT("_reshape_alias", _reshape_alias_batch_rule); + VMAP_SUPPORT("rot90", rot90_batch_rule); VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule); VMAP_SUPPORT("select_backward", select_backward_batch_rule); VMAP_SUPPORT("slice_backward", slice_backward_batch_rule); diff --git a/test/test_ops.py b/test/test_ops.py index 94f85fe84..f21d156cd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -456,7 +456,7 @@ def test_vmapvjp(self, device, dtype, op): xfail('quantile'), xfail('renorm'), xfail('roll'), - xfail('rot90'), + # xfail('rot90'), xfail('scatter_add'), xfail('solve'), xfail('sort'), diff --git a/test/test_vmap.py b/test/test_vmap.py index 8842a2607..099ee038e 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3027,7 +3027,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('resolve_conj'), xfail('resolve_neg'), xfail('roll'), - xfail('rot90'), + # xfail('rot90'), xfail('scatter'), xfail('scatter_add'), xfail('take'), From be083653d8b25a352a908bd373fdd522a74ed895 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 24 Sep 2021 09:35:23 +0000 Subject: [PATCH 2/3] Replaced manual batching rule with REDUCTION_BOXED_ARGS --- functorch/csrc/BatchRulesReduceOps.cpp | 1 + functorch/csrc/BatchRulesViews.cpp | 15 --------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/functorch/csrc/BatchRulesReduceOps.cpp b/functorch/csrc/BatchRulesReduceOps.cpp index 1f23bf2d6..a7f9a110f 100644 --- a/functorch/csrc/BatchRulesReduceOps.cpp +++ b/functorch/csrc/BatchRulesReduceOps.cpp @@ -316,6 +316,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { REDUCTION_BOXED_ARGS(topk, 2); REDUCTION_BOXED(var.correction); REDUCTION_BOXED(var_mean.correction); + REDUCTION_BOXED_ARGS(rot90, 2); VMAP_SUPPORT("_log_softmax_backward_data", _log_softmax_backward_batch_rule); VMAP_SUPPORT("_softmax_backward_data", _softmax_backward_batch_rule); diff --git a/functorch/csrc/BatchRulesViews.cpp b/functorch/csrc/BatchRulesViews.cpp index a1fe6a67a..7fb2bdbbc 100644 --- a/functorch/csrc/BatchRulesViews.cpp +++ b/functorch/csrc/BatchRulesViews.cpp @@ -283,20 +283,6 @@ std::tuple> _reshape_alias_batch_rule(const Tensor& se return std::make_tuple(at::reshape(self, new_shape), bdim); } -std::tuple> rot90_batch_rule( - const Tensor& self, - optional bdim, - int64_t k, - const IntArrayRef dims) { - - auto self_ = moveBatchDimToFront(self, bdim); - VmapDimVector new_dims; - for (auto i: dims) { - new_dims.push_back(getPhysicalDim(self_, true, i)); - } - return std::make_tuple(at::rot90(self_, k, new_dims), 0); -} - std::tuple> diagonal_backward_batch_rule( const Tensor& grad_input, optional grad_input_bdim, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { @@ -353,7 +339,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("squeeze", squeeze_batch_rule); VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule); VMAP_SUPPORT("_reshape_alias", _reshape_alias_batch_rule); - VMAP_SUPPORT("rot90", rot90_batch_rule); VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule); VMAP_SUPPORT("select_backward", select_backward_batch_rule); VMAP_SUPPORT("slice_backward", slice_backward_batch_rule); From fbca77c5d20be1ffaf70f61a64ae0a7c026c3ae3 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sat, 25 Sep 2021 20:18:01 +0000 Subject: [PATCH 3/3] Removed commented xfail('rot90') --- test/test_ops.py | 1 - test/test_vmap.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 32294a94d..987319611 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -455,7 +455,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('quantile'), xfail('renorm'), xfail('roll'), - # xfail('rot90'), xfail('scatter_add'), xfail('solve'), xfail('sort'), diff --git a/test/test_vmap.py b/test/test_vmap.py index bac685bb0..82b9b91c3 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3030,7 +3030,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('resolve_conj'), xfail('resolve_neg'), xfail('roll'), - # xfail('rot90'), xfail('scatter'), xfail('scatter_add'), xfail('take'),