Skip to content

Commit

Permalink
feat(new-ir): support L2Decay
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww committed Nov 23, 2023
1 parent ea0ef21 commit 183f636
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
5 changes: 5 additions & 0 deletions test/legacy_test/test_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from paddle import base, regularizer
from paddle.base import core, framework
from paddle.base.backward import append_backward
from paddle.pir_utils import test_with_pir_api


class TestL2Decay(unittest.TestCase):
@test_with_pir_api
def test_l2decay_regularizer(self):
paddle.enable_static()
program = framework.Program()
Expand Down Expand Up @@ -69,6 +71,7 @@ def test_l2decay_regularizer(self):


class TestL1Decay(unittest.TestCase):
@test_with_pir_api
def test_l2decay_regularizer(self):
paddle.enable_static()
program = framework.Program()
Expand Down Expand Up @@ -241,6 +244,7 @@ def check_l2decay(self, place, model):
param_sum = self.run_program(place, [data, label])
return param_sum

@test_with_pir_api
def test_l2(self):
for place in self.get_places():
dense_sparse_p_sum = []
Expand All @@ -261,6 +265,7 @@ def test_l2(self):
rtol=5e-5,
)

@test_with_pir_api
def test_repeated_regularization(self):
l1 = paddle.regularizer.L1Decay(coeff=0.1)
l2 = paddle.regularizer.L2Decay(coeff=0.01)
Expand Down
14 changes: 10 additions & 4 deletions test/legacy_test/test_regularizer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def bow_net(
Expand Down Expand Up @@ -99,11 +100,12 @@ def run_program(self, place, feed_list):
param_sum.append(p_sum)
return param_sum

@test_with_pir_api
def check_l2decay_regularizer(self, place, model):
paddle.seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = base.framework.Program()
startup_prog = base.framework.Program()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog
):
Expand Down Expand Up @@ -175,16 +177,20 @@ def test_l2(self):
rtol=5e-5,
)

@test_with_pir_api
def test_repeated_regularization(self):
paddle.enable_static()
l1 = paddle.regularizer.L1Decay(0.1)
l2 = paddle.regularizer.L2Decay(0.01)
fc_param_attr = paddle.ParamAttr(
regularizer=paddle.regularizer.L1Decay()
)
with base.program_guard(base.Program(), base.Program()):
with base.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.uniform([2, 2, 3])
out = paddle.static.nn.fc(x, 5, weight_attr=fc_param_attr)
linear = paddle.nn.Linear(1, 5, weight_attr=fc_param_attr)
out = linear(x)
loss = paddle.sum(out)
sgd = paddle.optimizer.SGD(learning_rate=0.1, weight_decay=l2)
sgd.minimize(loss)
Expand Down

0 comments on commit 183f636

Please sign in to comment.