Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Action masking for Space.sample() #2906

Merged
merged 26 commits into from
Jun 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c390a6b
Allows a new RNG to be generated with seed=-1 and updated env_checker…
pseudo-rnd-thoughts Jun 8, 2022
654476e
Revert "fixed `gym.vector.make` where the checker was being applied i…
pseudo-rnd-thoughts Jun 8, 2022
ea110ad
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 11, 2022
2e5dc9c
Remove bad pushed commits
pseudo-rnd-thoughts Jun 13, 2022
717bc1f
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 13, 2022
7e73d04
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 15, 2022
5743cd9
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 16, 2022
400b1e9
Fixed spelling in core.py
pseudo-rnd-thoughts Jun 17, 2022
4281c76
Pins pytest to the last py 3.6 version
pseudo-rnd-thoughts Jun 17, 2022
5dee690
Add support for action masking in Space.sample(mask=...)
pseudo-rnd-thoughts Jun 17, 2022
bc6ab4a
Fix action mask
pseudo-rnd-thoughts Jun 17, 2022
1700e9d
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
7f46df2
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
cd91007
Added docstrings, fixed bugs and added taxi examples
pseudo-rnd-thoughts Jun 19, 2022
be4063e
Fixed bugs
pseudo-rnd-thoughts Jun 19, 2022
2f14eb7
Add tests for sample
pseudo-rnd-thoughts Jun 20, 2022
f52d5d5
Add docstrings and test space sample mask Discrete and MultiBinary
pseudo-rnd-thoughts Jun 20, 2022
5e699e1
Add MultiDiscrete sampling and tests
pseudo-rnd-thoughts Jun 21, 2022
634da12
Remove sample mask from graph
pseudo-rnd-thoughts Jun 21, 2022
f85055c
Update gym/spaces/multi_discrete.py
pseudo-rnd-thoughts Jun 23, 2022
4a4b166
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
eb63c62
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
8918914
jjshoot review
pseudo-rnd-thoughts Jun 24, 2022
a53f0e7
jjshoot review
pseudo-rnd-thoughts Jun 25, 2022
8e71e46
Update assert check
pseudo-rnd-thoughts Jun 25, 2022
875ab44
Update type hints
pseudo-rnd-thoughts Jun 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def seed(self, seed=None):
there aren't accidental correlations between multiple generators.

Args:
seed(Optional int): The seed value for the random number geneartor
seed(Optional int): The seed value for the random number generator
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved

Returns:
seeds (List[int]): Returns the list of seeds used in this environment's random
Expand Down
44 changes: 41 additions & 3 deletions gym/envs/toy_text/taxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ class TaxiEnv(Env):
- 2: Y(ellow)
- 3: B(lue)

### Info

``step`` and ``reset(return_info=True)`` will return an info dictionary that contains "p" and "action_mask" containing
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.

As Taxi's initial state is a stochastic, the "p" key represents the probability of the
transition however this value is currently bugged being 1.0, this will be fixed soon.
As the steps are deterministic, "p" represents the probability of the transition which is always 1.0

For some cases, taking an action will have no effect on the state of the agent.
In v0.25.0, ``info["action_mask"]`` contains a np.ndarray for each of the action specifying
if the action will change the state.

To sample a modifying action, use ``action = env.action_space.sample(info["action_mask"])``
Or with a Q-value based algorithm ``action = np.argmax(q_values[obs, np.where(info["action_mask"] == 1)[0]])``.

