Skip to content

Commit

Permalink
adding support for gym to fix #379 and finish previous work
Browse files Browse the repository at this point in the history
  • Loading branch information
BDonnot committed Dec 1, 2022
1 parent 2a668e7 commit 721c471
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 21 deletions.
54 changes: 45 additions & 9 deletions grid2op/gym_compat/gymenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
# SPDX-License-Identifier: MPL-2.0
# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.

from os import truncate
import warnings
import gym
from grid2op.Chronics import Multifolder
from grid2op.gym_compat.gym_obs_space import GymObservationSpace
from grid2op.gym_compat.gym_act_space import GymActionSpace
from grid2op.gym_compat.utils import check_gym_version
from grid2op.gym_compat.utils import (check_gym_version, sample_seed,
_MAX_GYM_VERSION_RANDINT, GYM_VERSION)


class GymEnv(gym.Env):
Expand Down Expand Up @@ -49,27 +52,54 @@ def __init__(self, env_init, shuffle_chronics=True):
self.reward_range = self.init_env.reward_range
self.metadata = self.init_env.metadata
self._shuffle_chronics = shuffle_chronics

def step(self, gym_action):

if GYM_VERSION <= _MAX_GYM_VERSION_RANDINT:
self.seed = self._aux_seed
self.reset = self._aux_reset
self.step = self._aux_step
else:
self.reset = self._aux_reset_new
self.step = self._aux_step_new

def _aux_step(self, gym_action):
# used for gym < 0.26
g2op_act = self.action_space.from_gym(gym_action)
g2op_obs, reward, done, info = self.init_env.step(g2op_act)
gym_obs = self.observation_space.to_gym(g2op_obs)
return gym_obs, float(reward), done, info

def _aux_step_new(self, gym_action):
# used for gym >= 0.26
# TODO refacto with _aux_step
g2op_act = self.action_space.from_gym(gym_action)
g2op_obs, reward, done, info = self.init_env.step(g2op_act)
gym_obs = self.observation_space.to_gym(g2op_obs)
truncated = g2op_obs.current_step == g2op_obs.max_step
return gym_obs, float(reward), done, truncated, info

def reset(self, seed=None, return_info=False, options=None):
def _aux_reset(self, seed=None, return_info=None, options=None):
# used for gym < 0.26
if self._shuffle_chronics and isinstance(
self.init_env.chronics_handler.real_data, Multifolder
):
self.init_env.chronics_handler.sample_next_chronics()

if seed is not None:
self.init_env.seed(seed)
self._aux_seed(seed)

g2op_obs = self.init_env.reset()
gym_obs = self.observation_space.to_gym(g2op_obs)

if return_info:
return gym_obs, {}
chron_id = self.init_env.chronics_handler.get_id()
return gym_obs, {"time serie id": chron_id}
else:
return gym_obs

def _aux_reset_new(self, seed=None, options=None):
# used for gym > 0.26
return self._aux_reset(seed, True, options)

def render(self, mode="human"):
"""for compatibility with open ai gym render function"""
super(GymEnv, self).render(mode=mode)
Expand All @@ -87,9 +117,15 @@ def close(self):
self.observation_space.close()
self.observation_space = None

def seed(self, seed=None):
self.init_env.seed(seed)
# TODO seed also env space and observation space
def _aux_seed(self, seed=None):
# deprecated in gym >=0.26
if seed is not None:
# seed the gym env
super().reset(seed=seed)
# then seed the underlying grid2op env
max_ = 2**32-1
next_seed = sample_seed(max_, self._np_random)
self.init_env.seed(next_seed)

def __del__(self):
# delete possible dangling reference
Expand Down
22 changes: 11 additions & 11 deletions grid2op/tests/test_gym_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def test_convert_togym(self):
), f"Size should be {size_th} but is {dim_obs_space}"

# test that i can do basic stuff there
obs = env_gym.reset()
obs, info = env_gym.reset()
for k in env_gym.observation_space.spaces.keys():
assert obs[k] in env_gym.observation_space[k], f"error for key: {k}"
act = env_gym.action_space.sample()
obs2, reward2, done2, info2 = env_gym.step(act)
obs2, reward2, done2, truncated, info2 = env_gym.step(act)
assert obs2 in env_gym.observation_space

