Skip to content

Commit

Permalink
[SOT]Fix Train/Eval Switch BUG in SOT (#59747)
Browse files Browse the repository at this point in the history
* [SOT]Fix Train/Eval Switch BUG in SOT

* rm useless code
  • Loading branch information
Aurelius84 authored Dec 7, 2023
1 parent 3a0aeaa commit de0a282
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def _perform_call(self, *args, **kwargs):
traced_fun = symbolic_translate(
self._dygraph_function,
build_strategy=build_strategy,
training=self._training,
training=self._is_train_mode(),
backend=backend,
)
if self._class_instance is not None:
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def __init__(self, compiled_fn, SIR, is_training: bool):
self.concrete_program = None
self.SIR = SIR # for debug
self.is_training = is_training
self.compiled_fn.eval() if not is_training else self.compiled_fn.train()

def amp_cast_inputs(self, args, kwargs):
"""Prepare inputs for amp, cast float16 into float32 if needed."""
Expand Down Expand Up @@ -111,6 +110,7 @@ def __call__(self, *args, **kwargs):
self.concrete_program,
self.partial_program,
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
self.partial_program.training = self.is_training
with EventGuard("FallbackWrapper: sot call partial_program"):
outputs = self.partial_program.sot_call(*args, **kwargs)

Expand Down
15 changes: 15 additions & 0 deletions test/sot/test_model_switch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

import paddle
from paddle.jit.sot.symbolic.compile_cache import CompileSIRCache


class SimpleNet(paddle.nn.Layer):
Expand All @@ -35,6 +36,15 @@ class TestModelSwitchTraining(unittest.TestCase):
def setUp(self):
self.seed = 1127
self.net = SimpleNet()
# singleton
self.compile_cache = CompileSIRCache()

def check_mode(self, is_train):
self.assertEqual(len(self.compile_cache.cache), 1)
mode = list(self.compile_cache.cache.values())[0].partial_program.training
self.assertEqual(mode, is_train)

def get_dygraph_out(self, input):
paddle.seed(self.seed)
Expand All @@ -46,11 +56,16 @@ def get_dygraph_out(self, input):

def get_static_out(self, input):
paddle.seed(self.seed)
self.compile_cache.clear()
static_net = paddle.jit.to_static(self.net)
static_net.eval()
eval_result = static_net(input)
self.check_mode(is_train=False)
self.compile_cache.clear()

static_net.train()
train_result = static_net(input)
self.check_mode(is_train=True)
return eval_result, train_result

def test_model_switch_training(self):
Expand Down

0 comments on commit de0a282

Please sign in to comment.