Skip to content

Commit

Permalink
[gym/common] Fix support of non-batched mode for generic math.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 29, 2024
1 parent 78c6dc2 commit 2fdf1c1
Show file tree
Hide file tree
Showing 16 changed files with 104 additions and 103 deletions.
40 changes: 20 additions & 20 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
# pylint: disable=missing-module-docstring

from .generic_bases import (DT_EPS,
ObsT,
ActT,
BaseObsT,
BaseActT,
InfoType,
SensorMeasurementStackMap,
EngineObsType,
InterfaceObserver,
InterfaceController,
InterfaceJiminyEnv)
from .block_bases import (BlockStateT,
InterfaceBlock,
BaseObserverBlock,
BaseControllerBlock)
from .pipeline_bases import (BasePipelineWrapper,
BaseTransformObservation,
BaseTransformAction,
ObservedJiminyEnv,
ControlledJiminyEnv)
from .generic import (DT_EPS,
ObsT,
ActT,
BaseObsT,
BaseActT,
InfoType,
SensorMeasurementStackMap,
EngineObsType,
InterfaceObserver,
InterfaceController,
InterfaceJiminyEnv)
from .block import (BlockStateT,
InterfaceBlock,
BaseObserverBlock,
BaseControllerBlock)
from .pipeline import (BasePipelineWrapper,
BaseTransformObservation,
BaseTransformAction,
ObservedJiminyEnv,
ControlledJiminyEnv)


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from ..utils import FieldNested, DataNested, get_fieldnames, fill, zeros

from .generic_bases import (ObsT,
ActT,
BaseObsT,
BaseActT,
InterfaceController,
InterfaceObserver,
InterfaceJiminyEnv)
from .generic import (ObsT,
ActT,
BaseObsT,
BaseActT,
InterfaceController,
InterfaceObserver,
InterfaceJiminyEnv)


BlockStateT = TypeVar('BlockStateT', bound=Union[DataNested, None])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@

from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy

from .generic_bases import (DT_EPS,
ObsT,
ActT,
BaseObsT,
BaseActT,
InfoType,
EngineObsType,
InterfaceJiminyEnv)
from .block_bases import BaseControllerBlock, BaseObserverBlock
from .generic import (DT_EPS,
ObsT,
ActT,
BaseObsT,
BaseActT,
InfoType,
EngineObsType,
InterfaceJiminyEnv)
from .block import BaseControllerBlock, BaseObserverBlock


