
Releases: Lightning-AI/pytorch-lightning

Lightning v2.5 post0

21 Dec 01:35
9177ec0

Lightning v2.5

20 Dec 14:20
c45c3c9

Lightning AI ⚡ is excited to announce the release of Lightning 2.5.

Lightning 2.5 comes with improvements on several fronts, with zero API changes. Our users love it stable, so we keep it stable 😄.

Speaking of love ❤️, the lightning, pytorch-lightning and lightning-fabric packages are collectively getting more than 10M downloads per month 😮, for a total of over 180M downloads 🤯 since the early days. It's incredible to see PyTorch Lightning getting such strong adoption across industry and the sciences.

Release 2.5 embraces PyTorch 2.5, and it marks some of its more recent directions as officially supported, namely tensor subclass-based APIs like Distributed Tensors and TorchAO, in combination with torch.compile.

Here are a couple of examples:

Distributed FP8 transformer with PyTorch Lightning

Full example here

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # we skip the decoder because its vocabulary size is typically
            # not divisible by 16, as required by float8
            return fqn != "decoder"

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())

if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
Distributed FP8 transformer with Fabric

Full example here

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm

def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size is typically
        # not divisible by 16, as required by float8
        return fqn != "decoder"

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    return torch.compile(model)

def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    max_steps = 100

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        is_accumulating = i % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())

if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()

As these examples show, it's now easier than ever to take your PyTorch Lightning module and run it with FSDP2 and/or tensor parallelism in FP8 precision, using the ModelParallelStrategy we introduced in 2.4.

Also note the use of distributed tensor APIs, TorchAO APIs, and torch.compile directly in the configure_model hook (or in the parallelize function of Fabric's ModelParallelStrategy), rather than applied to the LightningModule as a whole. The advantage of this approach is that you can copy-paste the parallelize functions that come with native PyTorch models directly into configure_model and get the same effect, no head-scratching involved 🤓.
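
For illustration, here is a minimal sketch of that pattern. Note that build_model and parallelize_model below are hypothetical placeholders for the constructor and the native PyTorch parallelization function that ship with your model; only the Lightning hooks are real API.

import lightning as L
import torch

class LitParallelModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return
        with torch.device("meta"):
            model = build_model()  # hypothetical: builds the plain PyTorch model
        # reuse the model's native parallelization function unchanged,
        # passing the device mesh that Lightning sets up for you
        model = parallelize_model(model, self.device_mesh)  # hypothetical helper
        self.model = torch.compile(model)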

Speaking of head-scratching, we also made a pass over the PyTorch Lightning internals and hardened the parts that track progress counters during training, validation, and testing, as well as learning-rate scheduling, in relation to resuming from checkpoints. We made sure there are no (to the best of our knowledge) edge cases where stopping and resuming from checkpoints can change the sequence of loops or other internal state. Fault tolerance for the win 🥳!
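
For reference, all of that state is restored automatically when you pass a checkpoint path to fit(). A minimal sketch (the module, dataloader, and checkpoint path below are placeholders):

import lightning as L

model = MyLightningModule()  # placeholder: your LightningModule
trainer = L.Trainer(max_epochs=10)

# Resumes loop progress counters, optimizer state, and LR scheduler state
# from the checkpoint, then continues training where it left off.
trainer.fit(model, train_dataloader, ckpt_path="path/to/checkpoint.ckpt")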

Alright! Feel free to take a look at the full changelog below.

And of course: the best way to use PyTorch Lightning and Fabric is through Lightning Studio ⚡. Access GPUs, train models, deploy and more with zero setup. Focus on data and models - not infrastructure.

Changes

PyTorch Lightning

Added
  • Added step parameter to TensorBoardLogger.log_hyperparams to visualize changes during training (#20176) (see the sketch after this list)
  • Added str method to datamodule (#20301)
  • Added timeout to DeepSpeedStrategy (#20474)
  • Added doc for Truncated Back-Propagation Through Time (#20422)
  • Added FP8 + FSDP2 + torch.compile examples for PyTorch Lightning (#20440)
  • Added profiling to Trainer.save_checkpoint (#20405)
  • Added after_instantiate_classes hook to CLI (#20401)
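
As an example, the new step argument to TensorBoardLogger.log_hyperparams can be used roughly like this (a sketch; the hyperparameter values are made up):

from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("logs/")

# Log hyperparameters at specific global steps, so that a change
# (e.g. a manual learning-rate adjustment) shows up on the training timeline.
logger.log_hyperparams({"learning_rate": 1e-4}, step=0)
logger.log_hyperparams({"learning_rate": 5e-5}, step=1000)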


Read more

Lightning 2.5 RC

12 Dec 15:22
110d621
Pre-release
2.5.0rc0

Bump to 2.5.0rc0 (#20493)

Lightning v2.4

07 Aug 09:44
2129fdf

Lightning AI ⚡ is excited to announce the release of Lightning 2.4. This is mainly a compatibility upgrade for PyTorch 2.4 and Python 3.12, with a sprinkling of features and bug fixes.

Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.

Changes

PyTorch Lightning

Added
  • Made saving non-distributed checkpoints fully atomic (#20011)
  • Added dump_stats flag to AdvancedProfiler (#19703)
  • Added a flag verbose to the seed_everything() function (#20108) (see the sketch after this list)
  • Added support for PyTorch 2.4 (#20010)
  • Added support for Python 3.12 (#20078)
  • The TQDMProgressBar now provides an option to retain prior training epoch bars (#19578)
  • Added the count of modules in train and eval mode to the printed ModelSummary table (#20159)
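
As an example of the verbose flag on seed_everything() mentioned above (a sketch; the flag only controls the seed info message):

import lightning as L

# Seed Python, torch, and NumPy (if installed), plus dataloader workers,
# without printing the usual "Seed set to 42" message.
L.seed_everything(42, workers=True, verbose=False)
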
Changed
  • Triggering KeyboardInterrupt (Ctrl+C) during .fit(), .evaluate(), .test() or .predict() now terminates all processes launched by the Trainer and exits the program (#19976)
  • Changed the implementation of how seeds are chosen for dataloader workers when using seed_everything(..., workers=True) (#20055)
  • NumPy is no longer a required dependency (#20090)
Removed
  • Removed support for PyTorch 2.1 (#20009)
  • Removed support for Python 3.8 (#20071)
Fixed
  • Avoid LightningCLI saving hyperparameters with class_path and init_args since this would be a breaking change (#20068)
  • Fixed an issue that would cause too many printouts of the seed info when using seed_everything() (#20108)
  • Fixed _LoggerConnector's _ResultMetric to move all registered keys to the device of the logged value if needed (#19814)
  • Fixed _optimizer_to_device logic for special 'step' key in optimizer state causing performance regression (#20019)
  • Fixed parameter counts in ModelSummary when model has distributed parameters (DTensor) (#20163)

Lightning Fabric

Added
  • Made saving non-distributed checkpoints fully atomic (#20011)
  • Added a flag verbose to the seed_everything() function (#20108)
  • Added support for PyTorch 2.4 (#20028)
  • Added support for Python 3.12 (#20078)
Changed
  • Changed the implementation of how seeds are chosen for dataloader workers when using seed_everything(..., workers=True) (#20055)
  • NumPy is no longer a required dependency (#20090)
Removed
  • Removed support for PyTorch 2.1 (#20009)
  • Removed support for Python 3.8 (#20071)
Fixed
  • Fixed an attribute error when loading a checkpoint into a quantized model using the _lazy_load() function (#20121)
  • Fixed _optimizer_to_device logic for special 'step' key in optimizer state causing performance regression (#20019)

Full commit list: 2.3.0 -> 2.4.0

Contributors

We thank all our contributors who submitted pull requests for features, bug fixes and documentation updates.

New Contributors

Did you know?

Chuck Norris can solve NP-hard problems in polynomial time. In fact, any problem is easy when Chuck Norris solves it.

Patch release v2.3.3

08 Jul 20:42

This release removes the code from the main lightning package that was reported in CVE-2024-5980.

Patch release v2.3.2

04 Jul 09:06
056bb08

Includes a minor bugfix that avoids a conflict between the entrypoint command and another package (#20041).

Patch release v2.3.1

27 Jun 17:47
8b69285

Includes minor bugfixes and stability improvements.

Full Changelog: 2.3.0...2.3.1

Lightning v2.3: Tensor Parallelism and 2D Parallelism

13 Jun 21:30
a42484c

Lightning AI is excited to announce the release of Lightning 2.3 ⚡

Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.

This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.

Highlights

Tensor Parallelism (beta)

Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.

PyTorch Lightning

Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, where you convert selected layers of the model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.

Open In Studio

 

import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

Lightning Fabric

Applying TP in a model with Fabric requires you to implement a special function where you convert selected layers of the model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.

Open In Studio

 

import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_si...
Read more

Patch release v2.2.5

22 May 17:28
ac3f1ee

PyTorch Lightning + Fabric

Fixed

  • Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)

Full Changelog: 2.2.4...2.2.5

Patch release v2.2.4

01 May 22:50

App

Fixed

  • Fixed HTTPClient retry for flow/work queue (#19837)

PyTorch

No Changes.

Fabric

No Changes.

Full Changelog: 2.2.3...2.2.4