
[FSDP] Wrapping model again in FSDP doesn't contain root parameters #648

Open · SeanNaren opened this issue May 4, 2021 · 11 comments
Labels: FSDP FullyShardedDataParallel (zero-3)

SeanNaren commented May 4, 2021

🐛 Bug

Related Lightning-AI/pytorch-lightning#6152

When the module is wrapped twice in FSDP, the first wrap introduces a FlattenParamsWrapper that holds all of the parameters, so the second wrap does not contain the parameters of the model. Being able to re-wrap is required for Lightning, where we wrap the LightningModule for training and return this LightningModule back to the user, who may call trainer.test(module) independently.

This could potentially be solved by removing FlattenParamsWrapper and permanently restoring the sharded weights to their correct references after training. Is this doable?

I'm not 100% sure of a solution here (even from the Lightning side), so if you have ideas please let me know!

To Reproduce

import os
import unittest
from unittest import mock

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
import torch.nn.functional as F


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
def test_wrapping_module():
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    module = nn.Sequential(
        nn.Linear(5, 5),
        FullyShardedDataParallel(nn.Linear(5, 5)),
    )

    training_model = FullyShardedDataParallel(module).to(device)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    output = training_model(input)
    loss = F.mse_loss(input, output)
    loss.backward()

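    # Wrap the same module again, as Lightning would for a later stage such as
    # trainer.test(); the first wrap introduced a FlattenParamsWrapper that now
    # owns the root parameters.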
    inference_model = FullyShardedDataParallel(module).to(device)
    second_output = inference_model(input)

    assert torch.allclose(output, second_output)

    # Fails as we are missing parameters in the highest root level FSDP wrap
    assert len(list(inference_model.parameters())) == len(list(training_model.parameters()))

    torch.distributed.destroy_process_group()

Expected behavior

The model can be wrapped in an FSDP wrapper again after it has been trained.

cc @ananthsub @min-xu-ai @shuyingsunshine21

min-xu-ai added the FSDP FullyShardedDataParallel (zero-3) label May 4, 2021
@min-xu-ai (Contributor)

Thanks, @SeanNaren! What's the motivation to wrap it again for inference? Can you use the same (wrapped) model for it? I must be missing something obvious.

@SeanNaren (Author)

Thanks @min-xu-ai! Some Lightning context will help decide whether this is something we should look into fixing in Lightning!

Currently we treat train and test as two separate stages. So if the model is wrapped in train, we assume that when we finish training, the module the user passed to our Trainer is now a trained model. The motivation is to keep each stage independent, allowing each one to be run on its own, i.e. if test were run without train, we'd wrap the model using the same logic we wrap with in train. With FSDP this line is slightly blurred, however, since the final module is a sharded version, not a full-weight version!

Using the same wrapped model would break this independence; we'd have to keep a reference to the wrapped model and ensure all stages use it. This may have lots of unintended side effects in Lightning! However, if it is deemed difficult to provide an un-flattened version that can be wrapped with FSDP again, keeping such a reference might be a required change in Lightning to support this case.
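
To make that concrete, here's a minimal sketch of the flow (illustrative names only, not Lightning's actual internals): each stage wraps the raw module it was given on its own, so a later stage ends up re-wrapping a module that already contains nested FSDP wrappers.

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP


def run_stage(raw_module: nn.Module, stage_fn) -> None:
    # Each stage independently wraps whatever module the user handed to the Trainer.
    wrapped = FSDP(raw_module).cuda()
    stage_fn(wrapped)
    # The user keeps their reference to raw_module, not to `wrapped`,
    # so the next stage wraps raw_module again.


# trainer.fit(module) followed by trainer.test(module) boils down to:
#   run_stage(module, train_loop)
#   run_stage(module, test_loop)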

@min-xu-ai (Contributor)

I see. I think the decoupling makes sense. It is very interesting that DDP in this case happily wraps the model twice:

In [21]: ddp = torch.nn.parallel.DistributedDataParallel(ddp)
/home/owen/e/py39_clip/lib/python3.9/site-packages/torch/nn/parallel/distributed.py:487: UserWarning: Single-Process Multi-GPU is not the recommended mode for DDP. In this mode, each DDP instance operates on multiple devices and creates multiple module replicas within one process. The overhead of scatter/gather and GIL contention in every forward pass can slow down training. Please consider using one DDP instance per device or per module replica by explicitly setting device_ids or CUDA_VISIBLE_DEVICES.
  warnings.warn(

In [22]: ddp.module
Out[22]:
DistributedDataParallel(
  (module): Linear(in_features=1, out_features=1, bias=True)
)

In [23]: ddp.module.module
Out[23]: Linear(in_features=1, out_features=1, bias=True)

I am surprised that this doesn't cause problems. Maybe it is because your second wrap is for inference only and DDP is basically a no-op in that case?

I need to think more about what the proper thing to do is in this case, TBH.

@min-xu-ai (Contributor)

@SeanNaren, maybe this is related to: #649?

What if we have the following:

FSDP(FSDP()) --> throw an error, since wrapping consecutively multiple times doesn't seem to be useful

but

with enable_wrap():
    auto_wrap(FSDP()) or wrap(FSDP()) --> will return the already wrapped model without creating a new wrapper

Will the above be a good solution for your use cases?
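
Roughly, from the caller's side the proposal might look like this (proposed semantics only, not fairscale's current behaviour; it also assumes torch.distributed has been initialized, as in the repro above):

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

inner = FSDP(nn.Linear(5, 5))

# Proposed: wrapping an FSDP instance again directly would raise, since
# consecutive wrapping doesn't seem useful.
# FSDP(inner)  # -> error under the proposal

# Proposed: wrap()/auto_wrap() would detect the existing wrapper and return it
# unchanged instead of creating a new one.
with enable_wrap(wrapper_cls=FSDP):
    same = wrap(inner)
# Under the proposal, `same is inner` would be True.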

@SeanNaren (Author)

Thanks @min-xu-ai! Unfortunately this would still be an issue for Lightning, as we reference the model that is stored within the FSDP wrapper. If we had access to the model that was initially wrapped with FSDP, there would be no issues! I'm starting to suspect this is less something FairScale should try to support (as you mentioned for DDP) and more a limitation in Lightning that we need to address.

For now, our workaround is to only support the user wrapping the module themselves, rather than Lightning automatically wrapping it for them (a rough sketch of that pattern is below).

@shuyingsunshine21 is working on this, which may serve as a stop-gap while we discuss solutions here (but it is blocked by #658 atm).
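
For reference, a rough sketch of that user-wrapping pattern (module structure and names here are illustrative, and it assumes torch.distributed is already initialized): the user decides which submodules get their own FSDP wrapper, instead of Lightning re-wrapping the whole module on every stage.

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap


def build_user_wrapped_model() -> nn.Module:
    # The user wraps submodules explicitly; Lightning never wraps the result again.
    with enable_wrap(wrapper_cls=FSDP):
        return nn.Sequential(
            wrap(nn.Linear(5, 5)),
            wrap(nn.Linear(5, 5)),
        )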

@min-xu-ai (Contributor)

In case you guys have an update on your side, please share.

@myleott, @QuentinDuval, check out comment #648 (comment) if you haven't. Do you think it makes sense for FSDP to wrap another FSDP instance?

@rohitgr7

Hey guys! Any update on this? With PyTorch FSDP it's now possible to check whether the raw model has already been wrapped, so we can avoid this, but with fairscale FSDP it's not.
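
(For what it's worth, with PyTorch FSDP that check can be as simple as the sketch below; the helper name is made up.)

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def already_fsdp_wrapped(module: nn.Module) -> bool:
    # True if the module itself, or any of its submodules, is already FSDP-wrapped.
    return any(isinstance(m, FSDP) for m in module.modules())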

@min-xu-ai (Contributor)

> hey guys! any update on this? With PyTorch FSDP it's now possible to check if the raw model has already been wrapped or not so that we can avoid this, but with fairscale FSDP it's not.

Any reason for not using PyTorch FSDP? fairscale FSDP is in maintenance mode and is unlikely to change much.

@rohitgr7

It's just that in pytorch-lightning, we support both.
Also, I read this on the PyTorch FSDP release blog:

> In the near future, FairScale FSDP will stay in the FairScale repository for research projects, while generic and widely adopted features will be upstreamed to PyTorch incrementally and hardened accordingly.

So I am wondering: will new/experimental features be added to fairscale and then upstreamed to PyTorch, or will they be added to PyTorch itself?

@min-xu-ai (Contributor)

New/experimental features will be added to fairscale if the PyTorch team is not sure whether they will end up being useful for the general user base. If they think a new feature will be useful, I don't see why it wouldn't be added to PyTorch directly.

@rohitgr7

Then we might need to keep the fairscale FSDP support in PyTorch Lightning, so is it possible to resolve this issue?
