
[RFC] Introduce strategy flag to Trainer #9053

Closed
kaushikb11 opened this issue Aug 23, 2021 · 22 comments
Labels
design Includes a design discussion feature Is an improvement or enhancement help wanted Open to be worked on

Comments

@kaushikb11
Contributor

kaushikb11 commented Aug 23, 2021

🚀 Feature

Motivation

The motivation is to have a separate accelerator_strategy flag to support passing training type aliases (ddp, ddp_spawn, etc) and custom TrainingTypePlugin objects.

Trainer(strategy="ddp", accelerator="gpu", devices=4)
Trainer(strategy=DDPPlugin(find_unused_parameters=False), accelerator="gpu", devices=4)
Trainer(strategy="ddp_spawn", accelerator="cpu", devices=4)
Trainer(strategy="ddp_spawn", accelerator="tpu", devices=4)


Background

At the moment, there's a single accelerator flag tied to both Accelerators and Training Type plugins. We wish to have them decoupled and would like to add a separate flag accelerator_strategy for Training Type plugins!

trainer = Trainer(accelerator=GPUAccelerator(...))
trainer = Trainer(accelerator='ddp_spawn')

Alternate flags to set Training Types

  • accelerator
    • type: Optional[Union[str, Accelerator]] = None
    • Supports training types and Accelerator Objects
  • distributed_backend
    • type: Optional[str] = None
    • Deprecated, should use accelerator instead
  • plugins
    • type: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None
    • Supports custom lightning plugins & environment

What's the difference between passing training type to accelerator, distributed_backend, or plugins?

  • accelerator and distributed_backend only support DistributedType (ddp, ddp_spawn, etc), whereas plugins support Custom Training Types (DDPPlugin(), ddp_find_unused_parameters_false, etc).
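
To make the distinction concrete, here is a rough sketch of these current options side by side (illustrative only; the aliases and the pytorch_lightning.plugins import path for DDPPlugin assume the 1.4-era API):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# 1. via `accelerator` (string aliases only)
Trainer(accelerator="ddp", gpus=2)

# 2. via the deprecated `distributed_backend` (string aliases only)
Trainer(distributed_backend="ddp", gpus=2)

# 3. via `plugins` (also accepts custom objects and extended aliases)
Trainer(plugins=DDPPlugin(find_unused_parameters=False), gpus=2)
Trainer(plugins="ddp_find_unused_parameters_false", gpus=2)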


Proposed Solution

  • Introduce strategy flag to Trainer.
  • Support the exceptions and deprecations mentioned below

Exceptions (conflicting combinations that should raise an error):

  • Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")
  • Trainer(accelerator="ddp", strategy="ddp_spawn")
  • Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")

Deprecations: (Deprecated in v1.5 & will be removed in v1.6)

  • Passing training type to accelerator flag
  • Passing training type to plugins flag
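
A rough sketch of the intended behaviour for these cases (assumed semantics, not a final implementation; the exact exception type is an assumption):

from pytorch_lightning import Trainer

# conflicting flags: the training type is specified twice, so this is expected to raise
Trainer(accelerator="ddp", strategy="ddp_spawn")                            # expected: MisconfigurationException
Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")  # expected: MisconfigurationException

# deprecated paths: still accepted in v1.5, but expected to emit a deprecation warning
Trainer(accelerator="ddp_spawn")                      # prefer Trainer(strategy="ddp_spawn")
Trainer(plugins="ddp_find_unused_parameters_false")   # prefer Trainer(strategy="ddp_find_unused_parameters_false")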


Related PR: #8597
Related Issue: #6090

If you agree with this change, react with 🎉; if not, react with 🙅🏽 and leave a comment.

Alternatives

  • Only deprecate passing the TrainingTypePlugin into the plugins argument, not the accelerator argument.
  • Use simpler strategy argument instead of accelerator_strategy.

Additional context



@kaushikb11 kaushikb11 added feature Is an improvement or enhancement help wanted Open to be worked on design Includes a design discussion labels Aug 23, 2021
@ananthsub
Contributor

FYI @four4fish

@yifuwang
Contributor

yifuwang commented Aug 24, 2021

Hi @kaushikb11, thanks for the proposal! It makes perfect sense.

Using accelerator to specify distributed training algorithms only made sense when TrainingTypePlugin-s were mostly about DistributedDataParallel. It started to feel off with the introduction of ddp_fully_sharded. In model parallel use cases, the algorithm's responsibility is often not to make training faster, but to overcome the resource limitations (i.e. memory) of a single device/host. Hence, I really like the proposal to use different options to specify the hardware accelerator and the distributed training algorithm.

However, I don't think accelerator_strategy is a good name for specifying distributed training algorithms. IMHO, algorithms like ddp, ddp_fully_sharded, and potential algorithms for recsys and MoE are NOT strategies of accelerators. Today, TrainingTypePlugin-s are organized internally as strategies of Accelerator-s because Accelerator delegates most calls to TrainingTypePlugin. However, I don't think this delegation makes sense (I vaguely remember there's a plan to change that, @ananthsub).

I'd vote for parallelizer. IMHO, it feels very natural since it's directly borrowed from "data parallelism" and "model parallelism", and it covers both acceleration and scaling.

Edit: clarification

@ananthsub
Contributor

ananthsub commented Aug 24, 2021

@kaushikb11 - I really like the proposal. Regarding accelerator, what happens to the device-related arguments on the Trainer constructor?

  • num_processes
  • gpus
  • auto_select_gpus
  • tpu_cores
  • ipus

Will these be deprecated? I would strongly prefer we have one way to specify different devices for distributed training, and the device_type/strategy/device_ids distinction is highly extensible compared to having separate flags for every new device. For instance, as a user I would find it very confusing if I could specify accelerator="gpu", gpus, and tpu_cores all at the same time.

@yifuwang - what about accelerator as intended for hardware acceleration, like GPUs or TPUs? In this light, parallelizer="gpu" doesn't sound fully right either, as we could still be doing single-device training.

@kaushikb11
Contributor Author

@ananthsub

what happens to the device-related arguments on the Trainer constructor?

I think it's a very strong user-facing API change for users to accept, unfortunately. Also, it's more intuitive for the user to set gpus or ipus with one flag rather than remembering two.

But the Trainer(accelerator="x", devices="x") combo really shines in two instances (sketched below):

  • When the user does Trainer(gpus=8) and wishes to switch to TPU, they would need to make two changes: one to the value of gpus and one to tpu_cores. With the new API, they only have to change accelerator from gpu to tpu.

  • With Trainer(accelerator="auto", devices=4), auto-selecting an accelerator with precedence (TPUs > IPUs > GPUs > CPUs) is great functionality to have. We could also support devices="auto", which seems like a good enhancement as well.
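
A small sketch of both cases (illustrative; devices="auto" is only a possible extension at this point):

from pytorch_lightning import Trainer

# current API: switching from GPU to TPU means touching two flags
Trainer(gpus=8)
Trainer(tpu_cores=8)

# proposed API: only the accelerator value changes
Trainer(accelerator="gpu", devices=8)
Trainer(accelerator="tpu", devices=8)

# proposed auto-selection, with precedence TPUs > IPUs > GPUs > CPUs
Trainer(accelerator="auto", devices=4)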

@kaushikb11
Contributor Author

@yifuwang

If possible, I think parallelizer/parallelization_strategy would be the better term. IMHO they feel very natural since they are directly borrowed from "data parallelism" and "model parallelism", and they cover both acceleration and scaling.

Initially, we had considered distributed_type/distributed_strategy naming for the argument. But it wouldn't make sense if SingleDevice plugins are passed.

@ananthsub
Contributor

ananthsub commented Aug 24, 2021

I think it's a very strong user-facing API change for users to accept, unfortunately. Also, it's more intuitive for the user to set gpus or ipus with one flag rather than remembering two.

As a user, this is confusing to me: with the new arguments, there are now double the ways to configure training. What is the framework's recommendation for what I do? I'd have to learn the precedence of accelerator vs setting gpus directly - Do I set both? Only one? If only one, which one? What happens if I set both gpus=X and tpu_cores=Y ? what happens if I set both gpus=X and num_processes=Z ? Either we document this full cross product (which adds to onboarding costs, and risks documentation falling out of sync with the implementation) or users end up inspecting the implementation details themselves. Neither is great.
cc @awaelchli @justusschock since we were discussing this in #9006

This optionality will also be reflected in more framework complexity. You know best how complex the accelerator connector class has become :/ . The bloat has all sorts of side effects on dev velocity too (harder to work in the framework -> need for more tests across options to cover all the possible inputs -> dealing with slow and/or flaky tests, etc.).

All of this is a result of one more thing to remember: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md#main-core-value-one-less-thing-to-remember

@yifuwang
Contributor

yifuwang commented Aug 24, 2021

@yifuwang - what about accelerator as intended for hardware acceleration, like GPU or TPUs? In this light,' parallelizer="gpu" doesn't sound fully right either as we could still be doing single-device training

@ananthsub, @kaushikb11 - sorry for the confusion. I wasn't suggesting to rename accelerator to parallelizer. What I meant was that, compared to specifying ddp through accelerator="ddp", something like parallelizer="ddp" makes more sense. Thus I totally agree with this proposal that accelerator should be used for specifying the hardware accelerator type (I will update my comment to make it clearer).

Initially, we had considered distributed_type/distributed_strategy naming for the argument. But it wouldn't make sense if SingleDevice plugins are passed.

IMHO distributed training algorithms like ddp, ddp_fully_sharded, and potential algorithms for recsys and MoE are NOT strategies of accelerators. Today, TrainingTypePlugin-s are organized internally as strategies of Accelerator-s because Accelerator delegates most calls to TrainingTypePlugin. However, I don't think this delegation makes sense (I vaguely remember there's a plan to change that, @ananthsub).

I'd even argue that Accelerator-s should be strategies of TrainingTypePlugin-s (instead of the other way around), because most distributed training algorithms apply to different hardware accelerators. From Lightning's perspective, maybe Accelerator should just be abstracting away data movement and accelerator specific communicators (accelerator specific kernels are handled transparently by PyTorch).

Thus I don't think accelerator_strategy is an ideal name, because distributed training algorithms should not be strategies of accelerators.

@ananthsub
Contributor

ananthsub commented Aug 24, 2021

IMHO distributed training algorithms like ddp, ddp_fully_sharded, and potential algorithms for recsys and MoE are NOT strategies of accelerators. Today, TrainingTypePlugin-s are organized internally as strategies of Accelerator-s because Accelerator delegates most calls to TrainingTypePlugin. However, I don't think this delegation makes sense (I vaguely remember there's a plan to change that, @ananthsub).

I'd even argue that Accelerator-s should be strategies of TrainingTypePlugin-s (instead of the other way around), because most distributed training algorithms apply to different hardware accelerators. From Lightning's perspective, maybe Accelerator should just be abstracting away data movement and accelerator specific communicators (accelerator specific kernels are handled transparently by PyTorch).

Yes, this is the idea being formulated here: https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#heading=h.vv5zw27fkkpe (@four4fish )

Where we flip the accelerator & training type, and make training type (or whatever the new name is) own:

  • The hardware accelerator for data movement & device-specific information (e.g. Revamp Device Stats Logging #9032)
  • Collectives, which can be influenced by the hardware used: Consolidate collective functions #7534
  • The checkpointing agent (Introduce CheckpointIO Plugin #8743)
  • The rank information (we could elevate ClusterEnvironment as an abstraction, or some reduced interface of this, to be a top-level property of the base training type plugin)
  • The optimizer & LR scheduler initialization (and potential rewrapping for things like PostLocalSGD, ZeRO v2, Horovod, or others)
  • The precision (since this is highly hardware & comms dependent)
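
A purely hypothetical sketch of what that flipped ownership could look like; the class and argument names below are illustrative only, not an agreed-upon Lightning API:

class Strategy:  # today's TrainingTypePlugin, renamed here only for illustration
    def __init__(self, accelerator, checkpoint_io, cluster_environment, precision_plugin):
        # the strategy owns the hardware accelerator (data movement, device stats),
        # the checkpointing agent, the rank/cluster information, and the precision plugin
        self.accelerator = accelerator
        self.checkpoint_io = checkpoint_io
        self.cluster_environment = cluster_environment
        self.precision_plugin = precision_plugin

    def setup_optimizers(self, model):
        # initialize (and possibly re-wrap) optimizers and LR schedulers,
        # e.g. for PostLocalSGD, ZeRO, or Horovod
        ...

    def reduce(self, tensor, group=None, reduce_op="mean"):
        # collective ops, delegated to whatever backend matches the accelerator
        ...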

@ananthsub
Contributor

Thus I don't think accelerator_strategy is an ideal name, because distributed training algorithms should not be strategies of accelerators.

Ah @yifuwang I misunderstood, my apologies. I think including parallel is a great call out as I assume many people toggling this setting would be familiar with data/model/pipeline parallel terminologies.

@kaushikb11 what do you think of parallel_type, parallelizer, or parallel_strategy? This would also reinforce the core assumption of sync-SGD.

@tchaton
Contributor

tchaton commented Sep 1, 2021

Hey everyone,

Personally, I am not a huge fan of parallel, as it is new terminology to adapt to.

distributed_backend was good to me, and I believe the confusion was introduced when we started to enable accelerator to select the distributed_backend too. Users already know about this Trainer argument, and we could stop the deprecation and re-purpose it.

Trainer(distributed_backend="deepspeed", accelerator="auto", num_devices=4, precision=16)

Best,
T.C

@justusschock
Member

Imo we should not use the term distributed, since not all of this is distributed training (thinking of DP and single-device stuff). To me a term like accelerator_strategy or device_strategy makes more sense.

@ananthsub regarding the deprecation of arguments like gpus, cpus, tpu_cores and ipus: I think we have to stick with them for some time. The best we could do is revisit the docs so they no longer mention them and instead recommend the use of num_devices or something similar. But removing them is such a hard change that we probably cannot do it without a new major version (e.g. PL 2.x.x).

@tchaton
Contributor

tchaton commented Sep 2, 2021

Personally, I don't like strategy as it doesn't inform users that this is meant to be extended for their own use cases.

Would device_manager_plugin make more sense?

Best,
T.C

@justusschock
Member

I think device_manager_plugin only makes sense if we move the accelerator into this plugin, since the one responsible for devices is actually the accelerator...

@yifuwang
Contributor

yifuwang commented Sep 2, 2021

@tchaton distributed_backend sounds good to me too.

To me a term like accelerator_strategy or device_strategy makes more sense.

@justusschock with the encapsulation of collectives, distributed algorithms in Lightning are close to being device agnostic. IMO the algorithms are not strategies of device usage (as @ananthsub mentioned above, there's plan to eliminate the Accelerator -> TrainingTypePlugin delegation).

@justusschock
Member

@yifuwang Yes, that's true. That's why it's only the strategy for how to use the given devices (the device/hardware handling is implemented by the accelerator).

Just the term device_manager is a bit confusing for me, since it actually indicates that the plugin takes over the accelerator's responsibilities (even though the Accelerator -> TrainingTypePlugin delegation might be removed, there still has to be a proper split of responsibilities).

But please let's not use distributed here. We had this in the past, and it has proven not to cover everything Lightning offers; in particular, single device and DP fall outside of it. Let's use something more general here so we don't have to change it again...

@tchaton
Contributor

tchaton commented Sep 3, 2021

Makes sense!

What about accelerator_orchestrator?

@Tshimanga
Contributor

Tshimanga commented Sep 3, 2021

I think maybe we can draw inspiration from similar ideas in the Hadoop/Spark ecosystem. It sounds like we have two different concerns here: the responsibility of interfacing with hardware (CPUs, GPUs, TPUs, ...) and a separate, Hadoop YARN-like concern of orchestration/resource management. I would suggest calling the former a HardwareConnector and the latter a ResourceManager or HardwareOrchestrator?

Spark 3.x introduced GPU acceleration, so maybe there are good examples in their source code to check out.

@kaushikb11
Contributor Author

We (Lightning team) had a discussion about this issue at the offsite. We have decided to go ahead with the strategy flag.

@four4fish
Contributor

A few folks mentioned that we are planning to eliminate the Accelerator -> TrainingTypePlugin delegation. Once precision moves into the TTP and the TTP is called directly instead of going through the Accelerator, the Accelerator and its subclasses could become the device_manager_plugin @justusschock mentioned above.

I feel the flags will make more sense after those tasks finish. We could have Trainer(strategy="ddp", device_plugin="gpu", devices=4), where the flag-to-logic mapping and the ownership of each flag are clear. After the tasks finish: strategy maps to the current TTP and any TTP logic calls the strategy directly; the Accelerator is renamed/restructured into device_manager_plugin, and users define a device_plugin specific to device-related logic instead of the current hidden layer of accelerator delegation.

More details here: https://docs.google.com/document/d/1E5t8auWf5DrNHzutvMmrJC_KqBqVyZuYX0thra69Ad8/edit#heading=h.nehmb825f40s

Are we planning to change the flag in 1.5?

@SeanNaren
Contributor

Thanks @four4fish, I think that could be a follow-up decision outside this PR; there were a lot of ideas introduced during this PR, so it's important to focus on just one. As @kaushikb11 said, we spent some time going through the potential API choices and came to the conclusion that the strategy flag + devices flag allows device-agnostic distributed strategies (such as DDP) to work irrespective of the underlying hardware (CPU/GPU, etc.).

# current API
Trainer(accelerator='ddp', gpus=2)
Trainer(plugins=DDPPlugin(), gpus=2)

# introduce strategy API 
Trainer(strategy='ddp', accelerator='cpu', gpus=2, num_processes=2)

# introduce device API for agnostic devices
Trainer(strategy='ddp', accelerator='cpu', devices=2, num_nodes=2)
Trainer(strategy='ddp', accelerator='gpu', devices=2, num_nodes=2)

# introduce error handling
Trainer(strategy='ddp', accelerator='ipu', devices=2, num_nodes=2) # crashes as ddp not supported on IPUs
Trainer(strategy='ddp_find_unused_parameters_false', accelerator='ipu', devices=2, num_nodes=2)

@kaushikb11 kaushikb11 changed the title [RFC] Introduce accelerator_strategy flag to Trainer [RFC] Introduce strategy flag to Trainer Oct 11, 2021
@carmocca
Contributor

What's left to do here?

@kaushikb11
Contributor Author

Closing this issue, will create a new issue to keep track of this with additional pointers.
