
Removes need to unsqueeze from dp #1319

Merged: 11 commits, Apr 2, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

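For orientation, here is a minimal, hypothetical LightningModule sketch (not part of this diff; the class name and omitted methods are made up) showing what the changelog entry means for user code after this PR:

```python
import torch.nn.functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    # ... __init__, forward, configure_optimizers omitted ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)  # 0-dim (scalar) tensor

        # Before #1319 this manual step was needed under DP/DDP2:
        #   if self.trainer.use_dp or self.trainer.use_ddp2:
        #       loss = loss.unsqueeze(0)
        # After #1319 the DP override unsqueezes scalar outputs automatically.
        return {'loss': loss}
```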
9 changes: 0 additions & 9 deletions pl_examples/basic_examples/lightning_module_template.py
@@ -111,10 +111,6 @@ def training_step(self, batch, batch_idx):
# calculate loss
loss_val = self.loss(y, y_hat)

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
    loss_val = loss_val.unsqueeze(0)

tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
@@ -145,11 +141,6 @@ def validation_step(self, batch, batch_idx):
if self.on_gpu:
    val_acc = val_acc.cuda(loss_val.device.index)

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
    loss_val = loss_val.unsqueeze(0)
    val_acc = val_acc.unsqueeze(0)

output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
18 changes: 5 additions & 13 deletions pl_examples/domain_templates/gan.py
@@ -99,10 +99,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 0:
# sample noise
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)

# match gpu device (or keep as cpu)
if self.on_gpu:
    z = z.cuda(imgs.device.index)
z = z.type_as(imgs)

# generate images
self.generated_imgs = self(z)
@@ -115,8 +112,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
    valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
@@ -134,15 +130,13 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
    valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)

real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
if self.on_gpu:
    fake = fake.cuda(imgs.device.index)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
@@ -174,9 +168,7 @@ def train_dataloader(self):

def on_epoch_end(self):
z = torch.randn(8, self.hparams.latent_dim)
# match gpu device (or keep as cpu)
if self.on_gpu:
    z = z.cuda(self.last_imgs.device.index)
z = z.type_as(self.last_imgs)

# log sampled images
sample_imgs = self(z)
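The gan.py hunks above swap the explicit `if self.on_gpu: ... .cuda(index)` pattern for `Tensor.type_as`, which casts a newly created tensor to the reference tensor's dtype and, when that reference lives on a GPU, moves it there as well. A small standalone sketch of the idiom (plain PyTorch, values invented for illustration):

```python
import torch

imgs = torch.randn(4, 1, 28, 28)   # stand-in for a batch of images
if torch.cuda.is_available():
    imgs = imgs.cuda()              # pretend the Trainer already moved the batch

z = torch.randn(4, 100)             # new tensor, always created on CPU
z = z.type_as(imgs)                 # now matches imgs' dtype and device

assert z.dtype == imgs.dtype
assert z.device == imgs.device
```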
3 changes: 0 additions & 3 deletions pl_examples/domain_templates/reinforse_learn_Qnet.py
@@ -277,9 +277,6 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
# calculates training loss
loss = self.dqn_mse_loss(batch)

if self.trainer.use_dp or self.trainer.use_ddp2:
    loss = loss.unsqueeze(0)

if done:
    self.total_reward = self.episode_reward
    self.episode_reward = 0
12 changes: 0 additions & 12 deletions pl_examples/full_examples/imagenet/imagenet_example.py
@@ -46,12 +46,6 @@ def training_step(self, batch, batch_idx):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
    loss_val = loss_val.unsqueeze(0)
    acc1 = acc1.unsqueeze(0)
    acc5 = acc5.unsqueeze(0)

tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
@@ -69,12 +63,6 @@ def validation_step(self, batch, batch_idx):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
    loss_val = loss_val.unsqueeze(0)
    acc1 = acc1.unsqueeze(0)
    acc5 = acc5.unsqueeze(0)

output = OrderedDict({
'val_loss': loss_val,
'val_acc1': acc1,
15 changes: 15 additions & 0 deletions pytorch_lightning/overrides/data_parallel.py
@@ -163,6 +163,9 @@ def _worker(i, module, input, kwargs, device=None):

else:
    output = module.validation_step(*input, **kwargs)

if module.use_dp or module.use_ddp2:
    auto_squeeze_dim_zeros(output)
# ---------------

with lock:
@@ -199,3 +202,15 @@ def _worker(i, module, input, kwargs, device=None):
raise output
outputs.append(output)
return outputs


def auto_squeeze_dim_zeros(output):
Contributor review comment: How about just "unsqueeze_scalars"?

"""
In DP or DDP2 we need to unsqueeze dim 0
:param output:
:return:
"""
    for k, v in output.items():
        # a zero-dim (scalar) tensor cannot be gathered across replicas
        is_scalar = len(v.size()) == 0
        if is_scalar:
            output[k] = output[k].unsqueeze(0)
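
To make the helper's effect concrete, a short usage sketch (the dict contents are invented; the import path assumes a Lightning version that includes this PR):

```python
import torch
from pytorch_lightning.overrides.data_parallel import auto_squeeze_dim_zeros

# mimic what a step might return from one DP replica
output = {
    'loss': torch.tensor(0.25),      # 0-dim scalar -> gets unsqueezed
    'val_acc': torch.tensor([0.9]),  # already 1-dim -> left untouched
}

auto_squeeze_dim_zeros(output)

print(output['loss'].shape)     # torch.Size([1])
print(output['val_acc'].shape)  # torch.Size([1])
```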