Unroll dict input before call Accelerator X_steps and update function type #10907

Closed
four4fish opened this issue Dec 3, 2021 · 0 comments · Fixed by #10908
Comments

four4fish (Contributor) commented Dec 3, 2021

Proposed refactor

From @justusschock's comments in #10890:

  • Refactor the Accelerator X_step() signatures to support both positional and keyword arguments.
  • Instead of unrolling the step_kwargs dictionary inside the accelerator steps, unroll it on the caller side.

This will unblock #10648, which moves these functions from the accelerators to the strategies.

Motivation

Keep flexibility and improve function typing.

Pitch

The current Accelerator X_step typing signatures are at:
https://github.com/four4fish/pytorch-lightning/blob/master/pytorch_lightning/accelerators/accelerator.py#L124-L148

Each step currently takes a single step_kwargs dictionary, which does not align with the torch model training step and model() signatures. Update them to:

def X_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
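A minimal before/after sketch of the change for training_step. The dict-based "before" body is paraphrased from the linked accelerator.py of that era (precision_plugin.train_step_context() and training_type_plugin.training_step() are assumptions based on that file); STEP_OUTPUT comes from pytorch_lightning.utilities.types:

```python
from typing import Any, Dict, Optional

from pytorch_lightning.utilities.types import STEP_OUTPUT


# Before: the step receives one dictionary and unrolls it internally,
# so the type hints say nothing about the actual step arguments.
def training_step(self, step_kwargs: Dict[str, Any]) -> Optional[STEP_OUTPUT]:
    with self.precision_plugin.train_step_context():
        return self.training_type_plugin.training_step(*step_kwargs.values())


# After: positional and keyword arguments pass through unchanged,
# leaving the unrolling to the caller.
def training_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
    with self.precision_plugin.train_step_context():
        return self.training_type_plugin.training_step(*args, **kwargs)
```

With *args/**kwargs, the accelerator no longer needs to know which arguments a given step expects; it simply forwards whatever the loop assembled.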

On the caller side (e.g. optimizer_loop.py), instead of passing the step_kwargs dictionary, pass *step_kwargs.values().
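For illustration, a sketch of the caller-side change; the run_training_step wrapper is hypothetical, and the trainer.accelerator.training_step call paraphrases the loop code of that era:

```python
from collections import OrderedDict
from typing import Any


def run_training_step(trainer: Any, batch: Any, batch_idx: int) -> Any:
    # The loop assembles the step arguments in call order.
    step_kwargs = OrderedDict(batch=batch, batch_idx=batch_idx)

    # Before: the whole dictionary went to the accelerator, which unrolled it.
    # output = trainer.accelerator.training_step(step_kwargs)

    # After: unroll the values at the call site to match the new
    # *args/**kwargs accelerator signature.
    return trainer.accelerator.training_step(*step_kwargs.values())
```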

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.

cc @justusschock @awaelchli @akihironitta

four4fish changed the title from "Unroll dict input before call Accelerator X_steps and update function typing" to "Unroll dict input before call Accelerator X_steps and update function type" on Dec 3, 2021