# test for the __str__ method
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_scale_attr_converter(self):
ScalerAttrConverter(substract=0.0, divide=self.env.gen_pmax),
)
env_gym.observation_space = ob_space
obs = env_gym.reset()
obs, info = env_gym.reset()
assert key in env_gym.observation_space.spaces
low = np.zeros(self.env.n_gen) - 1
high = np.zeros(self.env.n_gen) + 1
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_add_key(self):
)

# we highly recommend to "reset" the environment after setting up the observation space
obs_gym = env_gym.reset()
obs_gym, info = env_gym.reset()
assert key in env_gym.observation_space.spaces
assert obs_gym in env_gym.observation_space

Expand Down Expand Up @@ -756,7 +756,7 @@ def test_can_create(self):
)
},
)
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in self.env_gym.observation_space
assert self.env_gym.observation_space._attr_to_keep == sorted(kept_attr)
assert len(obs_gym) == 3583
Expand All @@ -766,7 +766,7 @@ def test_can_create_int(self):
self.env_gym.observation_space = BoxGymObsSpace(
self.env.observation_space, attr_to_keep=kept_attr
)
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in self.env_gym.observation_space
assert self.env_gym.observation_space._attr_to_keep == sorted(kept_attr)
assert len(obs_gym) == 79
Expand All @@ -779,7 +779,7 @@ def test_scaling(self):
self.env.observation_space, attr_to_keep=kept_attr
)
self.env_gym.observation_space = observation_space
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in observation_space
assert observation_space._attr_to_keep == kept_attr
assert len(obs_gym) == 17
Expand All @@ -792,7 +792,7 @@ def test_scaling(self):
divide={"gen_p": self.env.gen_pmax, "load_p": self.obs_env.load_p},
)
self.env_gym.observation_space = observation_space
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in observation_space
assert observation_space._attr_to_keep == kept_attr
assert len(obs_gym) == 17
Expand All @@ -807,7 +807,7 @@ def test_scaling(self):
subtract={"gen_p": 100.0, "load_p": 100.0},
)
self.env_gym.observation_space = observation_space
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in observation_space
assert observation_space._attr_to_keep == kept_attr
assert len(obs_gym) == 17
Expand Down Expand Up @@ -844,7 +844,7 @@ def test_functs(self):
)
},
)
obs_gym = self.env_gym.reset()
obs_gym, info = self.env_gym.reset()
assert obs_gym in self.env_gym.observation_space
assert self.env_gym.observation_space._attr_to_keep == sorted(kept_attr)
assert len(obs_gym) == 3583
Expand Down Expand Up @@ -1846,7 +1846,7 @@ def test_all_attr_in_obs(self):
env = grid2op.make("educ_case14_storage", test=True,
action_class=PlayableAction)
gym_env = GymEnv(env)
obs = gym_env.reset()
obs, info = gym_env.reset()
all_attrs = ["year",
"month",
"day",
Expand Down
45 changes: 45 additions & 0 deletions grid2op/tests/test_issue_379.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2019-2022, RTE (https://www.rte-france.com)
# See AUTHORS.txt and https://github.com/rte-france/Grid2Op/pull/319
# This Source Code Form is subject to the terms of the Mozilla Public License, version 2.0.
# If a copy of the Mozilla Public License, version 2.0 was not distributed with this file,
# you can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.

import grid2op
import unittest
import warnings

from grid2op.gym_compat import GymEnv
import grid2op
from gym import Env
from gym.utils.env_checker import check_env
try:
from gym.utils.env_checker import check_reset_return_type, check_reset_options, check_reset_seed
CAN_TEST_ALL = True
except ImportError:
CAN_TEST_ALL = False


class Issue379Tester(unittest.TestCase):
def setUp(self) -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
self.env = grid2op.make("l2rpn_case14_sandbox", test=True)
self.gym_env = GymEnv(self.env)

def tearDown(self) -> None:
self.env.close()
self.gym_env.close()
return super().tearDown()

def test_check_env(self):
if CAN_TEST_ALL:
check_reset_return_type(self.gym_env)
check_reset_seed(self.gym_env)
check_reset_options(self.gym_env)
check_env(self.gym_env)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def my_test_suite():
],
"plot": ["imageio"],
"test": ["lightsim2grid",
"numba"
"numba",
"gym>=0.26"
],
"chronix2grid": [
"ChroniX2Grid>=1.1.0.post1"
Expand Down

0 comments on commit 721c471

Please sign in to comment.