Support for using random agents, improvements in CollectStats (#1207)
Random action agents can be useful for tests, but also as a baseline.

Also (independent change):

- improved CollectStats for action_std
- Batch: extension for changing shape to atleast_2d, including "casting"
torch distributions to atleast_2d

One thing is a bit weird in this implementation: in the high-level
interfaces, the optim_factory is obligatory. For random actions it
doesn't make sense, of course, but making it optional would require
changes and additional type checks in several places. @opcode81 what do
you think about this - should we make it optional?

@Trinkle23897 a quick glance should be enough, nothing controversial.
@opcode81 is on vacation and won't have time to review
MischaPanch authored Aug 26, 2024
2 parents 002ffd9 + 4e03191 commit 8eb2795
Showing 16 changed files with 372 additions and 31 deletions.
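
In short, the new pieces behave roughly as follows (a minimal sketch distilled from the tests in this commit; the gymnasium import and module paths are assumptions, not part of the diff):

    import gymnasium as gym  # assumed; matches the spaces used in the tests below
    import numpy as np
    import torch
    from torch.distributions import Categorical

    from tianshou.data import Batch
    from tianshou.policy.base import RandomActionPolicy

    # A random baseline agent that samples uniformly from its action space.
    policy = RandomActionPolicy(action_space=gym.spaces.Discrete(3))
    action = policy.compute_action(np.array([0]))   # single observation -> action in the space
    batch = policy(Batch(obs=np.zeros((10, 2))))    # batched forward; batch.act has shape (10,)

    # Batch.to_at_least_2d: scalar entries and 0-d distributions gain a batch dimension.
    b = Batch(a=1, dist=Categorical(probs=torch.ones(3))).to_at_least_2d()
    # b.a.shape == (1, 1) and b.dist.batch_shape == (1,)
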
6 changes: 3 additions & 3 deletions docs/01_tutorials/04_tictactoe.rst
@@ -122,7 +122,7 @@ Two Random Agents

.. Figure:: ../_static/images/marl.png

Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.RandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation.
Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.MARLRandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation.

::

@@ -202,7 +202,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
BasePolicy,
DQNPolicy,
MultiAgentPolicyManager,
RandomPolicy,
MARLRandomPolicy,
)
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
@@ -286,7 +286,7 @@ The following ``get_agents`` function returns agents and their optimizers from e

- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function;
- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values;
- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.
- The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.

Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment.

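The tutorial text above explains how a random opponent and the policy manager fit together; a minimal sketch of that wiring (the constructor keyword names are assumed here, the complete version is in test/pettingzoo/tic_tac_toe.py further down):

    from pettingzoo.classic import tictactoe_v3

    from tianshou.env.pettingzoo_env import PettingZooEnv
    from tianshou.policy import MARLRandomPolicy, MultiAgentPolicyManager

    # Two random agents playing Tic-Tac-Toe against each other.
    env = PettingZooEnv(tictactoe_v3.env())
    agents = [MARLRandomPolicy(action_space=env.action_space) for _ in range(2)]
    # The manager routes observations and training data to each agent by agent_id;
    # the keyword names below are assumed rather than taken from this diff.
    marl_policy = MultiAgentPolicyManager(policies=agents, env=env)
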
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
@@ -288,3 +288,5 @@ monte
carlo
subclass
subclassing
dist
dists
60 changes: 59 additions & 1 deletion test/base/test_batch.py
@@ -9,10 +9,11 @@
import pytest
import torch
from deepdiff import DeepDiff
from torch.distributions import Distribution, Independent, Normal
from torch.distributions.categorical import Categorical

from tianshou.data import Batch, to_numpy, to_torch
from tianshou.data.batch import IndexType, get_sliced_dist
from tianshou.data.batch import IndexType, dist_to_atleast_2d, get_sliced_dist


def test_batch() -> None:
@@ -766,6 +767,63 @@ def test_batch_over_batch_to_torch() -> None:
assert batch.b.d.dtype == torch.float32
assert batch.b.e.dtype == torch.float32

@staticmethod
@pytest.mark.parametrize(
"dist, expected_batch_shape",
[
(Categorical(probs=torch.tensor([0.3, 0.7])), (1,)),
(Categorical(probs=torch.tensor([[0.3, 0.7], [0.4, 0.6]])), (2,)),
(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), (1,)),
(Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])), (2,)),
(Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0), (1,)),
(
Independent(
Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])),
0,
),
(2,),
),
],
)
def test_dist_to_atleast_2d(dist: Distribution, expected_batch_shape: tuple[int]) -> None:
result = dist_to_atleast_2d(dist)
assert result.batch_shape == expected_batch_shape

# Additionally check that the parameters are correctly transformed
if isinstance(dist, Categorical):
assert isinstance(result, Categorical)
assert result.probs.shape[:-1] == expected_batch_shape
elif isinstance(dist, Normal):
assert isinstance(result, Normal)
assert result.loc.shape == expected_batch_shape
assert result.scale.shape == expected_batch_shape
elif isinstance(dist, Independent):
assert isinstance(result, Independent)
assert result.base_dist.batch_shape == expected_batch_shape

@staticmethod
@pytest.mark.parametrize(
"dist",
[
Categorical(probs=torch.tensor([0.3, 0.7])),
Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)),
Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0),
],
)
def test_dist_to_atleast_2d_idempotent(dist: Distribution) -> None:
result1 = dist_to_atleast_2d(dist)
result2 = dist_to_atleast_2d(result1)
assert result1 == result2

@staticmethod
def test_batch_to_atleast_2d() -> None:
scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))
assert scalar_batch.dist.batch_shape == ()
assert scalar_batch.a.shape == scalar_batch.b.shape == ()
scalar_batch_2d = scalar_batch.to_at_least_2d()
assert scalar_batch_2d.dist.batch_shape == (1,)
assert scalar_batch_2d.a.shape == scalar_batch_2d.b.shape == (1, 1)


class TestAssignment:
@staticmethod
45 changes: 44 additions & 1 deletion test/base/test_policy.py
@@ -4,8 +4,9 @@
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.data import Batch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy.base import episode_mc_return_to_go
from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor
@@ -85,3 +86,45 @@ def test_get_action(self, policy: PPOPolicy) -> None:
actions = [policy.compute_action(sample_obs) for _ in range(10)]
# check that the actions are the same in deterministic mode
assert len(set(map(_to_hashable, actions))) == 1

@staticmethod
def test_random_policy_discrete_actions() -> None:
action_space = gym.spaces.Discrete(3)
policy = RandomActionPolicy(action_space=action_space)

# forward of actor returns discrete probabilities, in compliance with the overall discrete actor
action_probs = policy.actor(np.zeros((10, 2)))[0]
assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3)))

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(actions)) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10,)
assert len(set(action_batch.act.tolist())) > 1

@staticmethod
def test_random_policy_continuous_actions() -> None:
action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
policy = RandomActionPolicy(action_space=action_space)

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(map(_to_hashable, actions))) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10, 3)
assert len(set(map(_to_hashable, action_batch.act))) > 1
42 changes: 42 additions & 0 deletions test/base/test_stats.py
@@ -1,5 +1,12 @@
from typing import cast

import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Normal

from tianshou.data import Batch, CollectStats
from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist
from tianshou.policy.base import TrainingStats, TrainingStatsWrapper


@@ -47,3 +54,38 @@ def test_training_stats_wrapper() -> None:
"loss_field",
), "Attribute `loss_field` not found in `wrapped_train_stats`."
assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13

@staticmethod
@pytest.mark.parametrize(
"act,dist",
(
(np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))),
(np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))),
),
)
def test_collect_stats_update_at_step(
act: np.ndarray,
dist: torch.distributions.Distribution,
) -> None:
step_batch = cast(
CollectStepBatchProtocol,
Batch(
info={},
obs=np.array([1, 2, 3]),
obs_next=np.array([4, 5, 6]),
act=act,
rew=np.array(1.0),
done=np.array(False),
terminated=np.array(False),
dist=dist,
).to_at_least_2d(),
)
stats = CollectStats()
for _ in range(10):
stats.update_at_step_batch(step_batch)
stats.refresh_all_sequence_stats()
assert stats.n_collected_steps == 10
assert stats.pred_dist_std_array is not None
assert np.allclose(stats.pred_dist_std_array, get_stddev_from_dist(dist))
assert stats.pred_dist_std_array_stat is not None
assert stats.pred_dist_std_array_stat[0].mean == get_stddev_from_dist(dist)[0].item()
9 changes: 7 additions & 2 deletions test/pettingzoo/tic_tac_toe.py
@@ -13,7 +13,12 @@
from tianshou.data.stats import InfoStats
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.policy import (
BasePolicy,
DQNPolicy,
MARLRandomPolicy,
MultiAgentPolicyManager,
)
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
@@ -131,7 +136,7 @@ def get_agents(
agent_opponent = deepcopy(agent_learn)
agent_opponent.load_state_dict(torch.load(args.opponent_path))
else:
agent_opponent = RandomPolicy(action_space=env.action_space)
agent_opponent = MARLRandomPolicy(action_space=env.action_space)

if args.agent_id == 1:
agents = [agent_learn, agent_opponent]
39 changes: 38 additions & 1 deletion tianshou/data/batch.py
@@ -285,6 +285,23 @@ def get_len_of_dist(dist: Distribution) -> int:
return dist.batch_shape[0]


def dist_to_atleast_2d(dist: TDistribution) -> TDistribution:
"""Convert a distribution to at least 2D, such that the `batch_shape` attribute has a len of at least 1."""
if len(dist.batch_shape) > 0:
return dist
if isinstance(dist, Categorical):
return Categorical(probs=dist.probs.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Normal):
return Normal(loc=dist.loc.unsqueeze(0), scale=dist.scale.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Independent):
return Independent(
dist_to_atleast_2d(dist.base_dist),
dist.reinterpreted_batch_ndims,
) # type: ignore[return-value]
else:
raise NotImplementedError(f"Unsupported distribution for conversion to 2D: {type(dist)}")


# Note: This is implemented as a protocol because the interface
# of Batch is always extended by adding new fields. Having a hierarchy of
# protocols building off this one allows for type safety and IDE support despite
@@ -602,6 +619,14 @@ def get(self, key: str, default: Any | None = None) -> Any:
def pop(self, key: str, default: Any | None = None) -> Any:
raise ProtocolCalledException

def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.
This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
raise ProtocolCalledException


class Batch(BatchProtocol):
"""See :class:`~tianshou.data.batch.BatchProtocol`."""
@@ -1160,7 +1185,7 @@ def __len__(self) -> int:
if isinstance(obj, Distribution):
lens.append(get_len_of_dist(obj))
continue
raise TypeError(f"Entry for {key} in {self} is {obj}has no len()")
raise TypeError(f"Entry for {key} in {self} is {obj} has no len()")
if not lens:
return 0
return min(lens)
@@ -1326,6 +1351,18 @@ def replace_empty_batches_by_none(self) -> None:
else:
val.replace_empty_batches_by_none()

def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.
This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
result = self.apply_values_transform(np.atleast_2d, inplace=False)
for key, val in self.items():
if isinstance(val, Distribution):
result[key] = dist_to_atleast_2d(val)
return result


def _apply_batch_values_func_recursively(
batch: TBatch,