Skip to content

Commit

Permalink
[misc] Fix python typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Duburcq committed Dec 13, 2021
1 parent 96bfad7 commit b6bd204
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 30 deletions.
13 changes: 7 additions & 6 deletions .github/workflows/ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ jobs:
matrix:
include:
- os: ubuntu-18.04
BUILD_TYPE: 'Release'
- os: ubuntu-20.04
BUILD_TYPE: 'Debug'
- os: ubuntu-20.04
BUILD_TYPE: 'Release'

defaults:
run:
Expand Down Expand Up @@ -58,7 +58,7 @@ jobs:
#####################################################################################

- name: PEP8 Code Style Check
if: matrix.os == 'ubuntu-18.04'
if: matrix.os == 'ubuntu-20.04'
run: |
flake8 --ignore=E121,E126,E123,E226,E241,E266,E402,F405,W504 --count --show-source --statistics "$RootDir/python"
Expand Down Expand Up @@ -109,7 +109,7 @@ jobs:
# Ubuntu 18 is distributed with Python3.6, which is not supported by Numpy>=1.20.
# The new type check support of Numpy is raising pylint and mypy errors, so Ubuntu 18
# is used to do type checking for now.
if: matrix.os == 'ubuntu-18.04'
if: matrix.os == 'ubuntu-20.04'
run: |
gym_modules=(
"common"
Expand All @@ -119,12 +119,13 @@ jobs:
for name in "${gym_modules[@]}"; do
cd "$RootDir/python/gym_jiminy/$name"
pylint --unsafe-load-any-extension=y --ignore-imports=y --min-similarity-lines=7 --max-nested-blocks=7 \
pylint --unsafe-load-any-extension=y --ignore-imports=y --min-similarity-lines=20 --max-nested-blocks=7 \
--good-names=i,j,k,t,q,v,x,e,u,s,v,b,c,f,M,dt,rg,fd,lo,hi,tb,_ \
--disable=fixme,abstract-method,protected-access,useless-super-delegation \
--disable=too-many-instance-attributes,too-many-arguments,too-few-public-methods,too-many-lines \
--disable=too-many-locals,too-many-branches,too-many-statements \
--disable=unspecified-encoding,logging-fstring-interpolation \
--disable=misplaced-comparison-constant \
--generated-members=numpy.*,torch.* "gym_jiminy/"
mypy --allow-redefinition --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs \
Expand All @@ -142,7 +143,7 @@ jobs:
cmake . -DCOMPONENT=docs -P ./cmake_install.cmake
- name: Deploy to GitHub Pages
if: >-
matrix.os == 'ubuntu-18.04' && success() &&
matrix.os == 'ubuntu-20.04' && success() &&
github.repository == 'duburcqa/jiminy' && github.event_name == 'push' && github.ref == 'refs/heads/master'
uses: crazy-max/ghaction-github-pages@v2
with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def get_fieldnames(self) -> FieldNested:
This method is not supposed to be called before `reset`, so that
the controller should be already initialized at this point.
"""
# Assertion(s) for type checker
assert self.action_space is not None

return get_fieldnames(self.action_space)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def get_observation(self) -> DataNested:
In most cases, it is not necessary to overloaded this method, and
doing so may lead to unexpected behavior if not done carefully.
"""
# Assertion(s) for type checker
assert self._observation is not None
return self._observation

# methods to override:
Expand Down
24 changes: 22 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _controller_handle(self,
.. warning::
This method is not supposed to be called manually nor overloaded.
"""
# Assertion(s) for type checker
assert self._action is not None

command[:] = self.compute_command(
self.env.get_observation(), self._action)

Expand All @@ -113,6 +116,9 @@ def get_observation(self) -> DataNested:
.. warning::
This method is not supposed to be called manually nor overloaded.
"""
# Assertion(s) for type checker
assert self._observation is not None

return copy(self._observation)

def reset(self,
Expand Down Expand Up @@ -176,6 +182,9 @@ def step(self,
:returns: Next observation, reward, status of the episode (done or
not), and a dictionary of extra information.
"""
# Assertion(s) for type checker
assert self._action is not None

# Backup the action to perform, if any
if action is not None:
set_value(self._action, action)
Expand Down Expand Up @@ -206,6 +215,9 @@ def _setup(self) -> None:
# Call base implementation
super()._setup()

# Assertion(s) for type checker
assert self._action is not None

# Reset some internal buffers
fill(self._action, 0.0)
fill(self._command, 0.0)
Expand Down Expand Up @@ -234,6 +246,10 @@ def compute_command(self,
:param measure: Observation of the environment.
:param action: Target to achieve.
"""
# Assertion(s) for type checker
assert self._action is not None
assert self.env._action is not None

set_value(self._action, action)
set_value(self.env._action, action)
return self.env.compute_command(measure, action)
Expand Down Expand Up @@ -380,7 +396,7 @@ def refresh_observation(self) -> None: # type: ignore[override]
if not self.simulator.is_simulation_running:
features = self.observer.get_observation()
if self.augment_observation:
# Assertion for type checker
# Assertion(s) for type checker
assert isinstance(self._observation, dict)
# Make sure to store references
if isinstance(obs, gym.spaces.Dict):
Expand Down Expand Up @@ -563,6 +579,10 @@ def compute_command(self,
:param measure: Observation of the environment.
:param action: High-level target to achieve.
"""
# Assertion(s) for type checker
assert self._observation is not None
assert self.env._action is not None

# Update the target to send to the subsequent block if necessary.
# Note that `_observation` buffer has already been updated right before
# calling this method by `_controller_handle`, so it can be used as
Expand Down Expand Up @@ -609,7 +629,7 @@ def refresh_observation(self) -> None: # type: ignore[override]
if not self.simulator.is_simulation_running:
obs = self.env.get_observation()
if self.augment_observation:
# Assertion for type checker
# Assertion(s) for type checker
assert isinstance(self._observation, dict)
# Make sure to store references
if isinstance(obs, dict):
Expand Down
12 changes: 11 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _controller_handle(self,
.. warning::
This method is not supposed to be called manually nor overloaded.
"""
assert self._action is not None
command[:] = self.compute_command(
self.get_observation(), self._action)

Expand Down Expand Up @@ -546,6 +547,7 @@ def reset(self,

# Assertion(s) for type checker
assert self.observation_space is not None
assert self._action is not None

# Stop the simulator
self.simulator.stop()
Expand Down Expand Up @@ -736,6 +738,9 @@ def step(self,
:returns: Next observation, reward, status of the episode (done or
not), and a dictionary of extra information
"""
# Assertion(s) for type checker
assert self._action is not None

# Make sure a simulation is already running
if not self.simulator.is_simulation_running:
raise RuntimeError(
Expand Down Expand Up @@ -1357,7 +1362,7 @@ def refresh_observation(self) -> None: # type: ignore[override]

def compute_command(self,
measure: DataNested,
action: np.ndarray
action: DataNested
) -> np.ndarray:
"""Compute the motors efforts to apply on the robot.
Expand All @@ -1378,6 +1383,11 @@ def compute_command(self,
if self.debug and not self.action_space.contains(action):
logger.warn("The action is out-of-bounds.")

if not isinstance(action, np.ndarray):
raise RuntimeError(
"`BaseJiminyEnv.compute_command` must be overloaded unless "
"the action space has type `gym.spaces.Box`.")

return action

def is_done(self, *args: Any, **kwargs: Any) -> bool:
Expand Down
9 changes: 5 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
"""
from typing import Optional, Union, Dict, Sequence, TypeVar

import numpy as np
import tree
import gym
import tree
import numpy as np
from numpy.random.mtrand import _rand as global_randstate


ValueType = TypeVar('ValueType')
Expand Down Expand Up @@ -93,7 +94,7 @@ def sample(low: Union[float, np.ndarray] = -1.0,
# Sample from normalized distribution.
# Note that some distributions are not normalized by default
if rg is None:
rg = np.random
rg = global_randstate
distrib_fn = getattr(rg, dist)
if dist == 'uniform':
value = distrib_fn(low=-1.0, high=1.0, size=shape)
Expand Down Expand Up @@ -123,7 +124,7 @@ def is_bounded(space_nested: gym.Space) -> bool:


def zeros(space_nested: gym.Space,
dtype: Optional[type] = None) -> Union[DataNested, int]:
dtype: Optional[type] = None) -> DataNested:
"""Allocate data structure from `gym.space.Space` and initialize it to zero.
:param space: Gym.Space on which to operate.
Expand Down
20 changes: 12 additions & 8 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" TODO: Write documentation.
"""
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
Expand All @@ -21,15 +21,15 @@ def on_episode_step(self,
*,
worker: RolloutWorker,
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Episode,
**kwargs) -> None:
**kwargs: Any) -> None:
""" TODO: Write documentation.
"""
super().on_episode_step(worker=worker,
base_env=base_env,
policies=policies,
pisode=episode,
episode=episode,
**kwargs)
info = episode.last_info_for()
if info is not None:
Expand All @@ -43,13 +43,13 @@ def on_episode_end(self,
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs) -> None:
**kwargs: Any) -> None:
""" TODO: Write documentation.
"""
super().on_episode_end(worker=worker,
base_env=base_env,
policies=policies,
pisode=episode,
episode=episode,
**kwargs)
episode.custom_metrics["episode_duration"] = \
base_env.get_sub_environments()[0].step_dt * episode.length
Expand All @@ -65,9 +65,13 @@ def on_train_result(self,
**kwargs: Any) -> None:
""" TODO: Write documentation.
"""
assert isinstance(trainer.workers, WorkerSet)
super().on_train_result(trainer=trainer, result=result, **kwargs)
trainer.workers.foreach_worker(

# Assertion(s) for type checker
workers = trainer.workers
assert isinstance(workers, WorkerSet)

workers.foreach_worker(
lambda worker: worker.foreach_env(
lambda env: env.update(result)))

Expand Down
23 changes: 18 additions & 5 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from ray.tune.logger import TBXLogger
from ray.tune.result import TRAINING_ITERATION, TIMESTEPS_TOTAL
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy import Policy
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.utils.typing import PolicyID

from gym_jiminy.toolbox.wrappers.meta_envs import DataTreeT

Expand All @@ -35,14 +39,19 @@ def __init__(self) -> None:

def on_episode_end(self,
*,
worker: RolloutWorker,
base_env: BaseEnv,
episode: MultiAgentEpisode,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs: Any) -> None:
""" TODO: Write documentation.
"""
# Call base implementation
super().on_episode_end(
base_env=base_env, episode=episode, **kwargs)
super().on_episode_end(worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
**kwargs)

# Monitor episode duration for each gait
for env in base_env.get_sub_environments():
Expand Down Expand Up @@ -196,8 +205,12 @@ def on_train_result(self,
if task_branch_next:
task_branches.append(task_branch_next)

# Assertion(s) for type checker
workers = trainer.workers
assert isinstance(workers, WorkerSet)

# Update envs accordingly
trainer.workers.foreach_worker(
workers.foreach_worker(
lambda worker: worker.foreach_env(
lambda env: env.task_tree_probas.update(task_tree_probas)))

Expand Down
1 change: 1 addition & 0 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def value_function(
if self.config["l2_reg"] > 0.0:
# Add actor l2-regularization loss
l2_reg = torch.tensor(0.0)
assert isinstance(model, torch.nn.Module)
for name, params in model.named_parameters():
if not name.endswith("bias"):
l2_reg += l2_loss(params)
Expand Down
8 changes: 5 additions & 3 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ray.rllib.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.filter import NoFilter
from ray.rllib.utils.filter import NoFilter, MeanStdFilter
from ray.rllib.agents.trainer import Trainer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.env.env_context import EnvContext
Expand Down Expand Up @@ -275,7 +275,7 @@ def compute_action(policy: Policy,
feed_dict = {policy._input_dict[key]: value
for key, value in input_dict.items()
if key in policy._input_dict.keys()}
feed_dict[policy._is_exploring] = explore
feed_dict[policy._is_exploring] = np.array(True)
action, *state = policy._sess.run(
[policy._sampled_action] + policy._state_outputs,
feed_dict=feed_dict)
Expand Down Expand Up @@ -651,10 +651,12 @@ def test(test_agent: Trainer,
obs_filter = test_agent.workers.local_worker().filters["default_policy"]
if isinstance(obs_filter, NoFilter):
obs_filter_fn = None
else:
elif isinstance(obs_filter, MeanStdFilter):
obs_mean, obs_std = obs_filter.rs.mean, obs_filter.rs.std
obs_filter_fn = \
lambda obs: (obs - obs_mean) / (obs_std + 1.0e-8) # noqa: E731
else:
raise RuntimeError(f"Filter '{obs_filter.__class__}' not supported.")

# Forward viewer keyword arguments
if viewer_kwargs is not None:
Expand Down
4 changes: 3 additions & 1 deletion python/gym_jiminy/toolbox/gym_jiminy/toolbox/math/qhull.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
""" TODO: Write documentation.
"""
from typing import Optional

import numpy as np
import numba as nb
from numba.np.extensions import cross2d
Expand Down Expand Up @@ -128,7 +130,7 @@ def __init__(self, points: np.ndarray) -> None:
self._hull = None

# Buffer to cache center computation
self._center = None
self._center: Optional[np.ndarray] = None

@property
def center(self) -> np.ndarray:
Expand Down

0 comments on commit b6bd204

Please sign in to comment.