Release 2.3.0 #19954

Merged
merged 6 commits into from
Jun 11, 2024

Conversation

awaelchli
Contributor

@awaelchli awaelchli commented Jun 6, 2024

Below is the draft for the release notes:

Lightning v2.3: Tensor Parallelism and 2D Parallelism

Lightning AI is excited to announce the release of Lightning 2.3 ⚡

Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.

This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.

Highlights

Tensor Parallelism (beta)

Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
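The arithmetic behind column-wise and row-wise sharding can be checked with a toy example. The sketch below uses plain Python lists and hypothetical helper names (`matmul`, `split_columns`, `split_rows`); it is not how Lightning or PyTorch implement TP. It assumes the `x @ W` convention with weights of shape (in, out), so a column split divides the output features of the first layer and a row split divides the input features of the second.

```python
# Toy check of the tensor-parallel arithmetic. All helpers are hypothetical
# and exist only for this illustration.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    """Element-wise nonlinearity; it can be applied to each column shard on its own."""
    return [[max(0, v) for v in row] for row in m]

def add(a, b):
    """Element-wise sum, standing in for the all-reduce across devices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def split_columns(m, parts):
    """Split a matrix into `parts` equally wide column blocks."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equally tall row blocks."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]

x = [[1, -2], [3, 4]]                     # input batch, shape (2, 2)
w1 = [[1, 0, -2, 1], [0, 1, 1, -3]]       # first layer, shape (2, 4)
w2 = [[1, 2], [0, -1], [3, 0], [1, 1]]    # second layer, shape (4, 2)

# Unsharded reference computation
reference = matmul(relu(matmul(x, w1)), w2)

# Each "device" holds one column shard of w1 and the matching row shard of w2
world_size = 2
w1_shards = split_columns(w1, world_size)
w2_shards = split_rows(w2, world_size)
partials = [
    matmul(relu(matmul(x, w1_shards[rank])), w2_shards[rank])
    for rank in range(world_size)
]

# Summing the partial outputs reproduces the unsharded result
combined = partials[0]
for partial in partials[1:]:
    combined = add(combined, partial)

assert combined == reference
```

Each simulated device stores only half of each weight matrix, and a single sum at the end recovers the full output. This is why a column-wise layer is typically followed by a row-wise one.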

PyTorch Lightning

Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, in which you convert selected layers of the model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.


import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)

Full training example (requires at least 2 GPUs):

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

Lightning Fabric

Applying TP in a model with Fabric requires you to implement a special function in which you convert selected layers of the model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.


import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

Full training example (requires at least 2 GPUs):

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

# Simplified training loop
for i, batch in enumerate(dataloader):
    output = model(batch)
    loss = output.sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    fabric.print(f"Iteration {i} complete")

fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

2D Parallelism (beta)

Tensor Parallelism by itself can be very effective for efficient inference of very large models. For training, TP is typically combined with other forms of parallelism, such as FSDP, to increase throughput and scalability on large clusters with hundreds of GPUs. The new ModelParallelStrategy in this release supports the combination of TP + FSDP, which is referred to as 2D parallelism.

For an introduction to this feature, please also refer to the tutorial Studios (PyTorch Lightning, Lightning Fabric). At the moment, the PyTorch team is reimplementing FSDP under the name FSDP2 with the aim of making it compose well with other forms of parallelism such as TP. Therefore, for the experimental 2D parallelism support, you'll need to switch to using FSDP2 with the new ModelParallelStrategy. Please refer to our docs (PyTorch Lightning, Lightning Fabric) and stay tuned for future releases as these APIs mature.
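To picture how ranks are arranged in 2D parallelism, here is a minimal sketch in plain Python. It assumes a row-major mesh of shape (data parallel, tensor parallel), which is how PyTorch's `init_device_mesh` lays out ranks. The function names and size parameters are hypothetical and not part of the Lightning API, and the "data parallel" dimension name is an assumption since only "tensor_parallel" appears in the examples above.

```python
# Sketch of a 2D device mesh layout. Helper names are hypothetical; the
# row-major layout of shape (data_parallel, tensor_parallel) is an assumption.

def mesh_coordinates(world_size, data_parallel_size, tensor_parallel_size):
    """Map each global rank to its (data-parallel, tensor-parallel) index."""
    if data_parallel_size * tensor_parallel_size != world_size:
        raise ValueError("mesh dimensions must multiply to the world size")
    return {
        rank: divmod(rank, tensor_parallel_size)
        for rank in range(world_size)
    }

def groups_along(coords, dim):
    """List the rank groups that communicate along mesh dimension `dim`.

    dim=0 groups ranks along the data-parallel (FSDP) dimension,
    dim=1 groups ranks along the tensor-parallel dimension.
    """
    fixed = 1 - dim  # ranks in one group share the index of the other dimension
    grouped = {}
    for rank, index in coords.items():
        grouped.setdefault(index[fixed], []).append(rank)
    return list(grouped.values())

# 8 GPUs arranged as 2 data-parallel replicas of 4 tensor-parallel devices
coords = mesh_coordinates(world_size=8, data_parallel_size=2, tensor_parallel_size=4)
tp_groups = groups_along(coords, dim=1)
dp_groups = groups_along(coords, dim=0)

print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Under this layout, tensor-parallel groups consist of consecutive ranks, which usually sit on the same machine, while the FSDP groups span across them.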

Training Mode in Model Summary

The model summary table that gets displayed when you run Trainer.fit() now contains a new column "Mode" that shows the training mode each layer is in (#19468).

  | Name                 | Type            | Params | Mode 
-----------------------------------------------------------------
0 | model                | Sam             | 93.7 M | train
1 | model.image_encoder  | ImageEncoderViT | 89.7 M | eval 
2 | model.prompt_encoder | PromptEncoder   | 6.2 K  | train
3 | model.mask_decoder   | MaskDecoder     | 4.1 M  | train
-----------------------------------------------------------------
93.7 M    Trainable params
0         Non-trainable params
93.7 M    Total params
374.942   Total estimated model params size (MB)

A module in PyTorch is always either in train (default) or eval mode.
This improvement should give users more visibility into the state of their model and help debug issues, for example when you need to make sure certain layers of the model are frozen.

Special Forward Methods in Fabric

Until now, Lightning Fabric warned the user whenever the forward pass of the model, or of a subset of its modules, ran through a method other than the dedicated forward method of the PyTorch module. The reason is that PyTorch needs to run special hooks for DDP, FSDP, and other strategies to function properly. Bypassing the real forward method skips these hooks and leads to correctness issues.

In Lightning Fabric 2.3, we added a feature to explicitly mark alternative forward methods so that Fabric can add the necessary rerouting behind the scenes:

import lightning as L

fabric = L.Fabric(devices=2, strategy="ddp")
fabric.launch()

model = MyModel()
model = fabric.setup(model)

# OK: Calling the model directly
output = model(input)

# ERROR: Calling another method that calls forward indirectly
prediction = model.generate(input)

# New: Mark special forward methods explicitly before using them
model.mark_forward_method(model.generate)

# OK: Now can use `model.generate()` in DDP/FSDP without issues
prediction = model.generate(input)

Find the full example and more details in our docs.
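To illustrate why marking matters, here is a toy model of the rerouting idea in plain Python. `ToyWrapper` and `ToyModel` are hypothetical stand-ins and this is not Fabric's actual implementation; the counter stands in for the hooks a strategy such as DDP runs around the forward pass.

```python
# Toy illustration of rerouting marked methods through a hooked call path.
# All class names are hypothetical; this is not Fabric's implementation.

class ToyModel:
    def forward(self, x):
        return x * 2

    def generate(self, x):
        # Calls forward indirectly
        return self.forward(x) + 1


class ToyWrapper:
    """Stand-in for a strategy wrapper that runs hooks around the forward pass."""

    def __init__(self, module):
        self.module = module
        self.hook_calls = 0
        self._marked = set()

    def _run_with_hooks(self, method, *args, **kwargs):
        self.hook_calls += 1  # stands in for the strategy's forward hooks
        return method(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self._run_with_hooks(self.module.forward, *args, **kwargs)

    def mark_forward_method(self, method):
        name = method if isinstance(method, str) else method.__name__
        self._marked.add(name)

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself
        method = getattr(self.module, name)
        if name in self._marked:
            return lambda *a, **k: self._run_with_hooks(method, *a, **k)
        return method


model = ToyWrapper(ToyModel())

model(3)                       # goes through the hooks
calls_after_direct = model.hook_calls

model.generate(3)              # unmarked: bypasses the hooks
calls_after_unmarked = model.hook_calls

model.mark_forward_method(model.generate)
result = model.generate(3)     # marked: rerouted through the hooks
calls_after_marked = model.hook_calls
```

In this sketch the unmarked call leaves the hook counter unchanged, which mirrors the situation Fabric flags. After marking, the same call passes through the hooked path.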

Notable Changes

The 2.0 series of Lightning releases guarantees core API stability: no name changes, argument renaming, hook removals, etc. on core interfaces (Trainer, LightningModule, etc.) unless a feature is specifically marked experimental. Here we list a few behavioral changes that we considered justified because they significantly improve the user experience, improve performance, or fix the correctness of a feature. These changes will likely not impact most users.

Skipping the training step in DDP

It is no longer allowed to skip training_step() by returning None in distributed training (#19918). The following usage was previously possible but would result in unpredictable hangs and timeouts in distributed training:

def training_step(self, batch):
    loss = ...
    if loss.isnan():
        # No longer allowed in multi-GPU!
        # Raises error in Lightning >= 2.3
        return None
    return loss

We decided to raise an error if the user attempts to return None when running in a multi-GPU setting.
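The hang can be pictured with a toy simulation. The sketch below is a plain-Python analogy using threads, not Lightning or PyTorch code: a `threading.Barrier` stands in for the gradient synchronization that every rank is expected to join, and the function names and the timeout value are hypothetical.

```python
# Toy analogy for a skipped training step in distributed training.
# A barrier stands in for the collective gradient synchronization.
import threading


def run_rank(rank, skip, barrier, results):
    if skip:
        # This rank "returned None" and never joins the collective
        results[rank] = "skipped"
        return
    try:
        barrier.wait(timeout=0.5)  # wait for all other ranks
        results[rank] = "synced"
    except threading.BrokenBarrierError:
        results[rank] = "timed out"


def simulate(skips):
    """Run one simulated step; `skips[rank]` says whether that rank skips."""
    barrier = threading.Barrier(len(skips))
    results = {}
    threads = [
        threading.Thread(target=run_rank, args=(rank, skip, barrier, results))
        for rank, skip in enumerate(skips)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results


all_join = simulate([False, False])
one_skips = simulate([False, True])

print(all_join)   # {0: 'synced', 1: 'synced'}
print(one_skips)  # {0: 'timed out', 1: 'skipped'}
```

When one rank skips, the other waits for a partner that never arrives. In real distributed training that wait shows up as a hang or a timeout, which is why an explicit error is preferable.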

Miscellaneous Changes

  • Dropped support for PyTorch 1.13 (#19300). With every new Lightning release, we add official support for the latest PyTorch stable version and drop the oldest version in our support window.
  • The prepare_data() hook in LightningModule and LightningDataModule is now subject to a barrier without a timeout, so that long-running tasks are not interrupted (#19448).

CHANGELOG

PyTorch Lightning

Added
Changed
Deprecated
Removed
Fixed

Lightning Fabric

Added
Changed
Removed
Fixed

Full commit list: 2.2.0 -> 2.3.0

Contributors

We thank all our contributors who submitted pull requests for features, bug fixes and documentation updates.

New Contributors

TODO

cc @Borda @carmocca @justusschock @awaelchli

@awaelchli awaelchli added this to the 2.3 milestone Jun 6, 2024
@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package package data labels Jun 6, 2024
@awaelchli awaelchli marked this pull request as ready for review June 10, 2024 23:56
Contributor

github-actions bot commented Jun 10, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/version.info.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/version.info.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

Collaborator

@lantiga lantiga left a comment


Looks good, would it make sense to add a mention to ModelParallelStrategy to the README?

@mergify mergify bot added the ready PRs ready to be merged label Jun 11, 2024
@awaelchli
Contributor Author

Thanks for the suggestion. I'll look for a good spot to mention it.

@awaelchli awaelchli merged commit f6fd046 into master Jun 11, 2024
64 of 67 checks passed
@awaelchli awaelchli deleted the draft-2.3.0 branch June 11, 2024 16:38
Labels
fabric lightning.fabric.Fabric package pl Generic label for PyTorch Lightning package ready PRs ready to be merged release