Feature/double precision #6595

Merged 35 commits on Mar 24, 2021.

Commits
b224197 - Add support for double precision `precision=64`. (ethanwharris, Mar 19, 2021)
d7a2098 - Update CHANGELOG.md (ethanwharris, Mar 19, 2021)
27c3e72 - Minor changes (ethanwharris, Mar 19, 2021)
55ced9e - Fix typings (ethanwharris, Mar 19, 2021)
07bdb23 - Switch to static methods (ethanwharris, Mar 19, 2021)
564ad70 - Use functools.wraps (ethanwharris, Mar 19, 2021)
dd79106 - Update test (ethanwharris, Mar 19, 2021)
0522ad8 - Add teardown and pickle test (ethanwharris, Mar 19, 2021)
4103196 - Minor doc fix (ethanwharris, Mar 19, 2021)
3006fab - Add copyright notice to test file (ethanwharris, Mar 19, 2021)
3b53c81 - Update error message in accelerator_connector.py (ethanwharris, Mar 20, 2021)
b1b8858 - Add test for training_step etc. (ethanwharris, Mar 22, 2021)
cf12a59 - Switch patch logic to separate class, and patch additional methods (ethanwharris, Mar 22, 2021)
df6d847 - Switch to `.double()` (ethanwharris, Mar 22, 2021)
72c9be4 - Add check for original float32 data (ethanwharris, Mar 22, 2021)
423302f - Enhance tests for double precision (ethanwharris, Mar 22, 2021)
b654be2 - Update tests/plugins/test_double_plugin.py (ethanwharris, Mar 22, 2021)
b9c662b - Update tests/plugins/test_double_plugin.py (ethanwharris, Mar 22, 2021)
dd608b3 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
982767a - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
f92dd2c - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
68dce05 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
e8af281 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
9a8c021 - Move `RandomFloatIntDataset` (ethanwharris, Mar 22, 2021)
6489776 - Fix type hint (ethanwharris, Mar 22, 2021)
f527a41 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
3da2d05 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
2e74cff - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 22, 2021)
a7507ad - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 23, 2021)
fa323a2 - Update pytorch_lightning/plugins/precision/double.py (ethanwharris, Mar 23, 2021)
23b21c5 - Add type hints to args and kwargs (ethanwharris, Mar 23, 2021)
210fd87 - Fix failing tests (ethanwharris, Mar 23, 2021)
e7b6c7f - Switch `predict` to `predict_step` (ethanwharris, Mar 23, 2021)
925d109 - Merge branch 'master' into feature/double_precision (ethanwharris, Mar 24, 2021)
59c093b - Remove line from test no longer needed (ethanwharris, Mar 24, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
5 changes: 4 additions & 1 deletion docs/source/common/trainer.rst
@@ -1157,7 +1157,7 @@ precision

|

Full precision (32), half precision (16).
Double precision (64), full precision (32) or half precision (16).
Can be used on CPU, GPU or TPUs.

If used on TPU will use torch.bfloat16 but tensor printing
@@ -1172,6 +1172,9 @@ will still show torch.float32.
    # 16-bit precision
    trainer = Trainer(precision=16, gpus=1)

    # 64-bit precision
    trainer = Trainer(precision=64)

Example::

    # one day
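To see the documented flag end to end, here is a minimal sketch of my own (not part of the diff): a toy LightningModule trained with `precision=64` on ordinary float32 data, followed by a check that the weights really end up in double precision. The `TinyModel` and `RandomFloatDataset` names are illustrative only.

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomFloatDataset(Dataset):
    # Plain float32 data; the plugin is expected to cast incoming batches to float64.
    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        out = self(batch)
        return torch.nn.functional.mse_loss(out, torch.ones_like(out))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModel()
trainer = pl.Trainer(precision=64, fast_dev_run=True)
trainer.fit(model, DataLoader(RandomFloatDataset()))

# After the plugin's connect() has run, the parameters should be double precision.
assert model.layer.weight.dtype == torch.float64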
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,6 +1,7 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
@@ -29,6 +30,7 @@
    "DDPSpawnPlugin",
    "DeepSpeedPlugin",
    "DeepSpeedPrecisionPlugin",
    "DoublePrecisionPlugin",
    "HorovodPlugin",
    "NativeMixedPrecisionPlugin",
    "PrecisionPlugin",
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/__init__.py
@@ -1,5 +1,6 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
95 changes: 95 additions & 0 deletions pytorch_lightning/plugins/precision/double.py
@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Any, Sequence, Tuple, TYPE_CHECKING, List

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection

if TYPE_CHECKING:
    from torch.nn import Module
    from torch.optim import Optimizer


class _DoublePrecisionPatch:
    """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown."""

    def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None:
        self.model = model
        self.method_name = method_name
        self.old_method = old_method

    def teardown(self) -> None:
        setattr(self.model, self.method_name, self.old_method)

    @staticmethod
    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
        if data.is_floating_point():
            return data.double()
        return data

    @staticmethod
    def _move_float_tensors_to_double(collection: Any) -> Any:
        return apply_to_collection(
            collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision
        )

    @classmethod
    def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch':
        old_method = getattr(model, method_name)

        @wraps(old_method)
        def new_method(*args: Any, **kwargs: Any) -> Any:
            return old_method(
                *_DoublePrecisionPatch._move_float_tensors_to_double(args),
                **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs)
            )

        setattr(model, method_name, new_method if callable(old_method) else old_method)
        return cls(model, method_name, old_method)


class DoublePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with double (``torch.float64``) precision."""

    precision: int = 64

    def __init__(self) -> None:
        self.patches: List[_DoublePrecisionPatch] = []

    def connect(
        self,
        model: 'Module',
        optimizers: Sequence['Optimizer'],
        lr_schedulers: Sequence[Any],
    ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]:
        """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`,
        `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter
        `optimizers` or `lr_schedulers`."""
        model = model.to(dtype=torch.float64)
        if isinstance(model, LightningModule):
            self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'forward'))

        return super().connect(model, optimizers, lr_schedulers)

    def post_dispatch(self) -> None:
        while len(self.patches) > 0:
            self.patches.pop().teardown()
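As a side note, a short sketch (mine, not from the PR) of what the patched step methods do to an incoming batch: `apply_to_collection` walks nested containers and applies the conversion function to every tensor it finds, and the helper above promotes only floating point tensors, so integer tensors such as label indices pass through unchanged.

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def to_double(tensor: torch.Tensor) -> torch.Tensor:
    # Mirrors _DoublePrecisionPatch._to_double_precision: promote floats, pass everything else through.
    return tensor.double() if tensor.is_floating_point() else tensor


batch = {"features": torch.randn(4, 32), "labels": torch.randint(10, (4,))}
converted = apply_to_collection(batch, torch.Tensor, function=to_double)

assert converted["features"].dtype == torch.float64  # promoted
assert converted["labels"].dtype == torch.int64  # untouched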
pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -32,6 +32,7 @@
    DDPSpawnShardedPlugin,
    DeepSpeedPlugin,
    DeepSpeedPrecisionPlugin,
    DoublePrecisionPlugin,
    HorovodPlugin,
    NativeMixedPrecisionPlugin,
    PrecisionPlugin,
@@ -319,7 +320,8 @@ def select_precision_plugin(self) -> PrecisionPlugin:

        if self.precision == 32:
            return PrecisionPlugin()

        elif self.precision == 64:
            return DoublePrecisionPlugin()
        elif self.precision == 16:
            if self.on_tpu:
                return TPUHalfPrecisionPlugin()
Expand Down Expand Up @@ -358,7 +360,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise NotImplementedError("We only support precisions 32 and 16!")
        raise NotImplementedError("We only support precisions 64, 32 and 16!")

    def select_training_type_plugin(self) -> TrainingTypePlugin:
        if self.use_ddp2:
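A quick sketch of what the new branch hands back (my illustration; it assumes the base `PrecisionPlugin` keeps its default class attribute `precision = 32`, which is how it is defined elsewhere in the codebase):

from pytorch_lightning.plugins import DoublePrecisionPlugin, PrecisionPlugin

plugin = DoublePrecisionPlugin()
# The new plugin is an ordinary PrecisionPlugin subclass; it only changes the precision
# attribute and the model/data dtype handling shown in double.py above.
assert isinstance(plugin, PrecisionPlugin)
assert plugin.precision == 64
assert PrecisionPlugin().precision == 32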
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -227,7 +227,8 @@

plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.
precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
TPUs.

max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
129 changes: 129 additions & 0 deletions tests/plugins/test_double_plugin.py
@@ -0,0 +1,129 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset


class RandomFloatIntDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.float_data = torch.randn(length, size)
        self.int_data = torch.randint(10, (length, 1))

    def __getitem__(self, index):
        return self.float_data[index], self.int_data[index]

    def __len__(self):
        return self.len


class DoublePrecisionBoringModel(BoringModel):

    def training_step(self, batch, batch_idx):
        float_data, int_data = batch
        assert float_data.dtype == torch.float64
        output = self(float_data)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        assert batch.dtype == torch.float64
        return self(batch)

    def on_fit_start(self):
        assert self.layer.weight.dtype == torch.float64

    def on_after_backward(self):
        assert self.layer.weight.grad.dtype == torch.float64

    def train_dataloader(self):
        dataset = RandomFloatIntDataset(32, 64)
        assert dataset.float_data.dtype == torch.float32  # Don't start with double data
        return DataLoader(dataset)

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


class DoublePrecisionBoringModelNoForward(BoringModel):

    def training_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"y": loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        return output

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


@pytest.mark.parametrize(
    'boring_model',
    (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward)
)
def test_double_precision(tmpdir, boring_model):
    model = boring_model()
    original_training_step = model.training_step

    trainer = Trainer(
        max_epochs=2,
        default_root_dir=tmpdir,
        fast_dev_run=2,
        precision=64,
        log_every_n_steps=1,
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.training_step == original_training_step