
Lightning v2.5

@lantiga released this 20 Dec 14:20
c45c3c9

Lightning AI ⚡ is excited to announce the release of Lightning 2.5.

Lightning 2.5 comes with improvements on several fronts, with zero API changes. Our users love it stable, we keep it stable 😄.

Speaking of love ❤️, the lightning, pytorch-lightning and lightning-fabric packages collectively get more than 10M downloads per month 😮, for a total of over 180M downloads 🤯 since the early days. It's incredible to see PyTorch Lightning gain such strong adoption across industry and the sciences.

Release 2.5 embraces PyTorch 2.5 and officially supports some of its more recent directions, namely tensor subclass-based APIs such as Distributed Tensors and TorchAO, in combination with torch.compile.

Here are a couple of examples:

Distributed FP8 transformer with PyTorch Lightning

Full example here

```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        # the hook can be called more than once; only build the model the first time
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # we skip the decoder because its vocabulary size typically
            # is not divisible by 16, as required by float8
            return fqn != "decoder"

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        # shard each transformer layer with FSDP2, then the root module
        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())

if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
```

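The module_filter_fn above excludes the decoder by name because, per its comment, float8 matmuls need dimensions divisible by 16. A shape-based check makes that rule explicit. The stdlib-only sketch below is illustrative, not part of Lightning or TorchAO: the helper names, the layer names, and the vocabulary size of 50257 are hypothetical, and it assumes the rule applies to both in_features and out_features; 1024 and 4096 mirror the ninp and nhid values in the example.

```python
from typing import Dict, Tuple

# Divisibility rule quoted in the example's comment (assumed to apply to both dims).
FLOAT8_MULTIPLE = 16


def float8_eligible(in_features: int, out_features: int, multiple: int = FLOAT8_MULTIPLE) -> bool:
    """True if both dimensions of a linear layer are multiples of `multiple`."""
    return in_features % multiple == 0 and out_features % multiple == 0


def select_float8_layers(linear_shapes: Dict[str, Tuple[int, int]]) -> Dict[str, bool]:
    """Map each (hypothetical) fully qualified layer name to its float8 eligibility."""
    return {fqn: float8_eligible(i, o) for fqn, (i, o) in linear_shapes.items()}


if __name__ == "__main__":
    # Illustrative shapes: ninp=1024, nhid=4096, and a hypothetical vocabulary of 50257.
    shapes = {
        "encoder.layer0.linear1": (1024, 4096),
        "encoder.layer0.linear2": (4096, 1024),
        "decoder": (1024, 50257),
    }
    print(select_float8_layers(shapes))
```

In a real model the shapes would come from each nn.Linear's in_features and out_features; the name-based filter in the example is simpler when, as here, only one layer is known to be problematic.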
Distributed FP8 transformer with Fabric

Full example here

```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm

def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size typically
        # is not divisible by 16, as required by float8
        return fqn != "decoder"

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # shard each transformer layer with FSDP2, then the root module
    for module in model.modules():
        if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    return torch.compile(model)

def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    max_steps = 100

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        # step the optimizer once every batch_size // micro_batch_size micro-batches
        is_accumulating = (i + 1) % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())

if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
```
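The Fabric loop accumulates gradients over batch_size // micro_batch_size micro-batches and only syncs gradients and steps the optimizer at the end of each group. The stdlib sketch below isolates that schedule; the function name step_indices is hypothetical. Testing (i + 1) % n makes every optimizer step see a full group of n micro-batches, whereas testing i % n also fires on the very first micro-batch (i = 0), so the first step would use a single micro-batch.

```python
from typing import List


def step_indices(num_micro_batches: int, accumulation_steps: int, one_based: bool = True) -> List[int]:
    """Micro-batch indices i after which the optimizer steps.

    one_based=True tests (i + 1) % n == 0, so every step sees a full group of n micro-batches.
    one_based=False tests i % n == 0, which also fires at i == 0 (a group of one).
    """
    offset = 1 if one_based else 0
    return [i for i in range(num_micro_batches) if (i + offset) % accumulation_steps == 0]


if __name__ == "__main__":
    n = 8 // 1  # batch_size // micro_batch_size, as in the Fabric example
    print(step_indices(24, n))                   # one-based schedule
    print(step_indices(24, n, one_based=False))  # zero-based schedule
```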

As these examples show, it's now easier than ever to take your PyTorch Lightning module and run it with FSDP2 and/or tensor parallelism in FP8 precision, using the ModelParallelStrategy we introduced in 2.4.

Also note that the distributed tensor APIs, the TorchAO APIs, and torch.compile are used directly in the configure_model hook (or in the parallelize function passed to Fabric's ModelParallelStrategy), as opposed to being applied to the LightningModule as a whole. The advantage of this approach is that you can copy-paste the parallelize functions that come with native PyTorch models directly into configure_model and get the same effect, no head-scratching involved 🤓.

Speaking of head-scratching, we also made a pass over the PyTorch Lightning internals and hardened the parts that keep track of progress counters during training, validation and testing, as well as learning rate scheduling, with respect to resuming from checkpoints. To the best of our knowledge, there are no longer any edge cases where stopping and resuming from a checkpoint can change the sequence of loops or other internal state. Fault tolerance for the win 🥳!
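The invariant described above can be illustrated with a toy loop. This is a hypothetical, stdlib-only sketch, not Lightning's loop code (which tracks many more counters): the ProgressCounters class and run function are invented for illustration. It checks that a run which is stopped, checkpointed and resumed replays exactly the same (epoch, batch, step) sequence as an uninterrupted one. The design choice that makes this work is advancing the counters only after a step completes, so the saved state always points at the next unit of work.

```python
from dataclasses import asdict, dataclass
from typing import List, Optional, Tuple


@dataclass
class ProgressCounters:
    """Hypothetical loop state: the next (epoch, batch) to run and optimizer steps done."""
    epoch: int = 0
    batch_in_epoch: int = 0
    global_step: int = 0


def run(counters: ProgressCounters, batches_per_epoch: int, max_steps: int,
        stop_after: Optional[int] = None) -> List[Tuple[int, int, int]]:
    """Advance the loop, returning the (epoch, batch, step) of each executed step.

    Counters are only advanced after a step completes, so a checkpoint taken
    between steps always points at the next unit of work.
    """
    trace: List[Tuple[int, int, int]] = []
    while counters.global_step < max_steps:
        if stop_after is not None and len(trace) == stop_after:
            break  # simulate an interruption
        trace.append((counters.epoch, counters.batch_in_epoch, counters.global_step))
        counters.global_step += 1
        counters.batch_in_epoch += 1
        if counters.batch_in_epoch == batches_per_epoch:
            counters.epoch += 1
            counters.batch_in_epoch = 0
    return trace


if __name__ == "__main__":
    uninterrupted = run(ProgressCounters(), batches_per_epoch=5, max_steps=12)

    counters = ProgressCounters()
    before = run(counters, batches_per_epoch=5, max_steps=12, stop_after=7)
    checkpoint = asdict(counters)  # what would be saved to disk
    after = run(ProgressCounters(**checkpoint), batches_per_epoch=5, max_steps=12)

    print(before + after == uninterrupted)
```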

Alright! Feel free to take a look at the full changelog below.

And of course: the best way to use PyTorch Lightning and Fabric is through Lightning Studio ⚡. Access GPUs, train models, deploy and more with zero setup. Focus on data and models, not infrastructure.

Changes

PyTorch Lightning

Added
  • Added step parameter to TensorBoardLogger.log_hyperparams to visualize changes during training (#20176)
  • Added __str__ method to datamodule (#20301)
  • Added timeout to DeepSpeedStrategy (#20474)
  • Added doc for Truncated Back-Propagation Through Time (#20422)
  • Added FP8 + FSDP2 + torch.compile examples for PyTorch Lightning (#20440)
  • Added profiling to Trainer.save_checkpoint (#20405)
  • Added after_instantiate_classes hook to CLI (#20401)
Changed
  • Updated checkpointing documentation to mark resume_from_checkpoint as deprecated (#20477)
  • Made plugin type checks more flexible (#20186)
  • Changed NumPy seeding to use np.random.SeedSequence() in pl_worker_init_function() to robustly seed NumPy-dependent dataloader workers (#20369)
  • Allowed callbacks to be restored not just during training (#20403)
  • Changed LightningCLI tests to account for future fix in jsonargparse (#20372)
  • Bumped PyTorch to version 2.5 (#20351)
  • Decoupled checkpoint artifact path from model artifact path (#20325)
  • Updated BitsAndBytes version (#20313)
  • Changed merging of hparams when logging to ignore parameter names that start with an underscore _ (#20221)
  • Re-enabled passing BytesIO as path in .to_onnx() (#20172)
Removed
  • Removed List[int] as input type for Trainer when accelerator="cpu" (#20399)
Fixed
  • Fixed UnboundLocalError when using the predict method with return_predictions=False (#20484)
  • Fixed use of convert_module in FSDP to avoid using more memory than necessary during initialization (#20323)
  • Fixed TypeError in configure_optimizers when running with ReduceLROnPlateau (#20471)
  • Fixed return type in configure_optimizers example (#20420)
  • Fixed incorrect URI prefix stripping in MLFlowLogger (#20365)
  • Fixed shuffling behavior when using a custom sampler in data module (#20327)
  • Ensured restarting from checkpoints leads to consistent internal counters compared to uninterrupted training (#20379)
  • Fixed LightningCLI failing when both module and data module save hyperparameters due to conflicting internal _class_path parameter (#20221)

Lightning Fabric

Added
  • Added step parameter to TensorBoardLogger.log_hyperparams to visualize changes during training (#20176)
  • Added timeout to DeepSpeedStrategy (#20474)
  • Added FP8 + FSDP2 + torch.compile examples for Fabric (#20440)
  • Added RTX 4080 super to chips dictionary (#20285)
  • Added device property to lazy load functionality (#20183)
  • Added ddp_find_unused_parameters_true alias in Fabric's DDPStrategy (#20125)
Changed
  • Changed NumPy seeding to use np.random.SeedSequence() in pl_worker_init_function() to robustly seed NumPy-dependent dataloader workers (#20369)
  • Bumped PyTorch to version 2.5 (#20351)
  • Updated BitsAndBytes version (#20313)
Removed
  • Nothing to see here 😄
Fixed
  • Fixed use of convert_module in FSDP to avoid using more memory than necessary during initialization (#20323)
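Both changelogs mention seeding NumPy with np.random.SeedSequence() in pl_worker_init_function(). SeedSequence mixes several integers into well-spread seed material, so nearby inputs (worker 0 vs. worker 1, rank 0 vs. rank 1) do not produce related streams. The sketch below is not Lightning's implementation: the helper name seed_numpy_for_worker and the choice of (base_seed, worker_id, global_rank) as inputs are assumptions for illustration. np.random.SeedSequence, its generate_state method, and np.random.seed are standard NumPy APIs, and the sketch requires NumPy to be installed.

```python
import numpy as np


def seed_numpy_for_worker(base_seed: int, worker_id: int, global_rank: int) -> np.ndarray:
    """Hypothetical helper: derive and apply a NumPy seed for one dataloader worker.

    SeedSequence hashes all three (non-negative) integers together, so nearby
    inputs such as worker 0 vs. worker 1 still yield unrelated seeds.
    """
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    state = ss.generate_state(4)  # four uint32 words of seed material
    np.random.seed(state)  # seed the legacy global RNG used by np.random.* calls
    return state


if __name__ == "__main__":
    for rank in range(2):
        for worker in range(2):
            seed_numpy_for_worker(base_seed=42, worker_id=worker, global_rank=rank)
            print(rank, worker, np.random.rand())
```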

Full commit list: 2.4.0 -> 2.5.0

Contributors

We thank all folks who submitted issues, features, fixes and doc changes. It's the only way we can collectively make Lightning ⚡ better for everyone, nice job!

In particular, we would like to thank the authors of the pull requests above, in no particular order:

@ringohoffman @MrWhatZitToYaa @jedyang97 @chualanagit @lantiga @AlessandroW @kazuar @t-vi @01AbhiSingh @WangYue0000 @amorehead @EricCousineau-TRI @mauvilsa @Borda @pete-mcelroy @ali-alshaar7 @GdoongMathew @farhadrgh @tshu-w @LukasSalchow @awindmann @dadwadw233 @qingquansong

Thank you ❤️ and we hope you'll keep them coming!