diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 435cddd238778a..3d3274f0f68cc7 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -30,7 +30,7 @@ def test_lightning_distributed_module_methods(): pl_module.validation_step.assert_called_with(batch, batch_idx) -def test_lightning_distributed_module_warn_none_output(wrapper_class): +def test_lightning_distributed_module_warn_none_output(): """ Test that the LightningModuleWrapper warns about forgotten return statement. """ pl_module = MagicMock() wrapped_module = LightningModuleWrapperBase(pl_module)