Skip to content

Commit

Permalink
Merge branch 'master' into fix-docs-tictactoe-dummy-vector-env
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Oct 2, 2022
2 parents 9b3cabf + 128feb6 commit 95eccaf
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 48 deletions.
35 changes: 35 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---
# pre-commit hook configuration: formatting (yapf, isort), linting (flake8)
# and docstring style (pydocstyle). Enable with `pre-commit install`.
repos:
  # mypy is currently disabled; kept here for reference.
  # - repo: local
  #   hooks:
  #     - id: mypy
  #       name: mypy
  #       entry: mypy
  #       language: python
  #       pass_filenames: false
  #       args: [--config-file=setup.cfg, tianshou]

  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.32.0
    hooks:
      - id: yapf
        args: [-r]

  - repo: https://github.com/pycqa/isort
    rev: 5.10.1
    hooks:
      - id: isort
        name: isort

  # flake8 moved from GitLab to GitHub; the old GitLab URL no longer resolves.
  - repo: https://github.com/PyCQA/flake8
    rev: 4.0.1
    hooks:
      - id: flake8
        args: [--config=setup.cfg, --count, --show-source, --statistics]
        additional_dependencies: ["flake8_bugbear"]

  - repo: https://github.com/pycqa/pydocstyle
    rev: 6.1.1
    hooks:
      - id: pydocstyle
        # NOTE(review): `^` only anchors the first alternative, so `docs/`,
        # `examples/` and `setup.py` match anywhere in a path -- confirm this
        # is intended before tightening the pattern.
        exclude: ^(test/)|(docs/)|(examples/)|(setup.py)
75 changes: 38 additions & 37 deletions README.md

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions docs/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ in the main directory. This installation is removable by
PEP8 Code Style Check and Code Formatter
----------------------------------------

Please set up pre-commit by running

.. code-block:: bash

    $ pre-commit install

in the main directory. This ensures that your contribution is properly
formatted before every commit.

We follow PEP8 python code style with flake8. To check, in the main directory, run:

.. code-block:: bash
Expand Down
2 changes: 0 additions & 2 deletions test/pettingzoo/test_pistonball.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pprint

import pytest
from pistonball import get_args, train_agent, watch


@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_piston_ball(args=get_args()):
if args.watch:
watch(args)
Expand Down
2 changes: 0 additions & 2 deletions test/pettingzoo/test_tic_tac_toe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pprint

import pytest
from tic_tac_toe import get_args, train_agent, watch


@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_tic_tac_toe(args=get_args()):
if args.watch:
watch(args)
Expand Down
40 changes: 35 additions & 5 deletions tianshou/env/pettingzoo_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple, Union

import gym.spaces
import pettingzoo
from packaging import version
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper

# Emit a deprecation notice at import time when the installed PettingZoo
# release is older than 1.21.0.
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
    warnings.warn(
        f"You are using PettingZoo {pettingzoo.__version__}. "
        "Future tianshou versions may not support PettingZoo<1.21.0. "
        "Consider upgrading your PettingZoo version.",
        DeprecationWarning,
    )


class PettingZooEnv(AECEnv, ABC):
"""The interface for petting zoo environments.
Expand Down Expand Up @@ -57,7 +67,20 @@ def __init__(self, env: BaseWrapper):

def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
self.env.reset(*args, **kwargs)
observation, _, _, info = self.env.last(self)

# Here, we do not label the return values explicitly to keep compatibility with
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
last_return = self.env.last(self)

if len(last_return) == 4:
warnings.warn(
"The PettingZoo environment is using the old step API. "
"This API may not be supported in future versions of tianshou. "
"We recommend that you update the environment code or apply a "
"compatibility wrapper.", DeprecationWarning
)

observation, info = last_return[0], last_return[-1]
if isinstance(observation, dict) and 'action_mask' in observation:
observation_dict = {
'agent_id': self.env.agent_selection,
Expand All @@ -83,9 +106,16 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
else:
return observation_dict

def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
def step(
self, action: Any
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
Dict]]:
self.env.step(action)
observation, rew, done, info = self.env.last()

# Here, we do not label the return values explicitly to keep compatibility with
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
last_return = self.env.last()
observation = last_return[0]
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
Expand All @@ -105,15 +135,15 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:

for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
return obs, self.rewards, done, info
return (obs, self.rewards, *last_return[2:]) # type: ignore

def close(self) -> None:
    """Close the wrapped environment by delegating to ``self.env.close()``."""
    self.env.close()

def seed(self, seed: Any = None) -> None:
    """Seed the wrapped environment.

    Calls ``self.env.seed(seed)`` first. If the wrapped environment raises
    ``NotImplementedError`` or has no ``seed`` attribute at all
    (``AttributeError``), falls back to ``self.env.reset(seed=seed)``.

    :param seed: the seed value forwarded to the wrapped environment.
    """
    try:
        self.env.seed(seed)
    except (NotImplementedError, AttributeError):
        # No usable ``seed`` method: seed through ``reset`` instead.
        self.env.reset(seed=seed)

def render(self, mode: str = "human") -> Any:
Expand Down
5 changes: 3 additions & 2 deletions tianshou/env/worker/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
def save(self, ndarray: np.ndarray) -> None:
    """Copy ``ndarray`` into the buffer returned by ``self.arr.get_obj()``.

    The buffer is viewed as a numpy array of ``self.dtype`` reshaped to
    ``self.shape`` and ``ndarray`` is copied into that view in place.

    :param ndarray: the array whose contents are written into the buffer.
    """
    assert isinstance(ndarray, np.ndarray)
    dst = self.arr.get_obj()
    # ``np.frombuffer`` creates a view over ``dst``, so ``copyto`` writes
    # straight into the underlying buffer.
    dst_np = np.frombuffer(dst,
                           dtype=self.dtype).reshape(self.shape)  # type: ignore
    np.copyto(dst_np, ndarray)

def get(self) -> np.ndarray:
    """Return the buffer from ``self.arr.get_obj()`` as a numpy array.

    The result has dtype ``self.dtype`` and shape ``self.shape``.
    """
    obj = self.arr.get_obj()
    return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)  # type: ignore


def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
Expand Down

0 comments on commit 95eccaf

Please sign in to comment.