Lightning v2.5
Lightning AI ⚡ is excited to announce the release of Lightning 2.5.
Lightning 2.5 comes with improvements on several fronts, with zero API changes. Our users love it stable, we keep it stable 😄.
Talking about love ❤️, the `lightning`, `pytorch-lightning` and `lightning-fabric` packages are collectively getting more than 10M downloads per month 😮, for a total of over 180M downloads 🤯 since the early days. It's incredible to see PyTorch Lightning getting such strong adoption across industry and the sciences.
Release 2.5 embraces PyTorch 2.5, and it marks some of its more recent directions as officially supported, namely tensor subclass-based APIs like Distributed Tensors and TorchAO, in combination with `torch.compile`.
Here are a couple of examples:
Distributed FP8 transformer with PyTorch Lightning
Full example here
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
class LanguageModel(L.LightningModule):
    """Demo language model trained in FP8 with FSDP2 sharding and ``torch.compile``.

    The underlying ``Transformer`` is not created in ``__init__``. It is built in
    ``configure_model`` on the ``meta`` device, converted to float8 training,
    sharded over ``self.device_mesh`` and finally compiled.

    Args:
        vocab_size: Vocabulary size passed to the demo ``Transformer``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        # Populated by `configure_model`.
        self.model = None

    def configure_model(self):
        """Build, float8-convert, shard and compile the model (only once)."""
        # Guard against building the model twice if the hook is called again.
        if self.model is not None:
            return

        # Instantiate on the meta device: no parameter storage is allocated
        # until the model has been sharded.
        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        # NOTE(review): the original example carried this install hint; presumably
        # the float8 kernels under torch.compile need a recent Triton build — confirm.
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
        float8_config = Float8LinearConfig(
            # Pad inner matmul dimensions so shapes need not already be multiples
            # of 16 (see torchao `Float8LinearConfig`).
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
            # Skip the decoder: its size depends on the vocabulary size, which
            # typically is not divisible by 16 as required by float8.
            return fqn != "decoder"

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        # Shard each transformer layer as its own FSDP2 unit first, then the root
        # module so the remaining parameters (e.g. the decoder) are sharded too.
        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)
        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        """Compute and log the NLL loss for one ``(input, target)`` batch."""
        inputs, target = batch
        output = self.model(inputs, target)
        # Targets are flattened before the loss; presumably `output` holds one row
        # of log-probabilities per target token — TODO confirm against Transformer.
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimize all module parameters with Adam (lr=1e-4)."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)
def train():
    """Fit ``LanguageModel`` on WikiText2 using FSDP2 across 4 data-parallel ranks.

    Runs 100 optimizer steps in ``bf16-true`` precision, accumulating gradients
    over 8 micro-batches of size 1, then prints the CUDA memory summary.
    """
    L.seed_everything(42)

    corpus = WikiText2()
    loader = DataLoader(corpus, num_workers=8, batch_size=1)
    lm = LanguageModel(vocab_size=corpus.vocab_size)

    trainer = L.Trainer(
        strategy=ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1),
        max_steps=100,
        precision="bf16-true",
        accumulate_grad_batches=8,
    )
    trainer.fit(lm, loader)
    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    # Allow faster, reduced-precision float32 matmuls where the hardware supports them.
    torch.set_float32_matmul_precision("high")
    train()
Distributed FP8 transformer with Fabric
Full example here
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm
def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    """Convert ``model`` to float8 training, shard it with FSDP2 and compile it.

    Used as the ``parallelize_fn`` of Fabric's ``ModelParallelStrategy``.

    Args:
        model: The model to parallelize.
        device_mesh: Mesh over which ``fully_shard`` distributes the parameters.

    Returns:
        The ``torch.compile``-wrapped, sharded model.
    """
    # NOTE(review): the original example carried this install hint; presumably
    # the float8 kernels under torch.compile need a recent Triton build — confirm.
    # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
    float8_config = Float8LinearConfig(
        # Pad inner matmul dimensions so shapes need not already be multiples
        # of 16 (see torchao `Float8LinearConfig`).
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
        # Skip the decoder: its size depends on the vocabulary size, which
        # typically is not divisible by 16 as required by float8.
        return fqn != "decoder"

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # Shard each transformer layer as its own FSDP2 unit first, then the root
    # module so the remaining parameters (e.g. the decoder) are sharded too.
    for module in model.modules():
        if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)
    fully_shard(model, mesh=device_mesh)

    return torch.compile(model)
def train():
    """Train the FP8 + FSDP2 demo transformer on WikiText2 with Lightning Fabric.

    Uses micro-batches of size ``micro_batch_size`` and takes one optimizer step
    per ``batch_size // micro_batch_size`` micro-batches, stopping after
    ``max_steps`` optimizer steps. Prints the CUDA memory summary at the end.
    """
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1
    max_steps = 100
    # Number of micro-batches whose gradients are accumulated per optimizer step.
    accumulation_steps = batch_size // micro_batch_size

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    # Create the model on the meta device; `configure_model` shards it later.
    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    # Only rank zero shows a progress bar.
    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        # Step on every `accumulation_steps`-th micro-batch. Counting from i + 1
        # ensures the first optimizer step also sees a full batch rather than a
        # single micro-batch.
        is_accumulating = (i + 1) % accumulation_steps != 0

        # Skip gradient synchronization while accumulating.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            # Average (rather than sum) gradients over the accumulated
            # micro-batches, so gradient clipping below acts on the same scale as
            # the Trainer example with `accumulate_grad_batches`.
            fabric.backward(loss / accumulation_steps)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            # Report the (unscaled) loss of the current micro-batch.
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    # Allow faster, reduced-precision float32 matmuls where the hardware supports them.
    torch.set_float32_matmul_precision("high")
    train()
As these examples show, it's now easier than ever to take your PyTorch Lightning module and run it with FSDP2 and/or tensor parallelism in FP8 precision, using the `ModelParallelStrategy` we introduced in 2.4.
Also note that distributed tensor APIs, TorchAO APIs, and `torch.compile` are applied directly in the `configure_model` hook (or in the parallelize function passed to Fabric's `ModelParallelStrategy`), as opposed to the `LightningModule` as a whole. The advantage of this approach is that you can just copy-paste the parallelize functions that come with native PyTorch models directly into `configure_model` and get the same effect, no head-scratching involved 🤓.
Talking about head-scratching, we also made a pass over the PyTorch Lightning internals and hardened the parts that keep track of progress counters during training, validation, and testing, as well as learning rate scheduling, when resuming from checkpoints. We have made sure there are no edge cases (to the best of our knowledge) where stopping and resuming from a checkpoint can change the sequence of loops or other internal states. Fault tolerance for the win 🥳!
Alright! Feel free to take a look at the full changelog below.
And of course: the best way to use PyTorch Lightning and Fabric is through Lightning Studio ⚡. Access GPUs, train models, deploy and more with zero setup. Focus on data and models - not infrastructure.
Changes
PyTorch Lightning
Added
- Added `step` parameter to `TensorBoardLogger.log_hyperparams` to visualize changes during training (#20176)
- Added `__str__` method to datamodule (#20301)
- Added timeout to DeepSpeedStrategy (#20474)
- Added doc for Truncated Back-Propagation Through Time (#20422)
- Added FP8 + FSDP2 + torch.compile examples for PyTorch Lightning (#20440)
- Added profiling to `Trainer.save_checkpoint` (#20405)
- Added `after_instantiate_classes` hook to CLI (#20401)
Changed
- Updated checkpointing documentation to mark `resume_from_checkpoint` as deprecated (#20477)
- Made plugin type checks more flexible (#20186)
- Changed seeding NumPy using `np.random.SeedSequence()` in `pl_worker_init_function()` to robustly seed NumPy-dependent dataloader workers (#20369)
- Allowed callbacks to be restored not just during training (#20403)
- Changed LightningCLI tests to account for future fix in jsonargparse (#20372)
- Bumped PyTorch to version 2.5 (#20351)
- Decoupled checkpoint artifact path from model artifact path (#20325)
- Updated BitsAndBytes version (#20313)
- Changed merging of hparams when logging to ignore parameter names that start with an underscore `_` (#20221)
- Re-enabled passing `BytesIO` as path in `.to_onnx()` (#20172)
Removed
- Removed `List[int]` as input type for Trainer when `accelerator="cpu"` (#20399)
Fixed
- Fixed `UnboundLocalError` when using the predict method with `return_predictions=False` (#20484)
- Fixed use of `convert_module` in FSDP to avoid using more memory than necessary during initialization (#20323)
- Fixed TypeError in `configure_optimizers` when running with `ReduceLROnPlateau` (#20471)
- Fixed return type in `configure_optimizers` example (#20420)
- Fixed incorrect URI prefix stripping in MLFlowLogger (#20365)
- Fixed shuffling behavior when using a custom sampler in data module (#20327)
- Ensured restarting from checkpoints leads to consistent internal counters compared to uninterrupted training (#20379)
- Fixed LightningCLI failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter (#20221)
Lightning Fabric
Added
- Added `step` parameter to `TensorBoardLogger.log_hyperparams` to visualize changes during training (#20176)
- Added timeout to DeepSpeedStrategy (#20474)
- Added FP8 + FSDP2 + torch.compile examples for Fabric (#20440)
- Added RTX 4080 super to chips dictionary (#20285)
- Added device property to lazy load functionality (#20183)
- Added `ddp_find_unused_parameters_true` alias in Fabric's DDPStrategy (#20125)
Changed
Removed
- Nothing to see here 😄
Fixed
- Fixed use of `convert_module` in FSDP to avoid using more memory than necessary during initialization (#20323)
Full commit list: 2.4.0 -> 2.5.0
Contributors
We thank all folks who submitted issues, features, fixes and doc changes. It's the only way we can collectively make Lightning ⚡ better for everyone, nice job!
In particular, we would like to thank the authors of the pull-requests above, in no particular order:
@ringohoffman @MrWhatZitToYaa @jedyang97 @chualanagit @lantiga @AlessandroW @kazuar @t-vi @01AbhiSingh @WangYue0000 @amorehead @EricCousineau-TRI @mauvilsa @Borda @pete-mcelroy @ali-alshaar7 @GdoongMathew @farhadrgh @tshu-w @LukasSalchow @awindmann @dadwadw233 @qingquansong
Thank you ❤️ and we hope you'll keep them coming!