FSDP integration #6152
Conversation
To test cpu offload and to help me fix:

from argparse import ArgumentParser
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
from pytorch_lightning.plugins import FullShardedPlugin
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.trainer.model.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser
def cli_main():
pl.seed_everything(1234)
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()
dm = MNISTDataModule.from_argparse_args(args)
model = LitClassifier(args.hidden_dim, args.learning_rate)
trainer = pl.Trainer.from_argparse_args(args, plugins=FullShardedPlugin(cpu_offload=True), precision=16, gpus=1,
max_epochs=1)
trainer.fit(model, datamodule=dm)
if __name__ == '__main__':
cli_main()
Codecov Report

@@            Coverage Diff            @@
##           master   #6152     +/-   ##
=========================================
- Coverage      87%     86%      -1%
=========================================
  Files         200     202       +2
  Lines       12857   13049     +192
=========================================
+ Hits        11224   11273      +49
- Misses       1633    1776     +143
Some shallow comments
Hello @SeanNaren! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-05-04 16:27:51 UTC
@@ -47,7 +47,7 @@ def test_invalid_apex_sharded(tmpdir):
     """
     model = BoringModel()
-    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
+    with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'):
Can we also extend the test case here, so that we test with all sharded plugins?
Thanks for the extensive reviews guys, it's much appreciated :) To give a status update:

Issues

CPU offload doesn't currently work with PyTorch native AMP (no idea about Apex) when using the AMP grad scaler. This can be tracked in facebookresearch/fairscale#421. While we discuss a long-term upstream solution, we can either use @tchaton's fix of delaying the grad moves (performance to be confirmed) 59dbb83, or disable this functionality for now. I think @tchaton's solution may be worthwhile, but we might be able to move the logic into the FSDP class. Need to POC this out.

Flatten parameters removes all parameters from the lightning module and moves them into a contiguous tensor. This means that when building your optimizer, you should refer to the wrapped model via self.trainer.model, not to individual layers or the lightning module.

Help Needed!

To get the best benefits (scaling to ridiculous parameter sizes), we need to recommend wrapping child modules in the FSDP wrapper, as suggested in the main issue facebookresearch/fairscale#413. Because the FSDP wrapper requires torch distributed to be initialized, we need to delay the wrapping. I've gotten past this in the past by doing something like below:

from pytorch_lightning.plugins.fully_sharded import ShardedModule

class MyModel(pl.LightningModule):
    def __init__(self):
        ...
        self.linear = ShardedModule(torch.nn.Linear(5, 5))

ShardedModule would then do something like this:

class ShardedModule(nn.Module):
    def __init__(self, module):
        ...
        self.module = module

    def init_module(self):
        self.module = FullyShardedDataParallel(self.module)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

Any thoughts here?
Note that this will be needed for large models (20B+) even after we add automatic wrapping, since initializing models of this size on CPU can cause users to run out of system RAM. Thus you want to wrap layers with FSDP as you initialize them, so that they get sharded in place and free CPU memory.
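A rough sketch of that wrap-as-you-initialize pattern using fairscale's FSDP wrapper directly (the Lightning-side hook for this was still being decided in this PR, so this is illustrative only and assumes torch.distributed is already initialized):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# wrap each layer as soon as it is created so it is sharded in place,
# meaning the full unsharded model never has to sit in CPU RAM at once
layers = [FSDP(nn.Linear(4096, 4096)) for _ in range(8)]
model = FSDP(nn.Sequential(*layers))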
Ah thanks for this, I did not see it! It brings up a great point: lazy init as I suggested won't solve this case, so we'll need to rethink further.
For some reason, when the typing initiative was going on, the model was removed from the precision plugin (f29ecbf#diff-3facc0e73962d7c559c4257f0845ee7de30191a51017643c0f8f83bb0edb8a12L79) cc @carmocca in case there was another reason it was removed. I've added this back in this PR and specified that it can also be a torch.nn.Module, since the model could be wrapped. @shuyingsunshine21 glad it fixed the issue! We'll definitely highlight this in the docs. Let me know how your XLM experiments go!
You might have tagged the wrong person. Will let you know (testing now) for the
Have another question: since FSDP wraps the whole lightning module, when we set up metrics (where the metric class has tensor fields) in the module and use FP16, those tensors are also converted to float16. This causes problems when the metric computation uses them together with other metrics computed on the fly that are float32. Is there a workaround for not casting those to float16?

import torch
import torch.nn.functional as F
import torchmetrics as metrics
# from fairscale.nn import wrap
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.plugins import FullyShardedPlugin
from torch.nn import Linear, ReLU, Sequential, BCEWithLogitsLoss
from torch.utils.data import DataLoader, Dataset
class RandomDataTensor:
def __init__(self, size, length, num_classes):
self.data = torch.randn(length, size)
# multi-class label
self.label = torch.zeros([length, num_classes])
for i in range(length):
self.label[i][torch.randint(num_classes, (3,))] = 1
class RandomDataset(Dataset):
def __init__(self, random_data_tensor):
self.len = random_data_tensor.data.shape[0]
self.data = random_data_tensor.data
self.label = random_data_tensor.label
def __getitem__(self, index):
return {"sample": self.data[index], "label": self.label[index]}
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.num_classes = 10
self.model = Sequential(
Linear(32, 32),
ReLU(),
Linear(32, self.num_classes),
)
self.loss = BCEWithLogitsLoss()
self.model_unrelated_parameter = torch.ones(3, 5, dtype=torch.float32) # <--- here is similar as a Metric class owning this
def training_step(self, batch, batch_idx):
logits = self(batch)
loss = self.loss(logits, batch["label"])
return {"loss": loss}
def configure_optimizers(self):
self.optimizer = torch.optim.SGD(
self.trainer.model.parameters(),
lr=0.1,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1)
return [self.optimizer], [lr_scheduler]
def on_train_start(self):
        assert self.model_unrelated_parameter.dtype == torch.float32  # <--- here model_unrelated_parameter is float16 for ddp_fully_sharded
model = BoringModel()
data_module = LightningDataModule.from_datasets(
train_dataset=RandomDataset(dataset_tensors["train"]),
val_dataset=RandomDataset(dataset_tensors["val"]),
test_dataset=RandomDataset(dataset_tensors["test"]),
batch_size=16,
)
trainer = Trainer(
gpus=1, max_epochs=1, precision=16, accelerator="ddp_fully_sharded"
)
trainer.fit(model, datamodule=data_module)

As pointed out above, the assertion in on_train_start fails when using ddp_fully_sharded.
@shuyingsunshine21 I'm not able to reproduce this; there were a few missing definitions in your example, so I made a new one:

import torch
from pytorch_lightning import Trainer, LightningModule
from torch.utils.data import DataLoader, Dataset
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(
torch.nn.Linear(32, 32),
torch.nn.ReLU(),
torch.nn.Linear(32, 2)
)
self.model_unrelated_parameter = torch.ones(3, 5, dtype=torch.float32)
def forward(self, x):
return self.layer(x)
def loss(self, batch, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
def training_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
return {"x": loss}
def test_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
return DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
return DataLoader(RandomDataset(32, 64))
def on_train_start(self):
assert self.model_unrelated_parameter.dtype == torch.float32
if __name__ == '__main__':
model = BoringModel()
trainer = Trainer(
max_epochs=1,
gpus=1,
precision=16,
plugins='ddp_fully_sharded'
)
trainer.fit(model)
    trainer.test(model)

This runs fine and the assertion passes; is there something missing from the above that I omitted?

EDIT: I also see that sync batch norm is supported in FSDP, so we should remove the guard. Is that right @min-xu-ai?
@SeanNaren, my bad, you are right, the above example is not correct. Let me re-paste, just adding a metric:

from torchmetrics import Metric
import torch
from pytorch_lightning import Trainer, LightningModule
from torch.utils.data import DataLoader, Dataset
class TestMetric(Metric):
thresholds: torch.Tensor
def __init__(
self,
num_thresholds: int = 100,
compute_on_step: bool = False,
**kwargs
) -> None:
super().__init__(compute_on_step=compute_on_step, **kwargs)
self.num_thresholds = num_thresholds
thresholds = torch.arange(num_thresholds) / num_thresholds
self.register_buffer("thresholds", thresholds)
assert self.thresholds.dtype == torch.float32 # <- this part is fine
def update(self, output: torch.Tensor) -> None:
assert self.thresholds.dtype == torch.float32 # <- this breaks for fully sharded
self.predictions = torch.rand((1 , self.num_thresholds), device=output.device)
def compute(self) -> torch.Tensor:
assert self.thresholds.dtype == torch.float32 # <- this breaks for fully sharded
condition = self.predictions >= 0.5 # <- this is float32
thresholds_at_p = (
torch.where(
condition, self.thresholds, torch.scalar_tensor(1e6, device=condition.device)
)
.min(dim=1)
.values
) # <- as a result, this computation would fail for fully sharded
return thresholds_at_p
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(
torch.nn.Linear(32, 32),
torch.nn.ReLU(),
torch.nn.Linear(32, 2)
)
self.val_test_metric = TestMetric(num_thresholds=100)
def forward(self, x):
return self.layer(x)
def loss(self, batch, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
def training_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
self.val_test_metric(output)
return {"x": loss}
def on_validation_epoch_end(self):
self.val_test_metric.compute()
def test_step(self, batch, batch_idx):
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
return DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
return DataLoader(RandomDataset(32, 64))
if __name__ == '__main__':
model = BoringModel()
trainer = Trainer(
max_epochs=1,
gpus=1,
precision=16,
plugins='ddp_fully_sharded'
)
trainer.fit(model)
    trainer.test(model)

Note: tested the above for
Thanks @shuyingsunshine21, we can iterate on this! Another bug I've run into, which will heavily block this PR, is that the parameters are not kept when going from fit to test. This is because with flatten parameters the parameters no longer live in the original module (as described above).

I think the cleanest fix would be that at teardown, or after training has finished, we unflatten the parameters back onto the original model. Is this doable @min-xu-ai via

EDIT: a high-level reprod for this issue:

import os
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("valid_loss", loss)
def test_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("test_loss", loss)
def configure_optimizers(self):
return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
def run():
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=1,
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
gpus=1,
plugins='ddp_fully_sharded',
weights_summary=None,
)
trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
# Passes
assert len(trainer.accelerator.training_type_plugin.model.state_dict()) > 0
trainer.test(model, test_dataloaders=test_data)
# Fails
assert len(trainer.accelerator.training_type_plugin.model.state_dict()) > 0
if __name__ == '__main__':
    run()

There are no parameters within the model, since we've assumed the parameters are stored in the model after
if self.automatic_module_wrap and not self._model_has_nested_fsdp():
    self.model = auto_wrap(LightningFullyShardedModule(self.model))
    if not isinstance(self.model, FullyShardedDataParallel):
        self.model = wrap(self.model)
else:
    self.model = wrap(LightningFullyShardedModule(self.model))
If manually wrapping contents inside the lightning module, is this final outer-layer wrap needed? Or could we defer this to the user in the lightning module too?
Then we would not need to wrap the model in the dummy LightningFullyShardedModule to map forward to one of the step functions. Would it also mean users don't have to refer to self.trainer.model inside of the lightning module?
Would this avoid the parameter flattening issue across stages?
Ah I understand! Theoretically I think so, since then we're just using the LM as a wrapper. So the cases I see:

1. User wraps nothing, expects the module to be wrapped by Lightning, and potentially auto_wrap to handle recursive wrapping.
2. User wraps some of the layers in configure_sharded_model but then expects all other layers to be included in a higher wrapper class (wrap the entire LM).
3. User wraps all of the layers in configure_sharded_model and doesn't require any high-level wrapping.

Solutions

1. This should be the default behaviour, i.e. plugins=fsdp or plugins=fsdp_auto_wrap.
2. This should be the same as 1., i.e. plugins=fsdp or plugins=fsdp_auto_wrap.
3. This could be plugins=fsdp_manual, where we do not wrap the highest-level module, allowing the user to do whatever they'd like in configure_optimizers.

In either case, it's important to fix the flattening issue for 1. and 2., which for most users trying this out will be the first step I think. Thoughts @ananthsub?
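To make case 2 concrete, here is a hypothetical sketch (hook semantics and names are assumptions while this PR is in flux): the user shards the expensive sub-module themselves in configure_sharded_model and leaves the outer wrap to the plugin, building the optimizer from the wrapped model.

import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule

class PartiallyWrappedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(4096, 4096)  # the "big" layer to shard manually
        self.head = torch.nn.Linear(4096, 10)

    def configure_sharded_model(self):
        # assumed to be called by the plugin inside fairscale's enable_wrap context,
        # after torch.distributed has been initialized
        self.backbone = wrap(self.backbone)

    def forward(self, x):
        return self.head(self.backbone(x))

    def configure_optimizers(self):
        # with flattened parameters the optimizer must see the wrapped model's parameters
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)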
Yes, exactly those 3 cases!
To unblock the initial integration, I was wondering if we should start with option #3 to unblock power users in release candidates, with the caveat that they are responsible for the full wrapping. Maybe this could be an option on the plugin as to whether the outer wrap on the lightning module needs to be applied, in order to distinguish between cases 2 and 3.
I completely agree with you that most users will opt for cases 1 and 2, so we'll need to figure out the parameter flattening, whether in Lightning or fairscale, but I wanted to offer this as one way we could sequence these cases.
I started making changes to see if any issues arise with case 3, and have a few observations:

The user still has to define a single model, be it a Module containing sub-modules in a sequential wrapper, or their own model structure defining a forward function. This means self.model will still probably be required in every case for FSDP to work in configure_optimizers.

I also ran into an issue where clipping grad norms in manual mode cannot be handled automatically, as we do not wrap the model:

class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Full Sharded Training"""

    def clip_gradients(
        self,
        optimizer: 'Optimizer',
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None
    ) -> None:
        # Model manages clipping of gradients
        model = cast(FullyShardedDataParallel, model)
        # todo: expose norm type once precision plugin supports this.
        model.clip_grad_norm_(clip_val, norm_type=2.0)  # This breaks

A potential solution, albeit not as elegant as I'd like, would be to go through the immediate children of the LightningModule, find the root FSDP module and call clip_grad_norm_ on it. I assume this will be a negligible cost added on top of the training loop, but what are your thoughts @ananthsub?
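A minimal sketch of that traversal, assuming the root FSDP wrapper sits among the LightningModule's immediate children (the helper name is made up for illustration):

from fairscale.nn.data_parallel import FullyShardedDataParallel
from pytorch_lightning import LightningModule

def clip_grad_via_root_fsdp(lightning_module: LightningModule, clip_val: float) -> None:
    # the FSDP wrapper owns the flattened, sharded parameters, so it must do the clipping
    for child in lightning_module.children():
        if isinstance(child, FullyShardedDataParallel):
            # todo: expose norm type once the precision plugin supports it
            child.clip_grad_norm_(clip_val, norm_type=2.0)
            return
    raise RuntimeError("no FullyShardedDataParallel child found to clip gradients on")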
self.model.to(self.root_device)
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)
We might need to be cautious about this, as fsdp_module.to(device) will summon the full parameters first: https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L348-L367, and when we perform teardown for GPU memory cleanup, we have self.lightning_module.cpu()
return unwrap_lightning_module_fully_sharded(self.model)

def on_save(self, checkpoint: dict) -> dict:
    state_dict = self.collate_state_dict()
@SeanNaren, after getting detailed memory usage, I finally figured out why the full model originally fits in one GPU but OOMs when checkpointing: in checkpoint_connector (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L270-L277), we have

model = self.trainer.lightning_module
checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': model.state_dict(),
}

Here we try to collect the full parameters again, which doubles the size.

One easy workaround for now is to add

del checkpoint['state_dict']

but this is not ideal; we summon the full parameters twice, which is unnecessary.

I feel we should modify that file to let the training type plugin control this, something like trainer.accelerator.training_type_plugin.state_dict(), especially when we would like to collect only the sharded state dict in the future.

cc @ananthsub
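For illustration, a hedged sketch of what that could look like in checkpoint_connector; the plugin-level state_dict() hook is only the suggestion above, not an API that exists yet:

model = self.trainer.lightning_module
checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    # let the training type plugin decide how to gather the weights
    # (full vs. sharded), instead of summoning the full parameters a
    # second time via model.state_dict()
    'state_dict': self.trainer.accelerator.training_type_plugin.state_dict(),
}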
@min-xu-ai, I think this is the root cause of the OOM; facebookresearch/fairscale#658 should not be a problem (for setting state_dict_device=torch.device("cpu"), CPU OOM would be a similar problem, as we also double the model storage in CPU).
Thanks @shuyingsunshine21 for your help here! This makes sense, since we're allocating new memory.
I agree with allowing the training type plugin to return the state dict; we already rely on the accelerator to dump the optimizer dicts. I'm happy to make the change!
@SeanNaren, thanks, no worries. If you have not already made the change, I could help by sending a small PR for that.
I am closing this in favour of #7487. Remaining is the ability to auto-wrap the model so users do not have to manually annotate layers. This will come in follow-up PRs once we figure out the case :)
What does this PR do?
Integrates fully sharded (ZeRO Stage 3) parallelism as seen in https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html
This also deprecates the pipe plugin and stops CI from running its tests by updating the fairscale installation, as we move towards a full replacement, primarily due to elegance and a long-term future using FSDP.
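A minimal usage sketch based on the examples in this thread (the plugin alias and flags were still in flux during this PR, so treat the exact names as assumptions):

import pytorch_lightning as pl

# enable fully sharded (ZeRO Stage 3) training on a single GPU with mixed precision
trainer = pl.Trainer(gpus=1, precision=16, plugins='ddp_fully_sharded')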
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