### Rewards
- -1 per step unless other reward is triggered.
- +20 delivering passenger.
Expand All @@ -99,7 +115,7 @@ class TaxiEnv(Env):
```

### Version History
* v3: Map Correction + Cleaner Domain Description
* v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
* v1: Remove (3,2) from locs, add passidx<4 check
* v0: Initial versions release
Expand Down Expand Up @@ -214,14 +230,36 @@ def decode(self, i):
assert 0 <= i < 5
return reversed(out)

def action_mask(self, state: int):
"""Computes an action mask for the action space using the state information."""
mask = np.zeros(6, dtype=np.int8)
taxi_row, taxi_col, pass_loc, dest_idx = self.decode(state)
if taxi_row < 4:
mask[0] = 1
if taxi_row > 0:
mask[1] = 1
if taxi_col < 4 and self.desc[taxi_row + 1, 2 * taxi_col + 2] == b":":
mask[2] = 1
if taxi_col > 0 and self.desc[taxi_row + 1, 2 * taxi_col] == b":":
mask[3] = 1
if pass_loc < 4 and (taxi_row, taxi_col) == self.locs[pass_loc]:
mask[4] = 1
if pass_loc == 4 and (
(taxi_row, taxi_col) == self.locs[dest_idx]
or (taxi_row, taxi_col) in self.locs
):
mask[5] = 1
return mask

def step(self, a):
transitions = self.P[self.s][a]
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d = transitions[i]
self.s = s
self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p})

return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)}

def reset(
self,
Expand All @@ -239,7 +277,7 @@ def reset(
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}

def render(self, mode="human"):
if self.render_mode is not None:
Expand Down
11 changes: 10 additions & 1 deletion gym/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

import gym.error
from gym import logger
from gym.spaces.space import Space
from gym.utils import seeding
Expand Down Expand Up @@ -146,7 +147,7 @@ def is_bounded(self, manner: str = "both") -> bool:
else:
raise ValueError("manner is not in {'below', 'above', 'both'}")

def sample(self) -> np.ndarray:
def sample(self, mask: None = None) -> np.ndarray:
r"""Generates a single random sample inside the Box.

In creating a sample of the box, each coordinate is sampled (independently) from a distribution
Expand All @@ -157,9 +158,17 @@ def sample(self) -> np.ndarray:
* :math:`(-\infty, b]` : shifted negative exponential distribution
* :math:`(-\infty, \infty)` : normal distribution

Args:
mask: A mask for sampling values from the Box space, currently unsupported.

Returns:
A sampled value from the Box
"""
if mask is not None:
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
raise gym.error.Error(
f"Box.sample cannot be provided a mask, actual value: {mask}"
)

high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)

Expand Down
17 changes: 16 additions & 1 deletion gym/spaces/dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Dict as TypingDict
from typing import Optional, Union

Expand Down Expand Up @@ -137,14 +138,28 @@ def seed(self, seed: Optional[Union[dict, int]] = None) -> list:

return seeds

def sample(self) -> dict:
def sample(self, mask: Optional[TypingDict[str, Any]] = None) -> dict:
"""Generates a single random sample from this space.

The sample is an ordered dictionary of independent samples from the constituent spaces.

Args:
mask: An optional mask for each of the subspaces, expects the same keys as the space

Returns:
A dictionary with the same key and sampled values from :attr:`self.spaces`
"""
if mask is not None:
assert isinstance(
mask, dict
), f"Expects mask to be a dict, actual type: {type(mask)}"
assert (
mask.keys() == self.spaces.keys()
), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
return OrderedDict(
[(k, space.sample(mask[k])) for k, space in self.spaces.items()]
)

return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])

def contains(self, x) -> bool:
Expand Down
30 changes: 28 additions & 2 deletions gym/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,40 @@ def __init__(
self.start = int(start)
super().__init__((), np.int64, seed)

def sample(self) -> int:
def sample(self, mask: Optional[np.ndarray] = None) -> int:
"""Generates a single random sample from this space.

A sample will be chosen uniformly at random.
A sample will be chosen uniformly at random with the mask if provided

Args:
mask: An optional mask for if an action can be selected.
Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions.
If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned.

Returns:
A sampled integer from the space
"""
if mask is not None:
assert isinstance(
mask, np.ndarray
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
assert (
mask.dtype == np.int8
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
assert mask.shape == (
self.n,
), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
valid_action_mask = mask == 1
assert np.all(
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
if np.any(valid_action_mask):
return int(
self.start + self.np_random.choice(np.where(valid_action_mask)[0])
)
else:
return self.start

return int(self.start + self.np_random.integers(self.n))

def contains(self, x) -> bool:
Expand Down
85 changes: 56 additions & 29 deletions gym/spaces/graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Union
from typing import NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete
from gym.spaces.space import Space
from gym.utils import seeding

Expand Down Expand Up @@ -70,53 +70,80 @@ def __init__(

def _generate_sample_space(
self, base_space: Union[None, Box, Discrete], num: int
) -> Optional[Union[Box, Discrete]]:
# the possibility of this space , got {type(base_space)}aving nothing
if num == 0:
) -> Optional[Union[Box, MultiDiscrete]]:
if num == 0 or base_space is None:
return None

if isinstance(base_space, Box):
return Box(
low=np.array(max(1, num) * [base_space.low]),
high=np.array(max(1, num) * [base_space.high]),
shape=(num, *base_space.shape),
shape=(num,) + base_space.shape,
dtype=base_space.dtype,
seed=self._np_random,
seed=self.np_random,
)
elif isinstance(base_space, Discrete):
return MultiDiscrete(nvec=[base_space.n] * num, seed=self._np_random)
elif base_space is None:
return None
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
else:
raise AssertionError(
f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
)

def _sample_sample_space(self, sample_space) -> Optional[np.ndarray]:
if sample_space is not None:
return sample_space.sample()
else:
return None

def sample(self) -> NamedTuple:
def sample(
self,
mask: Optional[
Tuple[
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
]
] = None,
num_nodes: int = 10,
num_edges: Optional[int] = None,
) -> NamedTuple:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.

Args:
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
(Box spaces don't support sample masks).
If no `num_edges` is provided then the `edge_mask` is multiplied by the number of edges
num_nodes: The number of nodes that will be sampled, the default is 10 nodes
num_edges: An optional number of edges, otherwise, a random number between 0 and `num_nodes`^2

Returns:
A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
"""
num_nodes = self.np_random.integers(low=1, high=10)
assert (
num_nodes > 0
), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

# we only have edges when we have at least 2 nodes
num_edges = 0
if num_nodes > 1:
# maximal number of edges is (n*n) allowing self connections and two way is allowed
num_edges = self.np_random.integers(num_nodes * num_nodes)

node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)
if mask is not None:
node_space_mask, edge_space_mask = mask
else:
node_space_mask, edge_space_mask = None, None

sampled_nodes = self._sample_sample_space(node_sample_space)
sampled_edges = self._sample_sample_space(edge_sample_space)
# we only have edges when we have at least 2 nodes
if num_edges is None:
if num_nodes > 1:
# maximal number of edges is `n*(n-1)` allowing self connections and two-way is allowed
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
else:
num_edges = 0
if edge_space_mask is not None:
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
else:
assert (
num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"

sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = (
sampled_edge_space.sample(edge_space_mask)
if sampled_edge_space is not None
else None
)

sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:
Expand Down
24 changes: 23 additions & 1 deletion gym/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,36 @@ def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore

def sample(self) -> np.ndarray:
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
"""Generates a single random sample from this space.

A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).

Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
Where mask == 0 then the samples will be 0.

Returns:
Sampled values from space
"""
if mask is not None:
assert isinstance(
mask, np.ndarray
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
assert (
mask.dtype == np.int8
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
assert (
mask.shape == self.shape
), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
assert np.all(
np.logical_or(mask == 0, mask == 1)
), f"All values of a mask should be 0 or 1, actual values: {mask}"

return mask * self.np_random.integers(
low=0, high=2, size=self.n, dtype=self.dtype
)

return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

def contains(self, x) -> bool:
Expand Down
Loading