Skip to content

Commit

Permalink
parity test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 20, 2021
1 parent c65938a commit c403835
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions benchmarks/test_basic_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
seed_everything(idx)

model = cls_model()
dataloader = model.train_dataloader()
# init model parts
trainer = Trainer(
# as the first run is skipped, no need to run it long
Expand All @@ -171,8 +172,7 @@ def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
deterministic=True,
logger=False,
replace_sampler_ddp=False,
num_sanity_val_steps=0,
)
trainer.fit(model)
trainer.fit(model, dataloader)

return trainer.train_loop.running_loss.last().item(), _hook_memory()

0 comments on commit c403835

Please sign in to comment.