OtherObsT = TypeVar('OtherObsT', bound=DataNested)
Expand Down
18 changes: 11 additions & 7 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/mahony_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
reinforcement learning pipeline environment design.
"""
import logging
from typing import List, Union, Optional, Tuple
from typing import List, Union, Optional, Tuple, no_type_check

import numpy as np
import numba as nb
Expand All @@ -23,7 +23,7 @@


@nb.jit(nopython=True, cache=True)
def compute_tilt(q: np.ndarray) -> None:
def compute_tilt(q: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Compute e_z in R(q) frame (Euler-Rodrigues Formula): R(q).T @ e_z.
:param q: Array whose rows are the 4 components of quaternions (x, y, z, w)
Expand Down Expand Up @@ -110,6 +110,7 @@ def mahony_filter(q: np.ndarray,


# FIXME: Enabling cache causes segfault on Apple Silicon
@no_type_check
@nb.jit(nopython=True, cache=False)
def quat_from_vector(
v_a: Tuple[ArrayOrScalar, ArrayOrScalar, ArrayOrScalar],
Expand Down Expand Up @@ -398,11 +399,14 @@ def _setup(self) -> None:
"provide a meaningful estimate of the IMU orientations. It "
"should not exceed 10ms.", self.observe_dt)

# Make sure that `mahony_filter` has been pre-compiled, otherwise the
# first simulation step may timeout because of it.
if not mahony_filter.signatures:
self._is_initialized = True
self.refresh_observation(self.env.observation)
# Call `mahony_filter` to make sure that it has been pre-compiled, to
# avoid raising a timeout exception during the first simulation step.
# Note that it is not reliable to check if `mahony_filter` has been
# compiled at least once, because it may have been compiled for a
# different environment, for which `mahony_filter` may be another
# signature and therefore trigger yet another compilation.
self._is_initialized = True
self.refresh_observation(self.env.observation)

# Consider that the observer is not initialized anymore
self._is_initialized = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,10 @@ def _setup(self) -> None:
# Reset the command state
fill(self._command_state, 0)

# Make sure that `pd_controller` has been pre-compiled, otherwise the
# first simulation step may timeout because of it.
if not pd_controller.signatures:
self.compute_command(self.env.action)
# Call `pd_controller` to make sure that it has been pre-compiled,
# otherwise the first simulation step will take much more time than
# expected, which is likely to raise a timeout exception.
self.compute_command(self.env.action)

@property
def fieldnames(self) -> List[str]:
Expand Down
4 changes: 2 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=missing-module-docstring

from .env_generic import BaseJiminyEnv
from .env_locomotion import WalkerJiminyEnv
from .generic import BaseJiminyEnv
from .locomotion import WalkerJiminyEnv


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..utils import sample
from ..bases import InfoType
from .env_generic import BaseJiminyEnv
from .generic import BaseJiminyEnv


GROUND_FRICTION_RANGE = (0.2, 2.0)
Expand Down
84 changes: 42 additions & 42 deletions python/gym_jiminy/common/gym_jiminy/common/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def quat_to_rpy(quat: np.ndarray,
else:
assert out.shape == (3, *quat.shape[1:])
out_ = out
roll, pitch, yaw = out_

# Compute some intermediary quantities
q_xx, q_xy, q_xz, q_xw = quat[-4] * quat[-4:]
Expand All @@ -86,10 +85,11 @@ def quat_to_rpy(quat: np.ndarray,
q_xz *= norm_inv

# Compute Roll, Pitch and Yaw separately
roll[:] = np.arctan2(2 * (q_xw + q_yz), 1.0 - 2 * (q_xx + q_yy))
pitch[:] = - np.pi / 2 + 2 * np.arctan2(
# roll, pitch, yaw = out_
out_[0] = np.arctan2(2 * (q_xw + q_yz), 1.0 - 2 * (q_xx + q_yy))
out_[1] = - np.pi / 2 + 2 * np.arctan2(
np.sqrt(1.0 + 2 * (q_yw - q_xz)), np.sqrt(1.0 - 2 * (q_yw - q_xz)))
yaw[:] = np.arctan2(2 * (q_zw + q_xy), 1.0 - 2 * (q_yy + q_zz))
out_[2] = np.arctan2(2 * (q_zw + q_xy), 1.0 - 2 * (q_yy + q_zz))

return out_

Expand Down Expand Up @@ -143,11 +143,11 @@ def matrix_to_quat(mat: np.ndarray,
else:
assert out.shape == (4, *mat.shape[2:])
out_ = out
q_x, q_y, q_z, q_w = out_
q_x[:] = mat[2, 1] - mat[1, 2]
q_y[:] = mat[0, 2] - mat[2, 0]
q_z[:] = mat[1, 0] - mat[0, 1]
q_w[:] = 1.0 + mat[0, 0] + mat[1, 1] + mat[2, 2]
# q_x, q_y, q_z, q_w = out_
out_[0] = mat[2, 1] - mat[1, 2]
out_[1] = mat[0, 2] - mat[2, 0]
out_[2] = mat[1, 0] - mat[0, 1]
out_[3] = 1.0 + mat[0, 0] + mat[1, 1] + mat[2, 2]
out_ /= np.sqrt(np.sum(np.square(out_), 0))
return out_

Expand All @@ -170,35 +170,35 @@ def matrices_to_quat(mat_list: Tuple[np.ndarray],
else:
assert out.shape == (4, len(mat_list))
out_ = out
q_x, q_y, q_z, q_w = out_
# q_x, q_y, q_z, q_w = out_
t = np.empty((len(mat_list),))
for i, mat in enumerate(mat_list):
if mat[2, 2] < 0:
if mat[0, 0] > mat[1, 1]:
t[i] = 1 + mat[0, 0] - mat[1, 1] - mat[2, 2]
q_x[i] = t[i]
q_y[i] = mat[1, 0] + mat[0, 1]
q_z[i] = mat[0, 2] + mat[2, 0]
q_w[i] = mat[2, 1] - mat[1, 2]
out_[0][i] = t[i]
out_[1][i] = mat[1, 0] + mat[0, 1]
out_[2][i] = mat[0, 2] + mat[2, 0]
out_[3][i] = mat[2, 1] - mat[1, 2]
else:
t[i] = 1 - mat[0, 0] + mat[1, 1] - mat[2, 2]
q_x[i] = mat[1, 0] + mat[0, 1]
q_y[i] = t[i]
q_z[i] = mat[2, 1] + mat[1, 2]
q_w[i] = mat[0, 2] - mat[2, 0]
out_[0][i] = mat[1, 0] + mat[0, 1]
out_[1][i] = t[i]
out_[2][i] = mat[2, 1] + mat[1, 2]
out_[3][i] = mat[0, 2] - mat[2, 0]
else:
if mat[0, 0] < -mat[1, 1]:
t[i] = 1 - mat[0, 0] - mat[1, 1] + mat[2, 2]
q_x[i] = mat[0, 2] + mat[2, 0]
q_y[i] = mat[2, 1] + mat[1, 2]
q_z[i] = t[i]
q_w[i] = mat[1, 0] - mat[0, 1]
out_[0][i] = mat[0, 2] + mat[2, 0]
out_[1][i] = mat[2, 1] + mat[1, 2]
out_[2][i] = t[i]
out_[3][i] = mat[1, 0] - mat[0, 1]
else:
t[i] = 1 + mat[0, 0] + mat[1, 1] + mat[2, 2]
q_x[i] = mat[2, 1] - mat[1, 2]
q_y[i] = mat[0, 2] - mat[2, 0]
q_z[i] = mat[1, 0] - mat[0, 1]
q_w[i] = t[i]
out_[0][i] = mat[2, 1] - mat[1, 2]
out_[1][i] = mat[0, 2] - mat[2, 0]
out_[2][i] = mat[1, 0] - mat[0, 1]
out_[3][i] = t[i]
out_ /= 2 * np.sqrt(t)
return out_

Expand Down Expand Up @@ -251,13 +251,13 @@ def matrix_to_rpy(mat: np.ndarray,
else:
assert out.shape == (3, *mat.shape[2:])
out_ = out
roll, pitch, yaw = out_
yaw[:] = np.arctan2(mat[1, 0], mat[0, 0])
# roll, pitch, yaw = out_
out_[2] = np.arctan2(mat[1, 0], mat[0, 0])
cos_pitch = np.sqrt(mat[2, 2] ** 2 + mat[2, 1] ** 2)
pitch[:] = np.arctan2(- mat[2, 0], np.sign(yaw) * cos_pitch)
yaw[:] += np.pi * (yaw < 0.0)
sin_yaw, cos_yaw = np.sin(yaw), np.cos(yaw)
roll[:] = np.arctan2(
out_[1] = np.arctan2(- mat[2, 0], np.sign(out_[2]) * cos_pitch)
out_[2] += np.pi * (out_[2] < 0.0)
sin_yaw, cos_yaw = np.sin(out_[2]), np.cos(out_[2])
out_[0] = np.arctan2(
sin_yaw * mat[0, 2] - cos_yaw * mat[1, 2],
cos_yaw * mat[1, 1] - sin_yaw * mat[0, 1])
return out_
Expand All @@ -280,15 +280,15 @@ def rpy_to_quat(rpy: np.ndarray,
else:
assert out.shape == (4, *rpy.shape[1:])
out_ = out
q_x, q_y, q_z, q_w = out_
roll, pitch, yaw = rpy
cos_roll, sin_roll = np.cos(roll / 2), np.sin(roll / 2)
cos_pitch, sin_pitch = np.cos(pitch / 2), np.sin(pitch / 2)
cos_yaw, sin_yaw = np.cos(yaw / 2), np.sin(yaw / 2)
q_x[:] = sin_roll * cos_pitch * cos_yaw - cos_roll * sin_pitch * sin_yaw
q_y[:] = cos_roll * sin_pitch * cos_yaw + sin_roll * cos_pitch * sin_yaw
q_z[:] = cos_roll * cos_pitch * sin_yaw - sin_roll * sin_pitch * cos_yaw
q_w[:] = cos_roll * cos_pitch * cos_yaw + sin_roll * sin_pitch * sin_yaw
# q_x, q_y, q_z, q_w = out_
out_[0] = sin_roll * cos_pitch * cos_yaw - cos_roll * sin_pitch * sin_yaw
out_[1] = cos_roll * sin_pitch * cos_yaw + sin_roll * cos_pitch * sin_yaw
out_[2] = cos_roll * cos_pitch * sin_yaw - sin_roll * sin_pitch * cos_yaw
out_[3] = cos_roll * cos_pitch * cos_yaw + sin_roll * sin_pitch * sin_yaw
return out_


Expand Down Expand Up @@ -321,12 +321,12 @@ def quat_multiply(quat_left: np.ndarray,
else:
assert out.shape == quat_left.shape
out_ = out
qx_out, qy_out, qz_out, qw_out = out_
(qx_l, qy_l, qz_l, qw_l), (qx_r, qy_r, qz_r, qw_r) = quat_left, quat_right
qx_out[:] = qw_l * qx_r + qx_l * qw_r + qy_l * qz_r - qz_l * qy_r
qy_out[:] = qw_l * qy_r - qx_l * qz_r + qy_l * qw_r + qz_l * qx_r
qz_out[:] = qw_l * qz_r + qx_l * qy_r - qy_l * qx_r + qz_l * qw_r
qw_out[:] = qw_l * qw_r - qx_l * qx_r - qy_l * qy_r - qz_l * qz_r
# qx_out, qy_out, qz_out, qw_out = out_
out_[0] = qw_l * qx_r + qx_l * qw_r + qy_l * qz_r - qz_l * qy_r
out_[1] = qw_l * qy_r - qx_l * qz_r + qy_l * qw_r + qz_l * qx_r
out_[2] = qw_l * qz_r + qx_l * qy_r - qy_l * qx_r + qz_l * qw_r
out_[3] = qw_l * qw_r - qx_l * qx_r - qy_l * qy_r - qz_l * qz_r
return out_


Expand Down
4 changes: 2 additions & 2 deletions python/gym_jiminy/envs/gym_jiminy/envs/ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def __init__(self, debug: bool = False, **kwargs: Any) -> None:

# Define observation slices proxy for fast access.
# Note that they will be initialized in `_initialize_buffers`.
self._obs_slices: Tuple[np.ndarray] = ()
self._obs_slices: Tuple[np.ndarray, ...] = ()

# Define base orientation and external forces proxies for fast access.
# Note that they will be initialized in `_initialize_buffers`.
self._base_rot = np.array([])
self._f_external: Tuple[np.ndarray] = ()
self._f_external: Tuple[np.ndarray, ...] = ()

# Initialize base class
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion python/jiminy_py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def finalize_options(self) -> None:
# Check PEP8 conformance of Python native code
"flake8",
# Python linter
"pylint>=3.0",
"pylint>=3.1",
# Python static type checker
"mypy>=1.5.0",
# Dependency for documentation generation
Expand Down
2 changes: 1 addition & 1 deletion python/jiminy_py/src/jiminy_py/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def robot(self) -> jiminy.Robot:
return self.engine.robots[0]

@property
def robot_state(self) -> jiminy.Robot:
def robot_state(self) -> jiminy.RobotState:
"""Convenience proxy to get the state of the robot.
Internally, all it does is returning `self.engine.robot_states[0]`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class MeshcatVisualizer(BaseVisualizer):
Based on https://github.com/stack-of-tasks/pinocchio/blob/master/bindings/python/pinocchio/visualize/meshcat_visualizer.py
Copyright (c) 2014-2020, CNRS
Copyright (c) 2018-2020, INRIA
""" # noqa: E501 # pylint: disable=line-too-long
""" # noqa: E501
def initViewer(self, # pylint: disable=arguments-differ
viewer: meshcat.Visualizer = None,
loadModel: bool = False,
Expand Down
1 change: 0 additions & 1 deletion python/jiminy_py/src/jiminy_py/viewer/meshcat/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def _meshcat_server(info: Dict[str, str], verbose: bool) -> None:
"""Meshcat server daemon, using in/out argument to get the zmq url instead
of reading stdout as it was.
"""
# pylint: disable=consider-using-with
# Redirect both stdout and stderr to devnull if not verbose
if not verbose:
devnull = open(os.devnull, 'w')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,10 +851,8 @@ def move_orbital_camera_task(self,
self.longitude_deg = self.longitude_deg - 360.0
if self.longitude_deg < -180.0:
self.longitude_deg = self.longitude_deg + 360.0
if self.latitude_deg > (90.0 - 0.001):
self.latitude_deg = 90.0 - 0.001
if self.latitude_deg < (-90.0 + 0.001):
self.latitude_deg = -90.0 + 0.001
self.latitude_deg = min(max(
self.latitude_deg, -90.0 + 0.001), 90.0 - 0.001)

longitude = self.longitude_deg * np.pi / 180.0
latitude = self.latitude_deg * np.pi / 180.0
Expand Down Expand Up @@ -2022,7 +2020,7 @@ class Panda3dVisualizer(BaseVisualizer):
Based on https://github.com/stack-of-tasks/pinocchio/blob/master/bindings/python/pinocchio/visualize/panda3d_visualizer.py
Copyright (c) 2014-2020, CNRS
Copyright (c) 2018-2020, INRIA
""" # noqa: E501 # pylint: disable=line-too-long
""" # noqa: E501
def initViewer(self, # pylint: disable=arguments-differ
viewer: Optional[Union[Panda3dViewer, Panda3dApp]] = None,
loadModel: bool = False,
Expand Down

0 comments on commit 2fdf1c1

Please sign in to comment.