Skip to content

Commit

Permalink
remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Nov 9, 2023
1 parent 26762cd commit 13332f2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
1 change: 0 additions & 1 deletion python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ def _create_program(self, is_infer_mode=False):
infer_program = PirPassContext.apply(
infer_program, self._build_strategy
)
# TODO(Aurelius84): Support this later.
if self._hooker:
self._hooker.after_infer(infer_program)
return infer_program
Expand Down
2 changes: 0 additions & 2 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,9 +1497,7 @@ def after_append_backward(self, whole_program, src_vars, forward_end_idx):
new_start_index = (
len(whole_program.global_block().ops) - backward_length
)
# print(whole_program)
return whole_program, new_start_index, dst_vars
# print(whole_program)
return whole_program, forward_end_idx, src_vars

def after_infer(self, infer_program):
Expand Down
41 changes: 24 additions & 17 deletions test/prim/pir_prim/test_prim_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from paddle.framework import core


def func(x):
    """Reduce ``x`` to its mean, then apply (exact, non-approximate) GELU."""
    reduced = paddle.mean(x)
    # approximate=False selects the exact erf-based GELU formulation.
    return paddle.nn.functional.gelu(reduced, False)


class TestDy2staticPir(unittest.TestCase):
def test_basic_network_backward(self):
core._set_prim_all_enabled(True)

def func(x):
x1 = paddle.mean(x)
# out = paddle.nn.functional.gelu(x1, False)
return x1

# ==== dygraph computation ====
static_func = paddle.jit.to_static(func, full_graph=True)
x = paddle.randn((8, 16, 64))
Expand All @@ -43,7 +44,16 @@ def func(x):
actual_out = out * 2
actual_out.backward()
actual_grad = x.grad

core._set_prim_all_enabled(False)
ops = [
op.name()
for op in static_func.program_cache.last()[-1][-1]
.train_program.program.global_block()
.ops
]
assert "pd_op.erf" in ops
assert "pd_op.gelu" not in ops

np.testing.assert_allclose(
ref_out, actual_out.numpy(), atol=1e-6, rtol=1e-6
Expand All @@ -58,11 +68,6 @@ class TestDy2staticPirEval(unittest.TestCase):
def test_basic_network_backward_(self):
core._set_prim_all_enabled(True)

def func(x):
x1 = paddle.mean(x)
out = paddle.nn.functional.gelu(x1, False)
return out

# ==== dygraph computation ====
static_func = paddle.jit.to_static(func, full_graph=True)
static_func.eval()
Expand All @@ -74,15 +79,17 @@ def func(x):
out = static_func(x)
actual_out = out * 2

# ops = [
# op.name()
# for op in static_func.program_cache.last()[-1][-1]
# .infer_program.program.global_block()
# .ops
# ]
# print(ops)
ops = [
op.name()
for op in static_func.program_cache.last()[-1][-1]
.infer_program.program.global_block()
.ops
]
core._set_prim_all_enabled(False)

assert "pd_op.erf" in ops
assert "pd_op.gelu" not in ops

np.testing.assert_allclose(
ref_out, actual_out.numpy(), atol=1e-6, rtol=1e-6
)
Expand Down

0 comments on commit 13332f2

Please sign in to comment.