diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 6080d8637..177217b98 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -405,16 +405,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager): def __init__(self, mode): self.mode = mode + self._saved_mode = None def clone(self): return type(self)(self.mode) def __enter__(self): global DISPATCH_TDNN_MODULES - self._saved_mode = DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self.mode + # We want to avoid changing global variables because compile puts guards on them + if DISPATCH_TDNN_MODULES != self.mode: + self._saved_mode = DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self.mode def __exit__(self, exc_type, exc_val, exc_tb): + if self._saved_mode is None: + return global DISPATCH_TDNN_MODULES DISPATCH_TDNN_MODULES = self._saved_mode