diff --git a/test/test_functorch.py b/test/test_functorch.py index 68c13be0e..8f8396dc5 100644 --- a/test/test_functorch.py +++ b/test/test_functorch.py @@ -362,7 +362,7 @@ def zero_grad(p): for p in params.flatten_keys().values() ) assert params.requires_grad - params.apply_(zero_grad) + params.apply_(zero_grad, filter_empty=True) assert params.requires_grad def test_repopulate(self):