Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.188】 Migrate paddle.regularizer.L2Decay into pir #59313

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions test/legacy_test/test_regularizer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def bow_net(
Expand Down Expand Up @@ -102,8 +103,8 @@ def run_program(self, place, feed_list):
def check_l2decay_regularizer(self, place, model):
paddle.seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = base.framework.Program()
startup_prog = base.framework.Program()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog
):
Expand Down Expand Up @@ -175,16 +176,20 @@ def test_l2(self):
rtol=5e-5,
)

@test_with_pir_api
def test_repeated_regularization(self):
paddle.enable_static()
l1 = paddle.regularizer.L1Decay(0.1)
l2 = paddle.regularizer.L2Decay(0.01)
fc_param_attr = paddle.ParamAttr(
regularizer=paddle.regularizer.L1Decay()
)
with base.program_guard(base.Program(), base.Program()):
with base.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.uniform([2, 2, 3])
out = paddle.static.nn.fc(x, 5, weight_attr=fc_param_attr)
linear = paddle.nn.Linear(3, 5, weight_attr=fc_param_attr)
out = linear(x)
loss = paddle.sum(out)
sgd = paddle.optimizer.SGD(learning_rate=0.1, weight_decay=l2)
sgd.minimize(loss)
Expand Down