Skip to content

Commit

Permalink
[gym_jiminy/common] Stop relying on 'true' sensor measurements in computations. (#845)
Browse files Browse the repository at this point in the history

* [gym_jiminy/common] Fix 'BaseJiminyEnv.play_interactive'.
* [gym_jiminy/common] Share base engine measurement through all the layers of the pipeline.
* [gym_jiminy/common] Stop relying on 'true' sensor measurements in computations.
* [gym_jiminy/common] Stop relying on stepper_state to get current time.
  • Loading branch information
duburcqa authored Dec 4, 2024
1 parent 0abf0f1 commit 27c3bb0
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 128 deletions.
39 changes: 26 additions & 13 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
import gymnasium as gym

import jiminy_py.core as jiminy
from jiminy_py.core import ( # pylint: disable=no-name-in-module
multi_array_copyto)
from jiminy_py.simulator import Simulator
from jiminy_py.viewer.viewer import is_display_available

import pinocchio as pin

from ..utils import DataNested
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv
Expand Down Expand Up @@ -192,7 +196,7 @@ class InterfaceJiminyEnv(
robot: jiminy.Robot
stepper_state: jiminy.StepperState
robot_state: jiminy.RobotState
sensor_measurements: SensorMeasurementStackMap
measurements: EngineObsType
is_simulation_running: npt.NDArray[np.bool_]

num_steps: npt.NDArray[np.int64]
Expand All @@ -217,13 +221,25 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.__is_observation_refreshed = True

# Store latest engine measurement for efficiency
self.__measurement = EngineObsType(
self.measurement = EngineObsType(
t=np.array(0.0),
states=OrderedDict(
agent=OrderedDict(q=np.array([]), v=np.array([]))),
measurements=OrderedDict(self.robot.sensor_measurements))
agent=OrderedDict(
q=pin.neutral(self.robot.pinocchio_model),
v=np.zeros(self.robot.pinocchio_model.nv))),
measurements=OrderedDict(zip(
self.robot.sensor_measurements.keys(),
map(np.copy, self.robot.sensor_measurements.values()))))
self._sensors_types = tuple(self.robot.sensor_measurements.keys())

# Define flattened engine measurement for efficiency
agent_state = self.measurement['states']['agent']
assert isinstance(agent_state, dict)
self._measurement_flat = (self.measurement['t'],
agent_state['q'],
agent_state['v'],
*self.measurement['measurements'].values())

# Call super to allow mixing interfaces through multiple inheritance
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -264,6 +280,10 @@ def _observer_handle(self,
:param v: Current extended velocity vector of the robot.
:param sensor_measurements: Current sensor data.
"""
# Update engine measurement
measurement_flat = (t, q, v, *sensor_measurements.values())
multi_array_copyto(self._measurement_flat, measurement_flat)

# Early return if no simulation is running
if not self.is_simulation_running:
return
Expand All @@ -288,16 +308,9 @@ def _observer_handle(self,
# that is supposed to be executed before `refresh_observation` is being
# called for the first time of an episode.
if not self.__is_observation_refreshed:
measurement = self.__measurement
measurement["t"][()] = t
measurement["states"]["agent"]["q"] = q
measurement["states"]["agent"]["v"] = v
measurement_sensors = measurement["measurements"]
sensor_measurements_it = iter(sensor_measurements.values())
for sensor_type in self._sensors_types:
measurement_sensors[sensor_type] = next(sensor_measurements_it)
# Refresh observation
try:
self.refresh_observation(measurement)
self.refresh_observation(self.measurement)
except RuntimeError as e:
raise RuntimeError(
"The observation space must be invariant.") from e
Expand Down
20 changes: 11 additions & 9 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
from gymnasium.core import RenderFrame
from gymnasium.envs.registration import EnvSpec

from jiminy_py.core import ( # pylint: disable=no-name-in-module
is_breakpoint, array_copyto)
from jiminy_py.core import array_copyto # pylint: disable=no-name-in-module
from jiminy_py.dynamics import Trajectory
from jiminy_py.tree import issubclass_mapping

Expand All @@ -46,7 +45,8 @@
zeros,
build_copyto,
copy,
get_robot_state_space)
get_robot_state_space,
is_breakpoint)
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv

Expand Down Expand Up @@ -180,7 +180,6 @@ def __init__(self,
self.stepper_state = env.stepper_state
self.robot = env.robot
self.robot_state = env.robot_state
self.sensor_measurements = env.sensor_measurements
self.is_simulation_running = env.is_simulation_running
self.num_steps = env.num_steps

Expand All @@ -193,6 +192,10 @@ def __init__(self,
# Call base implementation
super().__init__() # Do not forward any argument

# Bind engine measurement
self.measurement = env.measurement
self._measurement_flat = env._measurement_flat

# Enable direct forwarding by default for efficiency
if BasePipelineWrapper.has_terminated is type(self).has_terminated:
self.has_terminated = ( # type: ignore[method-assign]
Expand Down Expand Up @@ -433,7 +436,6 @@ def _setup(self) -> None:
# Refresh some proxies for fast lookup
self.robot = self.env.robot
self.robot_state = self.env.robot_state
self.sensor_measurements = self.env.sensor_measurements

# Initialize specialized operator(s) for efficiency
self._copyto_action = build_copyto(self.action)
Expand Down Expand Up @@ -926,7 +928,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
self.env.refresh_observation(measurement)

# Update observed features if necessary
if is_breakpoint(self.stepper_state, self.observe_dt, DT_EPS):
if is_breakpoint(measurement["t"], self.observe_dt, DT_EPS):
self.observer.refresh_observation(self.env.observation)

def compute_command(self, action: Act, command: np.ndarray) -> None:
Expand Down Expand Up @@ -1148,7 +1150,7 @@ def compute_command(self, action: Act, command: np.ndarray) -> None:
# Note that `observation` buffer has already been updated right before
# calling this method by `_controller_handle`, so it can be used as
# measure argument without issue.
if is_breakpoint(self.stepper_state, self.control_dt, DT_EPS):
if is_breakpoint(self.measurement["t"], self.control_dt, DT_EPS):
self.controller.compute_command(action, self.env.action)

# Update the command to send to the actuators of the robot.
Expand Down Expand Up @@ -1254,7 +1256,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
self.env.refresh_observation(measurement)

# Transform observation at the end of the step only
if is_breakpoint(self.stepper_state, self._step_dt, DT_EPS):
if is_breakpoint(measurement["t"], self._step_dt, DT_EPS):
self.transform_observation()

@abstractmethod
Expand Down Expand Up @@ -1364,7 +1366,7 @@ def compute_command(self,
:param command: Lower-level command to update in-place.
"""
# Transform action at the beginning of the step only
if is_breakpoint(self.stepper_state, self._step_dt, DT_EPS):
if is_breakpoint(self.measurement["t"], self._step_dt, DT_EPS):
self.transform_action(action)

# Delegate command computation to wrapped environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,9 @@ def __init__(self,
"Revolute unbounded joints are not supported for now.")
self.encoder_to_position_map[sensor.index] = joint.idx_q

# Extract measured motor / joint positions for fast access.
# Note that they will be initialized in `_setup` method.
self.encoder_data = np.array([])
# Extract measured motor / joint positions for fast access
self.encoder_data, _ = (
env.measurement["measurements"][EncoderSensor.type])

# Ratio to translate encoder data to joint side
self.encoder_to_joint_ratio = np.array([])
Expand Down Expand Up @@ -783,9 +783,6 @@ def _setup(self) -> None:
self._kin_flex_rots.append(kin_flex_rots)
self._kin_imu_rots.append(kin_imu_rots)

# Refresh measured motor position proxy
self.encoder_data, _ = self.env.sensor_measurements[EncoderSensor.type]

# Refresh mechanical reduction ratio
encoder_to_joint_ratio = []
for sensor in self.env.robot.sensors[EncoderSensor.type]:
Expand Down
17 changes: 7 additions & 10 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/mahony_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ def __init__(self,
# triggering yet another compilation.
self._is_compiled = False

# Define gyroscope and accelerometer proxies for fast access.
# Note that they will be initialized in `_setup` method.
self.gyro, self.acc = np.array([]), np.array([])
# Define gyroscope and accelerometer proxies for fast access
self.gyro, self.acc = np.split(
env.measurement["measurements"][ImuSensor.type], 2)

# Allocate gyroscope bias estimate
self._bias = np.zeros((3, num_imu_sensors))
Expand Down Expand Up @@ -358,10 +358,6 @@ def _setup(self) -> None:
raise ValueError(
"This block does not support time-continuous update.")

# Refresh gyroscope and accelerometer proxies
self.gyro, self.acc = np.split(
self.env.sensor_measurements[ImuSensor.type], 2)

# Reset the sensor bias estimate
fill(self._bias, 0)

Expand Down Expand Up @@ -411,9 +407,10 @@ def refresh_observation(self, measurement: BaseObs) -> None:
if not self._is_initialized:
if not self.exact_init:
if (np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY).all():
LOGGER.warning(
"The robot is free-falling. Impossible to initialize "
"Mahony filter for 'exact_init=False'.")
if self._is_compiled:
LOGGER.warning(
"The robot is free-falling. Impossible to "
"initialize Mahony filter for 'exact_init=False'.")
else:
# Try to determine the orientation of the IMU from its
# measured acceleration at initialization. This approach is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def __init__(self,
"Consider using the same ordering for encoders and motors for "
"optimal performance.")

# Extract measured motor positions and velocities for fast access.
# Note that they will be initialized in `_setup` method.
self.q_measured, self.v_measured = np.array([]), np.array([])
# Extract measured motor positions and velocities for fast access
self.q_measured, self.v_measured = (
env.measurement["measurements"][EncoderSensor.type])

# Initialize the controller
super().__init__(name, env, update_ratio=1)
Expand All @@ -186,14 +186,6 @@ def _initialize_action_space(self) -> None:
"""
self.action_space = self.env.action_space

def _setup(self) -> None:
# Call base implementation
super()._setup()

# Refresh measured motor positions and velocities proxies
self.q_measured, self.v_measured = (
self.env.sensor_measurements[EncoderSensor.type])

@property
def fieldnames(self) -> List[str]:
return [f"currentMotorTorque{motor.name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,9 @@ def __init__(self,
motors_velocity_limit,
acceleration_limit], axis=0)

# Extract measured motor positions and velocities for fast access.
# Note that they will be initialized in `_setup` method.
self.q_measured, self.v_measured = np.array([]), np.array([])
# Extract measured motor positions and velocities for fast access
self.q_measured, self.v_measured = (
env.measurement["measurements"][EncoderSensor.type])

# Allocate memory for the command state
self._command_state = np.zeros((3, env.robot.nmotors))
Expand Down Expand Up @@ -430,10 +430,6 @@ def _setup(self) -> None:
raise ValueError(
"This block does not support time-continuous update.")

# Refresh measured motor positions and velocities proxies
self.q_measured, self.v_measured = (
self.env.sensor_measurements[EncoderSensor.type])

# Reset the command state
fill(self._command_state, 0)

Expand Down
Loading

0 comments on commit 27c3bb0

Please sign in to comment.