Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support Gym 0.26.0 #748

Merged
merged 21 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions .github/workflows/extra_sys.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ jobs:
os: [macos-latest, windows-latest]
python-version: [3.7, 3.8]
steps:
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
python -m pip uninstall ray -y
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
run: |
pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
python -m pip uninstall ray -y
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
run: |
pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes
44 changes: 22 additions & 22 deletions .github/workflows/gputest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,25 @@ jobs:
runs-on: [self-hosted, Linux, X64]
if: "!contains(github.event.head_commit.message, 'ci skip')"
steps:
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
58 changes: 29 additions & 29 deletions .github/workflows/lint_and_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@ jobs:
check:
runs-on: ubuntu-latest
steps:
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Code formatter
run: |
yapf -r -d .
isort --check .
- name: Type check
run: |
mypy
- name: Documentation test
run: |
make check-docstyle
make spelling
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Code formatter
run: |
yapf -r -d .
isort --check .
- name: Type check
run: |
mypy
- name: Documentation test
run: |
make check-docstyle
make spelling
36 changes: 18 additions & 18 deletions .github/workflows/profile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ jobs:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'ci skip')"
steps:
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: Test with pytest
run: |
pytest test/throughput --durations=0 -v --color=yes
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: Test with pytest
run: |
pytest test/throughput --durations=0 -v --color=yes
60 changes: 30 additions & 30 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ jobs:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV }}
file: ./coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
- name: Cancel previous run
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV }}
file: ./coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ exclude =
dist
*.egg-info
max-line-length = 87
ignore = B305,W504,B006,B008
ignore = B305,W504,B006,B008,B024

[yapf]
based_on_style = pep8
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_extras_require() -> str:
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
"nni>=2.3",
"pytorch_lightning",
],
"atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"],
Expand Down
43 changes: 17 additions & 26 deletions test/base/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,19 @@ def __init__(
self.action_space = MultiDiscrete([2, 2])
else:
self.action_space = Discrete(2)
self.done = False
self.terminated = False
self.index = 0
self.seed()

def seed(self, seed=0):
self.rng = np.random.RandomState(seed)
return [seed]

def reset(self, state=0, seed=None, return_info=False):
if seed is not None:
self.rng = np.random.RandomState(seed)
self.done = False
def reset(self, state=0, seed=None):
super().reset(seed=seed)
self.terminated = False
self.do_sleep()
self.index = state
if return_info:
return self._get_state(), {'key': 1, 'env': self}
else:
return self._get_state()
return self._get_state(), {'key': 1, 'env': self}

def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True."""
end_flag = int(self.done)
end_flag = int(self.terminated)
if self.ma_rew > 0:
return [end_flag] * self.ma_rew
return end_flag
Expand All @@ -102,14 +93,14 @@ def _get_state(self):
if self.dict_state:
return {
'index': np.array([self.index], dtype=np.float32),
'rand': self.rng.rand(1)
'rand': self.np_random.random(1)
}
elif self.recurse_state:
return {
'index': np.array([self.index], dtype=np.float32),
'dict': {
"tuple": (np.array([1], dtype=int), self.rng.rand(2)),
"rand": self.rng.rand(1, 2)
"tuple": (np.array([1], dtype=int), self.np_random.random(2)),
"rand": self.np_random.random((1, 2))
}
}
elif self.array_state:
Expand All @@ -132,21 +123,21 @@ def step(self, action):
self.steps += 1
if self._md_action:
action = action[0]
if self.done:
if self.terminated:
raise ValueError('step after done !!!')
self.do_sleep()
if self.index == self.size:
self.done = True
return self._get_state(), self._get_reward(), self.done, {}
self.terminated = True
return self._get_state(), self._get_reward(), self.terminated, False, {}
if action == 0:
self.index = max(self.index - 1, 0)
return self._get_state(), self._get_reward(), self.done, \
return self._get_state(), self._get_reward(), self.terminated, False, \
{'key': 1, 'env': self} if self.dict_state else {}
elif action == 1:
self.index += 1
self.done = self.index == self.size
self.terminated = self.index == self.size
return self._get_state(), self._get_reward(), \
self.done, {'key': 1, 'env': self}
self.terminated, False, {'key': 1, 'env': self}


class NXEnv(gym.Env):
Expand All @@ -168,10 +159,10 @@ def reset(self):
graph_state = np.random.rand(self.size, self.feat_dim)
for i in range(self.size):
self.graph.nodes[i]["data"] = graph_state[i]
return self._encode_obs()
return self._encode_obs(), {}

def step(self, action):
next_graph_state = np.random.rand(self.size, self.feat_dim)
for i in range(self.size):
self.graph.nodes[i]["data"] = next_graph_state[i]
return self._encode_obs(), 1.0, 0, {}
return self._encode_obs(), 1.0, 0, 0, {}
Loading