Training fails at the end of the epoch when returning None in the training step #7544
Comments
Thanks for reporting this. Can you simulate it with our bug report model please? That would help me a lot, thanks!
Sure, this reproduces the bug:

```python
import os
import random

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        if batch_idx == 2:
            loss = None
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=5,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()
```
I think it's because of this:

```python
if batch_idx == 2:
    loss = None
self.log("train_loss", loss)
```

which should probably be:

```python
if batch_idx == 2:
    loss = None
else:
    self.log("train_loss", loss)
```

Or should Lightning handle this internally?
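Applied to the reproduction script above, a guarded training_step might look like this (just a sketch; the rest of BoringModel stays unchanged):

```python
def training_step(self, batch, batch_idx):
    loss = self(batch).sum()
    if batch_idx == 2:
        # Simulate a bad batch: skip it entirely.
        # Returning None skips backward/optimizer step, and nothing is logged,
        # so the epoch-end aggregation only ever sees real tensors.
        return None
    self.log("train_loss", loss)
    return loss
```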
Ahh, I see, that makes sense. When averaging the loss across multiple batches, how does Lightning handle the fact that a batch was skipped because the loss was None? Does it simply not include it in the average?
Perfect, thank you.
Sorry, I had to delete my answer and double-check, but yes: it averages only over the metrics that were logged, not over all training_steps.
To be specific, it does a weighted average by default, using the batch_size as the weight. In your case it never even reaches that point: the error is thrown while converting the list of logged values to a PyTorch tensor, and since that list contains a None, the conversion fails. Ideally a skipped batch shouldn't contribute to the aggregated results, so adding an else there, as you showed, will work fine.
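For illustration only (this is not Lightning's internal code, and the numbers are made up), a batch_size-weighted average over the logged values would look roughly like this:

```python
import torch

# Hypothetical per-step logged losses and their batch sizes; the step that
# returned None simply never appears in these lists.
logged_losses = torch.tensor([0.90, 0.85, 0.70, 0.65])  # 4 of 5 steps logged
batch_sizes = torch.tensor([2.0, 2.0, 2.0, 2.0])

# Weighted epoch average: skipped batches contribute nothing to either sum.
epoch_loss = (logged_losses * batch_sizes).sum() / batch_sizes.sum()
print(epoch_loss)  # tensor(0.7750)
```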
🐛 Bug
Sometimes my training loss in a batch is NaN. Hence, I return None as the loss so that the model will not backpropagate through it, as suggested in #4956. It works fine during the epoch; however, the code fails at the end of the epoch in the function reduce_across_time (line 532).
In the None case, value ends up equal to [None], and torch cannot create a proper tensor out of it (RuntimeError: Could not infer dtype of NoneType).
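The failing conversion can be reproduced in isolation; this is roughly what happens when the list of logged values containing None is turned into a tensor:

```python
import torch

values = [None]  # what the logged "train_loss" list contains for the skipped batch
torch.tensor(values)  # RuntimeError: Could not infer dtype of NoneType
```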
Am I doing something wrong, or is this a bug in Lightning? Is there any workaround?
Versions
pytorch-lightning 1.3.1
torch 1.8.1+cu11
python 3.7.9