Skip to content

Commit

Permalink
[gym_jiminy/common] Train/eval mode API of InterfaceJiminyEnv is now …
Browse files Browse the repository at this point in the history
…consistent with PyTorch.
  • Loading branch information
duburcqa committed Jan 8, 2025
1 parent d949a46 commit 608e440
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 145 deletions.
26 changes: 13 additions & 13 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def __init__(self,
grace_period: float = 0.0,
*,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
Expand All @@ -451,16 +451,16 @@ def __init__(self,
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
self.env = env
self._name = name
self.grace_period = grace_period
self.is_truncation = is_truncation
self.is_training_only = is_training_only
self.training_only = training_only

@property
def name(self) -> str:
Expand Down Expand Up @@ -505,7 +505,7 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
"""
# Skip termination condition in eval mode or during grace period
termination_info: InfoType = {}
if (self.is_training_only and not self.env.is_training) or (
if (self.training_only and not self.env.training) or (
self.env.stepper_state.t < self.grace_period):
# Always continue
is_terminated, is_truncated = False, False
Expand Down Expand Up @@ -555,7 +555,7 @@ def __init__(self,
grace_period: float = 0.0,
*,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
Expand All @@ -578,10 +578,10 @@ def __init__(self,
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.low = np.asarray(low) if isinstance(low, Sequence) else low
Expand All @@ -593,7 +593,7 @@ def __init__(self,
name,
grace_period,
is_truncation=is_truncation,
is_training_only=is_training_only)
training_only=training_only)

# Add quantity to the set of quantities managed by the environment
self.env.quantities[self.name] = quantity
Expand Down
27 changes: 18 additions & 9 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,13 @@ def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
"""

@abstractmethod
def train(self) -> None:
"""Sets the environment in training mode.
def train(self, mode: bool = True) -> None:
"""Sets the environment in training or evaluation mode.
:param mode: Whether to set training (True) or evaluation mode (False).
Optional: `True` by default.
"""

@abstractmethod
def eval(self) -> None:
"""Sets the environment in evaluation mode.
Expand All @@ -423,6 +425,19 @@ def eval(self) -> None:
time specifically. See the documentation of a given environment for
details about its behavior in training and evaluation modes.
"""
self.train(False)

@property
@abstractmethod
def training(self) -> bool:
"""Check whether the environment is in training or evaluation mode.
"""

@training.setter
def training(self, mode: bool) -> None:
"""Sets the environment in training or evaluation mode.
"""
self.train(mode)

@property
@abstractmethod
Expand All @@ -436,9 +451,3 @@ def unwrapped(self) -> "BaseJiminyEnv":
def step_dt(self) -> float:
"""Get timestep of a single 'step'.
"""

@property
@abstractmethod
def is_training(self) -> bool:
"""Check whether the environment is in 'train' or 'eval' mode.
"""
15 changes: 6 additions & 9 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __getattr__(self, name: str) -> Any:
try:
# Make sure that no simulation is running
if (self.__getattribute__('is_simulation_running') and
self.env.is_training and not hasattr(sys, 'ps1')):
self.env.training and not hasattr(sys, 'ps1')):
# `hasattr(sys, 'ps1')` is used to detect whether the method
# was called from an interpreter or within a script. For
# details, see: https://stackoverflow.com/a/64523765/4820605
Expand Down Expand Up @@ -294,15 +294,12 @@ def unwrapped(self) -> "BaseJiminyEnv":
def step_dt(self) -> float:
return self.env.step_dt

@property
def is_training(self) -> bool:
return self.env.is_training

def train(self) -> None:
self.env.train()
@InterfaceJiminyEnv.training.getter # type: ignore[attr-defined]
def training(self) -> bool:
return self.env.training

def eval(self) -> None:
self.env.eval()
def train(self, mode: bool = True) -> None:
self.env.train(mode)

def update_pipeline(self, derived: Optional[InterfaceJiminyEnv]) -> None:
if derived is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _setup(self) -> None:

# Unbounded effort limits in evaluation mode.
# Note that training/evaluation cannot be changed at this point.
self._enable_limit_soft = self.env.is_training
self._enable_limit_soft = self.env.training

def compute_command(self,
action: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _setup(self) -> None:

# Enable deadband in evaluation mode only.
# Note that training/evaluation cannot be changed at this point.
self._enable_deadband = not self.env.is_training
self._enable_deadband = not self.env.training

@property
def fieldnames(self) -> List[str]:
Expand Down
60 changes: 30 additions & 30 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(self,
post_fn: Optional[Callable[
[ArrayOrScalar], ArrayOrScalar]] = None,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
Expand Down Expand Up @@ -214,10 +214,10 @@ def __init__(self,
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# pylint: disable=unnecessary-lambda-assignment

Expand Down Expand Up @@ -256,7 +256,7 @@ def __init__(self,
high,
grace_period,
is_truncation=is_truncation,
is_training_only=is_training_only)
training_only=training_only)

def _compute_drift_error(self,
left: np.ndarray,
Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(self,
*,
op: Callable[[np.ndarray, np.ndarray], np.ndarray] = sub,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
Expand Down Expand Up @@ -335,10 +335,10 @@ def __init__(self,
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# pylint: disable=unnecessary-lambda-assignment

Expand Down Expand Up @@ -388,7 +388,7 @@ def min_norm(values: np.ndarray) -> float:
np.array(thr),
max(grace_period, horizon),
is_truncation=is_truncation,
is_training_only=is_training_only)
training_only=training_only)

def _compute_min_distance(self,
left: np.ndarray,
Expand Down Expand Up @@ -483,7 +483,7 @@ def __init__(self,
velocity_max: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param position_margin: Distance of actuated joints from their
Expand All @@ -496,10 +496,10 @@ def __init__(self,
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.position_margin = position_margin
Expand All @@ -511,7 +511,7 @@ def __init__(self,
"termination_mechanical_safety",
grace_period,
is_truncation=False,
is_training_only=is_training_only)
training_only=training_only)

# Add quantity to the set of quantities managed by the environment
self.env.quantities["_".join((self.name, "position_delta"))] = (
Expand Down Expand Up @@ -577,7 +577,7 @@ def __init__(
generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param max_power: Maximum average mechanical power consumption applied
Expand All @@ -589,10 +589,10 @@ def __init__(
Optional: 0.0 by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the average.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.max_power = max_power
Expand All @@ -610,7 +610,7 @@ def __init__(
self.max_power,
grace_period,
is_truncation=False,
is_training_only=is_training_only)
training_only=training_only)


class ShiftTrackingMotorPositionsTermination(ShiftTrackingQuantityTermination):
Expand Down Expand Up @@ -638,7 +638,7 @@ def __init__(self,
horizon: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param thr: Maximum shift above which termination is triggered.
Expand All @@ -648,10 +648,10 @@ def __init__(self,
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Call base implementation
super().__init__(
Expand All @@ -665,4 +665,4 @@ def __init__(self,
horizon,
grace_period,
is_truncation=False,
is_training_only=is_training_only)
training_only=training_only)
Loading

0 comments on commit 608e440

Please sign in to comment.