Skip to content

Commit

Permalink
Update test_remove_1-5.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthsub committed May 4, 2021
1 parent e4bc3ff commit f694c31
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,53 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
ModelCheckpoint(dirpath=tmpdir, period=1)


def test_v1_5_0_old_on_train_epoch_end(tmpdir):
callback_warning_cache.clear()

class OldSignature(Callback):

def on_train_epoch_end(self, trainer, pl_module, outputs): # noqa
...

class OldSignatureModel(BoringModel):

def on_train_epoch_end(self, outputs): # noqa
...

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

callback_warning_cache.clear()

model = OldSignatureModel()

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

trainer.train_loop.warning_cache.clear()

class NewSignature(Callback):

def on_train_epoch_end(self, trainer, pl_module):
...

trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)

class NewSignatureModel(BoringModel):

def on_train_epoch_end(self):
...

model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)


@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
def test_v1_5_0_profiler_output_filename(tmpdir, cls):
filepath = str(tmpdir / "test.txt")
Expand Down

0 comments on commit f694c31

Please sign in to comment.