From de0a282abdfff986bbccd79b5d93c77592aa4524 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Thu, 7 Dec 2023 10:23:02 +0800
Subject: [PATCH] [SOT]Fix Train/Eval Switch BUG in SOT (#59747)

* [SOT]Fix Train/Eval Switch BUG in SOT

* rm usless code
---
 python/paddle/jit/dy2static/program_translator.py |  2 +-
 python/paddle/jit/sot/symbolic/compile_cache.py   |  2 +-
 test/sot/test_model_switch_training.py            | 15 +++++++++++++++
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index c64c499c59abb..b3872a415abe4 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -731,7 +731,7 @@ def _perform_call(self, *args, **kwargs):
         traced_fun = symbolic_translate(
             self._dygraph_function,
             build_strategy=build_strategy,
-            training=self._training,
+            training=self._is_train_mode(),
             backend=backend,
         )
         if self._class_instance is not None:
diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py
index c57b4a6a63c8a..465de3f6adf50 100644
--- a/python/paddle/jit/sot/symbolic/compile_cache.py
+++ b/python/paddle/jit/sot/symbolic/compile_cache.py
@@ -62,7 +62,6 @@ def __init__(self, compiled_fn, SIR, is_training: bool):
         self.concrete_program = None
         self.SIR = SIR  # for debug
         self.is_training = is_training
-        self.compiled_fn.eval() if not is_training else self.compiled_fn.train()
 
     def amp_cast_inputs(self, args, kwargs):
         """Prepare inputs for amp, cast float16 into float32 if needed."""
@@ -111,6 +110,7 @@ def __call__(self, *args, **kwargs):
                     self.concrete_program,
                     self.partial_program,
                 ) = self.compiled_fn.get_concrete_program(*args, **kwargs)
+                self.partial_program.training = self.is_training
 
         with EventGuard("FallbackWrapper: sot call partial_program"):
             outputs = self.partial_program.sot_call(*args, **kwargs)
diff --git a/test/sot/test_model_switch_training.py b/test/sot/test_model_switch_training.py
index e5bab3bf2db86..8d1e95b22f431 100644
--- a/test/sot/test_model_switch_training.py
+++ b/test/sot/test_model_switch_training.py
@@ -17,6 +17,7 @@
 import numpy as np
 
 import paddle
+from paddle.jit.sot.symbolic.compile_cache import CompileSIRCache
 
 
 class SimpleNet(paddle.nn.Layer):
@@ -35,6 +36,15 @@ class TestModelSwitchTraining(unittest.TestCase):
     def setUp(self):
         self.seed = 1127
         self.net = SimpleNet()
+        # singleton
+        self.compile_cache = CompileSIRCache()
+
+    def check_mode(self, is_train):
+        self.assertEqual(len(self.compile_cache.cache), 1)
+        mode = list(self.compile_cache.cache.values())[
+            0
+        ].partial_program.training
+        self.assertEqual(mode, is_train)
 
     def get_dygraph_out(self, input):
         paddle.seed(self.seed)
@@ -46,11 +56,16 @@ def get_dygraph_out(self, input):
 
     def get_static_out(self, input):
         paddle.seed(self.seed)
+        self.compile_cache.clear()
         static_net = paddle.jit.to_static(self.net)
         static_net.eval()
         eval_result = static_net(input)
+        self.check_mode(is_train=False)
+        self.compile_cache.clear()
+
         static_net.train()
         train_result = static_net(input)
+        self.check_mode(is_train=True)
         return eval_result, train_result
 
     def test_model_switch_training(self):
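
Note (not part of the patch): the core of the change is that `FallbackWrapper` no longer toggles `compiled_fn.train()/eval()` in `__init__`; instead the dy2static entry point passes the layer's current mode (`self._is_train_mode()`) at trace time, and the wrapper stamps `partial_program.training = self.is_training` once the concrete program has been built, so the mode captured when the cache entry was created is what the compiled program actually runs with. The toy sketch below illustrates that stamping pattern in isolation; `ToyProgram` and `ToyWrapper` are illustrative stand-ins assumed for this note, not Paddle APIs.

```python
# Minimal sketch (toy classes, not Paddle's real API) of the pattern in this
# patch: record ``is_training`` on the wrapper and stamp it onto the lazily
# compiled program, instead of flipping train/eval on the compiled fn eagerly.


class ToyProgram:
    """Stand-in for the partial program produced by compilation."""

    def __init__(self):
        self.training = True  # default; may not match the caller's mode


class ToyWrapper:
    """Stand-in for FallbackWrapper: one instance per compile-cache entry."""

    def __init__(self, is_training: bool):
        self.is_training = is_training
        self.partial_program = None  # compiled lazily on first call

    def __call__(self):
        if self.partial_program is None:
            self.partial_program = ToyProgram()  # expensive compile, done once
            # Stamp the mode captured at trace time onto the compiled program,
            # mirroring ``self.partial_program.training = self.is_training``.
            self.partial_program.training = self.is_training
        return self.partial_program.training


eval_wrapper = ToyWrapper(is_training=False)  # built while the layer was in eval()
assert eval_wrapper() is False

train_wrapper = ToyWrapper(is_training=True)  # separate entry built under train()
assert train_wrapper() is True
```

The added unit test checks the same property end to end: it clears the `CompileSIRCache` singleton before each run, executes the network once in `eval()` and once in `train()`, and asserts that the single cached entry's `partial_program.training` matches the mode in effect at that time.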