Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【AutoParallelism】Add refined recompute support #58421

Merged
merged 4 commits into from
Oct 31, 2023

Conversation

heavyrain-lzy
Copy link
Contributor

@heavyrain-lzy heavyrain-lzy commented Oct 26, 2023

PR types

New features

PR changes

Others

Description

PCard-71568
Add refined recompute support in automatic parallelism of static graph.

  • How to use
import paddle
import paddle.distributed.auto_parallel as auto
paddle.enable_static()

class LeNet(nn.Layer):
    def __init__(self, num_classes=10):

        super().__init__()
        self.num_classes = num_classes
        
        self.conv2d_0 = nn.Conv2D(1, 6, 3, stride=1, padding=1)
            self.relu_0 = nn.ReLU(),
            self.maxpool_0 = nn.MaxPool2D(2, 2),
            self.conv2d_1 = nn.Conv2D(6, 16, 5, stride=1, padding=0)
            self.relu_1 = nn.ReLU(),
            self.maxpool_1 = nn.MaxPool2D(2, 2)
 
     def fun_0(self, x):
            tmp0 = self.maxpool_0(x)
            tmp1 = self.conv2d_1(tmp0)
            out = self.relu_1(tmp1)
            return out

    def forward(self, inputs):
        x = self.conv2d_0(inputs)
        x = self.maxpool_0(x)
       # 调用auto.exclude_ops_in_recompute
        x = auto.recompute(auto.exclude_ops_in_recompute(self.fun_0(x)))
        return x

You can also refer to PaddlePaddle/PaddleNLP#7317
The test results of on single compute with 8 GPU devices:
image

@paddle-bot
Copy link

paddle-bot bot commented Oct 26, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Oct 26, 2023
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@heavyrain-lzy heavyrain-lzy merged commit 721f834 into PaddlePaddle:develop Oct 31, 2023
28 checks passed
@paddle-bot paddle-bot bot removed the contributor External developers label Nov 3, 2023
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
* add refined-recompute support

* fix bug in recompute_pass

* fix coverage
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* add refined-recompute support

* fix bug in recompute_pass

* fix coverage
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants