[docs] Add step to ensure sync_dist is added to logging when multi-gpu enabled (#4817)
* Add additional check to ensure validation/test steps are updated accordingly

* Update docs/source/multi_gpu.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people authored Nov 23, 2020
1 parent ccf38ce commit 9186abe
Showing 1 changed file with 27 additions and 0 deletions: docs/source/multi_gpu.rst
@@ -103,6 +103,33 @@ Lightning adds the correct samplers when needed, so no need to explicitly add sa

.. note:: For iterable datasets, we don't do this automatically.


Synchronize validation and test logging
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step.
This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.

Note that if you use any built-in metrics or custom metrics that use the :ref:`Metrics API <metrics>`, these do not need to be updated; synchronization is handled automatically for you.

.. testcode::

def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss(logits, y)
# Add sync_dist=True to sync logging across all GPU workers
self.log('validation_loss', loss, on_step=True, on_epoch=True, sync_dist=True)

def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss(logits, y)
# Add sync_dist=True to sync logging across all GPU workers
self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
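
Conceptually, ``sync_dist=True`` asks Lightning to reduce the logged value across all processes (a mean reduction by default) before it is recorded, so every worker sees the same number. The snippet below is a minimal pure-Python sketch of that reduction under this assumption; the ``sync_reduce`` helper and the per-worker loss values are hypothetical illustrations, not Lightning's actual implementation.

```python
# Hypothetical sketch: with sync_dist=True, a value logged on each worker
# is reduced (mean by default) across all workers before it is recorded,
# so every process makes the same checkpointing decision.

def sync_reduce(per_worker_values):
    """Mean-reduce one logged scalar across workers (stand-in for an all-reduce)."""
    return sum(per_worker_values) / len(per_worker_values)

# Suppose each of 4 GPU workers computed a different validation loss on its shard:
losses = [0.52, 0.48, 0.50, 0.46]
synced_loss = sync_reduce(losses)  # all workers now log the same mean value
```

In real distributed training this reduction is performed collectively (e.g. via an all-reduce over the process group) rather than by gathering values into one list.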


Make models pickleable
^^^^^^^^^^^^^^^^^^^^^^
It's very likely your code is already `pickleable <https://docs.python.org/3/library/pickle.html>`_,