diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 3796164df176b..da68af20fecc8 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -415,7 +415,7 @@ def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: return self.lightning_module.test_step(*args, **kwargs) def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: - with self.precision_plugin.test_step_context(): + with self.precision_plugin.predict_step_context(): return self.lightning_module.predict_step(*args, **kwargs) def post_training_step(self):