Skip to content

Commit

Permalink
Removes need to unsqueeze from dp (#1319)
Browse files Browse the repository at this point in the history
* removes need to unsqueeze from dp

* removes need to unsqueeze from dp

* fixed examples

* added auto unsqueeze

* added auto unsqueeze

* added auto unsqueeze

* added auto unsqueeze

* Update pytorch_lightning/overrides/data_parallel.py

Co-Authored-By: Adrian Wälchli <adrian.waelchli@students.unibe.ch>

* fixed dp parse

* fixed dp parse

Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
  • Loading branch information
williamFalcon and Adrian Wälchli authored Apr 2, 2020
1 parent 6b41b5c commit 3cb149f
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))
Expand Down
9 changes: 0 additions & 9 deletions pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ def training_step(self, batch, batch_idx):
# calculate loss
loss_val = self.loss(y, y_hat)

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)

tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
Expand Down Expand Up @@ -145,11 +141,6 @@ def validation_step(self, batch, batch_idx):
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
val_acc = val_acc.unsqueeze(0)

output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
Expand Down
18 changes: 5 additions & 13 deletions pl_examples/domain_templates/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 0:
# sample noise
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)

# match gpu device (or keep as cpu)
if self.on_gpu:
z = z.cuda(imgs.device.index)
z = z.type_as(imgs)

# generate images
self.generated_imgs = self(z)
Expand All @@ -115,8 +112,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
Expand All @@ -134,15 +130,13 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)

real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
if self.on_gpu:
fake = fake.cuda(imgs.device.index)
fake = fake.type_as(fake)

fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
Expand Down Expand Up @@ -174,9 +168,7 @@ def train_dataloader(self):

def on_epoch_end(self):
z = torch.randn(8, self.hparams.latent_dim)
# match gpu device (or keep as cpu)
if self.on_gpu:
z = z.cuda(self.last_imgs.device.index)
z = z.type_as(self.last_imgs)

# log sampled images
sample_imgs = self(z)
Expand Down
3 changes: 0 additions & 3 deletions pl_examples/domain_templates/reinforse_learn_Qnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,6 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
# calculates training loss
loss = self.dqn_mse_loss(batch)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.episode_reward = 0
Expand Down
12 changes: 0 additions & 12 deletions pl_examples/full_examples/imagenet/imagenet_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ def training_step(self, batch, batch_idx):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
acc1 = acc1.unsqueeze(0)
acc5 = acc5.unsqueeze(0)

tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
Expand All @@ -69,12 +63,6 @@ def validation_step(self, batch, batch_idx):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
acc1 = acc1.unsqueeze(0)
acc5 = acc5.unsqueeze(0)

output = OrderedDict({
'val_loss': loss_val,
'val_acc1': acc1,
Expand Down
18 changes: 18 additions & 0 deletions pytorch_lightning/overrides/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def _worker(i, module, input, kwargs, device=None):

else:
output = module.validation_step(*input, **kwargs)

if module.use_dp or module.use_ddp2:
auto_squeeze_dim_zeros(output)
# ---------------

with lock:
Expand Down Expand Up @@ -199,3 +202,18 @@ def _worker(i, module, input, kwargs, device=None):
raise output
outputs.append(output)
return outputs


def auto_squeeze_dim_zeros(output):
"""
In DP or DDP2 we need to unsqueeze dim 0
:param output:
:return:
"""
for k, v in output.items():
if not isinstance(v, torch.Tensor):
continue

is_scalar = v.dim() == 0
if is_scalar:
output[k] = output[k].unsqueeze(0)

0 comments on commit 3cb149f

Please sign in to comment.